Skip to content

Commit 2dd249c

Browse files
Enforce positive domain of sqr(sqrt(x))
1 parent 8ebf636 commit 2dd249c

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -424,12 +424,10 @@ def local_sqrt_sqr(fgraph, node):
424424

425425
# Case for sqr(sqrt(x)) -> x
426426
if isinstance(prev_op, ps.Sqr) and isinstance(node_op, ps.Sqrt):
427-
new_out = x.owner.inputs[0]
427+
x = x.owner.inputs[0]
428428
old_out = node.outputs[0]
429+
new_out = switch(ge(x, 0), x, np.asarray(np.nan, old_out.dtype))
429430

430-
# Handle potential integer to float cast by sqrt
431-
if x.dtype != old_out.dtype:
432-
new_out = cast(new_out, old_out.dtype)
433431
return [new_out]
434432

435433

0 commit comments

Comments
 (0)