Skip to content

Commit dff0aca

Browse files
sudarsan2k5brandonwillard
authored andcommitted
Add type hint None to specify_shape
1 parent a7ef6db commit dff0aca

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

aesara/tensor/shape.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
from aesara.tensor.var import TensorConstant, TensorVariable
2323

2424

25+
ShapeValueType = Union[None, np.integer, int, Variable]
26+
27+
2528
def register_shape_c_code(type, code, version=()):
2629
"""
2730
Tell Shape Op how to generate C code for an Aesara Type.
@@ -541,9 +544,7 @@ def c_code_cache_version(self):
541544

542545
def specify_shape(
543546
x: Union[np.ndarray, Number, Variable],
544-
shape: Union[
545-
int, List[Union[int, Variable]], Tuple[Union[int, Variable]], Variable
546-
],
547+
shape: Union[ShapeValueType, List[ShapeValueType], Tuple[ShapeValueType]],
547548
):
548549
"""Specify a fixed shape for a `Variable`.
549550

0 commit comments

Comments
 (0)