-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #144 from Sampreet/pr/feature-numerical-backend
JAX Numerical Backend for GPU/TPU Support
- Loading branch information
Showing
64 changed files
with
1,098 additions
and
687 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# Development | ||
|
||
The current development branch "dev/jax" implements | ||
|
||
* [Experimental Support for GPUs/TPUs](#experimental-support-for-gpustpus) | ||
|
||
## Experimental Support for GPUs/TPUs | ||
|
||
Although OQuPy is built on top of the backend-agnostic | ||
[TensorNetwork](https://github.com/google/TensorNetwork) library, | ||
OQuPy uses vanilla NumPy and SciPy throughout its implementation. | ||
|
||
The "dev/jax" branch adds supports for GPUs/TPUs via the | ||
[JAX](https://jax.readthedocs.io/en/latest/) library. | ||
A new `oqupy.backends.numerical_backend.py` module handles the | ||
[breaking changes in JAX NumPy](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html), | ||
while the rest of the modules utilizes `numpy` and `scipy.linalg` instances from there | ||
without explicitly importing JAX-based libraries. | ||
|
||
### Enabling Experimental Features | ||
|
||
To enable experimental features switch to the `dev/jax` branch and use | ||
```python | ||
from oqupy.backends import enable_jax_features | ||
enable_jax_features() | ||
``` | ||
Alternatively, the `OQUPY_BACKEND` environmental variable may be set to `jax` to | ||
initialize the jax backend by default. | ||
|
||
### Contributing Guidelines | ||
|
||
To contribute features compatible with the JAX backend, | ||
please adhere to the following set of guidelines: | ||
|
||
* avoid wildcard imports of NumPy and SciPy. | ||
* use `from oqupy.backends.numerical_backend import np` instead of `import numpy as np` and use the alias `default_np` in cases vanilla NumPy is explicitly required. | ||
* use `from oqupy.backends.numerical_backend import la` instead of `import scipy.linalg as la`, except that for non-symmetric eigen-decomposition, `scipy.linalg.eig` should be used. | ||
* use one of `np.dtype_complex` (`np.dtype_float`) or `oqupy.config.NumPyDtypeComplex` (`oqupy.config.NumPyDtypeFloat`) instead of `np.complex_` (`np.float_`). | ||
* convert lists or tuples to arrays when passing them as arguments inside functions. | ||
* use `array = np.update(array, indices, values)` instead of `array[indices] = values`. | ||
* use `np.get_random_floats(seed, shape)` instead of `np.random.default_rng(seed).random(shape)`. | ||
* declare signatures for `np.vectorize` explicitly. | ||
* avoid directly changing the `shape` attribute of an array (use `.reshape` instead) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
Experimental Support for GPUs/TPUs | ||
================================== | ||
The current development branch "dev/jax" implements experimental support | ||
for GPUs/TPUs. | ||
|
||
Although OQuPy is built on top of the backend-agnostic | ||
`TensorNetwork <https://github.com/google/TensorNetwork>`__ library, | ||
OQuPy uses vanilla NumPy and SciPy throughout its implementation. | ||
|
||
The "dev/jax" branch adds supports for GPUs/TPUs via the | ||
`JAX <https://jax.readthedocs.io/en/latest/>`__ library. A new | ||
``oqupy.backends.numerical_backend.py`` module handles the | ||
`breaking changes in JAX | ||
NumPy <https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html>`__, | ||
while the rest of the modules utilizes ``numpy`` and ``scipy.linalg`` | ||
instances from there without explicitly importing JAX-based libraries. | ||
|
||
Enabling Experimental Features | ||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
To enable experimental features, switch to the ``dev/jax`` branch and use | ||
|
||
.. code:: python | ||
from oqupy.backends import enable_jax_features | ||
enable_jax_features() | ||
Alternatively, the `OQUPY_BACKEND` environmental variable may be set to `jax` to | ||
initialize the jax backend by default. | ||
|
||
Contributing Guidelines | ||
~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
To contribute features compatible with the JAX backend, | ||
please adhere to the following set of guidelines: | ||
|
||
- avoid wildcard imports of NumPy and SciPy. | ||
- use ``from oqupy.backends.numerical_backend import np`` instead of | ||
``import numpy as np`` and use the alias ``default_np`` in cases | ||
vanilla NumPy is explicitly required. | ||
- use ``from oqupy.backends.numerical_backend import la`` instead of | ||
``import scipy.linalg as la``, except that for non-symmetric | ||
eigen-decomposition, ``scipy.linalg.eig`` should be used. | ||
- use one of ``np.dtype_complex`` (``np.dtype_float``) or | ||
``oqupy.config.NumPyDtypeComplex`` (``oqupy.config.NumPyDtypeFloat``) | ||
instead of ``np.complex_`` (``np.float_``). | ||
- convert lists or tuples to arrays when passing them as arguments | ||
inside functions. | ||
- use ``array = np.update(array, indices, values)`` instead of | ||
``array[indices] = values``. | ||
- use ``np.get_random_floats(seed, shape)`` instead of | ||
``np.random.default_rng(seed).random(shape)``. | ||
- declare signatures for ``np.vectorize`` explicitly. | ||
- avoid directly changing the ``shape`` attribute of an array (use | ||
``.reshape`` instead) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
#!/usr/bin/env python | ||
|
||
import sys | ||
sys.path.insert(0, '.') | ||
# set the 'OQUPY_BACKEND' environment variable | ||
# to 'jax' to initialize JAX backend by default | ||
# or switch to JAX backend using oqupy.backends | ||
import oqupy | ||
from oqupy.backends import enable_jax_features | ||
# import NumPy from numerical_backend | ||
#from oqupy.backends.numerical_backend import np | ||
#enable_jax_features() | ||
|
||
import matplotlib.pyplot as plt | ||
sigma_x = oqupy.operators.sigma("x") | ||
sigma_z = oqupy.operators.sigma("z") | ||
up_density_matrix = oqupy.operators.spin_dm("z+") | ||
Omega = 1.0 | ||
omega_cutoff = 5.0 | ||
alpha = 0.3 | ||
|
||
system = oqupy.System(0.5 * Omega * sigma_x) | ||
correlations = oqupy.PowerLawSD(alpha=alpha, | ||
zeta=1, | ||
cutoff=omega_cutoff, | ||
cutoff_type='exponential') | ||
bath = oqupy.Bath(0.5 * sigma_z, correlations) | ||
tempo_parameters = oqupy.TempoParameters(dt=0.1, tcut=3.0, epsrel=10**(-4)) | ||
|
||
dynamics = oqupy.tempo_compute(system=system, | ||
bath=bath, | ||
initial_state=up_density_matrix, | ||
start_time=0.0, | ||
end_time=2.0, | ||
parameters=tempo_parameters, | ||
unique=True) | ||
t, s_z = dynamics.expectations(0.5*sigma_z, real=True) | ||
print(s_z) | ||
plt.plot(t, s_z, label=r'$\alpha=0.3$') | ||
plt.xlabel(r'$t\,\Omega$') | ||
plt.ylabel(r'$\langle\sigma_z\rangle$') | ||
plt.savefig('simple_dynamics_jax.png') | ||
#plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
"""Module to initialize OQuPy's backends.""" | ||
|
||
from oqupy.backends.numerical_backend import set_numerical_backends | ||
|
||
def enable_jax_features(): | ||
"""Function to enable experimental features.""" | ||
|
||
# set numerical backend to JAX | ||
set_numerical_backends('jax') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
""" | ||
Module containing NumPy-like and SciPy-like numerical backends. | ||
""" | ||
|
||
import os | ||
|
||
import numpy as default_np | ||
import scipy.linalg as default_la | ||
|
||
from tensornetwork.backend_contextmanager import \ | ||
set_default_backend | ||
|
||
import oqupy.config as oc | ||
|
||
# store instances of the initialized backends | ||
# this way, `oqupy.config` remains unchanged | ||
# and `ocupy.config.DEFAULT_BACKEND` is used | ||
# when NumPy and LinAlg are initialized | ||
NUMERICAL_BACKEND_INSTANCES = {} | ||
|
||
def get_numerical_backends( | ||
backend_name: str, | ||
): | ||
"""Function to get numerical backend. | ||
Parameters | ||
---------- | ||
backend_name: str | ||
Name of the backend. Options are `'jax'` and `'numpy'`. | ||
Returns | ||
------- | ||
backends: list | ||
NumPy and LinAlg backends. | ||
""" | ||
|
||
_bn = backend_name.lower() | ||
if _bn in NUMERICAL_BACKEND_INSTANCES: | ||
set_default_backend(_bn) | ||
return NUMERICAL_BACKEND_INSTANCES[_bn] | ||
assert _bn in ['jax', 'numpy'], \ | ||
"currently supported backends are `'jax'` and `'numpy'`" | ||
|
||
if 'jax' in _bn: | ||
try: | ||
# explicitly import and configure jax | ||
import jax | ||
import jax.numpy as jnp | ||
import jax.scipy.linalg as jla | ||
jax.config.update('jax_enable_x64', True) | ||
|
||
# # TODO: GPU memory allocation (default is 0.75) | ||
# os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform' | ||
# os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.5' | ||
|
||
# set TensorNetwork backend | ||
set_default_backend('jax') | ||
|
||
NUMERICAL_BACKEND_INSTANCES['jax'] = [jnp, jla] | ||
return NUMERICAL_BACKEND_INSTANCES['jax'] | ||
except ImportError: | ||
print("JAX not installed, defaulting to NumPy") | ||
|
||
# set TensorNetwork backend | ||
set_default_backend('numpy') | ||
|
||
NUMERICAL_BACKEND_INSTANCES['numpy'] = [default_np, default_la] | ||
return NUMERICAL_BACKEND_INSTANCES['numpy'] | ||
|
||
class NumPy: | ||
""" | ||
The NumPy backend employing | ||
dynamic switching through `oqupy.config`. | ||
""" | ||
def __init__(self, | ||
backend_name=oc.DEFAULT_BACKEND, | ||
): | ||
"""Getter for the backend.""" | ||
self.backend = get_numerical_backends(backend_name)[0] | ||
|
||
@property | ||
def dtype_complex(self) -> default_np.dtype: | ||
"""Getter for the complex datatype.""" | ||
return oc.NumPyDtypeComplex | ||
|
||
@property | ||
def dtype_float(self) -> default_np.dtype: | ||
"""Getter for the float datatype.""" | ||
return oc.NumPyDtypeFloat | ||
|
||
def __getattr__(self, | ||
name: str, | ||
): | ||
"""Return the backend's default attribute.""" | ||
return getattr(self.backend, name) | ||
|
||
def update(self, | ||
array, | ||
indices: tuple, | ||
values, | ||
) -> default_np.ndarray: | ||
"""Option to update select indices of an array with given values.""" | ||
if not isinstance(array, default_np.ndarray): | ||
return array.at[indices].set(values) | ||
array[indices] = values | ||
return array | ||
|
||
def get_random_floats(self, | ||
seed, | ||
shape, | ||
): | ||
"""Method to obtain random floats with a given seed and shape.""" | ||
random_floats = default_np.random.default_rng(seed).random(shape, \ | ||
dtype=default_np.float64) | ||
return self.backend.array(random_floats, dtype=self.dtype_float) | ||
|
||
class LinAlg: | ||
""" | ||
The Linear Algebra backend employing | ||
dynamic switching through `oqupy.config`. | ||
""" | ||
def __init__(self, | ||
backend_name=oc.DEFAULT_BACKEND, | ||
): | ||
"""Getter for the backend.""" | ||
self.backend = get_numerical_backends(backend_name)[1] | ||
|
||
def __getattr__(self, | ||
name: str, | ||
): | ||
"""Return the backend's default attribute.""" | ||
return getattr(self.backend, name) | ||
|
||
# setup libraries using environment variable | ||
# fall back to oqupy.config.DEFAULT_BACKEND | ||
try: | ||
BACKEND_NAME = os.environ[oc.BACKEND_ENV_VAR] | ||
except KeyError: | ||
BACKEND_NAME = oc.DEFAULT_BACKEND | ||
np = NumPy(backend_name=BACKEND_NAME) | ||
la = LinAlg(backend_name=BACKEND_NAME) | ||
|
||
def set_numerical_backends( | ||
backend_name: str | ||
): | ||
"""Function to set numerical backend. | ||
Parameters | ||
---------- | ||
backend_name: str | ||
Name of the backend. Options are `'jax'` and `'numpy'`. | ||
""" | ||
backends = get_numerical_backends(backend_name) | ||
np.backend = backends[0] | ||
la.backend = backends[1] |
Oops, something went wrong.