Skip to content

Commit f9d1c55

Browse files
ferrineMaxim Kochurov
andauthored
add default to argument (#93)
Co-authored-by: Maxim Kochurov <[email protected]>
1 parent bb2359b commit f9d1c55

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

pymc_experimental/tests/test_prior_from_trace.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,9 @@ def test_prior_from_idata(idata, user_param_cfg, coords, param_cfg):
163163
test_prior = pm.sample_prior_predictive(1)
164164
names = [p["name"] for p in param_cfg.values()]
165165
assert set(model.named_vars) == {"trace_prior_", *names}
166+
167+
168+
def test_empty(idata, coords):
169+
with pm.Model(coords=coords):
170+
priors = pmx.utils.prior.prior_from_idata(idata)
171+
assert not priors

pymc_experimental/utils/prior.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def prior_from_idata(
132132
idata: arviz.InferenceData,
133133
name="trace_prior_",
134134
*,
135-
var_names: Sequence[str],
135+
var_names: Sequence[str] = (),
136136
**kwargs: Union[ParamCfg, RVTransform, str, Tuple]
137137
) -> Dict[str, pt.TensorVariable]:
138138
"""
@@ -192,5 +192,7 @@ def prior_from_idata(
192192
... trace1 = pm.sample_prior_predictive(100)
193193
"""
194194
param_cfg = _parse_args(var_names=var_names, **kwargs)
195+
if not param_cfg:
196+
return {}
195197
flat_info = _flatten(idata, **param_cfg)
196198
return _mvn_prior_from_flat_info(name, flat_info)

0 commit comments

Comments
 (0)