Skip to content

Commit 285263d

Browse files
committed
Allow ordering of multi-output OpFromGraph variables in toposort_replace
1 parent 71d25e7 commit 285263d

File tree

2 files changed

+68
-3
lines changed

2 files changed

+68
-3
lines changed

pymc/pytensorf.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,10 +1122,40 @@ def toposort_replace(
11221122
fgraph: FunctionGraph, replacements: Sequence[tuple[Variable, Variable]], reverse: bool = False
11231123
) -> None:
11241124
"""Replace multiple variables in place in topological order."""
1125-
toposort = fgraph.toposort()
1125+
fgraph_toposort = {node: i for i, node in enumerate(fgraph.toposort())}
1126+
_inner_fgraph_toposorts = {} # Cache inner toposorts
1127+
1128+
def _nested_toposort_index(var, fgraph_toposort) -> tuple[int]:
1129+
"""Compute position of variable in fgraph toposort.
1130+
1131+
When a variable is an OpFromGraph output, extend output with the toposort index of the inner graph(s).
1132+
1133+
This allows ordering variables that come from the same OpFromGraph.
1134+
"""
1135+
if not var.owner:
1136+
return (-1,)
1137+
1138+
index = fgraph_toposort[var.owner]
1139+
1140+
# Recurse into OpFromGraphs
1141+
# TODO: Could also recurse into Scans
1142+
if isinstance(var.owner.op, OpFromGraph):
1143+
inner_fgraph = var.owner.op.fgraph
1144+
1145+
if inner_fgraph not in _inner_fgraph_toposorts:
1146+
_inner_fgraph_toposorts[inner_fgraph] = {
1147+
node: i for i, node in enumerate(inner_fgraph.toposort())
1148+
}
1149+
1150+
inner_fgraph_toposort = _inner_fgraph_toposorts[inner_fgraph]
1151+
inner_var = inner_fgraph.outputs[var.owner.outputs.index(var)]
1152+
return (index, *_nested_toposort_index(inner_var, inner_fgraph_toposort))
1153+
else:
1154+
return (index,)
1155+
11261156
sorted_replacements = sorted(
11271157
replacements,
1128-
key=lambda pair: toposort.index(pair[0].owner) if pair[0].owner else -1,
1158+
key=lambda pair: _nested_toposort_index(pair[0], fgraph_toposort),
11291159
reverse=reverse,
11301160
)
11311161
fgraph.replace_all(sorted_replacements, import_missing=True)

tests/test_initial_point.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717
import pytensor.tensor as pt
1818
import pytest
1919

20+
from pytensor.compile.builders import OpFromGraph
2021
from pytensor.tensor.random.op import RandomVariable
2122

2223
import pymc as pm
2324

24-
from pymc.distributions.distribution import support_point
25+
from pymc.distributions.distribution import _support_point, support_point
2526
from pymc.initial_point import make_initial_point_fn, make_initial_point_fns_per_chain
2627

2728

@@ -192,6 +193,40 @@ def test_string_overrides_work(self):
192193
assert np.isclose(iv["B_log__"], 0)
193194
assert iv["C_log__"] == 0
194195

196+
@pytest.mark.parametrize("reverse_rvs", [False, True])
197+
def test_dependent_initval_from_OFG(self, reverse_rvs):
198+
class MyTestOp(OpFromGraph):
199+
pass
200+
201+
@_support_point.register(MyTestOp)
202+
def my_test_op_support_point(op, out):
203+
out1, out2 = out.owner.outputs
204+
if out is out1:
205+
return out1
206+
else:
207+
return out1 * 4
208+
209+
out1 = pt.zeros(())
210+
out2 = out1 * 2
211+
rv_op = MyTestOp([], [out1, out2])
212+
213+
with pm.Model() as model:
214+
A, B = rv_op()
215+
if reverse_rvs:
216+
model.register_rv(B, "B")
217+
model.register_rv(A, "A")
218+
else:
219+
model.register_rv(A, "A")
220+
model.register_rv(B, "B")
221+
222+
assert model.initial_point() == {"A": 0, "B": 0}
223+
224+
model.set_initval(A, 1)
225+
assert model.initial_point() == {"A": 1, "B": 4}
226+
227+
model.set_initval(B, 3)
228+
assert model.initial_point() == {"A": 1, "B": 3}
229+
195230

196231
class TestSupportPoint:
197232
def test_basic(self):

0 commit comments

Comments
 (0)