-
-
Notifications
You must be signed in to change notification settings - Fork 61
Support HMM via marginalization of DiscreteMarkovChain #257
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,29 +10,26 @@ | |
from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform | ||
from pymc.distributions.transforms import Chain | ||
from pymc.logprob.abstract import _logprob | ||
from pymc.logprob.basic import conditional_logp | ||
from pymc.logprob.basic import conditional_logp, logp | ||
from pymc.logprob.transforms import IntervalTransform | ||
from pymc.model import Model | ||
from pymc.pytensorf import compile_pymc, constant_fold, inputvars | ||
from pymc.util import _get_seeds_per_chain, dataset_to_point_list, treedict | ||
from pytensor import Mode | ||
from pytensor import Mode, scan | ||
from pytensor.compile import SharedVariable | ||
from pytensor.compile.builders import OpFromGraph | ||
from pytensor.graph import ( | ||
Constant, | ||
FunctionGraph, | ||
ancestors, | ||
clone_replace, | ||
vectorize_graph, | ||
) | ||
from pytensor.graph import Constant, FunctionGraph, ancestors, clone_replace | ||
from pytensor.graph.replace import vectorize_graph | ||
from pytensor.scan import map as scan_map | ||
from pytensor.tensor import TensorVariable | ||
from pytensor.tensor import TensorType, TensorVariable | ||
from pytensor.tensor.elemwise import Elemwise | ||
from pytensor.tensor.shape import Shape | ||
from pytensor.tensor.special import log_softmax | ||
|
||
__all__ = ["MarginalModel"] | ||
|
||
from pymc_experimental.distributions import DiscreteMarkovChain | ||
|
||
|
||
class MarginalModel(Model): | ||
"""Subclass of PyMC Model that implements functionality for automatic | ||
|
@@ -247,16 +244,25 @@ def marginalize( | |
self[var] if isinstance(var, str) else var for var in rvs_to_marginalize | ||
] | ||
|
||
supported_dists = (Bernoulli, Categorical, DiscreteUniform) | ||
for rv_to_marginalize in rvs_to_marginalize: | ||
if rv_to_marginalize not in self.free_RVs: | ||
raise ValueError( | ||
f"Marginalized RV {rv_to_marginalize} is not a free RV in the model" | ||
) | ||
if not isinstance(rv_to_marginalize.owner.op, supported_dists): | ||
|
||
rv_op = rv_to_marginalize.owner.op | ||
if isinstance(rv_op, DiscreteMarkovChain): | ||
if rv_op.n_lags > 1: | ||
raise NotImplementedError( | ||
"Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported" | ||
) | ||
if rv_to_marginalize.owner.inputs[0].type.ndim > 2: | ||
raise NotImplementedError( | ||
"Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can a markov chain have a non-matrix transition probability? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should be valid for batch dims |
||
) | ||
elif not isinstance(rv_op, (Bernoulli, Categorical, DiscreteUniform)): | ||
raise NotImplementedError( | ||
f"RV with distribution {rv_to_marginalize.owner.op} cannot be marginalized. " | ||
f"Supported distribution include {supported_dists}" | ||
f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported" | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought the old error message was more helpful There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It was but it not gonna scale There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe link to the docs where it lists all the supported distributions then? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The notes state that functionality is restricted, only finite discrete RVs are supported which is kind of true. Although we don't yet support Truncated/Censored of infinite discrete RVs which thus become finite: #95 We also don't support Multinomial which in theory is finite... So I think we the disclaimer functionality is restricted and this error message indicating the type of the RV that could not be marginalized it's fair game? |
||
|
||
if rv_to_marginalize.name in self.named_vars_to_dims: | ||
|
@@ -381,41 +387,36 @@ def transform_input(inputs): | |
|
||
rv_dict = {} | ||
rv_dims = {} | ||
for seed, rv in zip(seeds, vars_to_recover): | ||
for seed, marginalized_rv in zip(seeds, vars_to_recover): | ||
supported_dists = (Bernoulli, Categorical, DiscreteUniform) | ||
if not isinstance(rv.owner.op, supported_dists): | ||
if not isinstance(marginalized_rv.owner.op, supported_dists): | ||
raise NotImplementedError( | ||
f"RV with distribution {rv.owner.op} cannot be recovered. " | ||
f"RV with distribution {marginalized_rv.owner.op} cannot be recovered. " | ||
f"Supported distribution include {supported_dists}" | ||
) | ||
|
||
m = self.clone() | ||
rv = m.vars_to_clone[rv] | ||
m.unmarginalize([rv]) | ||
dependent_vars = find_conditional_dependent_rvs(rv, m.basic_RVs) | ||
joint_logps = m.logp(vars=dependent_vars + [rv], sum=False) | ||
marginalized_rv = m.vars_to_clone[marginalized_rv] | ||
m.unmarginalize([marginalized_rv]) | ||
dependent_vars = find_conditional_dependent_rvs(marginalized_rv, m.basic_RVs) | ||
joint_logps = m.logp(vars=[marginalized_rv] + dependent_vars, sum=False) | ||
|
||
marginalized_value = m.rvs_to_values[rv] | ||
marginalized_value = m.rvs_to_values[marginalized_rv] | ||
other_values = [v for v in m.value_vars if v is not marginalized_value] | ||
|
||
# Handle batch dims for marginalized value and its dependent RVs | ||
joint_logp = joint_logps[-1] | ||
for dv in joint_logps[:-1]: | ||
dbcast = dv.type.broadcastable | ||
mbcast = marginalized_value.type.broadcastable | ||
mbcast = (True,) * (len(dbcast) - len(mbcast)) + mbcast | ||
values_axis_bcast = [ | ||
i for i, (m, v) in enumerate(zip(mbcast, dbcast)) if m and not v | ||
] | ||
joint_logp += dv.sum(values_axis_bcast) | ||
marginalized_logp, *dependent_logps = joint_logps | ||
joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps( | ||
marginalized_rv.type, dependent_logps | ||
) | ||
|
||
rv_shape = constant_fold(tuple(rv.shape)) | ||
rv_domain = get_domain_of_finite_discrete_rv(rv) | ||
rv_shape = constant_fold(tuple(marginalized_rv.shape)) | ||
rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv) | ||
rv_domain_tensor = pt.moveaxis( | ||
pt.full( | ||
(*rv_shape, len(rv_domain)), | ||
rv_domain, | ||
dtype=rv.dtype, | ||
dtype=marginalized_rv.dtype, | ||
), | ||
-1, | ||
0, | ||
|
@@ -431,7 +432,7 @@ def transform_input(inputs): | |
joint_logps_norm = log_softmax(joint_logps, axis=-1) | ||
if return_samples: | ||
sample_rv_outs = pymc.Categorical.dist(logit_p=joint_logps) | ||
if isinstance(rv.owner.op, DiscreteUniform): | ||
if isinstance(marginalized_rv.owner.op, DiscreteUniform): | ||
sample_rv_outs += rv_domain[0] | ||
|
||
rv_loglike_fn = compile_pymc( | ||
|
@@ -456,18 +457,20 @@ def transform_input(inputs): | |
logps, samples = zip(*logvs) | ||
logps = np.array(logps) | ||
samples = np.array(samples) | ||
rv_dict[rv.name] = samples.reshape( | ||
rv_dict[marginalized_rv.name] = samples.reshape( | ||
tuple(len(coord) for coord in stacked_dims.values()) + samples.shape[1:], | ||
) | ||
else: | ||
logps = np.array(logvs) | ||
|
||
rv_dict["lp_" + rv.name] = logps.reshape( | ||
rv_dict["lp_" + marginalized_rv.name] = logps.reshape( | ||
tuple(len(coord) for coord in stacked_dims.values()) + logps.shape[1:], | ||
) | ||
if rv.name in m.named_vars_to_dims: | ||
rv_dims[rv.name] = list(m.named_vars_to_dims[rv.name]) | ||
rv_dims["lp_" + rv.name] = rv_dims[rv.name] + ["lp_" + rv.name + "_dim"] | ||
if marginalized_rv.name in m.named_vars_to_dims: | ||
rv_dims[marginalized_rv.name] = list(m.named_vars_to_dims[marginalized_rv.name]) | ||
rv_dims["lp_" + marginalized_rv.name] = rv_dims[marginalized_rv.name] + [ | ||
"lp_" + marginalized_rv.name + "_dim" | ||
] | ||
ricardoV94 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
coords, dims = coords_and_dims_for_inferencedata(self) | ||
dims.update(rv_dims) | ||
|
@@ -495,6 +498,10 @@ class FiniteDiscreteMarginalRV(MarginalRV): | |
"""Base class for Finite Discrete Marginalized RVs""" | ||
|
||
|
||
class DiscreteMarginalMarkovChainRV(MarginalRV): | ||
"""Base class for Discrete Marginal Markov Chain RVs""" | ||
|
||
|
||
def static_shape_ancestors(vars): | ||
"""Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph).""" | ||
return [ | ||
|
@@ -582,10 +589,13 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs | |
raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}") | ||
|
||
ndim_supp = {rv.owner.op.ndim_supp for rv in dependent_rvs} | ||
if max(ndim_supp) > 0: | ||
if len(ndim_supp) != 1: | ||
raise NotImplementedError( | ||
"Marginalization of withe dependent Multivariate RVs not implemented" | ||
"Marginalization with dependent variables of different support dimensionality not implemented" | ||
) | ||
[ndim_supp] = ndim_supp | ||
if ndim_supp > 0: | ||
raise NotImplementedError("Marginalization with dependent Multivariate RVs not implemented") | ||
|
||
marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs) | ||
dependent_rvs_input_rvs = [ | ||
|
@@ -620,11 +630,17 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs | |
replace_inputs.update({input_rv: input_rv.type() for input_rv in input_rvs}) | ||
cloned_outputs = clone_replace(outputs, replace=replace_inputs) | ||
|
||
marginalization_op = FiniteDiscreteMarginalRV( | ||
if isinstance(rv_to_marginalize.owner.op, DiscreteMarkovChain): | ||
marginalize_constructor = DiscreteMarginalMarkovChainRV | ||
else: | ||
marginalize_constructor = FiniteDiscreteMarginalRV | ||
|
||
marginalization_op = marginalize_constructor( | ||
inputs=list(replace_inputs.values()), | ||
outputs=cloned_outputs, | ||
ndim_supp=0, | ||
ndim_supp=ndim_supp, | ||
) | ||
|
||
marginalized_rvs = marginalization_op(*replace_inputs.keys()) | ||
fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs))) | ||
return rvs_to_marginalize, marginalized_rvs | ||
|
@@ -640,10 +656,29 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]: | |
elif isinstance(op, DiscreteUniform): | ||
lower, upper = constant_fold(rv.owner.inputs[3:]) | ||
return tuple(range(lower, upper + 1)) | ||
elif isinstance(op, DiscreteMarkovChain): | ||
P = rv.owner.inputs[0] | ||
return tuple(range(pt.get_vector_length(P[-1]))) | ||
|
||
raise NotImplementedError(f"Cannot compute domain for op {op}") | ||
|
||
|
||
def _add_reduce_batch_dependent_logps( | ||
marginalized_type: TensorType, dependent_logps: Sequence[TensorVariable] | ||
): | ||
"""Add the logps of dependent RVs while reducing extra batch dims relative to `marginalized_type`.""" | ||
|
||
mbcast = marginalized_type.broadcastable | ||
reduced_logps = [] | ||
for dependent_logp in dependent_logps: | ||
dbcast = dependent_logp.type.broadcastable | ||
dim_diff = len(dbcast) - len(mbcast) | ||
mbcast_aligned = (True,) * dim_diff + mbcast | ||
vbcast_axis = [i for i, (m, v) in enumerate(zip(mbcast_aligned, dbcast)) if m and not v] | ||
reduced_logps.append(dependent_logp.sum(vbcast_axis)) | ||
return pt.add(*reduced_logps) | ||
|
||
|
||
@_logprob.register(FiniteDiscreteMarginalRV) | ||
def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs): | ||
# Clone the inner RV graph of the Marginalized RV | ||
|
@@ -659,17 +694,12 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs): | |
logps_dict = conditional_logp(rv_values=inner_rvs_to_values, **kwargs) | ||
|
||
# Reduce logp dimensions corresponding to broadcasted variables | ||
joint_logp = logps_dict[inner_rvs_to_values[marginalized_rv]] | ||
for inner_rv, inner_value in inner_rvs_to_values.items(): | ||
if inner_rv is marginalized_rv: | ||
continue | ||
vbcast = inner_value.type.broadcastable | ||
mbcast = marginalized_rv.type.broadcastable | ||
mbcast = (True,) * (len(vbcast) - len(mbcast)) + mbcast | ||
values_axis_bcast = [i for i, (m, v) in enumerate(zip(mbcast, vbcast)) if m != v] | ||
joint_logp += logps_dict[inner_value].sum(values_axis_bcast, keepdims=True) | ||
|
||
# Wrap the joint_logp graph in an OpFromGrah, so that we can evaluate it at different | ||
marginalized_logp = logps_dict.pop(inner_rvs_to_values[marginalized_rv]) | ||
joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps( | ||
marginalized_rv.type, logps_dict.values() | ||
) | ||
|
||
# Wrap the joint_logp graph in an OpFromGraph, so that we can evaluate it at different | ||
# values of the marginalized RV | ||
# Some inputs are not root inputs (such as transformed projections of value variables) | ||
# Or cannot be used as inputs to an OpFromGraph (shared variables and constants) | ||
|
@@ -697,6 +727,7 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs): | |
) | ||
|
||
# Arbitrary cutoff to switch to Scan implementation to keep graph size under control | ||
# TODO: Try vectorize here | ||
if len(marginalized_rv_domain) <= 10: | ||
joint_logps = [ | ||
joint_logp_op(marginalized_rv_domain_tensor[i], *values, *inputs) | ||
|
@@ -718,3 +749,69 @@ def logp_fn(marginalized_rv_const, *non_sequences): | |
|
||
# We have to add dummy logps for the remaining value variables, otherwise PyMC will raise | ||
return joint_logps, *(pt.constant(0),) * (len(values) - 1) | ||
|
||
|
||
@_logprob.register(DiscreteMarginalMarkovChainRV) | ||
def marginal_hmm_logp(op, values, *inputs, **kwargs): | ||
|
||
marginalized_rvs_node = op.make_node(*inputs) | ||
inner_rvs = clone_replace( | ||
op.inner_outputs, | ||
replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)}, | ||
) | ||
|
||
chain_rv, *dependent_rvs = inner_rvs | ||
P, n_steps_, init_dist_, rng = chain_rv.owner.inputs | ||
domain = pt.arange(P.shape[-1], dtype="int32") | ||
|
||
# Construct logp in two steps | ||
# Step 1: Compute the probability of the data ("emissions") under every possible state (vec_logp_emission) | ||
|
||
# First we need to vectorize the conditional logp graph of the data, in case there are batch dimensions floating | ||
# around. To do this, we need to break the dependency between chain and the init_dist_ random variable. Otherwise, | ||
# PyMC will detect a random variable in the logp graph (init_dist_), that isn't relevant at this step. | ||
chain_value = chain_rv.clone() | ||
dependent_rvs = clone_replace(dependent_rvs, {chain_rv: chain_value}) | ||
logp_emissions_dict = conditional_logp(dict(zip(dependent_rvs, values))) | ||
|
||
# Reduce and add the batch dims beyond the chain dimension | ||
reduced_logp_emissions = _add_reduce_batch_dependent_logps( | ||
chain_rv.type, logp_emissions_dict.values() | ||
) | ||
|
||
# Add a batch dimension for the domain of the chain | ||
chain_shape = constant_fold(tuple(chain_rv.shape)) | ||
batch_chain_value = pt.moveaxis(pt.full((*chain_shape, domain.size), domain), -1, 0) | ||
batch_logp_emissions = vectorize_graph(reduced_logp_emissions, {chain_value: batch_chain_value}) | ||
|
||
# Step 2: Compute the transition probabilities | ||
# This is the "forward algorithm", alpha_t = p(y | s_t) * sum_{s_{t-1}}(p(s_t | s_{t-1}) * alpha_{t-1}) | ||
# We do it entirely in logs, though. | ||
|
||
# To compute the prior probabilities of each state, we evaluate the logp of the domain (all possible states) under | ||
# the initial distribution. This is robust to everything the user can throw at it. | ||
batch_logp_init_dist = pt.vectorize(lambda x: logp(init_dist_, x), "()->()")( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No way to avoid this lambda here with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The way to avoid it is to be a little function, but seems like a fine use for lambda? |
||
batch_chain_value[..., 0] | ||
) | ||
log_alpha_init = batch_logp_init_dist + batch_logp_emissions[..., 0] | ||
|
||
def step_alpha(logp_emission, log_alpha, log_P): | ||
step_log_prob = pt.logsumexp(log_alpha[:, None] + log_P, axis=0) | ||
return logp_emission + step_log_prob | ||
|
||
P_bcast_dims = (len(chain_shape) - 1) - (P.type.ndim - 2) | ||
log_P = pt.shape_padright(pt.log(P), P_bcast_dims) | ||
log_alpha_seq, _ = scan( | ||
step_alpha, | ||
non_sequences=[log_P], | ||
outputs_info=[log_alpha_init], | ||
# Scan needs the time dimension first, and we already consumed the 1st logp computing the initial value | ||
sequences=pt.moveaxis(batch_logp_emissions[..., 1:], -1, 0), | ||
) | ||
# Final logp is just the sum of the last scan state | ||
joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0) | ||
|
||
# If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first | ||
# return is the joint probability of everything together, but PyMC still expects one logp for each one. | ||
dummy_logps = (pt.constant(0),) * (len(values) - 1) | ||
return joint_logp, *dummy_logps |
Uh oh!
There was an error while loading. Please reload this page.