-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for PyMC >= 5 with PyTensor backend #89
Conversation
Copied pymc4 stuff, renamted aesara to pytensor, and fixed missing/renamed modules
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @vandalt!! I've been wanting to get to this for a long time. My only hesitation is that the implementation details of pytensor still seem to be changing dramatically with each version, but we have to give this a try sooner or later. I left some inline comments that would be worth looking into. Can you also:
- Add these tests to
.github/workflows/tests.yml
. - See what you can dig up about the jax failures. I'm happy to scrap that part of the implementation if it's not stable.
In terms of naming, I'd be happy to just remove the PyMC4 implementation since it was so short lived and then we'll just have |
…xoplanet-dev#91) * Import ShapedArray from `jax.core` instead for `jax.abstract_arrays` Related JAX changelog: https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-16-sept-18-2023 * Fixing another deprecated function call --------- Co-authored-by: Dan Foreman-Mackey <[email protected]>
updates: - [github.com/pre-commit/pre-commit-hooks: v4.4.0 → v4.5.0](pre-commit/pre-commit-hooks@v4.4.0...v4.5.0) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Bumps [pypa/cibuildwheel](https://github.com/pypa/cibuildwheel) from 2.15.0 to 2.16.2. - [Release notes](https://github.com/pypa/cibuildwheel/releases) - [Changelog](https://github.com/pypa/cibuildwheel/blob/main/docs/changelog.md) - [Commits](pypa/cibuildwheel@v2.15.0...v2.16.2) --- updated-dependencies: - dependency-name: pypa/cibuildwheel dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
@vandalt — I'm making a few changes to your branch to get the test suite running. Can you pull from your branch before making anymore changes? Remaining on your end is to remove the compiler flags, and to look into the details of the shape question above. Let me know if there's anything I can do yo help get this over the finish line! |
Thanks, noted!
I started reading about the shape question this weekend and ended up doing a deep dive into how the graphs works in PyMC (which I had not really done before). I'll probably have more time to look into this in ~ 2 weeks. |
Sounds good - thank you!! |
I finally had time to look more into this. Regarding the Jax part, it was failing because of #90: import was skipped because importing For shape and compiler flags, I left comments in the threads above. |
Thanks @vandalt!! |
Hi @dfm!
I wanted to play with PPLs during .Astronomy 12, including the latest PyMC version. I figured adding it to
exoplanet-core
would be a simple way to get familiar with the codebase while doing something useful. A few notes:pymc4
directory topymc
and did the changes there, as PyMC no longer uses the version number in the package name.aesara
withpytensor
.pymc4
andpymc
backends are actually in conflict because PyMC v4 also usespymc
as the package name.pymc5
as the submodule name here for the latest version?pymc_jax_test.py
are failing, but they were already failing inpymc4_jax_test.py
Let me know if there is anything I should add/change.
Thank you!