Skip to content

Commit bdf8382

Browse files
committed
Add type hints to MarginalModel methods
1 parent 1f70c7a commit bdf8382

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

pymc_experimental/model/marginal_model.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import warnings
2-
from typing import Sequence, Tuple, Union
2+
from typing import Sequence
33

44
import numpy as np
55
import pymc
66
import pytensor.tensor as pt
7-
from arviz import dict_to_dataset
7+
from arviz import InferenceData, dict_to_dataset
88
from pymc import SymbolicRandomVariable
99
from pymc.backends.arviz import coords_and_dims_for_inferencedata, dataset_to_point_list
1010
from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
@@ -14,7 +14,7 @@
1414
from pymc.logprob.transforms import IntervalTransform
1515
from pymc.model import Model
1616
from pymc.pytensorf import compile_pymc, constant_fold, inputvars
17-
from pymc.util import _get_seeds_per_chain, treedict
17+
from pymc.util import RandomState, _get_seeds_per_chain, treedict
1818
from pytensor import Mode, scan
1919
from pytensor.compile import SharedVariable
2020
from pytensor.compile.builders import OpFromGraph
@@ -236,7 +236,7 @@ def clone(self):
236236

237237
def marginalize(
238238
self,
239-
rvs_to_marginalize: Union[TensorVariable, str, Sequence[TensorVariable], Sequence[str]],
239+
rvs_to_marginalize: TensorVariable | Sequence[TensorVariable] | str | Sequence[str],
240240
):
241241
if not isinstance(rvs_to_marginalize, Sequence):
242242
rvs_to_marginalize = (rvs_to_marginalize,)
@@ -293,7 +293,8 @@ def _to_transformed(self):
293293
fn = self.compile_fn(inputs=self.free_RVs, outs=transformed_rvs)
294294
return fn, transformed_names
295295

296-
def unmarginalize(self, rvs_to_unmarginalize):
296+
def unmarginalize(self, rvs_to_unmarginalize: Sequence[TensorVariable]):
297+
"""TODO: Allow passing strings as well"""
297298
for rv in rvs_to_unmarginalize:
298299
self.marginalized_rvs.remove(rv)
299300
if rv.name in self._marginalized_named_vars_to_dims:
@@ -304,11 +305,11 @@ def unmarginalize(self, rvs_to_unmarginalize):
304305

305306
def recover_marginals(
306307
self,
307-
idata,
308-
var_names=None,
309-
return_samples=True,
310-
extend_inferencedata=True,
311-
random_seed=None,
308+
idata: InferenceData,
309+
var_names: Sequence[str] | None = None,
310+
return_samples: bool = True,
311+
extend_inferencedata: bool = True,
312+
random_seed: RandomState = None,
312313
):
313314
"""Computes posterior log-probabilities and samples of marginalized variables
314315
conditioned on parameters of the model given InferenceData with posterior group
@@ -649,7 +650,7 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
649650
return rvs_to_marginalize, marginalized_rvs
650651

651652

652-
def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]:
653+
def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
653654
op = rv.owner.op
654655
if isinstance(op, Bernoulli):
655656
return (0, 1)

0 commit comments

Comments
 (0)