Skip to content

Commit

Permalink
Add print-rate as optional argument (#11)
Browse files Browse the repository at this point in the history
* Add print-rate as optional argument

* Update dependencies

* Update isort

* Update README, extend tests add checks on print-rate

* README tweaks from PR comments

* Fix typo

* Bump minor version number ahead of new release
  • Loading branch information
zombie-einstein authored Feb 13, 2023
1 parent ca18c7e commit 9034145
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 73 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ repos:
- "--ignore=W503"
- "--per-file-ignores=__init__.py:F401"
- repo: https://github.com/pycqa/isort
rev: 5.11.4
rev: 5.12.0
hooks:
- id: isort
args:
Expand Down
34 changes: 32 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,41 @@ def step(i, val):
last_number = lax.fori_loop(0, n, step, 0)
```

### Print Rate

By default, the progress bar is updated 20 times over the course of the scan/loop
(for performance purposes, see [below](#why-jax-tqdm)). This
update rate can be manually controlled with the `print_rate` keyword argument. For
example:

```python
from jax_tqdm import scan_tqdm
from jax import lax
import jax.numpy as jnp

n = 10_000

@scan_tqdm(n, print_rate=2)
def step(carry, x):
return carry + 1, carry + 1

last_number, all_numbers = lax.scan(step, 0, jnp.arange(n))
```

will update every other step.

## Why JAX-tqdm?

JAX functions are [purely functional](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions), so side effects such as printing progress when running scans and loops are not allowed. However, the [host_callback module](https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html) has primitives for calling Python functions on the host from JAX code. This can be used to update a Python tqdm progress bar regularly during the computation. JAX-tqdm implements this for JAX scans and loops and is used by simply adding a decorator to the body of your update function.
JAX functions are [pure](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions),
so side effects such as printing progress when running scans and loops are not allowed.
However, the [host_callback module](https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html)
has primitives for calling Python functions on the host from JAX code. This can be used
to update a Python tqdm progress bar regularly during the computation. JAX-tqdm
implements this for JAX scans and loops and is used by simply adding a decorator to the
body of your update function.

Note that as the tqdm progress bar is only updated 20 times during the scan or loop, there is no performance penalty.
Note that as the tqdm progress bar is only updated 20 times during the scan or loop,
there is no performance penalty.

The code is explained in more detail in this [blog post](https://www.jeremiecoullon.com/2021/01/29/jax_progress_bar/).

Expand Down
40 changes: 32 additions & 8 deletions jax_tqdm/pbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,21 @@
from tqdm.auto import tqdm


def scan_tqdm(n: int, message: typing.Optional[str] = None) -> typing.Callable:
def scan_tqdm(
n: int,
print_rate: typing.Optional[int] = None,
message: typing.Optional[str] = None,
) -> typing.Callable:
"""
tqdm progress bar for a JAX scan
Parameters
----------
n : int
Number of scan steps/iterations.
print_rate: int
Optional integer rate at which the progress bar will be updated,
by default the print rate will 1/20th of the total number of steps.
message : str
Optional string to prepend to tqdm progress bar.
Expand All @@ -22,7 +29,7 @@ def scan_tqdm(n: int, message: typing.Optional[str] = None) -> typing.Callable:
Progress bar wrapping function.
"""

_update_progress_bar, close_tqdm = build_tqdm(n, message)
_update_progress_bar, close_tqdm = build_tqdm(n, print_rate, message)

def _scan_tqdm(func):
"""Decorator that adds a tqdm progress bar to `body_fun` used in `jax.lax.scan`.
Expand All @@ -45,14 +52,21 @@ def wrapper_progress_bar(carry, x):
return _scan_tqdm


def loop_tqdm(n: int, message: typing.Optional[str] = None) -> typing.Callable:
def loop_tqdm(
n: int,
print_rate: typing.Optional[int] = None,
message: typing.Optional[str] = None,
) -> typing.Callable:
"""
tqdm progress bar for a JAX fori_loop
Parameters
----------
n : int
Number of iterations.
print_rate: int
Optional integer rate at which the progress bar will be updated,
by default the print rate will 1/20th of the total number of steps.
message : str
Optional string to prepend to tqdm progress bar.
Expand All @@ -62,7 +76,7 @@ def loop_tqdm(n: int, message: typing.Optional[str] = None) -> typing.Callable:
Progress bar wrapping function.
"""

_update_progress_bar, close_tqdm = build_tqdm(n, message)
_update_progress_bar, close_tqdm = build_tqdm(n, print_rate, message)

def _loop_tqdm(func):
"""
Expand All @@ -81,7 +95,7 @@ def wrapper_progress_bar(i, val):


def build_tqdm(
n: int, message: typing.Optional[str] = None
n: int, print_rate: typing.Optional[int], message: typing.Optional[str] = None
) -> typing.Tuple[typing.Callable, typing.Callable]:
"""
Build the tqdm progress bar on the host
Expand All @@ -91,10 +105,20 @@ def build_tqdm(
message = f"Running for {n:,} iterations"
tqdm_bars = {}

if n > 20:
print_rate = int(n / 20)
if print_rate is None:
if n > 20:
print_rate = int(n / 20)
else:
print_rate = 1
else:
print_rate = 1
if print_rate < 1:
raise ValueError(f"Print rate should be > 0 got {print_rate}")
elif print_rate > n:
raise ValueError(
"Print rate should be less than the "
f"number of steps {n}, got {print_rate}"
)

remainder = n % print_rate

def _define_tqdm(arg, transform):
Expand Down
112 changes: 56 additions & 56 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 9034145

Please sign in to comment.