15
15
16
16
from typing import Sequence , Union
17
17
18
+ import numpy as np
18
19
import pymc as pm
19
20
import pytensor .tensor as pt
20
21
@@ -69,6 +70,41 @@ def _R2D2M2CP_beta(
69
70
return beta
70
71
71
72
73
+ def _broadcast_as_dims (* values , dims ):
74
+ model = pm .modelcontext (None )
75
+ shape = [len (model .coords [d ]) for d in dims ]
76
+ return tuple (np .broadcast_to (v , shape ) for v in values )
77
+
78
+
79
+ def _psi_masked (positive_probs , positive_probs_std , * , dims ):
80
+ if not (
81
+ isinstance (positive_probs , pt .Constant ) and isinstance (positive_probs_std , pt .Constant )
82
+ ):
83
+ raise TypeError (
84
+ "Only constant values for positive_probs and positive_probs_std are accepted"
85
+ )
86
+ positive_probs , positive_probs_std = _broadcast_as_dims (
87
+ positive_probs .data , positive_probs_std .data , dims = dims
88
+ )
89
+ mask = ~ np .bitwise_or (positive_probs == 1 , positive_probs == 0 )
90
+ if np .bitwise_and (~ mask , positive_probs_std != 0 ).any ():
91
+ raise ValueError ("Can't have both positive_probs == '1 or 0' and positive_probs_std != 0" )
92
+ if (~ mask ).any ():
93
+ r_idx = mask .nonzero ()
94
+ with pm .Model ("psi" ):
95
+ psi = pm .Beta (
96
+ "masked" ,
97
+ mu = positive_probs [r_idx ],
98
+ sigma = positive_probs_std [r_idx ],
99
+ shape = len (r_idx [0 ]),
100
+ )
101
+ psi = pt .set_subtensor (pt .as_tensor (positive_probs )[r_idx ], psi )
102
+ psi = pm .Deterministic ("psi" , psi , dims = dims )
103
+ else :
104
+ psi = pm .Beta ("psi" , mu = positive_probs , sigma = positive_probs_std , dims = dims )
105
+ return mask , psi
106
+
107
+
72
108
def R2D2M2CP (
73
109
name ,
74
110
output_sigma ,
@@ -258,6 +294,7 @@ def R2D2M2CP(
258
294
* broadcast_dims , dim = dims
259
295
input_sigma = pt .as_tensor (input_sigma )
260
296
output_sigma = pt .as_tensor (output_sigma )
297
+ positive_probs = pt .as_tensor (positive_probs )
261
298
with pm .Model (name ) as model :
262
299
if variables_importance is not None :
263
300
if variance_explained is not None :
@@ -272,7 +309,7 @@ def R2D2M2CP(
272
309
raise TypeError ("Can't use variance explained with less than two variables" )
273
310
phi = pt .as_tensor (variance_explained )
274
311
else :
275
- phi = 1 / len (model .coords [dim ])
312
+ phi = pt . as_tensor ( 1 / len (model .coords [dim ]) )
276
313
if r2_std is not None :
277
314
r2 = pm .Beta ("r2" , mu = r2 , sigma = r2_std , dims = broadcast_dims )
278
315
if positive_probs_std is not None :
0 commit comments