1
1
import warnings
2
- from typing import Sequence , Tuple , Union
2
+ from typing import Sequence
3
3
4
4
import numpy as np
5
5
import pymc
6
6
import pytensor .tensor as pt
7
- from arviz import dict_to_dataset
7
+ from arviz import InferenceData , dict_to_dataset
8
8
from pymc import SymbolicRandomVariable
9
9
from pymc .backends .arviz import coords_and_dims_for_inferencedata , dataset_to_point_list
10
10
from pymc .distributions .discrete import Bernoulli , Categorical , DiscreteUniform
14
14
from pymc .logprob .transforms import IntervalTransform
15
15
from pymc .model import Model
16
16
from pymc .pytensorf import compile_pymc , constant_fold , inputvars
17
- from pymc .util import _get_seeds_per_chain , treedict
17
+ from pymc .util import RandomState , _get_seeds_per_chain , treedict
18
18
from pytensor import Mode , scan
19
19
from pytensor .compile import SharedVariable
20
20
from pytensor .compile .builders import OpFromGraph
@@ -236,7 +236,7 @@ def clone(self):
236
236
237
237
def marginalize (
238
238
self ,
239
- rvs_to_marginalize : Union [ TensorVariable , str , Sequence [TensorVariable ], Sequence [str ] ],
239
+ rvs_to_marginalize : TensorVariable | Sequence [TensorVariable ] | str | Sequence [str ],
240
240
):
241
241
if not isinstance (rvs_to_marginalize , Sequence ):
242
242
rvs_to_marginalize = (rvs_to_marginalize ,)
@@ -293,7 +293,8 @@ def _to_transformed(self):
293
293
fn = self .compile_fn (inputs = self .free_RVs , outs = transformed_rvs )
294
294
return fn , transformed_names
295
295
296
- def unmarginalize (self , rvs_to_unmarginalize ):
296
+ def unmarginalize (self , rvs_to_unmarginalize : Sequence [TensorVariable ]):
297
+ """TODO: Allow passing strings as well"""
297
298
for rv in rvs_to_unmarginalize :
298
299
self .marginalized_rvs .remove (rv )
299
300
if rv .name in self ._marginalized_named_vars_to_dims :
@@ -304,11 +305,11 @@ def unmarginalize(self, rvs_to_unmarginalize):
304
305
305
306
def recover_marginals (
306
307
self ,
307
- idata ,
308
- var_names = None ,
309
- return_samples = True ,
310
- extend_inferencedata = True ,
311
- random_seed = None ,
308
+ idata : InferenceData ,
309
+ var_names : Sequence [ str ] | None = None ,
310
+ return_samples : bool = True ,
311
+ extend_inferencedata : bool = True ,
312
+ random_seed : RandomState = None ,
312
313
):
313
314
"""Computes posterior log-probabilities and samples of marginalized variables
314
315
conditioned on parameters of the model given InferenceData with posterior group
@@ -649,7 +650,7 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
649
650
return rvs_to_marginalize , marginalized_rvs
650
651
651
652
652
- def get_domain_of_finite_discrete_rv (rv : TensorVariable ) -> Tuple [int , ...]:
653
+ def get_domain_of_finite_discrete_rv (rv : TensorVariable ) -> tuple [int , ...]:
653
654
op = rv .owner .op
654
655
if isinstance (op , Bernoulli ):
655
656
return (0 , 1 )
0 commit comments