Skip to content

Commit

Permalink
Fixing jax config import
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Mar 6, 2024
1 parent 2ebe865 commit 40808e8
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion docs/tutorials/first.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@
"metadata": {},
"outputs": [],
"source": [
"from jax.config import config\n",
"from jax import config\n",
"\n",
"config.update(\"jax_enable_x64\", True)\n",
"\n",
Expand Down
4 changes: 2 additions & 2 deletions python/celerite2/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

logger = logging.getLogger(__name__)

from jax.config import config # noqa isort:skip
from jax import config # noqa isort:skip

if not config.read("jax_enable_x64"):
logger.warning(
Expand All @@ -13,7 +13,7 @@
"already run some jax code.\n"
"You can squash this warning by setting the environment variable "
"'JAX_ENABLE_X64=True' or by running:\n"
">>> from jax.config import config\n"
">>> from jax import config\n"
">>> config.update('jax_enable_x64', True)"
)
config.update("jax_enable_x64", True)
Expand Down
2 changes: 1 addition & 1 deletion python/test/jax/test_jax_terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

try:
import jax
from jax.config import config
from jax import config

from celerite2 import terms as pyterms
from celerite2.jax import terms
Expand Down

0 comments on commit 40808e8

Please sign in to comment.