I am trying to use JAX's pmap, but I get an error saying that only one XLA device is available. Here's my code:
import jax.numpy as jnp
import os
from jax import pmap
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
out = pmap(lambda x: x ** 2)(jnp.arange(8))
print(out)
Here's the error:
Traceback (most recent call last):
  File "new.py", line 6, in <module>
    out = pmap(lambda x: x ** 2)(jnp.arange(8))
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/api.py", line 1779, in cache_miss
    execute = pxla.xla_pmap_impl_lazy(fun_, *tracers, **params)
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 411, in xla_pmap_impl_lazy
    compiled_fun, fingerprint = parallel_callable(
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
    ans = call(fun, *args)
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 682, in parallel_callable
    pmap_executable = pmap_computation.compile()
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 923, in compile
    executable = UnloadedPmapExecutable.from_hlo(
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 993, in from_hlo
    raise ValueError(msg.format(shards.num_global_shards,
jax._src.traceback_util.UnfilteredStackTrace: ValueError: compiling computation that requires 8 logical devices, but only 1 XLA devices are available (num_replicas=8)

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "new.py", line 6, in <module>
    out = pmap(lambda x: x ** 2)(jnp.arange(8))
ValueError: compiling computation that requires 8 logical devices, but only 1 XLA devices are available (num_replicas=8)
Based on this discussion and this one, I set os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8', but it doesn't seem to have any effect.
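A quick way to see which devices pmap is actually picking up is to print the default backend next to the device counts. This is a minimal diagnostic sketch, assuming the flag only multiplies the devices of the host (CPU) platform: if the default backend comes back as gpu, pmap would map over the GPU devices by default and the forced CPU count would not apply.

```python
import os

# The flag is read when JAX initializes its CPU backend, so set it before importing jax.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax

# Backend that pmap uses when no devices/backend are passed explicitly.
print("default backend:", jax.default_backend())
# Number of devices on that default backend.
print("default-backend devices:", jax.device_count())
# Number of CPU devices, which is what the flag is meant to change.
print("CPU devices:", jax.device_count("cpu"))
```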
Edit 1:
I tried setting the flag before importing JAX, but it still doesn't work:
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
import jax
from jax import pmap
import jax.numpy as jnp
out = pmap(lambda x: x ** 2)(jnp.arange(8))
print(out)
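If the default backend on this machine turns out to be a GPU, one workaround is to hand pmap the CPU devices explicitly through its devices argument (which the JAX docs describe as experimental). The sketch below assumes that is the cause; setting the environment variable JAX_PLATFORMS=cpu before importing JAX is another way to keep JAX on the CPU.

```python
import os

# Must be set before JAX initializes its CPU backend, i.e. before importing jax.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp

# Assumption: the default backend may be a single-device GPU, so request the
# (forced) CPU devices explicitly instead of relying on the default backend.
cpu_devices = jax.devices("cpu")
n = len(cpu_devices)

# When devices= is given, the size of the mapped axis must equal the number
# of those devices local to this process.
square = jax.pmap(lambda x: x ** 2, devices=cpu_devices)
out = square(jnp.arange(n))
print(out)
```

Sizing the input from len(cpu_devices) rather than hard-coding 8 keeps the mapped axis consistent with however many CPU devices JAX actually created.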