
Commit fcc1177

Allow creating MarginalModel from existing Model
1 parent bdf8382 commit fcc1177

4 files changed: +134 −23 lines

docs/api_reference.rst

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ methods in the current release of PyMC experimental.
 
    as_model
    MarginalModel
+   marginalize
    model_builder.ModelBuilder
 
 Inference

pymc_experimental/model/marginal_model.py

Lines changed: 69 additions & 23 deletions
@@ -1,5 +1,5 @@
 import warnings
-from typing import Sequence
+from typing import Sequence, Union
 
 import numpy as np
 import pymc
@@ -26,10 +26,12 @@
 from pytensor.tensor.shape import Shape
 from pytensor.tensor.special import log_softmax
 
-__all__ = ["MarginalModel"]
+__all__ = ["MarginalModel", "marginalize"]
 
 from pymc_experimental.distributions import DiscreteMarkovChain
 
+ModelRVs = TensorVariable | Sequence[TensorVariable] | str | Sequence[str]
+
 
 class MarginalModel(Model):
     """Subclass of PyMC Model that implements functionality for automatic
@@ -208,35 +210,50 @@ def logp(self, vars=None, **kwargs):
         vars = [m[var.name] for var in vars]
         return m._logp(vars=vars, **kwargs)
 
-    def clone(self):
-        m = MarginalModel(coords=self.coords)
-        model_vars = self.basic_RVs + self.potentials + self.deterministics + self.marginalized_rvs
-        data_vars = [var for name, var in self.named_vars.items() if var not in model_vars]
+    @staticmethod
+    def _clone(model: Union[Model, "MarginalModel"]):
+        new_model = MarginalModel(coords=model.coords)
+        if isinstance(model, MarginalModel):
+            marginalized_rvs = model.marginalized_rvs
+            marginalized_named_vars_to_dims = model._marginalized_named_vars_to_dims
+        else:
+            marginalized_rvs = []
+            marginalized_named_vars_to_dims = {}
+
+        model_vars = model.basic_RVs + model.potentials + model.deterministics + marginalized_rvs
+        data_vars = [var for name, var in model.named_vars.items() if var not in model_vars]
         vars = model_vars + data_vars
         cloned_vars = clone_replace(vars)
         vars_to_clone = {var: cloned_var for var, cloned_var in zip(vars, cloned_vars)}
-        m.vars_to_clone = vars_to_clone
-
-        m.named_vars = treedict({name: vars_to_clone[var] for name, var in self.named_vars.items()})
-        m.named_vars_to_dims = self.named_vars_to_dims
-        m.values_to_rvs = {i: vars_to_clone[rv] for i, rv in self.values_to_rvs.items()}
-        m.rvs_to_values = {vars_to_clone[rv]: i for rv, i in self.rvs_to_values.items()}
-        m.rvs_to_transforms = {vars_to_clone[rv]: i for rv, i in self.rvs_to_transforms.items()}
-        m.rvs_to_initial_values = {
-            vars_to_clone[rv]: i for rv, i in self.rvs_to_initial_values.items()
+        new_model.vars_to_clone = vars_to_clone
+
+        new_model.named_vars = treedict(
+            {name: vars_to_clone[var] for name, var in model.named_vars.items()}
+        )
+        new_model.named_vars_to_dims = model.named_vars_to_dims
+        new_model.values_to_rvs = {i: vars_to_clone[rv] for i, rv in model.values_to_rvs.items()}
+        new_model.rvs_to_values = {vars_to_clone[rv]: i for rv, i in model.rvs_to_values.items()}
+        new_model.rvs_to_transforms = {
+            vars_to_clone[rv]: i for rv, i in model.rvs_to_transforms.items()
+        }
+        new_model.rvs_to_initial_values = {
+            vars_to_clone[rv]: i for rv, i in model.rvs_to_initial_values.items()
         }
-        m.free_RVs = [vars_to_clone[rv] for rv in self.free_RVs]
-        m.observed_RVs = [vars_to_clone[rv] for rv in self.observed_RVs]
-        m.potentials = [vars_to_clone[pot] for pot in self.potentials]
-        m.deterministics = [vars_to_clone[det] for det in self.deterministics]
+        new_model.free_RVs = [vars_to_clone[rv] for rv in model.free_RVs]
+        new_model.observed_RVs = [vars_to_clone[rv] for rv in model.observed_RVs]
+        new_model.potentials = [vars_to_clone[pot] for pot in model.potentials]
+        new_model.deterministics = [vars_to_clone[det] for det in model.deterministics]
 
-        m.marginalized_rvs = [vars_to_clone[rv] for rv in self.marginalized_rvs]
-        m._marginalized_named_vars_to_dims = self._marginalized_named_vars_to_dims
-        return m
+        new_model.marginalized_rvs = [vars_to_clone[rv] for rv in marginalized_rvs]
+        new_model._marginalized_named_vars_to_dims = marginalized_named_vars_to_dims
+        return new_model
+
+    def clone(self):
+        return self._clone(self)
 
     def marginalize(
         self,
-        rvs_to_marginalize: TensorVariable | Sequence[TensorVariable] | str | Sequence[str],
+        rvs_to_marginalize: ModelRVs,
     ):
         if not isinstance(rvs_to_marginalize, Sequence):
             rvs_to_marginalize = (rvs_to_marginalize,)
@@ -491,6 +508,35 @@ def transform_input(inputs):
     return rv_dataset
 
 
+def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
+    """Marginalize a subset of variables in a PyMC model.
+
+    This creates a `MarginalModel` from an existing `Model`, with the specified
+    variables marginalized.
+
+    See the documentation for `MarginalModel` for more information.
+
+    Parameters
+    ----------
+    model : Model
+        PyMC model to marginalize. Original variables will be cloned.
+    rvs_to_marginalize : Sequence[TensorVariable]
+        Variables to marginalize in the returned model.
+
+    Returns
+    -------
+    marginal_model : MarginalModel
+        Marginal model with the specified variables marginalized.
+    """
+    if not isinstance(rvs_to_marginalize, tuple | list):
+        rvs_to_marginalize = (rvs_to_marginalize,)
+    rvs_to_marginalize = [rv if isinstance(rv, str) else rv.name for rv in rvs_to_marginalize]
+
+    marginal_model = MarginalModel._clone(model)
+    marginal_model.marginalize(rvs_to_marginalize)
+    return marginal_model
+
+
 class MarginalRV(SymbolicRandomVariable):
     """Base class for Marginalized RVs"""
 
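
For orientation, here is a minimal usage sketch of the new `marginalize` function, modeled on the test added in this commit. The model itself is illustrative, and the final sampling step is a typical follow-up rather than something this diff exercises:

import pymc as pm
import pytensor.tensor as pt
from pymc_experimental.model.marginal_model import marginalize

with pm.Model(coords={"trial": range(10)}) as model:
    idx = pm.Bernoulli("idx", p=0.5, dims="trial")  # discrete variable to be marginalized
    sigma = pm.HalfNormal("sigma")
    pm.Normal("y", mu=pt.where(idx, 1, -1), sigma=sigma, dims="trial", observed=[1] * 10)

# Returns a MarginalModel clone of `model` with "idx" marginalized out;
# the original `model` is left untouched because all variables are cloned.
marginal_model = marginalize(model, ["idx"])

with marginal_model:
    idata = pm.sample()  # assumed follow-up: sample the marginalized model as usual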

pymc_experimental/tests/model/test_marginal_model.py

Lines changed: 33 additions & 0 deletions
@@ -10,15 +10,18 @@
 from pymc import ImputationWarning, inputvars
 from pymc.distributions import transforms
 from pymc.logprob.abstract import _logprob
+from pymc.model.fgraph import fgraph_from_model
 from pymc.util import UNSET
 from scipy.special import log_softmax, logsumexp
 from scipy.stats import halfnorm, norm
+from utils import equal_computations_up_to_root
 
 from pymc_experimental.distributions import DiscreteMarkovChain
 from pymc_experimental.model.marginal_model import (
     FiniteDiscreteMarginalRV,
     MarginalModel,
     is_conditional_dependent,
+    marginalize,
 )
 
 
@@ -776,3 +779,33 @@ def test_mutable_indexing_jax_backend():
         pm.LogNormal("y", mu=cat_effect[cat_effect_idx], sigma=1 + is_outlier, observed=data)
     model.marginalize(["is_outlier"])
     get_jaxified_logp(model)
+
+
+def test_marginal_model_func():
+    def create_model(model_class):
+        with model_class(coords={"trial": range(10)}) as m:
+            idx = pm.Bernoulli("idx", p=0.5, dims="trial")
+            mu = pt.where(idx, 1, -1)
+            sigma = pm.HalfNormal("sigma")
+            y = pm.Normal("y", mu=mu, sigma=sigma, dims="trial", observed=[1] * 10)
+        return m
+
+    marginal_m = marginalize(create_model(pm.Model), ["idx"])
+    assert isinstance(marginal_m, MarginalModel)
+
+    reference_m = create_model(MarginalModel)
+    reference_m.marginalize(["idx"])
+
+    # Check forward graph representation is the same
+    marginal_fgraph, _ = fgraph_from_model(marginal_m)
+    reference_fgraph, _ = fgraph_from_model(reference_m)
+    assert equal_computations_up_to_root(marginal_fgraph.outputs, reference_fgraph.outputs)
+
+    # Check logp graph is the same
+    # This fails because OpFromGraphs comparison is broken
+    # assert equal_computations_up_to_root([marginal_m.logp()], [reference_m.logp()])
+    ip = marginal_m.initial_point()
+    np.testing.assert_allclose(
+        marginal_m.compile_logp()(ip),
+        reference_m.compile_logp()(ip),
+    )

pymc_experimental/tests/utils.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+from typing import Sequence
+
+from pytensor.compile import SharedVariable
+from pytensor.graph import Constant, graph_inputs
+from pytensor.graph.basic import Variable, equal_computations
+from pytensor.tensor.random.type import RandomType
+
+
+def equal_computations_up_to_root(
+    xs: Sequence[Variable], ys: Sequence[Variable], ignore_rng_values=True
+) -> bool:
+    # Check if graphs are equivalent even if root variables have distinct identities
+
+    x_graph_inputs = [var for var in graph_inputs(xs) if not isinstance(var, Constant)]
+    y_graph_inputs = [var for var in graph_inputs(ys) if not isinstance(var, Constant)]
+    if len(x_graph_inputs) != len(y_graph_inputs):
+        return False
+    for x, y in zip(x_graph_inputs, y_graph_inputs):
+        if x.type != y.type:
+            return False
+        if x.name != y.name:
+            return False
+        if isinstance(x, SharedVariable):
+            if not isinstance(y, SharedVariable):
+                return False
+            if isinstance(x.type, RandomType) and ignore_rng_values:
+                continue
+            if not x.type.values_eq(x.get_value(), y.get_value()):
+                return False
+
+    return equal_computations(xs, ys, in_xs=x_graph_inputs, in_ys=y_graph_inputs)
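
As a standalone sketch of what the new helper checks (the tensors below are hypothetical examples, not taken from the commit): two graphs built from distinct root variables compare equal as long as the roots match in type and name, while a name mismatch on a root makes the comparison fail.

import pytensor.tensor as pt

from utils import equal_computations_up_to_root  # imported as in the new test; assumes the tests directory is on sys.path

x1 = pt.vector("x")
x2 = pt.vector("x")  # a different Variable object with the same name and type
assert equal_computations_up_to_root([pt.exp(x1) + 1], [pt.exp(x2) + 1])

z = pt.vector("z")  # root name differs, so the comparison returns False
assert not equal_computations_up_to_root([pt.exp(x1) + 1], [pt.exp(z) + 1])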
