From 6e6ea453c0edffb6914b0f5fcf237fd4237c3fe6 Mon Sep 17 00:00:00 2001 From: Dan F-M Date: Thu, 11 Mar 2021 08:27:54 -0500 Subject: [PATCH] fixing pymc3 compatibility --- .github/workflows/tests.yml | 6 ++---- src/aesara_theano_fallback/__init__.py | 5 +++-- src/aesara_theano_fallback/compat.py | 12 ++++++++++-- src/aesara_theano_fallback/graph.py | 6 ++++-- src/aesara_theano_fallback/tensor.py | 13 ++++++------- tests/test_imports.py | 11 +++++++++++ 6 files changed, 36 insertions(+), 17 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2b217cb..889527f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -9,7 +9,6 @@ on: jobs: tests: - name: "py${{ matrix.python-version }}; ${{ matrix.aesara-or-theano }}" runs-on: "ubuntu-latest" strategy: matrix: @@ -19,6 +18,7 @@ jobs: - "theano-pymc==1.0.11" - "theano-pymc" - "aesara" + - '"pymc3==3.11.1" "aesara"' - "https://github.com/pymc-devs/aesara/archive/master.zip" steps: @@ -32,10 +32,8 @@ jobs: - name: Install dependencies run: | python -m pip install -U pip - python -m pip install $AESARA + python -m pip install ${{ matrix.aesara-or-theano }} python -m pip install ".[test]" - env: - AESARA: ${{ matrix.aesara-or-theano }} - name: Run the unit tests run: python -m pytest -v tests diff --git a/src/aesara_theano_fallback/__init__.py b/src/aesara_theano_fallback/__init__.py index 8ded804..c144b01 100644 --- a/src/aesara_theano_fallback/__init__.py +++ b/src/aesara_theano_fallback/__init__.py @@ -1,16 +1,17 @@ # -*- coding: utf-8 -*- __all__ = [ + "USE_AESARA", "graph", "aesara", "sparse", "tensor", "change_flags", - "ifelse" + "ifelse", ] from . import graph, tensor -from .compat import aesara, sparse, change_flags, ifelse +from .compat import USE_AESARA, aesara, sparse, change_flags, ifelse from .aesara_theano_fallback_version import version as __version__ # noqa __author__ = "Dan Foreman-Mackey, Rodrigo Luger" diff --git a/src/aesara_theano_fallback/compat.py b/src/aesara_theano_fallback/compat.py index a43a87f..fb42d9a 100644 --- a/src/aesara_theano_fallback/compat.py +++ b/src/aesara_theano_fallback/compat.py @@ -1,13 +1,21 @@ # -*- coding: utf-8 -*- -__all__ = ["aesara", "sparse", "tensor", "change_flags", "ifelse"] +__all__ = ["USE_AESARA", "aesara", "sparse", "change_flags", "ifelse"] +USE_AESARA = False try: import aesara - except ImportError: + aesara = None +else: + try: + import pymc3.theanof # noqa + except ImportError: + USE_AESARA = True + +if aesara is None or not USE_AESARA: try: import theano.graph diff --git a/src/aesara_theano_fallback/graph.py b/src/aesara_theano_fallback/graph.py index 921bd09..2a419b1 100644 --- a/src/aesara_theano_fallback/graph.py +++ b/src/aesara_theano_fallback/graph.py @@ -1,9 +1,11 @@ # -*- coding: utf-8 -*- -try: +from .compat import USE_AESARA + +if USE_AESARA: from aesara.graph import * # noqa -except ImportError: +else: try: from theano.graph import * # noqa diff --git a/src/aesara_theano_fallback/tensor.py b/src/aesara_theano_fallback/tensor.py index 19cf6a2..0c3f3bb 100644 --- a/src/aesara_theano_fallback/tensor.py +++ b/src/aesara_theano_fallback/tensor.py @@ -1,11 +1,10 @@ # -*- coding: utf-8 -*- -try: - from aesara.tensor import * # noqa - -except ImportError: - from theano.tensor import * # noqa - from theano.tensor import slinalg +from .compat import USE_AESARA -else: +if USE_AESARA: + from aesara.tensor import * # noqa from aesara.tensor import slinalg # noqa +else: + from theano.tensor import * # noqa + from theano.tensor import slinalg # noqa diff --git a/tests/test_imports.py b/tests/test_imports.py index 55f0461..6bf641c 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- +import pytest + def test_core_imports(): from aesara_theano_fallback import tensor @@ -23,3 +25,12 @@ def test_graph_op_imports(): op.Op op.ExternalCOp + + +def test_pymc3_compat(): + import aesara_theano_fallback.tensor as tt + + pm = pytest.importorskip("pymc3") + with pm.Model(): + x = pm.Normal("x", shape=10) + tt.dot(x, x)