Skip to content

Commit dbad516

Browse files
committed
add helper function to initialize masked psi
1 parent 462b1a6 commit dbad516

File tree

1 file changed

+38
-1
lines changed
  • pymc_experimental/distributions/multivariate

1 file changed

+38
-1
lines changed

pymc_experimental/distributions/multivariate/r2d2m2cp.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from typing import Sequence, Union
1717

18+
import numpy as np
1819
import pymc as pm
1920
import pytensor.tensor as pt
2021

@@ -69,6 +70,41 @@ def _R2D2M2CP_beta(
6970
return beta
7071

7172

73+
def _broadcast_as_dims(*values, dims):
74+
model = pm.modelcontext(None)
75+
shape = [len(model.coords[d]) for d in dims]
76+
return tuple(np.broadcast_to(v, shape) for v in values)
77+
78+
79+
def _psi_masked(positive_probs, positive_probs_std, *, dims):
80+
if not (
81+
isinstance(positive_probs, pt.Constant) and isinstance(positive_probs_std, pt.Constant)
82+
):
83+
raise TypeError(
84+
"Only constant values for positive_probs and positive_probs_std are accepted"
85+
)
86+
positive_probs, positive_probs_std = _broadcast_as_dims(
87+
positive_probs.data, positive_probs_std.data, dims=dims
88+
)
89+
mask = ~np.bitwise_or(positive_probs == 1, positive_probs == 0)
90+
if np.bitwise_and(~mask, positive_probs_std != 0).any():
91+
raise ValueError("Can't have both positive_probs == '1 or 0' and positive_probs_std != 0")
92+
if (~mask).any():
93+
r_idx = mask.nonzero()
94+
with pm.Model("psi"):
95+
psi = pm.Beta(
96+
"masked",
97+
mu=positive_probs[r_idx],
98+
sigma=positive_probs_std[r_idx],
99+
shape=len(r_idx[0]),
100+
)
101+
psi = pt.set_subtensor(pt.as_tensor(positive_probs)[r_idx], psi)
102+
psi = pm.Deterministic("psi", psi, dims=dims)
103+
else:
104+
psi = pm.Beta("psi", mu=positive_probs, sigma=positive_probs_std, dims=dims)
105+
return mask, psi
106+
107+
72108
def R2D2M2CP(
73109
name,
74110
output_sigma,
@@ -258,6 +294,7 @@ def R2D2M2CP(
258294
*broadcast_dims, dim = dims
259295
input_sigma = pt.as_tensor(input_sigma)
260296
output_sigma = pt.as_tensor(output_sigma)
297+
positive_probs = pt.as_tensor(positive_probs)
261298
with pm.Model(name) as model:
262299
if variables_importance is not None:
263300
if variance_explained is not None:
@@ -272,7 +309,7 @@ def R2D2M2CP(
272309
raise TypeError("Can't use variance explained with less than two variables")
273310
phi = pt.as_tensor(variance_explained)
274311
else:
275-
phi = 1 / len(model.coords[dim])
312+
phi = pt.as_tensor(1 / len(model.coords[dim]))
276313
if r2_std is not None:
277314
r2 = pm.Beta("r2", mu=r2, sigma=r2_std, dims=broadcast_dims)
278315
if positive_probs_std is not None:

0 commit comments

Comments
 (0)