Skip to content

Commit

Permalink
Update JAX dependency constraint, and bump minor version number (#18)
Browse files Browse the repository at this point in the history
* Update JAX dependency constraint, and bump minor version number

* Update link in README to callback docs
  • Loading branch information
zombie-einstein authored Apr 26, 2024
1 parent 7c7d609 commit 43d5fda
Show file tree
Hide file tree
Showing 3 changed files with 313 additions and 247 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ last_number, all_numbers = lax.scan(step, 0, jnp.arange(n))

JAX functions are [pure](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions),
so side effects such as printing progress when running scans and loops are not allowed.
However, the [host_callback module](https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html)
However, the
[debug module](https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html#exploring-debug-callback)
has primitives for calling Python functions on the host from JAX code. This can be used
to update a Python tqdm progress bar regularly during the computation. JAX-tqdm
implements this for JAX scans and loops and is used by simply adding a decorator to the
Expand Down
Loading

0 comments on commit 43d5fda

Please sign in to comment.