Implement Dynamical CPU/GPU/TPU Backends #142
Hi @Sampreet, the code changes look really good. I haven't been able to run things on my system yet, as Linux + AMD for JAX is experimental, and there are some barriers to building from source, but hopefully I'll be able to get that done soon.

[On a related note, it looks like JAX requires Python 3.10+; not an issue, but we will need to make a note of it in the README or similar, as I think we are still hoping to support 3.6 as a minimum - #80]

The dynamical switching looks fine, and I actually prefer the neatness of using

I guess it's nice to be able to choose JAX for either a

EDIT: Running CPU JAX, I've observed that the calls
**RE Potential Issues**

**Shape**

I don't see any issue with this. In fact, the current use of shape vs reshape in the code looks sporadic. I think we should switch to always using

**Vectorization Signatures**

Yes, signatures can be added and that sounds like a good addition. You are correct: Hamiltonians for all

**Row Degeneracy Checks**
Overall, this all looks very good and close to being PR-ready. It would be good to do widespread testing, because presumably there may well be other small compatibility issues, such as those you have caught in the numerical backend with
**RE Unit testing**

We have a large number of coverage and 'physical' (assessing physical consistency or accuracy) tests under tests/. What is the best way to extend these to test against the JAX backend? I guess (assuming we want the code to work 100% with both backends) we can have two testing environments, one for each backend (the JAX one also being Python 3.10), to run all the tests against.
Hi @piperfw, thanks again for your detailed comments. I shall implement the suggestions for shape, signatures and the degeneracy arrays and let you know if I face any issues. Based on my understanding of the source code, I would like to share the following in relation to your queries:

**Data types**

I do agree with certain checks on datatypes, especially since the precision requirements (single/double floating points and their corresponding complex datatypes) are interrelated in most of the modules. Also, as the operations for a single simulation (by a user) are likely to involve the same degree of precision, I feel it might not be useful to provide the dynamic alterability of precision (to the user) through the corresponding methods of

And yes, JAX, by default, supports 32-bit floats (and correspondingly, 64-bit complex) datatypes and requires a call to the

**Unit tests**

I am still to update the test files with the classes from

**Regarding PR**

It might be handy to have a separate branch (say

**Update on speedup**

In the latest commit of the pr/feature-numerical-backend branch, I have implemented the generalized method to create deltas (with a NumPy-friendly

**About**
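As a side note on the 64-bit point above, a minimal sketch of enabling double precision via JAX's standard `jax_enable_x64` flag (guarded so the snippet also runs without JAX installed):

```python
from importlib.util import find_spec

# guard the import so the snippet degrades gracefully without JAX
if find_spec('jax') is not None:
    import jax
    import jax.numpy as jnp
    # JAX defaults to 32-bit floats; this switches array creation to 64-bit
    jax.config.update('jax_enable_x64', True)
    dtype = str(jnp.zeros(3).dtype)  # 'float64' once x64 mode is enabled
else:
    dtype = 'jax not installed'
print(dtype)
```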
Hi Sampreet, thanks for the detailed updates. I will respond to particular comments below. The code changes and generally the implementation look really good. For the bigger picture, there are two things we want to avoid:
These are potentially conflicting: if we require contributions to be JAX compatible, then that is a large barrier, but if we don't, then contributed code may break for the JAX backend and someone will need to assess and fix any issues. It seems like we should be able to keep the core OQuPy and existing features JAX-compatible (to the extent of lacking integrate, eigen decomposition etc.). I suggest we add a note to developers (and forgetful maintainers!) e.g. in

I'm not sure we should absolutely require developers to write modules for new features that are necessarily JAX compatible. I guess there might be some boilerplate code we can have ready to add to a module e.g. to raise

Otherwise, I realised we did in fact switch to Python 3.10, making my concern about versions above irrelevant. So that's good.

**Data types**

I'll let @gefux tell us about the intended purpose, and if this is still relevant, as I don't know.

**Unit tests**

I'm currently having difficulty getting a working JAX install on my system within python testing environments (3.10 is not my OS python). I'll hopefully get this sorted in due course and provide an update. But regardless, the assertion and vectorized changes sound fine. I don't know of any code that would suffer from the change of

My question here Sampreet is how we can make switching to the JAX testing neater (bear in mind we want to use GitHub's CI). I'm not so familiar with pytest, but do you think there is a simple command flag or plugin we can optionally specify from the test invocation command to optionally load the JAX

**PR**

That's a good suggestion. For now I think just creating a PR from your current branch and linking it here would be helpful so everyone can follow changes.

**create_delta**

I or someone else will look into #140 when we get the chance; sounds good.

**integrate / linalg**

These deficiencies may be good to note in addition to advice for developers. Maybe we need a page in the docs about the JAX backend, what do you think?
Under 'Development', perhaps a new page on JAX after 'Contributing' (https://oqupy.readthedocs.io/en/latest/pages/contributing.html), in place of or in addition to changes to

I do not think we should use other libraries or projects. In particular, the bath eigendecomposition is only done once for small matrices if I'm not mistaken, so there is no cost in just avoiding the GPU here.
Dear @Sampreet, I finally found some time to look into your code - I apologize for the delay.

**General comment**

As piperfw writes, we need to be mindful of the maintenance difficulty and the contributing barriers. I think that using

**Technical comments (in addition to piperfw's)**

**Dynamical switching and automated testing:**

The general idea of defining the default in

**Numpy data types**

Fixing the data types to one specific type is simply a kind of defensive coding style which I embraced when starting this project, because numpy tends to become really slow if it performs computations on mixed types. At least that is what I used to experience a few years ago -- maybe this has changed in the meantime. I also found it potentially useful to be able to switch all computations to a different type from one point.
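The environment-variable idea for the default backend could look roughly like this (the variable name `OQUPY_BACKEND` and the `'numpy'` fallback are assumptions for illustration):

```python
import os

def default_backend(env=None):
    # read the backend name from the environment,
    # falling back to 'numpy' when the variable is unset
    env = os.environ if env is None else env
    return env.get('OQUPY_BACKEND', 'numpy')

print(default_backend({}))                        # numpy
print(default_backend({'OQUPY_BACKEND': 'jax'}))  # jax
```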
Really good idea with the environmental variable @gefux, that clears up a lot of the mess with the testing I was concerned about. Completely agree about the experimental label for the backend; we can advertise this in an appropriate page in the docs.
Thanks a lot for the detailed comments. I shall implement the suggested changes sometime this week and create a draft PR, where we can continue the code-related discussions. Below are a few points I would like to add here:

**Contribution and Maintenance**
This sounds like a good idea and is more user/contributor/maintainer-friendly for now. The proposed changes are also fully compatible with vanilla NumPy users without requiring any change in the existing examples. The same goes for contributions with vanilla NumPy. Additionally, interested contributors can follow an "optional" set of guidelines (detailed next) to make sure that their changes work seamlessly with JAX, thereby easing maintenance. Moreover, the current
The following guidelines can be added for "optional" contribution to the "experimental" features:
It might be a good idea to add a separate section/page in the documentation to highlight the "experimental" nature of this feature. In time, this can include installation guidelines, gotchas (some points from

**Continuous Integration and Unit Tests**
From what I understand, the automated tests are performed with `tox`; we can add a dedicated test environment for JAX:

```ini
[testenv:pytest_jax]
description = run pytests with JAX and make coverage report
basepython = python3.10
deps = -rrequirements_ci_jax.txt
commands =
    pytest --cov-report term-missing --cov=oqupy ./tests/{posargs}
    python -m coverage xml
```

Then, we can use Python's `importlib.util` to configure JAX only when it is available:

```python
from importlib.util import find_spec

if find_spec('jax') is not None:
    # JAX configuration here
```

I have implemented this approach and it works well. What I am worried about is that the time taken by the JAX-based tests is very high, since we do not have any explicit JAX-based primitives in OQuPy's modules yet (further details in the next section).

**Dynamical Switching with Environment Variables**
I believe what you are suggesting here is an explicit backend module for JAX, which is imported during the initialization of

The above description is similar to what has been implemented inside

On second thoughts, since
As of now, I don't see any major disadvantage in the environment-variable approach, other than the fact that we might need to add a few additional guidelines for "optional" contributions.
Hi @Sampreet, thanks for making the PR and apologies for the delay in responding here.

**Contribution and Maintenance**

I do think we should keep to no required jax imports in modules, but something like

**Continuous Integration and Unit Tests**

Your proposal for

**Dynamical Switching with Environment Variables**

I understand @gefux's suggestion to mean, the dynamical switching in

Being able to use jax primitives etc. does sound really nice; that would really open up possibilities for developers to push and optimise oqupy (/for GPUs), although I'm not sure about the code complexity. I'll make a brief comment/summary on the PR itself.
Hi @piperfw, thank you for the feedback and the suggestions for the next steps. In line with your comment on Contribution and Maintenance, I shall update the docs sometime during the weekend. As for CI and Unit Tests, I think it's best if the JAX-based tests are run locally, since GitHub CI has a monthly cap for its free tier. To achieve this, I have implemented the following:

```python
# store instances of the initialized backends
# this way, `oqupy.config` remains unchanged
# and `OQUPY_BACKEND` is used by default
# when NumPy and LinAlg are initialized
NUMERICAL_BACKEND_INSTANCES = {}

def get_numerical_backends(
        backend_name: str,
):
    if backend_name in NUMERICAL_BACKEND_INSTANCES:
        return NUMERICAL_BACKEND_INSTANCES[backend_name]
    assert backend_name in ['jax', 'numpy'], \
        "currently supported backends are `'jax'` and `'numpy'`"
    if 'jax' in backend_name.lower():
        try:
            # explicitly import and configure jax
            # tox catches import-outside-toplevel
            import jax
            import jax.numpy as jnp
            import jax.scipy.linalg as jla
            jax.config.update('jax_enable_x64', True)
            # set tensor network backend
            set_default_backend('jax')
            NUMERICAL_BACKEND_INSTANCES['jax'] = [jnp, jla]
            return NUMERICAL_BACKEND_INSTANCES['jax']
        except ImportError:
            print("JAX not installed, falling back to NumPy")
    set_default_backend('numpy')
    NUMERICAL_BACKEND_INSTANCES['numpy'] = [default_np, default_la]
    return NUMERICAL_BACKEND_INSTANCES['numpy']

# NumPy and LinAlg initialize with the `OQUPY_BACKEND` value by default
class NumPy:
    def __init__(self):
        self.backend = get_numerical_backends(oc.OQUPY_BACKEND)[0]
    # other methods remain the same

class LinAlg:
    def __init__(self):
        self.backend = get_numerical_backends(oc.OQUPY_BACKEND)[1]
    # other methods remain the same

# initialize
np = NumPy()
la = LinAlg()

# switch the backend without changing `oqupy.config`
def enable_experimental_features():
    backends = get_numerical_backends('jax')
    np.backend = backends[0]
    la.backend = backends[1]
```

I suppose this is in line with @gefux's comment. Kindly let me know if this works. Also, to answer your query on explicit modules, what I meant was that the alternate form of implementation which I mentioned in my previous comment would use separate modules for each numerical backend (as in quantrl), and implement an abstract base class inherited by the numerical backends. This approach typically follows the file structure:
Here, each backend can be individually optimized to implement the required methods. Although this will enhance scalability and allow further optimization of the package (as you have mentioned), the codebase might change significantly. As such, this can be implemented at a later stage when OQuPy's backends are modified to implement other planned changes (as mentioned in some of the ToDos and in #28). Thanks again!
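A rough sketch of the abstract-base-class layout described above (class and method names are hypothetical, loosely modelled on the description rather than on quantrl's actual code):

```python
from abc import ABC, abstractmethod
import numpy

class BaseBackend(ABC):
    """Interface that each numerical backend module would implement."""

    @abstractmethod
    def zeros(self, shape):
        """Return an array of zeros on this backend."""

    @abstractmethod
    def reshape(self, a, shape):
        """Return a reshaped array on this backend."""

class NumPyBackend(BaseBackend):
    def zeros(self, shape):
        return numpy.zeros(shape)

    def reshape(self, a, shape):
        return numpy.reshape(a, shape)

# a JAXBackend would implement the same interface via jax.numpy
backend = NumPyBackend()
print(backend.reshape(backend.zeros(6), (2, 3)).shape)  # (2, 3)
```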
Hi @Sampreet! I see what you're doing here and this looks close to what we want. Thanks for explaining the explicit module idea, I agree we can leave that for the time being. To check, is the purpose of

I think the main idea of the dynamical switching was to switch based on the environment, which I think can be done with something like:

```python
# config.py
OQUPY_BACKEND_ENV_VAR = 'OQUPY_BACKEND'  # name of environmental variable to check
```

```python
# numerical_backend.py
import os

def get_numerical_backends(backend_name): ...  # as defined previously

class NumPy:
    def __init__(self, backend_name=oc.OQUPY_DEFAULT_BACKEND):
        self.backend = get_numerical_backends(backend_name)[0]

# Similar for LinAlg

try:
    backend_name = os.environ[oc.OQUPY_BACKEND_ENV_VAR]
except KeyError:
    backend_name = oc.OQUPY_DEFAULT_BACKEND

np = NumPy(backend_name=backend_name)
la = LinAlg(backend_name=backend_name)
```

Note I think it makes sense to call the config variable

In fact, can one lose

```python
class NumPy:
    def __init__(self, backend=None):
        if backend is None:
            backend = get_numerical_backends(oc.OQUPY_DEFAULT_BACKEND)[0]
        self.backend = backend

# Similar for LinAlg

np_backend, la_backend = get_numerical_backends(backend_name)
np = NumPy(np_backend)
la = LinAlg(la_backend)
```

I guess that might not be advisable, because someone could try to initialise with modules that were not canonically jax or numpy. What do you think?

To summarise, I think this is close to being what we want, but I am questioning the relevance of
Hi @piperfw,

Indeed, the main idea behind the

Also, I do agree that allowing

I have made some slight changes to the
After some further discussion on the PR page, completed with 6ccbd2e
**Summary**
With reference to the email conversation with @piperfw and @gefux, I am opening a dedicated issue pertaining to a dynamical numerical backend to support GPUs and TPUs. OQuPy utilizes TensorNetwork modules which support frameworks like TensorFlow, PyTorch and JAX. However, TensorNetwork's `Node` requires its `tensor` parameter to be passed with the same numerical backend, and OQuPy's modules use NumPy explicitly to create these nodes. As such, using a dynamical backend to switch from NumPy/SciPy to the corresponding libraries of (say) JAX will facilitate the usage of TensorNetwork's GPU/TPU-friendly frameworks, which may speed up certain methods. This issue proposes a way similar to what has been recently implemented in QuTiP to support auto-differentiation using its JAX backend.

**Implementation**

The `oqupy.config` module is updated with:

The specific choice of `scipy.linalg` instead of `scipy` is because `jax.scipy.integrate` doesn't yet have implementations of `dblquad`, `quad` and `quad_vec` as required by some of OQuPy's modules.

A new `oqupy.backends.numerical_backend` module now takes care of the switching. The coarse contents of the module are:

All the modules of OQuPy are modified as:

The corresponding changes can be viewed by comparing the pr/feature-numerical-backend branch.
**Testing**
I have tested the above implementations using `tox` for the vanilla NumPy/SciPy backend and reproduced the plots of arXiv:2406.16650. To use the JAX backend, the following snippet can be added at any point in the scripts:

**Issues**

The JAX backend tests have multiple issues that require suggestions before any modification. A few of these issues are mentioned below:
**Immutability of `jax.numpy` Arrays**

The `add_singleton` function in `oqupy.utils` is often called by modules like `oqupy.backends.pt_tempo_backend` to create/update MPOs and MPSs. This function updates the shape of the array by adding an additional dimension to it at a specified index. Whereas NumPy supports altering the `shape` attribute, JAX's `ArrayImpl` datatypes are immutable, and as such `shape` doesn't have a setter method, leading to recurrent errors. Similar mutations can be found in the `oqupy.gradient`, `oqupy.mps_mpo` and `oqupy.system_dynamics` modules. In this case, would reshaping the array instead of updating its shape break the functionality? Otherwise, a separate variable can be dedicated for the shape.
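To illustrate the reshape-based alternative (the function name mirrors `add_singleton` but this is a sketch, not the actual OQuPy implementation):

```python
import numpy as np

def add_singleton_reshape(a, index):
    # backend-agnostic: reshape returns a new array, so it also works
    # for immutable JAX arrays (unlike assigning to a.shape in place)
    new_shape = list(a.shape)
    new_shape.insert(index, 1)
    return a.reshape(new_shape)

a = np.zeros((3, 4))
print(add_singleton_reshape(a, 1).shape)  # (3, 1, 4)
```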
**Vectorization Signatures**

Modules such as `oqupy.bath_correlations` and `oqupy.system` use the `vectorize` method of NumPy without explicitly mentioning the signatures for the input and output. This raises issues with JAX when the parameters contain non-scalar inputs. Depending on the type of implementation, can the signatures be added for each vectorized function? For example, the Hamiltonians for `TimeDependentSystem` and `TimeDependentSystemWithField` will have the signatures `()->(m,m)` and `(),()->(m,m)` respectively (kindly correct me if I have made a mistake here).
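A minimal example of the signature idea (the 2x2 Hamiltonian here is made up for illustration):

```python
import numpy as np

def hamiltonian(t):
    # hypothetical 2x2 time-dependent Hamiltonian
    return np.array([[0.0, t], [t, 0.0]])

# the signature tells vectorize that a scalar input maps to an (m, m) matrix
h_vec = np.vectorize(hamiltonian, signature='()->(m,m)')
ts = np.linspace(0.0, 1.0, 5)
print(h_vec(ts).shape)  # (5, 2, 2)
```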
**Row Degeneracy Checks**

Unlike the `numpy.unique` method, the `jax.numpy.unique` method returns array indices with the same shape as the input array even with the `axis` parameter. Although this can be resolved by using the `flatten` method, I am not sure if it will introduce additional errors.

Kindly share your views as and when time permits. Thank you once again.
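As a reference for the flatten idea, a NumPy-only sketch of the degeneracy-check pattern (JAX is deliberately not imported so the snippet stays runnable either way):

```python
import numpy as np

a = np.array([[1, 2], [1, 2], [3, 4]])
_, inverse = np.unique(a, axis=0, return_inverse=True)
# ravel guards against backends (e.g. jax.numpy) that return the
# inverse indices with extra dimensions
inverse = np.ravel(inverse)
print(inverse.shape)  # (3,)
print(inverse.tolist())  # [0, 0, 1]
```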