I am using the Acme framework to run some experiments, and I installed Acme following its documentation. However, importing the DQN agent raises an AttributeError that appears to come from JAX and Haiku. I looked at the related GitHub issue, but no solution has been posted there yet. Can anyone tell which package dependency causes this?
Here is my venv spec:
dm-acme 0.4.0
dm-control 0.0.364896371
dm-env 1.6
dm-haiku 0.0.10
dm-launchpad 0.5.0
dm-reverb 0.7.0
dm-tree 0.1.8
acme 2.10.0
jax 0.4.26
jaxlib 0.4.26+cuda12.cudnn89
python -V Python 3.9.5
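For reference, a minimal sketch of how a list like the one above can be collected from inside the venv, using only the standard library. The helper name `installed_versions` and the `PACKAGES` selection are my own choices for illustration, not part of Acme or JAX.

```python
from importlib import metadata

# Distribution names as they appear in `pip list`; adjust as needed.
PACKAGES = ["dm-acme", "dm-haiku", "jax", "jaxlib"]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(PACKAGES).items():
        print(f"{name} {version or 'not installed'}")
```

Running this with the venv's interpreter should report the versions that the failing script actually sees, which may differ from a system-wide `pip list`.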
error details:
File "/data/acme/examples/baselines/rl_discrete/run_dqn.py", line 18, in <module>
from acme.agents.jax import dqn
File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/agents/jax/dqn/__init__.py", line 18, in <module>
from acme.agents.jax.dqn.actor import behavior_policy
File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/agents/jax/dqn/actor.py", line 20, in <module>
from acme.agents.jax import actor_core as actor_core_lib
File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/agents/jax/actor_core.py", line 22, in <module>
from acme.jax import networks as networks_lib
File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/jax/networks/__init__.py", line 18, in <module>
from acme.jax.networks.atari import AtariTorso
File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/jax/networks/atari.py", line 29, in <module>
from acme.jax.networks import base
File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/jax/networks/base.py", line 24, in <module>
import haiku as hk
File "/data/acme/acme_venv_new/lib/python3.9/site-packages/haiku/__init__.py", line 20, in <module>
from haiku import experimental
File "/data/acme/acme_venv_new/lib/python3.9/site-packages/haiku/experimental/__init__.py", line 34, in <module>
from haiku._src.dot import abstract_to_dot
File "/data/acme/acme_venv_new/lib/python3.9/site-packages/haiku/_src/dot.py", line 163, in <module>
@jax.linear_util.transformation
File "/data/acme/acme_venv_new/lib/python3.9/site-packages/jax/_src/deprecations.py", line 54, in getattr
raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax' has no attribute 'linear_util'
The error seems to be raised where Haiku accesses jax.linear_util. How can this be fixed? Any quick thoughts?
Updated attempt
Following @jakevdp's suggestion, I reinstalled jax and jaxlib, but now I get a similar error for a different attribute:
Traceback (most recent call last):
File "/data/acme/examples/baselines/rl_discrete/run_dqn.py", line 18, in <module>
from acme.agents.jax import dqn
File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/agents/jax/dqn/__init__.py", line 18, in <module>
from acme.agents.jax.dqn.actor import behavior_policy
File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/agents/jax/dqn/actor.py", line 20, in <module>
from acme.agents.jax import actor_core as actor_core_lib
File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/agents/jax/actor_core.py", line 22, in <module>
from acme.jax import networks as networks_lib
File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/jax/networks/__init__.py", line 45, in <module>
from acme.jax.networks.multiplexers import CriticMultiplexer
File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/jax/networks/multiplexers.py", line 20, in <module>
from acme.jax import utils
File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/jax/utils.py", line 190, in <module>
devices: Optional[Sequence[jax.xla.Device]] = None,
File "/data/acme/acme_venv_new/lib/python3.9/site-packages/jax/_src/deprecations.py", line 53, in getattr
raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax' has no attribute 'xla'
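Both tracebacks end in an AttributeError raised on attribute access of the jax module, so a quick way to see which of these names the installed JAX still exposes is a check like the one below. This is only a sketch: `missing_attributes` is a hypothetical helper of mine, not a JAX or Acme function, and it only probes top-level attributes.

```python
import importlib


def missing_attributes(module_name, attr_names):
    """Return the names in attr_names that the module does not expose.

    Returns None if the module itself cannot be imported.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    # hasattr() returns False when attribute access raises AttributeError,
    # which is the exception shown in the tracebacks above.
    return [name for name in attr_names if not hasattr(module, name)]


if __name__ == "__main__":
    # The two attributes that the Haiku and Acme imports failed on.
    print(missing_attributes("jax", ["linear_util", "xla"]))
```

An empty list would mean the installed JAX still provides both names; any name that is printed is one that the installed Haiku or Acme code can no longer find.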
Here is my pip freeze list in this public gist: acme pip list
I also looked into this GitHub issue: jax xla attribute issue
@jakevdp, any updated comment or possible workaround for this jax.xla issue? Thanks.