Skip to content

Commit

Permalink
fixing pymc3 compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Mar 11, 2021
1 parent 13be965 commit 6e6ea45
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 17 deletions.
6 changes: 2 additions & 4 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ on:

jobs:
tests:
name: "py${{ matrix.python-version }}; ${{ matrix.aesara-or-theano }}"
runs-on: "ubuntu-latest"
strategy:
matrix:
Expand All @@ -19,6 +18,7 @@ jobs:
- "theano-pymc==1.0.11"
- "theano-pymc"
- "aesara"
- '"pymc3==3.11.1" "aesara"'
- "https://github.com/pymc-devs/aesara/archive/master.zip"

steps:
Expand All @@ -32,10 +32,8 @@ jobs:
- name: Install dependencies
run: |
python -m pip install -U pip
python -m pip install $AESARA
python -m pip install ${{ matrix.aesara-or-theano }}
python -m pip install ".[test]"
env:
AESARA: ${{ matrix.aesara-or-theano }}
- name: Run the unit tests
run: python -m pytest -v tests
Expand Down
5 changes: 3 additions & 2 deletions src/aesara_theano_fallback/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
# -*- coding: utf-8 -*-

__all__ = [
"USE_AESARA",
"graph",
"aesara",
"sparse",
"tensor",
"change_flags",
"ifelse"
"ifelse",
]

from . import graph, tensor
from .compat import aesara, sparse, change_flags, ifelse
from .compat import USE_AESARA, aesara, sparse, change_flags, ifelse
from .aesara_theano_fallback_version import version as __version__ # noqa

__author__ = "Dan Foreman-Mackey, Rodrigo Luger"
Expand Down
12 changes: 10 additions & 2 deletions src/aesara_theano_fallback/compat.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
# -*- coding: utf-8 -*-

__all__ = ["aesara", "sparse", "tensor", "change_flags", "ifelse"]
__all__ = ["USE_AESARA", "aesara", "sparse", "change_flags", "ifelse"]


USE_AESARA = False
try:
import aesara

except ImportError:
aesara = None
else:
try:
import pymc3.theanof # noqa
except ImportError:
USE_AESARA = True


if aesara is None or not USE_AESARA:
try:
import theano.graph

Expand Down
6 changes: 4 additions & 2 deletions src/aesara_theano_fallback/graph.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# -*- coding: utf-8 -*-

try:
from .compat import USE_AESARA

if USE_AESARA:
from aesara.graph import * # noqa

except ImportError:
else:
try:
from theano.graph import * # noqa

Expand Down
13 changes: 6 additions & 7 deletions src/aesara_theano_fallback/tensor.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
# -*- coding: utf-8 -*-

try:
from aesara.tensor import * # noqa

except ImportError:
from theano.tensor import * # noqa
from theano.tensor import slinalg
from .compat import USE_AESARA

else:
if USE_AESARA:
from aesara.tensor import * # noqa
from aesara.tensor import slinalg # noqa
else:
from theano.tensor import * # noqa
from theano.tensor import slinalg # noqa
11 changes: 11 additions & 0 deletions tests/test_imports.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-

import pytest


def test_core_imports():
from aesara_theano_fallback import tensor
Expand All @@ -23,3 +25,12 @@ def test_graph_op_imports():

op.Op
op.ExternalCOp


def test_pymc3_compat():
import aesara_theano_fallback.tensor as tt

pm = pytest.importorskip("pymc3")
with pm.Model():
x = pm.Normal("x", shape=10)
tt.dot(x, x)

0 comments on commit 6e6ea45

Please sign in to comment.