diff --git a/pymc_extras/model/marginal/distributions.py b/pymc_extras/model/marginal/distributions.py index de6ec5ba..d1807f61 100644 --- a/pymc_extras/model/marginal/distributions.py +++ b/pymc_extras/model/marginal/distributions.py @@ -7,7 +7,6 @@ from pymc.logprob.abstract import MeasurableOp, _logprob from pymc.logprob.basic import conditional_logp, logp from pymc.pytensorf import constant_fold -from pytensor import Variable from pytensor.compile.builders import OpFromGraph from pytensor.compile.mode import Mode from pytensor.graph import Op, vectorize_graph diff --git a/pymc_experimental/sampling/__init__.py b/pymc_extras/sampling/__init__.py similarity index 100% rename from pymc_experimental/sampling/__init__.py rename to pymc_extras/sampling/__init__.py diff --git a/pymc_experimental/sampling/mcmc.py b/pymc_extras/sampling/mcmc.py similarity index 98% rename from pymc_experimental/sampling/mcmc.py rename to pymc_extras/sampling/mcmc.py index c57e7fa2..65b9dd2e 100644 --- a/pymc_experimental/sampling/mcmc.py +++ b/pymc_extras/sampling/mcmc.py @@ -4,7 +4,7 @@ from pymc.sampling.mcmc import sample from pytensor.graph.rewriting.basic import GraphRewriter -from pymc_experimental.sampling.optimizations.optimize import ( +from pymc_extras.sampling.optimizations.optimize import ( TAGS_TYPE, optimize_model_for_mcmc_sampling, ) diff --git a/pymc_experimental/sampling/optimizations/__init__.py b/pymc_extras/sampling/optimizations/__init__.py similarity index 54% rename from pymc_experimental/sampling/optimizations/__init__.py rename to pymc_extras/sampling/optimizations/__init__.py index 78d0c858..46fdf848 100644 --- a/pymc_experimental/sampling/optimizations/__init__.py +++ b/pymc_extras/sampling/optimizations/__init__.py @@ -1,9 +1,8 @@ # ruff: noqa: F401 # Add rewrites to the optimization DBs -import pymc_experimental.sampling.optimizations.conjugacy -import pymc_experimental.sampling.optimizations.summary_stats -from pymc_experimental.sampling.optimizations.optimize import ( +from pymc_extras.sampling.optimizations import conjugacy, summary_stats +from pymc_extras.sampling.optimizations.optimize import ( optimize_model_for_mcmc_sampling, posterior_optimization_db, ) diff --git a/pymc_experimental/sampling/optimizations/conjugacy.py b/pymc_extras/sampling/optimizations/conjugacy.py similarity index 97% rename from pymc_experimental/sampling/optimizations/conjugacy.py rename to pymc_extras/sampling/optimizations/conjugacy.py index f4c3a360..f1f6e246 100644 --- a/pymc_experimental/sampling/optimizations/conjugacy.py +++ b/pymc_extras/sampling/optimizations/conjugacy.py @@ -10,10 +10,10 @@ from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.subtensor import _sum_grad_over_bcasted_dims as sum_bcasted_dims -from pymc_experimental.sampling.optimizations.conjugate_sampler import ( +from pymc_extras.sampling.optimizations.conjugate_sampler import ( ConjugateRV, ) -from pymc_experimental.sampling.optimizations.optimize import posterior_optimization_db +from pymc_extras.sampling.optimizations.optimize import posterior_optimization_db def register_conjugacy_rewrites_variants(rewrite_fn, tracks=(ModelFreeRV,)): diff --git a/pymc_experimental/sampling/optimizations/conjugate_sampler.py b/pymc_extras/sampling/optimizations/conjugate_sampler.py similarity index 98% rename from pymc_experimental/sampling/optimizations/conjugate_sampler.py rename to pymc_extras/sampling/optimizations/conjugate_sampler.py index 9ddaa462..1502cfe1 100644 --- a/pymc_experimental/sampling/optimizations/conjugate_sampler.py +++ b/pymc_extras/sampling/optimizations/conjugate_sampler.py @@ -13,7 +13,7 @@ from pytensor.link.jax.linker import JAXLinker from pytensor.tensor.random.type import RandomGeneratorType -from pymc_experimental.utils.ofg import inline_ofg_outputs +from pymc_extras.utils.ofg import inline_ofg_outputs class ConjugateRV(OpFromGraph, MeasurableOp): diff --git a/pymc_experimental/sampling/optimizations/optimize.py b/pymc_extras/sampling/optimizations/optimize.py similarity index 100% rename from pymc_experimental/sampling/optimizations/optimize.py rename to pymc_extras/sampling/optimizations/optimize.py diff --git a/pymc_experimental/sampling/optimizations/summary_stats.py b/pymc_extras/sampling/optimizations/summary_stats.py similarity index 96% rename from pymc_experimental/sampling/optimizations/summary_stats.py rename to pymc_extras/sampling/optimizations/summary_stats.py index 58a9c806..167e2f1d 100644 --- a/pymc_experimental/sampling/optimizations/summary_stats.py +++ b/pymc_extras/sampling/optimizations/summary_stats.py @@ -5,7 +5,7 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import node_rewriter -from pymc_experimental.sampling.optimizations.optimize import posterior_optimization_db +from pymc_extras.sampling.optimizations.optimize import posterior_optimization_db @node_rewriter(tracks=[ModelObservedRV]) diff --git a/pymc_extras/utils/ofg.py b/pymc_extras/utils/ofg.py index 6de8ed4a..c85c85cb 100644 --- a/pymc_extras/utils/ofg.py +++ b/pymc_extras/utils/ofg.py @@ -5,7 +5,7 @@ from pytensor.graph.replace import clone_replace -def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]: +def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> list[Variable]: """Inline the inner graph (outputs) of an OpFromGraph Op. Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps" diff --git a/tests/sampling/mcmc/test_mcmc.py b/tests/sampling/mcmc/test_mcmc.py index f179e22d..beda5bbe 100644 --- a/tests/sampling/mcmc/test_mcmc.py +++ b/tests/sampling/mcmc/test_mcmc.py @@ -8,7 +8,7 @@ from pymc.sampling.mcmc import sample from pymc.step_methods import Slice -from pymc_experimental import opt_sample +from pymc_extras import opt_sample def test_custom_step_raises(): diff --git a/tests/sampling/optimizations/test_conjugacy.py b/tests/sampling/optimizations/test_conjugacy.py index b45aabfc..658f56d4 100644 --- a/tests/sampling/optimizations/test_conjugacy.py +++ b/tests/sampling/optimizations/test_conjugacy.py @@ -6,8 +6,8 @@ from pymc.model.transform.conditioning import remove_value_transforms from pymc.sampling import draw -from pymc_experimental.sampling.optimizations.conjugate_sampler import ConjugateRV -from pymc_experimental.sampling.optimizations.optimize import optimize_model_for_mcmc_sampling +from pymc_extras.sampling.optimizations import optimize_model_for_mcmc_sampling +from pymc_extras.sampling.optimizations.conjugate_sampler import ConjugateRV @pytest.mark.parametrize("eager", [False, True]) diff --git a/tests/sampling/optimizations/test_summary_stats.py b/tests/sampling/optimizations/test_summary_stats.py index bb390e0e..bf439bbe 100644 --- a/tests/sampling/optimizations/test_summary_stats.py +++ b/tests/sampling/optimizations/test_summary_stats.py @@ -3,7 +3,7 @@ from pymc.distributions import HalfNormal, Normal from pymc.model.core import Model -from pymc_experimental.sampling.optimizations.optimize import optimize_model_for_mcmc_sampling +from pymc_extras.sampling.optimizations import optimize_model_for_mcmc_sampling def test_summary_stats_normal():