Skip to content

Commit

Permalink
Added placeholder/reminder to remove jax dependency when converting t…
Browse files Browse the repository at this point in the history
…race data to InferenceData
  • Loading branch information
aphc14 committed Nov 7, 2024
1 parent fdc3f38 commit 1fd7a11
Showing 1 changed file with 18 additions and 7 deletions.
25 changes: 18 additions & 7 deletions pymc_experimental/inference/pathfinder/pathfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def convert_flat_trace_to_idata(
samples,
include_transformed=False,
postprocessing_backend="cpu",
inference_backend="pymc",
model=None,
):
model = modelcontext(model)
Expand All @@ -139,10 +140,21 @@ def convert_flat_trace_to_idata(
var_names = model.unobserved_value_vars
vars_to_sample = list(get_default_varnames(var_names, include_transformed=include_transformed))
print("Transforming variables...", file=sys.stdout)
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
result = jax.vmap(jax.vmap(jax_fn))(
*jax.device_put(list(trace.values()), jax.devices(postprocessing_backend)[0])
)

if inference_backend == "pymc":
# TODO: we need to remove JAX dependency as win32 users can now use Pathfinder with inference_backend="pymc".
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
result = jax.vmap(jax.vmap(jax_fn))(
*jax.device_put(list(trace.values()), jax.devices(postprocessing_backend)[0])
)
elif inference_backend == "blackjax":
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
result = jax.vmap(jax.vmap(jax_fn))(
*jax.device_put(list(trace.values()), jax.devices(postprocessing_backend)[0])
)
else:
raise ValueError(f"Invalid inference_backend: {inference_backend}")

trace = {v.name: r for v, r in zip(vars_to_sample, result)}
coords, dims = coords_and_dims_for_inferencedata(model)
idata = az.from_dict(trace, dims=dims, coords=coords)
Expand Down Expand Up @@ -742,7 +754,6 @@ def fit_pathfinder(
random_seed=random_seed,
**pathfinder_kwargs,
)

elif inference_backend == "blackjax":
jitter_seed, pathfinder_seed, sample_seed = _get_seeds_per_chain(random_seed, 3)
# TODO: extend initial points initialisation to blackjax
Expand Down Expand Up @@ -773,15 +784,15 @@ def fit_pathfinder(
state=pathfinder_state,
num_samples=num_draws,
)

else:
raise ValueError(f"Inference backend {inference_backend} not supported")
raise ValueError(f"Invalid inference_backend: {inference_backend}")

print("Running pathfinder...", file=sys.stdout)

idata = convert_flat_trace_to_idata(
pathfinder_samples,
postprocessing_backend=postprocessing_backend,
inference_backend=inference_backend,
model=model,
)
return idata

0 comments on commit 1fd7a11

Please sign in to comment.