I was surprised to see that, depending on the size of an input matrix that is vmapped over inside a function, the output of the function changes slightly. That is, not only does the size of the output change (which is what I would expect from vmapping), but the numerics also change slightly. (Note that this only occurs in float32 and only on the GPU.)
I wrote a minimal reproducible example to illustrate the behaviour:
import jax
import jax.numpy as jnp
import equinox as eqx
def equinox_vmap(x, mlp):
    out = eqx.filter_vmap(mlp.__call__)(x)
    return out
key = jax.random.PRNGKey(0)
key, network_key = jax.random.split(key, 2)
mlp = eqx.nn.MLP(2, 2, 10, 2, key=network_key)
key, key_x = jax.random.split(key, 2)
x = jax.random.normal(key_x, (10000, 2))
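# Compare the first 10 rows: vmapping over a 10-row slice vs. over all 10000 rows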
error_eqx = equinox_vmap(x[:10], mlp) - equinox_vmap(x, mlp)[:10]
print("eqx error:", error_eqx)
When running this example, I get the following output:
eqx error: [[-1.4442205e-04 1.0999292e-04]
[-5.9515238e-05 -9.1716647e-06]
[ 1.4841557e-05 5.6132674e-05]
[ 0.0000000e+00 0.0000000e+00]
[-9.1642141e-06 -2.5466084e-05]
[ 3.8832426e-05 -3.3110380e-05]
[ 3.3825636e-05 -2.4946406e-05]
[ 4.0918589e-05 -3.2216311e-05]
[ 1.3601780e-04 8.7693334e-06]
[ 0.0000000e+00 0.0000000e+00]]
I understand that float32 arithmetic is not exact and some error is to be expected. However, I was surprised that the result changes depending on how much of the input array is passed to the function. I expected that the first row of the x array, i.e. x[0, :], would still contain the same values, and therefore that the first row of the output would be the same.
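Since eqx.nn.MLP ultimately bottoms out in plain matrix multiplies, the same comparison can be written without equinox at all. Here is a sketch (the weight matrix W is a hypothetical stand-in for one MLP layer; I would expect a similar batch-size dependence if the GPU matmul kernels are the cause):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
key, key_w = jax.random.split(key, 2)
W = jax.random.normal(key_w, (2, 10))  # hypothetical stand-in for one MLP layer
key, key_x = jax.random.split(key, 2)
x = jax.random.normal(key_x, (10000, 2))

# Same comparison as above, but with a single matmul instead of the full MLP
error_matmul = (x[:10] @ W) - (x @ W)[:10]
print("matmul error:", error_matmul)
print("float32 eps:", jnp.finfo(jnp.float32).eps)  # scale reference for the error magnitudes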
Further notes:
- I enabled float64 (jax.config.update("jax_enable_x64", True)), which made the discrepancy disappear completely. I understand that this is a numerical problem, but I am a little confused about how the vmapping interacts with the example.
- When I run the same example on the CPU (using jax.config.update("jax_platform_name", "cpu")), the problem also disappears, which I likewise find difficult to understand. (The exact config toggles are collected in a snippet after this list.)
Questions:
- Is this to be expected?
- Where does this "inconsistency" come from?
- Why does it occur only on the GPU and not on the CPU?
Setup:
- GPU: NVIDIA RTX 6000 Ada Generation 48 GB
- Python 3.11.11 with
equinox 0.13.0
jax 0.7.0
jax-cuda12-pjrt 0.7.0
jax-cuda12-plugin 0.7.0
jaxlib 0.7.0
jaxtyping 0.3.2
ml_dtypes 0.5.3
numpy 2.3.2
nvidia-cublas-cu12 12.9.1.4
nvidia-cuda-cupti-cu12 12.9.79
nvidia-cuda-nvcc-cu12 12.9.86
nvidia-cuda-nvrtc-cu12 12.9.86
nvidia-cuda-runtime-cu12 12.9.79
nvidia-cudnn-cu12 9.11.0.98
nvidia-cufft-cu12 11.4.1.4
nvidia-cusolver-cu12 11.7.5.82
nvidia-cusparse-cu12 12.5.10.65
nvidia-nccl-cu12 2.27.6
nvidia-nvjitlink-cu12 12.9.86
nvidia-nvshmem-cu12 3.3.9
opt_einsum 3.4.0
pip 24.0
scipy 1.16.1
setuptools 65.5.0
typing_extensions 4.14.1
wadler_lindig 0.1.7
Any explanations are greatly appreciated.