
Commit aa616e6

Pass dtype directly to zeros_like
1 parent: 935ce79
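The change is mechanical throughout: the two-step `zeros_like().astype(...)` pattern becomes a single `zeros_like(dtype=...)` call. A minimal sketch of the difference, assuming only that pytensor is installed — the keyword form builds the zero tensor directly in the target dtype, while the old form first allocates zeros in the input's dtype and then typically adds a separate cast node:

    import pytensor.tensor as pt
    from pytensor import config

    x = pt.ivector("x")

    # Old pattern: zeros in x's dtype (int32), then a cast to floatX.
    old = x.zeros_like().astype(config.floatX)

    # New pattern: zeros created directly as floatX.
    new = x.zeros_like(dtype=config.floatX)

    assert old.dtype == new.dtype == config.floatX

Both forms yield the same dtype; the keyword form simply skips the intermediate integer-typed zeros and the extra `astype` step.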

File tree: 4 files changed (+28 -28 lines)

pytensor/ifelse.py (1 addition, 1 deletion)

@@ -273,7 +273,7 @@ def grad(self, ins, grads):
         # `condition` does affect the elements of the output so it is connected.
         # For the sake of making the gradient convenient we assume that
         # condition + epsilon always triggers the same branch as condition
-        condition_grad = condition.zeros_like().astype(config.floatX)
+        condition_grad = condition.zeros_like(dtype=config.floatX)

         return [
             condition_grad,
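As the comment in this hunk notes, the gradient with respect to the boolean condition is taken to be identically zero: assuming `condition + epsilon` selects the same branch, the output is locally flat in the condition. An illustrative sketch (the variable names are mine, not from the commit):

    import pytensor.tensor as pt
    from pytensor import config

    cond = pt.scalar("cond", dtype="bool")
    # Zero gradient, materialized as floatX so downstream grad code
    # sees a float-typed gradient for the discrete condition input.
    condition_grad = cond.zeros_like(dtype=config.floatX)
    assert condition_grad.dtype == config.floatX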

pytensor/scalar/basic.py (24 additions, 24 deletions)

@@ -1323,8 +1323,8 @@ def L_op(self, inputs, outputs, output_gradients):
         x, y = inputs
         assert outputs[0].type == bool
         return [
-            x.zeros_like().astype(config.floatX),
-            y.zeros_like().astype(config.floatX),
+            x.zeros_like(dtype=config.floatX),
+            y.zeros_like(dtype=config.floatX),
         ]

     def c_code_cache_version(self):
@@ -1358,7 +1358,7 @@ def output_types(self, *input_dtypes):
     def L_op(self, inputs, outputs, output_gradients):
         (x,) = inputs
         assert outputs[0].type == bool
-        return [x.zeros_like().astype(config.floatX)]
+        return [x.zeros_like(dtype=config.floatX)]

     def c_code_cache_version(self):
         super_version = super().c_code_cache_version()
@@ -1577,7 +1577,7 @@ def get_grad(self, elem):
            )
            raise NotImplementedError(msg)
        elif elem.type in discrete_types:
-           return elem.zeros_like().astype(config.floatX)
+           return elem.zeros_like(dtype=config.floatX)
        else:
            return elem.zeros_like()

@@ -1611,13 +1611,13 @@ def L_op(self, inputs, outputs, gout):
         second_part = switch(cond, 0.0, gz)

         if outputs[0].type in discrete_types:
-            first_part = ift.zeros_like(config.floatX)
-            second_part = iff.zeros_like(config.floatX)
+            first_part = ift.zeros_like(dtype=config.floatX)
+            second_part = iff.zeros_like(dtype=config.floatX)

         # cond does affect the elements of the output so it is connected.
         # For the sake of making the gradient convenient we assume that
         # condition + epsilon always triggers the same branch as condition
-        condition_grad = cond.zeros_like().astype(config.floatX)
+        condition_grad = cond.zeros_like(dtype=config.floatX)

         return (condition_grad, first_part, second_part)

@@ -1644,7 +1644,7 @@ def output_types(self, *input_types):
         return upcast_out(*input_types[0])

     def grad(self, inputs, output_gradients):
-        return [inputs[0].zeros_like().astype(config.floatX)]
+        return [inputs[0].zeros_like(dtype=config.floatX)]


 class BinaryBitOp(BinaryScalarOp):
@@ -1664,8 +1664,8 @@ def output_types(self, *input_types):
     def grad(self, inputs, output_gradients):
         a, b = inputs
         return [
-            a.zeros_like().astype(config.floatX),
-            b.zeros_like().astype(config.floatX),
+            a.zeros_like(dtype=config.floatX),
+            b.zeros_like(dtype=config.floatX),
         ]


@@ -1776,8 +1776,8 @@ def L_op(self, inputs, outputs, gout):

         if outputs[0].type in discrete_types:
             return [
-                x.zeros_like().astype(config.floatX),
-                y.zeros_like().astype(config.floatX),
+                x.zeros_like(dtype=config.floatX),
+                y.zeros_like(dtype=config.floatX),
             ]
         # This form handle the case when both value are the same.
         # In that case, gx will be gz, gy will be 0.
@@ -1818,8 +1818,8 @@ def L_op(self, inputs, outputs, gout):

         if outputs[0].type in discrete_types:
             return [
-                x.zeros_like().astype(config.floatX),
-                y.zeros_like().astype(config.floatX),
+                x.zeros_like(dtype=config.floatX),
+                y.zeros_like(dtype=config.floatX),
             ]
         # This form handle the case when both value are the same.
         # In that case, gx will be gz, gy will be 0.
@@ -1861,7 +1861,7 @@ def L_op(self, inputs, outputs, gout):
             retval = []
             for ii, inp in enumerate(inputs):
                 if hasattr(inp, "zeros_like"):
-                    retval.append(inp.zeros_like().astype(config.floatX))
+                    retval.append(inp.zeros_like(dtype=config.floatX))
                 else:
                     retval.append(grad_undefined(self, ii, inp))
         else:
@@ -1937,7 +1937,7 @@ def grad(self, inputs, gout):
            )

         if output_type in discrete_types:
-            return [ipt.zeros_like().astype(config.floatX) for ipt in inputs]
+            return [ipt.zeros_like(dtype=config.floatX) for ipt in inputs]

         for input in inputs:
             if gz.type in complex_types:
@@ -1980,8 +1980,8 @@ def L_op(self, inputs, outputs, gout):
             raise NotImplementedError()
         if outputs[0].type in discrete_types:
             return [
-                x.zeros_like().astype(config.floatX),
-                y.zeros_like().astype(config.floatX),
+                x.zeros_like(dtype=config.floatX),
+                y.zeros_like(dtype=config.floatX),
             ]

         first_part = gz
@@ -2293,8 +2293,8 @@ def L_op(self, inputs, outputs, gout):

         if outputs[0].type in discrete_types:
             return [
-                x.zeros_like().astype(config.floatX),
-                y.zeros_like().astype(config.floatX),
+                x.zeros_like(dtype=config.floatX),
+                y.zeros_like(dtype=config.floatX),
             ]

         first_part = gz * y * x ** (y - 1)
@@ -2385,7 +2385,7 @@ def L_op(self, inputs, outputs, gout):

         def handle_int(v):
             if outputs[0].type in int_types:
-                return v.zeros_like().astype(config.floatX)
+                return v.zeros_like(dtype=config.floatX)
             return v

         return list(map(handle_int, [gx, gmn, gmx]))
@@ -2422,7 +2422,7 @@ def grad(self, inputs, gout):
         # to deal with real-valued inputs by rounding them to the
         # nearest integer. f(x+eps) thus equals f(x) so the gradient
         # is zero, not disconnected or undefined
-        return DisconnectedType()(), y.zeros_like()
+        return DisconnectedType()(), y.zeros_like(dtype=config.floatX)


 second = Second(transfer_type(1), name="second")
@@ -2494,7 +2494,7 @@ def grad(self, inputs, gout):
         if self.o_type in continuous_types:
             return [gz]
         else:
-            return [x.zeros_like().astype(config.floatX)]
+            return [x.zeros_like(dtype=config.floatX)]

     def c_code_cache_version(self):
         s = super().c_code_cache_version()
@@ -2715,7 +2715,7 @@ def impl(self, x):
     def grad(self, inputs, gout):
         (x,) = inputs
         (gz,) = gout
-        return [x.zeros_like().astype(config.floatX)]
+        return [x.zeros_like(dtype=config.floatX)]

     def c_code(self, node, name, inputs, outputs, sub):
         (x,) = inputs
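One subtlety in the hunk at line 1611: the old `ift.zeros_like(config.floatX)` calls already passed floatX, but positionally. Assuming the method's only optional parameter here is `dtype`, the two spellings build the same node, and the commit merely makes the keyword explicit to match every other call site:

    import pytensor.tensor as pt
    from pytensor import config

    ift = pt.ivector("ift")
    positional = ift.zeros_like(config.floatX)      # before (implicit)
    keyword = ift.zeros_like(dtype=config.floatX)   # after (explicit)
    assert positional.dtype == keyword.dtype == config.floatX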

pytensor/tensor/basic.py (2 additions, 2 deletions)

@@ -589,7 +589,7 @@ def grad(self, inp, grads):
         # Currently, pytensor.grad insists that the dtype of the returned
         # gradient has a float dtype, so we use floatX.
         if s.type.dtype in discrete_dtypes:
-            return [s.zeros_like().astype(config.floatX)]
+            return [s.zeros_like(dtype=config.floatX)]

         raise NotImplementedError("grad not implemented for complex dtypes")

@@ -1876,7 +1876,7 @@ def infer_shape(self, fgraph, node, ishapes):
     def grad(self, inputs, output_gradients):
         # If the output is of an integer dtype, no gradient shall pass
         if self.dtype in discrete_dtypes:
-            return [ipt.zeros_like().astype(config.floatX) for ipt in inputs]
+            return [ipt.zeros_like(dtype=config.floatX) for ipt in inputs]

         grads = [output_gradients[0][i] for i in range(len(inputs))]
         return grads
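The comment in the first hunk states the underlying constraint: `pytensor.grad` insists that returned gradients have a float dtype, hence floatX zeros for discrete inputs. A sketch of the second hunk's pattern, which appears to sit in a grad that fans one output gradient back out to several inputs (e.g. stacking scalars into an integer vector):

    import pytensor.tensor as pt
    from pytensor import config

    inputs = [pt.iscalar("a"), pt.iscalar("b")]
    # Integer output dtype: no real gradient flows, so each input gets
    # a floatX zero placeholder instead of a slice of the output grad.
    grads = [ipt.zeros_like(dtype=config.floatX) for ipt in inputs]
    assert all(g.dtype == config.floatX for g in grads)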

pytensor/tensor/subtensor.py (1 addition, 1 deletion)

@@ -946,7 +946,7 @@ def grad(self, inputs, grads):
         x = inputs[0]
         rest = inputs[1:]
         if x.dtype in discrete_dtypes:
-            first = x.zeros_like().astype(config.floatX)
+            first = x.zeros_like(dtype=config.floatX)
         else:
             # For best optimization, we let this as an inc.
             # This allow the opt local_IncSubtensor_serialize to apply first.
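The else-branch comment explains the design choice on the non-discrete path: the gradient is kept as an increment into zeros rather than a fresh allocation, so the `local_IncSubtensor_serialize` rewrite can still reorder and merge the increments. A hedged sketch of that shape using the public `inc_subtensor` helper (the slice and names are illustrative, not from this file):

    import pytensor.tensor as pt

    x = pt.vector("x")
    g_out = pt.vector("g_out")  # gradient w.r.t. x[1:3]
    # Gradient w.r.t. x, expressed as an inc into a zero tensor:
    g_x = pt.inc_subtensor(pt.zeros_like(x)[1:3], g_out)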
