I have been trying to generalise this JAX program to run on both CPU and GPU depending on the machine it's running on (essentially I need CPU parallelisation to speed up testing, versus GPU for production). I can get JAX to parallelise on the GPU, but no matter what I do, JAX will not detect my CPU core count, so arrays cannot be sharded across cores (for context, I am running on an 8-core, 16-thread laptop processor).
I found out that xla_force_host_platform_device_count has to be set before JAX is initialised in any way (it was previously set inside the if statement in the code below), but it is still not working. I also tried setting it at the very start of my code (this is a snippet from the only file that uses jax itself, but some other files use jnp as a JAX drop-in replacement for numpy).
Can anyone tell me why JAX will not pick up the flag? (Relevant code snippet and Jupyter notebook output included below.) Thanks.
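Since the flags are only read when jax is first imported, one way to make the failure mode explicit is to assert that nothing has already pulled in jax before the variable is set. This guard is my own addition, not part of the original program or the JAX API, but it would catch e.g. a helper module doing `import jax.numpy as jnp` earlier in the process:

```python
import os
import sys

# Hypothetical guard: if jax is already in sys.modules (imported directly,
# or indirectly via jnp in another module, or by an earlier Jupyter cell in
# the same kernel), setting XLA_FLAGS now is silently ignored.
assert "jax" not in sys.modules, "jax already imported; XLA_FLAGS set too late"

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"
print(os.environ["XLA_FLAGS"])
```

If the assertion fires inside a notebook, only restarting the kernel (not re-running the cell) gives a clean slate.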
Relevant code snippet:
import os
from multiprocessing import cpu_count

core_count = cpu_count()
### THIS NEEDS TO BE SET BEFORE JAX IS INITIALISED IN ANY WAY, INCLUDING IMPORTING
# - XLA_FLAGS are read WHEN jax is IMPORTED
# you can see other ways of setting the environment variable that I've tried here
#jax.config.update('xla_force_host_platform_device_count', core_count)
#os.environ["XLA_FORCE_HOST_PLATFORM_DEVICE_COUNT"] = '16'#str(core_count)
#os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=' + str(core_count)
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={cpu_count()}"
import jax
# defaults float data types to 64-bit instead of 32 for greater precision
jax.config.update('jax_enable_x64', True)
jax.config.update('jax_captured_constants_report_frames', -1)
jax.config.update('jax_captured_constants_warn_bytes', 128 * 1024 ** 2)
jax.config.update('jax_traceback_filtering', 'off')
# https://docs.jax.dev/en/latest/gpu_memory_allocation.html
#jax.config.update('xla_python_client_allocator', 'platform')
# can't be set via jax.config.update - XLA_PYTHON_CLIENT_ALLOCATOR is an
# environment variable, not a jax.config option; note the value is the bare
# string 'platform' (the escaped quotes in '\"platform\"' would be taken literally)
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = 'platform'
print("\nDefault jax backend:", jax.default_backend())
available_devices = jax.devices()
print(f"Available devices: {available_devices}")
running_device = jax.default_backend()  # avoids the deprecated xla_bridge import
print("Running device:", running_device, end='')
if running_device == 'cpu':
    print(", with:", core_count, "cores.")
from jax.sharding import PartitionSpec as P, NamedSharding
# Create a Sharding object to distribute a value across devices:
# Assume core_count is the no. of core devices available
mesh = jax.make_mesh((core_count,), ('cols',)) # 1D mesh for columns
# Example matrix shape (9, N), e.g., N = 1e7; s0 below is the (9, N) data
# matrix built elsewhere in the program, e.g.:
#s0 = jax.random.normal(jax.random.key(0), (9, int(1e7)))
# Specify sharding: don't split axis 0 (rows), split axis 1 (columns) across devices,
# then apply the sharding via jax.device_put to distribute the array:
s0_sharded = jax.device_put(s0, NamedSharding(mesh, P(None, 'cols'))) # 'None' means don't shard axis 0
print(s0_sharded.sharding) # See the sharding spec
print(s0_sharded.addressable_shards) # Check each device's shard
jax.debug.visualize_array_sharding(s0_sharded)
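For what it's worth, a way to rule out import-order problems entirely is to check from a fresh process, where the environment variable is guaranteed to exist before jax is first imported. This is a sketch, not the original program: the child script here only echoes the variable back; swapping it for `import jax; print(len(jax.devices()))` would perform the real device-count check on a machine with jax installed.

```python
import os
import subprocess
import sys

# Spawn a fresh interpreter with the flag already in its environment.
# A long-running Jupyter kernel that has executed `import jax` (directly,
# or via jnp in another module) keeps its original device count until the
# kernel is restarted, so a clean child process sidesteps that trap.
env = dict(os.environ, XLA_FLAGS="--xla_force_host_platform_device_count=4")
child = "import os; print(os.environ['XLA_FLAGS'])"
result = subprocess.run(
    [sys.executable, "-c", child],
    env=env, capture_output=True, text=True, check=True,
)
print(result.stdout.strip())
# prints: --xla_force_host_platform_device_count=4
```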
Output:
Default jax backend: cpu
Available devices: [CpuDevice(id=0)]
Running device: cpu, with: 16 cores.
...
Relevant line of my code: --> 258 mesh = jax.make_mesh((core_count,), ('cols',)) # 1D mesh for columns
... jax backend trace
ValueError: Number of devices 1 must be >= the product of mesh_shape (16,)