I tried upgrading jax to version 0.4.24 by changing jax>=0.4.10 to jax>=0.4.24 in requirements.txt and changing line 36 of the Dockerfile to:
RUN if [ "$USE_CUDA" = true ] ; \
then pip install "jax[cuda11]>=0.4.24" -f "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html" ; \
fi
However, I then get the following warning and can no longer use my GPU:
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
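To confirm which backend jax actually picked up inside the container, a quick check like the one below can help. It is a minimal sketch: the helper name jax_backend is hypothetical, and it relies only on jax.default_backend() and jax.devices(). A result of "cpu" corresponds to the fallback described in the warning above; with a working CUDA-enabled jaxlib it should report "gpu".

```python
import importlib.util
from typing import Optional


def jax_backend() -> Optional[str]:
    """Hypothetical helper: jax's default platform ("cpu", "gpu", "tpu"), or None if jax is absent."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax  # imported lazily so the helper still runs where jax is not installed

    return jax.default_backend()


if __name__ == "__main__":
    backend = jax_backend()
    if backend is None:
        print("jax is not installed")
    else:
        import jax

        # "gpu" means a CUDA-enabled jaxlib was found; "cpu" matches the fallback warning.
        print("default backend:", backend)
        print("devices:", jax.devices())
```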
Do you have any idea how to solve this?
Full traceback:
Traceback (most recent call last):
File "/opt/project/xposure/stoix/systems/q_learning/ff_ddqn.py", line 6, in <module>
import flashbax as fbx
File "/xposure/lib/python3.10/site-packages/flashbax/__init__.py", line 16, in <module>
from flashbax.buffers import (
File "/xposure/lib/python3.10/site-packages/flashbax/buffers/__init__.py", line 16, in <module>
from flashbax.buffers.prioritised_flat_buffer import make_prioritised_flat_buffer
File "/xposure/lib/python3.10/site-packages/flashbax/buffers/prioritised_flat_buffer.py", line 25, in <module>
from flashbax.buffers.prioritised_trajectory_buffer import (
File "/xposure/lib/python3.10/site-packages/flashbax/buffers/prioritised_trajectory_buffer.py", line 39, in <module>
from flashbax.buffers import sum_tree, trajectory_buffer
File "/xposure/lib/python3.10/site-packages/flashbax/buffers/sum_tree.py", line 33, in <module>
from flax.struct import dataclass
File "/xposure/lib/python3.10/site-packages/flax/__init__.py", line 24, in <module>
from flax import core
File "/xposure/lib/python3.10/site-packages/flax/core/__init__.py", line 15, in <module>
from .axes_scan import broadcast as broadcast
File "/xposure/lib/python3.10/site-packages/flax/core/axes_scan.py", line 23, in <module>
from jax.extend import linear_util as lu
ImportError: cannot import name 'linear_util' from 'jax.extend' (/xposure/lib/python3.10/site-packages/jax/extend/__init__.py)
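The failing line is flax's from jax.extend import linear_util as lu, so one way to test a given jax/flax pairing without going through flashbax is to attempt exactly that import on its own. The sketch below uses a hypothetical helper, can_import_name (not part of jax or flax), that mirrors the semantics of from module import name, including the case where name is a submodule that has not been loaded yet.

```python
import importlib


def can_import_name(module_name: str, name: str) -> bool:
    """Hypothetical helper: mimic `from module_name import name` without raising on failure."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return False  # the parent module itself is missing or broken
    if hasattr(module, name):
        return True
    # `from pkg import sub` also succeeds when `sub` is a submodule that is not loaded yet.
    try:
        importlib.import_module(f"{module_name}.{name}")
    except ImportError:
        return False
    return True


if __name__ == "__main__":
    # The exact import that fails in flax/core/axes_scan.py in the traceback above.
    print("jax.extend.linear_util importable:", can_import_name("jax.extend", "linear_util"))
```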
To Reproduce
Steps to reproduce the behavior:
1. Run make build.
2. Run ff_ddqn.py inside the Docker container.
Possible Solution
Change the versions of flax and jax/jaxlib in requirements.txt and the Dockerfile.
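When pinning new versions, it can help to print what actually ended up installed and check that jax and jaxlib come from the same release. The sketch below assumes the convention of pinning both to the same release, and it ignores local labels such as +cuda11.cudnn86 when comparing. The helper names release, installed_version and jax_and_jaxlib_aligned are hypothetical.

```python
import importlib.metadata
import re
from typing import Optional, Tuple


def release(version: str) -> Tuple[int, ...]:
    """Numeric release of a version string, ignoring local labels like '+cuda11.cudnn86'."""
    public = version.split("+", 1)[0]
    match = re.match(r"\d+(?:\.\d+)*", public)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def installed_version(dist: str) -> Optional[str]:
    """Installed version of a distribution, or None if it is not installed."""
    try:
        return importlib.metadata.version(dist)
    except importlib.metadata.PackageNotFoundError:
        return None


def jax_and_jaxlib_aligned(jax_version: str, jaxlib_version: str) -> bool:
    """Assumed convention: jax and jaxlib pinned to the same release."""
    return release(jax_version) == release(jaxlib_version)


if __name__ == "__main__":
    versions = {name: installed_version(name) for name in ("jax", "jaxlib", "flax")}
    print(versions)
    if versions["jax"] and versions["jaxlib"]:
        print("same release:", jax_and_jaxlib_aligned(versions["jax"], versions["jaxlib"]))
```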
Context (Environment)
Linux 24.04 with Docker.
This is the pip freeze output when I run the Docker image with the repository's current settings:
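To keep such a listing short, the relevant packages can be filtered with the standard library instead of scanning the full pip freeze. This is a sketch: the helper installed_packages is hypothetical, the prefix list (jax, flax, flashbax, nvidia-) is an assumption about which packages matter, and nvidia- entries only appear when CUDA libraries were installed through pip.

```python
import importlib.metadata
from typing import Dict, Iterable


def installed_packages(prefixes: Iterable[str]) -> Dict[str, str]:
    """Hypothetical helper: {name: version} for installed distributions whose name starts with a prefix."""
    wanted = tuple(prefix.lower() for prefix in prefixes)
    found: Dict[str, str] = {}
    for dist in importlib.metadata.distributions():
        name = dist.metadata.get("Name")
        if name and name.lower().startswith(wanted):
            found[name] = dist.version
    return dict(sorted(found.items()))


if __name__ == "__main__":
    # Assumed prefix list; "jax" also matches jaxlib and any jax CUDA plugin packages.
    for name, version in installed_packages(["jax", "flax", "flashbax", "nvidia-"]).items():
        print(f"{name}=={version}")  # same name==version layout as pip freeze
```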
Hmm, let me look into this. Unfortunately, I don't currently have access to a GPU machine, so it'll be hard for me to test this, but regardless, this reminds me to raise the jax version in the requirements file. Just make sure that the image you are pulling and the jax version you install use the same CUDA and cuDNN versions.
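Following the advice to keep the image and the jax build aligned, one rough check is to compare the CUDA major version encoded in the jaxlib version string with the one the base image reports. This sketch assumes two things: that the CUDA-enabled jaxlib wheel carries a local label such as +cuda11.cudnn86 (as wheels from the jax_cuda_releases.html index typically did), and that the base image is an nvidia/cuda image that sets the CUDA_VERSION environment variable. It does not check cuDNN, and the helper names are hypothetical.

```python
import importlib.metadata
import os
import re
from typing import Optional


def jaxlib_cuda_major(jaxlib_version: str) -> Optional[int]:
    """CUDA major version from a local label such as '0.4.24+cuda11.cudnn86', else None."""
    match = re.search(r"\+cuda(\d+)", jaxlib_version)
    return int(match.group(1)) if match else None


def cuda_majors_match(jaxlib_version: str, image_cuda_version: str) -> bool:
    """True only if the jaxlib CUDA tag and the image's CUDA version share a major version."""
    tagged = jaxlib_cuda_major(jaxlib_version)
    if tagged is None:
        return False  # no CUDA tag: CPU-only build, or a format this sketch does not handle
    return tagged == int(image_cuda_version.split(".")[0])


if __name__ == "__main__":
    try:
        jaxlib_version: Optional[str] = importlib.metadata.version("jaxlib")
    except importlib.metadata.PackageNotFoundError:
        jaxlib_version = None
    image_cuda = os.environ.get("CUDA_VERSION")  # assumption: set by nvidia/cuda base images
    print(f"jaxlib={jaxlib_version} CUDA_VERSION={image_cuda}")
    if jaxlib_version and image_cuda:
        print("CUDA major versions match:", cuda_majors_match(jaxlib_version, image_cuda))
```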
Describe the bug
Hello!
After building the image with the Dockerfile, I get the error
Cannot import name 'linear_util' from 'jax'
when running the examples. This seems to be caused by an incompatibility between flax and jax: https://stackoverflow.com/questions/78210393/cannot-import-name-linear-util-from-jax (with those settings I do have access to my 2070 Max-Q GPU).