Skip to content

Commit 09bddf1

Browse files
Use rewrite_mode defined in test_math.py for testing
1 parent 3b66eba commit 09bddf1

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tests/tensor/rewriting/test_math.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4674,15 +4674,15 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply):
46744674
else:
46754675
out = d @ x
46764676

4677-
fn = pytensor.function([a, b, c, d], out)
4677+
fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode)
46784678
assert not any(
46794679
isinstance(node.op, BlockDiagonal) for node in fn.maker.fgraph.toposort()
46804680
)
46814681

46824682
fn_expected = pytensor.function(
46834683
[a, b, c, d],
46844684
out,
4685-
mode=get_default_mode().excluding("local_block_diag_dot_to_dot_block_diag"),
4685+
mode=rewrite_mode.excluding("local_block_diag_dot_to_dot_block_diag"),
46864686
)
46874687

46884688
rng = np.random.default_rng()

0 commit comments

Comments
 (0)