
Commit 17fbeb3

Handle case with multiple clients
1 parent 09bddf1 commit 17fbeb3

2 files changed (+51, -47 lines)


pytensor/tensor/rewriting/math.py

Lines changed: 42 additions & 41 deletions
@@ -191,53 +191,54 @@ def check_for_block_diag(x):
     )
 
     # Check that the BlockDiagonal is an input to a Dot node:
-    clients = list(get_clients_at_depth(fgraph, node, depth=1))
-    if not clients or len(clients) > 1 or not isinstance(clients[0].op, Dot):
-        return
+    for client in get_clients_at_depth(fgraph, node, depth=1):
+        if not isinstance(client.op, Dot):
+            return
 
-    [dot_node] = clients
-    op = dot_node.op
-    x, y = dot_node.inputs
+        op = client.op
+        x, y = client.inputs
 
-    if not (check_for_block_diag(x) or check_for_block_diag(y)):
-        return None
+        if not (check_for_block_diag(x) or check_for_block_diag(y)):
+            return None
 
-    # Case 1: Only one input is BlockDiagonal. In this case, multiply all components of the block-diagonal with the
-    # non-block diagonal, and return a new block diagonal
-    if check_for_block_diag(x) and not check_for_block_diag(y):
-        components = x.owner.inputs
-        y_splits = split(
-            y,
-            splits_size=[component.shape[-1] for component in components],
-            n_splits=len(components),
-        )
-        new_components = [
-            op(component, y_split) for component, y_split in zip(components, y_splits)
-        ]
-        new_output = join(0, *new_components)
-
-    elif not check_for_block_diag(x) and check_for_block_diag(y):
-        components = y.owner.inputs
-        x_splits = split(
-            x,
-            splits_size=[component.shape[0] for component in components],
-            n_splits=len(components),
-            axis=1,
-        )
+        # Case 1: Only one input is BlockDiagonal. In this case, multiply all components of the block-diagonal with the
+        # non-block diagonal, and return a new block diagonal
+        if check_for_block_diag(x) and not check_for_block_diag(y):
+            components = x.owner.inputs
+            y_splits = split(
+                y,
+                splits_size=[component.shape[-1] for component in components],
+                n_splits=len(components),
+            )
+            new_components = [
+                op(component, y_split)
+                for component, y_split in zip(components, y_splits)
+            ]
+            new_output = join(0, *new_components)
+
+        elif not check_for_block_diag(x) and check_for_block_diag(y):
+            components = y.owner.inputs
+            x_splits = split(
+                x,
+                splits_size=[component.shape[0] for component in components],
+                n_splits=len(components),
+                axis=1,
+            )
 
-        new_components = [
-            op(x_split, component) for component, x_split in zip(components, x_splits)
-        ]
-        new_output = join(1, *new_components)
+            new_components = [
+                op(x_split, component)
+                for component, x_split in zip(components, x_splits)
+            ]
+            new_output = join(1, *new_components)
 
-    # Case 2: Both inputs are BlockDiagonal. Do nothing
-    else:
-        # TODO: If shapes are statically known and all components have equal shapes, we could rewrite
-        # this case to block_diag(*[dot(comp_1, comp_2) for comp_1, comp_2 in zip(x.owner.inputs, y.owner.inputs)])
-        return None
+        # Case 2: Both inputs are BlockDiagonal. Do nothing
+        else:
+            # TODO: If shapes are statically known and all components have equal shapes, we could rewrite
+            # this case to block_diag(*[dot(comp_1, comp_2) for comp_1, comp_2 in zip(x.owner.inputs, y.owner.inputs)])
+            return None
 
-    copy_stack_trace(node.outputs[0], new_output)
-    return {dot_node.outputs[0]: new_output}
+        copy_stack_trace(node.outputs[0], new_output)
+        return {client.outputs[0]: new_output}
 
 
 @register_canonicalize
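
For context (not part of the commit): the rewrite exploits block-wise multiplication identities for block-diagonal matrices. Below is a minimal NumPy sketch, with hypothetical names x1, x2, y, z1, z2, checking both the Case 1 identity implemented above and the Case 2 identity mentioned in the TODO:

import numpy as np
from scipy.linalg import block_diag

rng = np.random.default_rng(0)
x1 = rng.normal(size=(2, 3))
x2 = rng.normal(size=(4, 5))
y = rng.normal(size=(3 + 5, 7))

# Case 1 (block-diagonal on the left): split y row-wise by each block's
# column count, multiply block-wise, and stack along axis 0. This is what
# split(...) followed by join(0, *new_components) computes above.
y1, y2 = np.split(y, [x1.shape[1]], axis=0)
np.testing.assert_allclose(
    block_diag(x1, x2) @ y,
    np.vstack([x1 @ y1, x2 @ y2]),
)

# Case 2 (the TODO): when the column split of the left operand matches the
# row split of the right one, the product is itself block diagonal:
# block_diag(x1, x2) @ block_diag(z1, z2) == block_diag(x1 @ z1, x2 @ z2).
z1 = rng.normal(size=(3, 2))
z2 = rng.normal(size=(5, 6))
np.testing.assert_allclose(
    block_diag(x1, x2) @ block_diag(z1, z2),
    block_diag(x1 @ z1, x2 @ z2),
)

The right-multiplication branch is the mirror image: the dense operand is split column-wise by each block's row count (component.shape[0], axis=1) and the partial products are joined along axis 1.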

tests/tensor/rewriting/test_math.py

Lines changed: 9 additions & 6 deletions
@@ -4666,21 +4666,23 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply):
     b = tensor("b", shape=(2, 4))
     c = tensor("c", shape=(4, 4))
     d = tensor("d", shape=(10, 10))
+    e = tensor("e", shape=(10, 10))
 
     x = pt.linalg.block_diag(a, b, c)
 
+    # Test multiple clients are all rewritten
     if left_multiply:
-        out = x @ d
+        out = [x @ d, x @ e]
     else:
-        out = d @ x
+        out = [d @ x, e @ x]
 
-    fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode)
+    fn = pytensor.function([a, b, c, d, e], out, mode=rewrite_mode)
     assert not any(
         isinstance(node.op, BlockDiagonal) for node in fn.maker.fgraph.toposort()
     )
 
     fn_expected = pytensor.function(
-        [a, b, c, d],
+        [a, b, c, d, e],
         out,
         mode=rewrite_mode.excluding("local_block_diag_dot_to_dot_block_diag"),
     )
@@ -4690,10 +4692,11 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply):
     b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
     c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
     d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
+    e_val = rng.normal(size=e.type.shape).astype(e.type.dtype)
 
     np.testing.assert_allclose(
-        fn(a_val, b_val, c_val, d_val),
-        fn_expected(a_val, b_val, c_val, d_val),
+        fn(a_val, b_val, c_val, d_val, e_val),
+        fn_expected(a_val, b_val, c_val, d_val, e_val),
         atol=1e-6 if config.floatX == "float32" else 1e-12,
         rtol=1e-6 if config.floatX == "float32" else 1e-12,
     )
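
To illustrate what changed, here is a hedged sketch (not part of the commit; it uses the same public PyTensor API as the test above) of the graph shape the rewrite previously skipped: a single BlockDiagonal output feeding two Dot clients. Before this commit the rewrite returned early whenever the output had more than one client; the updated test asserts that both dot products are now rewritten away:

import pytensor
import pytensor.tensor as pt

a = pt.tensor("a", shape=(2, 2))
b = pt.tensor("b", shape=(3, 3))
d = pt.tensor("d", shape=(5, 4))
e = pt.tensor("e", shape=(5, 4))

# One BlockDiagonal node whose output is consumed by two Dot nodes
x = pt.linalg.block_diag(a, b)
outs = [x @ d, x @ e]

# After rewriting, no BlockDiagonal op should remain in the compiled graph.
fn = pytensor.function([a, b, d, e], outs)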
