Skip to content

Commit 3d747af

Browse files
Introduce StatsBijection to combine stats dicts
1 parent 10c66b4 commit 3d747af

File tree

3 files changed

+83
-4
lines changed

3 files changed

+83
-4
lines changed

pymc/blocking.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@
3131

3232
T = TypeVar("T")
3333
PointType: TypeAlias = Dict[str, np.ndarray]
34-
StatsType: TypeAlias = List[Dict[str, Any]]
34+
StatsDict: TypeAlias = Dict[str, Any]
35+
StatsType: TypeAlias = List[StatsDict]
3536

3637
# `point_map_info` is a tuple of tuples containing `(name, shape, dtype)` for
3738
# each of the raveled variables.

pymc/step_methods/compound.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,15 @@
1818
@author: johnsalvatier
1919
"""
2020

21-
2221
from abc import ABC, abstractmethod
2322
from enum import IntEnum, unique
24-
from typing import Dict, List, Tuple
23+
from typing import Dict, List, Sequence, Tuple, Union
2524

2625
import numpy as np
2726

2827
from pytensor.graph.basic import Variable
2928

30-
from pymc.blocking import PointType, StatsType
29+
from pymc.blocking import PointType, StatsDict, StatsType
3130
from pymc.model import modelcontext
3231

3332
__all__ = ("Competence", "CompoundStep")
@@ -165,3 +164,43 @@ def reset_tuning(self):
165164
@property
166165
def vars(self):
167166
return [var for method in self.methods for var in method.vars]
167+
168+
169+
def flatten_steps(step: Union[BlockedStep, CompoundStep]) -> List[BlockedStep]:
170+
"""Flatten a hierarchy of step methods to a list."""
171+
if isinstance(step, BlockedStep):
172+
return [step]
173+
steps = []
174+
if not isinstance(step, CompoundStep):
175+
raise ValueError(f"Unexpected type of step method: {step}")
176+
for sm in step.methods:
177+
steps += flatten_steps(sm)
178+
return steps
179+
180+
181+
class StatsBijection:
182+
"""Map between a `list` of stats to `dict` of stats."""
183+
184+
def __init__(self, sampler_stats_dtypes: Sequence[Dict[str, type]]) -> None:
185+
# Keep a list of flat vs. original stat names
186+
self._stat_groups: List[List[Tuple[str, str]]] = [
187+
[(f"sampler_{s}__{statname}", statname) for statname, _ in names_dtypes.items()]
188+
for s, names_dtypes in enumerate(sampler_stats_dtypes)
189+
]
190+
191+
def map(self, stats_list: StatsType) -> StatsDict:
192+
"""Combine stats dicts of multiple samplers into one dict."""
193+
stats_dict = {}
194+
for s, sts in enumerate(stats_list):
195+
for statname, sval in sts.items():
196+
sname = f"sampler_{s}__{statname}"
197+
stats_dict[sname] = sval
198+
return stats_dict
199+
200+
def rmap(self, stats_dict: StatsDict) -> StatsType:
201+
"""Split a global stats dict into a list of sampler-wise stats dicts."""
202+
stats_list = []
203+
for namemap in self._stat_groups:
204+
d = {statname: stats_dict[sname] for sname, statname in namemap}
205+
stats_list.append(d)
206+
return stats_list

pymc/tests/step_methods/test_compound.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
Metropolis,
2626
Slice,
2727
)
28+
from pymc.step_methods.compound import StatsBijection, flatten_steps
2829
from pymc.tests.helpers import StepMethodTester, fast_unstable_sampling_mode
2930
from pymc.tests.models import simple_2model_continuous
3031

@@ -91,3 +92,41 @@ def test_compound_step(self):
9192
step2 = NUTS([c2])
9293
step = CompoundStep([step1, step2])
9394
assert {m.rvs_to_values[c1], m.rvs_to_values[c2]} == set(step.vars)
95+
96+
97+
class TestStatsBijection:
98+
def test_flatten_steps(self):
99+
with pm.Model():
100+
a = pm.Normal("a")
101+
b = pm.Normal("b")
102+
c = pm.Normal("c")
103+
s1 = Metropolis([a])
104+
s2 = Metropolis([b])
105+
c1 = CompoundStep([s1, s2])
106+
s3 = NUTS([c])
107+
c2 = CompoundStep([c1, s3])
108+
assert flatten_steps(s1) == [s1]
109+
assert flatten_steps(c2) == [s1, s2, s3]
110+
with pytest.raises(ValueError, match="Unexpected type"):
111+
flatten_steps("not a step")
112+
113+
def test_stats_bijection(self):
114+
step_stats_dtypes = [
115+
{"a": float, "b": int},
116+
{"a": float, "c": int},
117+
]
118+
bij = StatsBijection(step_stats_dtypes)
119+
stats_l = [
120+
dict(a=1.5, b=3),
121+
dict(a=2.5, c=4),
122+
]
123+
stats_d = bij.map(stats_l)
124+
assert isinstance(stats_d, dict)
125+
assert stats_d["sampler_0__a"] == 1.5
126+
assert stats_d["sampler_0__b"] == 3
127+
assert stats_d["sampler_1__a"] == 2.5
128+
assert stats_d["sampler_1__c"] == 4
129+
rev = bij.rmap(stats_d)
130+
assert isinstance(rev, list)
131+
assert len(rev) == len(stats_l)
132+
assert rev == stats_l

0 commit comments

Comments
 (0)