Skip to content

Commit

Permalink
Fix jax.config import
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Feb 29, 2024
1 parent 4cdb4bb commit 92c162d
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions phiml/backend/jax/_jax_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
from .._dtype import DType, to_numpy_dtype, from_numpy_dtype
from .._backend import Backend, ComputeDevice, combined_dim, ML_LOGGER, TensorType, map_structure

from jax.config import config
config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)


class JaxBackend(Backend):
Expand Down

0 comments on commit 92c162d

Please sign in to comment.