Skip to content

Commit

Permalink
rename pymc_experimental -> pymc_extras
Browse files Browse the repository at this point in the history
  • Loading branch information
jessegrabowski committed Jan 2, 2025
1 parent 8de5346 commit 0359373
Show file tree
Hide file tree
Showing 12 changed files with 12 additions and 14 deletions.
1 change: 0 additions & 1 deletion pymc_extras/model/marginal/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from pymc.logprob.abstract import MeasurableOp, _logprob
from pymc.logprob.basic import conditional_logp, logp
from pymc.pytensorf import constant_fold
from pytensor import Variable
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.mode import Mode
from pytensor.graph import Op, vectorize_graph
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pymc.sampling.mcmc import sample
from pytensor.graph.rewriting.basic import GraphRewriter

from pymc_experimental.sampling.optimizations.optimize import (
from pymc_extras.sampling.optimizations.optimize import (
TAGS_TYPE,
optimize_model_for_mcmc_sampling,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# ruff: noqa: F401
# Add rewrites to the optimization DBs
import pymc_experimental.sampling.optimizations.conjugacy
import pymc_experimental.sampling.optimizations.summary_stats

from pymc_experimental.sampling.optimizations.optimize import (
from pymc_extras.sampling.optimizations import conjugacy, summary_stats
from pymc_extras.sampling.optimizations.optimize import (
optimize_model_for_mcmc_sampling,
posterior_optimization_db,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.subtensor import _sum_grad_over_bcasted_dims as sum_bcasted_dims

from pymc_experimental.sampling.optimizations.conjugate_sampler import (
from pymc_extras.sampling.optimizations.conjugate_sampler import (
ConjugateRV,
)
from pymc_experimental.sampling.optimizations.optimize import posterior_optimization_db
from pymc_extras.sampling.optimizations.optimize import posterior_optimization_db


def register_conjugacy_rewrites_variants(rewrite_fn, tracks=(ModelFreeRV,)):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pytensor.link.jax.linker import JAXLinker
from pytensor.tensor.random.type import RandomGeneratorType

from pymc_experimental.utils.ofg import inline_ofg_outputs
from pymc_extras.utils.ofg import inline_ofg_outputs


class ConjugateRV(OpFromGraph, MeasurableOp):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter

from pymc_experimental.sampling.optimizations.optimize import posterior_optimization_db
from pymc_extras.sampling.optimizations.optimize import posterior_optimization_db


@node_rewriter(tracks=[ModelObservedRV])
Expand Down
2 changes: 1 addition & 1 deletion pymc_extras/utils/ofg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pytensor.graph.replace import clone_replace


def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]:
def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> list[Variable]:
"""Inline the inner graph (outputs) of an OpFromGraph Op.
Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
Expand Down
2 changes: 1 addition & 1 deletion tests/sampling/mcmc/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pymc.sampling.mcmc import sample
from pymc.step_methods import Slice

from pymc_experimental import opt_sample
from pymc_extras import opt_sample


def test_custom_step_raises():
Expand Down
4 changes: 2 additions & 2 deletions tests/sampling/optimizations/test_conjugacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from pymc.model.transform.conditioning import remove_value_transforms
from pymc.sampling import draw

from pymc_experimental.sampling.optimizations.conjugate_sampler import ConjugateRV
from pymc_experimental.sampling.optimizations.optimize import optimize_model_for_mcmc_sampling
from pymc_extras.sampling.optimizations import optimize_model_for_mcmc_sampling
from pymc_extras.sampling.optimizations.conjugate_sampler import ConjugateRV


@pytest.mark.parametrize("eager", [False, True])
Expand Down
2 changes: 1 addition & 1 deletion tests/sampling/optimizations/test_summary_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pymc.distributions import HalfNormal, Normal
from pymc.model.core import Model

from pymc_experimental.sampling.optimizations.optimize import optimize_model_for_mcmc_sampling
from pymc_extras.sampling.optimizations import optimize_model_for_mcmc_sampling


def test_summary_stats_normal():
Expand Down

0 comments on commit 0359373

Please sign in to comment.