I am using the acme framework to run some experiments, and I installed acme following its documentation. However, I get an AttributeError that seems to come from JAX/Haiku, and the related GitHub issue I found has no solution yet. Can anyone tell me which package dependency is causing this?

Here is my venv spec:

dm-acme                      0.4.0
dm-control                   0.0.364896371
dm-env                       1.6
dm-haiku                     0.0.10
dm-launchpad                 0.5.0
dm-reverb                    0.7.0
dm-tree                      0.1.8
acme                         2.10.0
jax                          0.4.26
jaxlib                       0.4.26+cuda12.cudnn89
python -V                    Python 3.9.5
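A list like the one above can be regenerated from inside the venv with only the standard library. This is a minimal sketch, assuming Python 3.8+ for `importlib.metadata`; the distribution names are copied from the spec above, and `collect_versions` is a hypothetical helper name.

```python
import sys
from importlib import metadata
from typing import Dict, Iterable, Optional

# Distribution names taken from the venv spec above.
PACKAGES = (
    "dm-acme", "dm-control", "dm-env", "dm-haiku", "dm-launchpad",
    "dm-reverb", "dm-tree", "acme", "jax", "jaxlib",
)


def collect_versions(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Return {distribution name: installed version}, or None if not installed."""
    versions: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in collect_versions(PACKAGES).items():
        print(f"{name:<28} {version or 'not installed'}")
    print(f"{'python':<28} {sys.version.split()[0]}")
```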

Error details:

Traceback (most recent call last):
  File "/data/acme/examples/baselines/rl_discrete/run_dqn.py", line 18, in <module>
    from acme.agents.jax import dqn
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/agents/jax/dqn/__init__.py", line 18, in <module>
    from acme.agents.jax.dqn.actor import behavior_policy
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/agents/jax/dqn/actor.py", line 20, in <module>
    from acme.agents.jax import actor_core as actor_core_lib
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/agents/jax/actor_core.py", line 22, in <module>
    from acme.jax import networks as networks_lib
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/jax/networks/__init__.py", line 18, in <module>
    from acme.jax.networks.atari import AtariTorso
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/jax/networks/atari.py", line 29, in <module>
    from acme.jax.networks import base
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/jax/networks/base.py", line 24, in <module>
    import haiku as hk
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/haiku/__init__.py", line 20, in <module>
    from haiku import experimental
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/haiku/experimental/__init__.py", line 34, in <module>
    from haiku._src.dot import abstract_to_dot
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/haiku/_src/dot.py", line 163, in <module>
    @jax.linear_util.transformation
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/jax/_src/deprecations.py", line 54, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax' has no attribute 'linear_util'

It seems to come from Haiku and JAX. How can this be fixed? Any quick thoughts?
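One quick way to confirm where things break is to probe the installed jax for the attributes the tracebacks complain about. This is a minimal sketch; `probe_jax_attributes` is a hypothetical helper, and it relies only on the fact (visible in the traceback) that jax raises `AttributeError` for removed names, which `hasattr` turns into `False`.

```python
import importlib
import warnings
from typing import Dict, Sequence

# Attributes used by the failing imports: haiku/_src/dot.py needs
# jax.linear_util; acme/jax/utils.py needs jax.xla; jax.Device is a
# candidate replacement for jax.xla.Device.
ATTRIBUTES = ("linear_util", "xla", "Device")


def probe_jax_attributes(names: Sequence[str] = ATTRIBUTES) -> Dict[str, bool]:
    """Report which of `names` are reachable on the installed jax module.

    Returns an all-False dict if jax itself cannot be imported.
    """
    try:
        jax = importlib.import_module("jax")
    except ImportError:
        return {name: False for name in names}
    with warnings.catch_warnings():
        # Names that are deprecated but not yet removed emit a
        # DeprecationWarning on access; silence it for the probe.
        warnings.simplefilter("ignore", DeprecationWarning)
        return {name: hasattr(jax, name) for name in names}


if __name__ == "__main__":
    for name, present in probe_jax_attributes().items():
        print(f"jax.{name}: {'present' if present else 'missing'}")
```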

Update

Following @jakevdp's suggestion, I reinstalled jax and jaxlib, but now I am getting this error again:

Traceback (most recent call last):
  File "/data/acme/examples/baselines/rl_discrete/run_dqn.py", line 18, in <module>
    from acme.agents.jax import dqn
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/agents/jax/dqn/__init__.py", line 18, in <module>
    from acme.agents.jax.dqn.actor import behavior_policy
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/agents/jax/dqn/actor.py", line 20, in <module>
    from acme.agents.jax import actor_core as actor_core_lib
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/agents/jax/actor_core.py", line 22, in <module>
    from acme.jax import networks as networks_lib
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/jax/networks/__init__.py", line 45, in <module>
    from acme.jax.networks.multiplexers import CriticMultiplexer
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/jax/networks/multiplexers.py", line 20, in <module>
    from acme.jax import utils
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/jax/utils.py", line 190, in <module>
    devices: Optional[Sequence[jax.xla.Device]] = None,
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/jax/_src/deprecations.py", line 53, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax' has no attribute 'xla'

Here is my pip freeze list in this public gist: acme pip list

I looked into this GitHub issue: jax xla attribute issue

@jakevdp, any updated comments or possible workarounds for this jax.xla issue? Thanks.

1 Answer

jax.linear_util was deprecated in JAX v0.4.16 and removed in JAX v0.4.24.
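The version arithmetic here explains both tracebacks in the question. The sketch below encodes the deprecation/removal versions quoted in this answer and in the comments on it; they are not read from any JAX API, and `TIMELINE`, `parse_version`, and `attribute_status` are hypothetical names.

```python
import re
from typing import Dict, Tuple

# (deprecated_in, removed_in) for each attribute, using the versions quoted
# in this answer and its comments (not read from any JAX API).
TIMELINE: Dict[str, Tuple[Tuple[int, ...], Tuple[int, ...]]] = {
    "jax.linear_util": ((0, 4, 16), (0, 4, 24)),
    "jax.xla": ((0, 4, 11), (0, 4, 14)),
}


def parse_version(version: str) -> Tuple[int, ...]:
    """Parse '0.4.26' or '0.4.26+cuda12.cudnn89' into (0, 4, 26)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def attribute_status(attribute: str, jax_version: str) -> str:
    """Classify `attribute` as 'available', 'deprecated', or 'removed'."""
    deprecated_in, removed_in = TIMELINE[attribute]
    version = parse_version(jax_version)
    if version >= removed_in:
        return "removed"
    if version >= deprecated_in:
        return "deprecated"
    return "available"


if __name__ == "__main__":
    for jax_version in ("0.4.26+cuda12.cudnn89", "0.4.23", "0.3.0"):
        statuses = {name: attribute_status(name, jax_version) for name in TIMELINE}
        print(jax_version, statuses)
```

With the jaxlib version from the question, `attribute_status("jax.linear_util", "0.4.26+cuda12.cudnn89")` returns `"removed"`. Pinning to a version that satisfies `<0.4.24` but is still at least 0.4.14 (for example 0.4.23) makes `jax.linear_util` merely deprecated, yet `attribute_status("jax.xla", "0.4.23")` still returns `"removed"`, which is consistent with the second traceback.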

It sounds like you have too new a JAX version for the framework code you are using. I'd try installing an older version; e.g.

pip install --upgrade "jax[cuda12_pip]<0.4.24" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

See JAX installation for more installation options.

If you're hoping to update the framework code for compatibility with more recent JAX versions, you might find replacements for previous functionality in jax.extend.linear_util.
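For patching the framework code instead of downgrading (here, haiku's `_src/dot.py`, which decorates a function with `@jax.linear_util.transformation`), one common pattern is a guarded import that prefers `jax.extend.linear_util` and falls back to the old location. This is a sketch under the assumption that `jax.extend.linear_util` provides a compatible `transformation` decorator in the JAX release you install; whether it behaves identically in every later release is worth checking against the JAX changelog. `lu` is only a local alias.

```python
# Prefer the newer location (jax.extend.linear_util); fall back to the
# pre-0.4.24 location for older JAX. `lu` is a local alias only.
try:
    from jax.extend import linear_util as lu
except ImportError:
    try:
        from jax import linear_util as lu  # older JAX releases only
    except ImportError:
        lu = None  # JAX not installed, or neither location exists

# In haiku/_src/dot.py, the failing line
#     @jax.linear_util.transformation
# would then become
#     @lu.transformation
```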


4 Comments

Thanks for your help. I downgraded JAX and installed the version you suggested, but now I get the old error, "AttributeError: module 'jax' has no attribute 'xla'". I updated my post with the reinstalled jax/jaxlib versions and the new error. Would you mind pointing me toward a fix for this one? Thanks for your support.
jax.xla was deprecated in jax v0.4.11 and removed in jax v0.4.14. It seems you still have too new a version of JAX for the framework code you're using. I would look at the documentation of this framework to see what version of JAX is recommended, and use that one.
dm-acme was last released in February 2022, which suggests that the appropriate JAX version is somewhere around 0.3.0. Note that JAX 0.3.0 does not support CUDA 12, so you may have difficulty running acme on your hardware.
I manually changed jax.xla.Device to jax.Device, and it works for now. Thanks.
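The workaround in the last comment amounts to changing the type annotation at line 190 of `acme/jax/utils.py` from `jax.xla.Device` to `jax.Device`. Below is a sketch of a version-tolerant variant, assuming `jax.Device` exists on newer releases and `jax.xla.Device` on older ones; `_resolve_device_type` and `count_devices` are hypothetical names, not acme code.

```python
from typing import Any, Optional, Sequence

try:
    import jax
except ImportError:  # lets the sketch import even without JAX installed
    jax = None


def _resolve_device_type() -> Any:
    """Return the JAX device class under whichever name this JAX exposes."""
    if jax is None:
        return Any
    device_type = getattr(jax, "Device", None)  # newer JAX
    if device_type is None:
        # Older JAX only; removed in v0.4.14 per the comments on this answer.
        device_type = jax.xla.Device
    return device_type


Device = _resolve_device_type()


def count_devices(devices: Optional[Sequence[Device]] = None) -> int:
    """Hypothetical stand-in for an acme utility that takes a device list."""
    if devices is None:
        if jax is None:
            return 0
        devices = jax.local_devices()
    return len(devices)
```

Resolving the name once at import time keeps the annotation readable while avoiding a hard dependency on either attribute.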
