
Commit fd8d2da

Merge branch 'main' into time
2 parents 9586329 + 471657a commit fd8d2da


12 files changed (+119, -70 lines)


aesara/compile/debugmode.py

Lines changed: 4 additions & 3 deletions
@@ -807,7 +807,7 @@ def _get_preallocated_maps(
     for r in considered_outputs:
         if isinstance(r.type, TensorType):
             # Build a C-contiguous buffer
-            new_buf = r.type.value_zeros(r_vals[r].shape)
+            new_buf = np.empty(r_vals[r].shape, dtype=r.type.dtype)
             assert new_buf.flags["C_CONTIGUOUS"]
             new_buf[...] = np.asarray(def_val).astype(r.type.dtype)

@@ -875,7 +875,8 @@ def _get_preallocated_maps(
                     buf_shape.append(s)
                 else:
                     buf_shape.append(s * 2)
-            new_buf = r.type.value_zeros(buf_shape)
+
+            new_buf = np.empty(buf_shape, dtype=r.type.dtype)
             new_buf[...] = np.asarray(def_val).astype(r.type.dtype)
             init_strided[r] = new_buf

@@ -950,7 +951,7 @@ def _get_preallocated_maps(
                 max((s + sd), 0)
                 for s, sd in zip(r_vals[r].shape, r_shape_diff)
             ]
-            new_buf = r.type.value_zeros(out_shape)
+            new_buf = np.empty(out_shape, dtype=r.type.dtype)
             new_buf[...] = np.asarray(def_val).astype(r.type.dtype)
             wrong_size[r] = new_buf

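Note: the `value_zeros` helper is being removed (see aesara/tensor/type.py below), so these preallocated buffers are now built with plain NumPy calls. A minimal sketch of the equivalent allocation; the shape, dtype, and default value are illustrative stand-ins:

    import numpy as np

    # Illustrative stand-ins for the values used in `_get_preallocated_maps`.
    shape, dtype, def_val = (3, 4), "float64", 2.0

    # Previously: r.type.value_zeros(shape) returned a zero-filled array.
    # Now: allocate uninitialized memory and fill it explicitly.
    new_buf = np.empty(shape, dtype=dtype)
    assert new_buf.flags["C_CONTIGUOUS"]
    new_buf[...] = np.asarray(def_val).astype(dtype)
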
aesara/scan/op.py

Lines changed: 23 additions & 15 deletions
@@ -161,8 +161,9 @@ def check_broadcast(v1, v2):
     which may wrongly be interpreted as broadcastable.

     """
-    if not hasattr(v1, "broadcastable") and not hasattr(v2, "broadcastable"):
+    if not isinstance(v1.type, TensorType) and not isinstance(v2.type, TensorType):
         return
+
     msg = (
         "The broadcast pattern of the output of scan (%s) is "
         "inconsistent with the one provided in `output_info` "

@@ -173,13 +174,13 @@ def check_broadcast(v1, v2):
         "them consistent, e.g. using aesara.tensor."
         "{unbroadcast, specify_broadcastable}."
     )
-    size = min(len(v1.broadcastable), len(v2.broadcastable))
+    size = min(v1.type.ndim, v2.type.ndim)
     for n, (b1, b2) in enumerate(
-        zip(v1.broadcastable[-size:], v2.broadcastable[-size:])
+        zip(v1.type.broadcastable[-size:], v2.type.broadcastable[-size:])
     ):
         if b1 != b2:
-            a1 = n + size - len(v1.broadcastable) + 1
-            a2 = n + size - len(v2.broadcastable) + 1
+            a1 = n + size - v1.type.ndim + 1
+            a2 = n + size - v2.type.ndim + 1
             raise TypeError(msg % (v1.type, v2.type, a1, b1, b2, a2))

@@ -200,7 +201,7 @@ def copy_var_format(var, as_var):
         rval = as_var.type.filter_variable(rval)
     else:
         tmp = as_var.type.clone(
-            shape=(tuple(var.broadcastable[:1]) + tuple(as_var.broadcastable))
+            shape=(tuple(var.type.shape[:1]) + tuple(as_var.type.shape))
         )
         rval = tmp.filter_variable(rval)
     return rval

@@ -628,6 +629,7 @@ def validate_inner_graph(self):
             type_input = self.inner_inputs[inner_iidx].type
             type_output = self.inner_outputs[inner_oidx].type
             if (
+                # TODO: Use the `Type` interface for this
                 type_input.dtype != type_output.dtype
                 or type_input.broadcastable != type_output.broadcastable
             ):

@@ -805,7 +807,9 @@ def tensorConstructor(shape, dtype):
             # output sequence
             o = outputs[idx]
             self.output_types.append(
-                typeConstructor((False,) + o.type.broadcastable, o.type.dtype)
+                # TODO: What can we actually say about the shape of this
+                # added dimension?
+                typeConstructor((None,) + o.type.shape, o.type.dtype)
             )

             idx += len(info.mit_mot_out_slices[jdx])

@@ -816,7 +820,9 @@ def tensorConstructor(shape, dtype):

             for o in outputs[idx:end]:
                 self.output_types.append(
-                    typeConstructor((False,) + o.type.broadcastable, o.type.dtype)
+                    # TODO: What can we actually say about the shape of this
+                    # added dimension?
+                    typeConstructor((None,) + o.type.shape, o.type.dtype)
                 )

             # shared outputs + possibly the ending condition

@@ -1380,11 +1386,13 @@ def prepare_fgraph(self, fgraph):
                 # the output value, possibly inplace, at the end of the
                 # function execution. Also, since an update is defined,
                 # a default value must also be (this is verified by
-                # DebugMode). Use an array of size 0 with the correct
-                # ndim and dtype (use a shape of 1 on broadcastable
-                # dimensions, and 0 on the others).
-                default_shape = [1 if _b else 0 for _b in inp.broadcastable]
-                default_val = inp.type.value_zeros(default_shape)
+                # DebugMode).
+                # TODO FIXME: Why do we need a "default value" here?
+                # This sounds like a serious design issue.
+                default_shape = tuple(
+                    s if s is not None else 0 for s in inp.type.shape
+                )
+                default_val = np.empty(default_shape, dtype=inp.type.dtype)
                 wrapped_inp = In(
                     variable=inp,
                     value=default_val,

@@ -2318,8 +2326,8 @@ def infer_shape(self, fgraph, node, input_shapes):
                 # equivalent (if False). Here, we only need the variable.
                 v_shp_i = validator.check(shp_i)
                 if v_shp_i is None:
-                    if hasattr(r, "broadcastable") and r.broadcastable[i]:
-                        shp.append(1)
+                    if r.type.shape[i] is not None:
+                        shp.append(r.type.shape[i])
                     else:
                         shp.append(Shape_i(i)(r))
                 else:

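These scan changes replace ad hoc `broadcastable` attribute lookups with `Type`-level queries. A minimal sketch of the `TensorType` properties they rely on; the example type is illustrative:

    from aesara.tensor.type import TensorType

    # A static shape entry of `None` means "unknown at compile time"; an
    # entry of 1 corresponds to the old `broadcastable=True` flag.
    xt = TensorType("float64", shape=(None, 1))
    x = xt()

    print(x.type.ndim)           # 2
    print(x.type.shape)          # (None, 1)
    print(x.type.broadcastable)  # (False, True)
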
aesara/sparse/type.py

Lines changed: 0 additions & 8 deletions
@@ -224,14 +224,6 @@ def get_size(self, shape_info):
             + (shape_info[2] + shape_info[3]) * np.dtype("int32").itemsize
         )

-    def value_zeros(self, shape):
-        matrix_constructor = self.format_cls.get(self.format)
-
-        if matrix_constructor is None:
-            raise ValueError(f"Sparse matrix type {self.format} not found in SciPy")
-
-        return matrix_constructor(shape, dtype=self.dtype)
-
     def __eq__(self, other):
         res = super().__eq__(other)

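`SparseTensorType.value_zeros` is deleted outright; code that needs an empty sparse matrix can build one directly with SciPy, roughly as the removed method did. A sketch, not part of this commit:

    import scipy.sparse as sp

    # Roughly what the removed method returned for the "csr" format.
    empty = sp.csr_matrix((3, 4), dtype="float64")
    print(empty.shape, empty.nnz)  # (3, 4) 0
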
aesara/tensor/basic.py

Lines changed: 31 additions & 12 deletions
@@ -10,7 +10,7 @@
 from collections.abc import Sequence
 from functools import partial
 from numbers import Number
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 from typing import Sequence as TypeSequence
 from typing import Tuple, Union
 from typing import cast as type_cast

@@ -68,6 +68,10 @@
 from aesara.tensor.var import TensorConstant, TensorVariable, get_unique_value


+if TYPE_CHECKING:
+    from aesara.tensor import TensorLike
+
+
 def __oplist_tag(thing, tag):
     tags = getattr(thing, "__oplist_tags", [])
     tags.append(tag)

@@ -1334,11 +1338,25 @@ def identity_like(x, dtype: Optional[Union[str, np.generic, np.dtype]] = None):
     return eye(_x.shape[0], _x.shape[1], k=0, dtype=dtype)


-def infer_broadcastable(shape):
-    """Infer the broadcastable dimensions for `shape`.
+def infer_static_shape(
+    shape: Union[Variable, TypeSequence[Union[Variable, int]]]
+) -> Tuple[TypeSequence["TensorLike"], TypeSequence[Optional[int]]]:
+    """Infer the static shapes implied by the potentially symbolic elements in `shape`.
+
+    `shape` will be validated and constant folded. As a result, this function
+    can be expensive and shouldn't be used unless absolutely necessary.
+
+    It mostly exists as a hold-over from pre-static shape times, when it was
+    required in order to produce correct broadcastable arrays and prevent
+    some graphs from being unusable. Now, it is no longer strictly required,
+    so don't use it unless you want the same shape graphs to be rewritten
+    multiple times during graph construction.
+
+    Returns
+    -------
+    A validated sequence of symbolic shape values, and a sequence of
+    ``None``/``int`` values that can be used as `TensorType.shape` values.

-    `shape` will be validated and constant folded in order to determine
-    which dimensions are broadcastable (i.e. equal to ``1``).
     """
     from aesara.tensor.rewriting.basic import topo_constant_folding
     from aesara.tensor.rewriting.shape import ShapeFeature

@@ -1362,9 +1380,10 @@ def check_type(s):
         clone=True,
     )
     folded_shape = rewrite_graph(shape_fg, custom_rewrite=topo_constant_folding).outputs
-
-    bcast = tuple(getattr(s, "data", s) == 1 for s in folded_shape)
-    return sh, bcast
+    static_shape = tuple(
+        s.data.item() if isinstance(s, Constant) else None for s in folded_shape
+    )
+    return sh, static_shape


 class Alloc(COp):

@@ -1394,15 +1413,15 @@ class Alloc(COp):

     def make_node(self, value, *shape):
         v = as_tensor_variable(value)
-        sh, bcast = infer_broadcastable(shape)
+        sh, static_shape = infer_static_shape(shape)
         if v.ndim > len(sh):
             raise TypeError(
                 "The Alloc value to use has more dimensions"
                 " than the specified dimensions",
                 v.ndim,
                 len(sh),
             )
-        otype = TensorType(dtype=v.dtype, shape=bcast)
+        otype = TensorType(dtype=v.dtype, shape=static_shape)
         return Apply(self, [v] + sh, [otype()])

     def perform(self, node, inputs, out_):

@@ -3823,8 +3842,8 @@ def typecode(self):
         return np.dtype(self.dtype).num

     def make_node(self, *_shape):
-        _shape, bcast = infer_broadcastable(_shape)
-        otype = TensorType(dtype=self.dtype, shape=bcast)
+        _shape, static_shape = infer_static_shape(_shape)
+        otype = TensorType(dtype=self.dtype, shape=static_shape)
         output = otype()

         output.tag.values_eq_approx = values_eq_approx_always_true

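A minimal usage sketch of the renamed helper, assuming the behavior described in its new docstring; the exact folded values depend on the shape graph:

    import aesara.tensor as at
    from aesara.tensor.basic import infer_static_shape

    # One symbolic and one constant dimension.
    s = at.iscalar("s")
    shape_vars, static_shape = infer_static_shape((s, 3))

    # Only the constant entry folds to a static value.
    print(static_shape)  # e.g. (None, 3)
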
aesara/tensor/extra_ops.py

Lines changed: 46 additions & 10 deletions
@@ -23,6 +23,7 @@
 from aesara.raise_op import Assert
 from aesara.scalar import int32 as int_t
 from aesara.scalar import upcast
+from aesara.scalar.basic import Composite
 from aesara.tensor import basic as at
 from aesara.tensor import get_vector_length
 from aesara.tensor.exceptions import NotScalarConstantError

@@ -1552,16 +1553,32 @@ def broadcast_shape_iter(
             # be broadcastable or equal to the one non-broadcastable
             # constant `const_nt_shape_var`.
             assert_dim = Assert("Could not broadcast dimensions")
+
+            scalar_nonconst_nb_shapes = [
+                at.scalar_from_tensor(s)
+                if isinstance(s.type, TensorType)
+                else s
+                for s in nonconst_nb_shapes
+            ]
+
+            dummy_nonconst_nb_shapes = [
+                aes.get_scalar_type(dtype=v.dtype)()
+                for v in scalar_nonconst_nb_shapes
+            ]
             assert_cond = reduce(
                 aes.and_,
                 (
                     aes.or_(
                         aes.eq(nbv, one_at), aes.eq(nbv, const_nt_shape_var)
                     )
-                    for nbv in nonconst_nb_shapes
+                    for nbv in dummy_nonconst_nb_shapes
                 ),
             )
-            bcast_dim = assert_dim(const_nt_shape_var, assert_cond)
+            assert_cond_op = Composite(dummy_nonconst_nb_shapes, [assert_cond])
+
+            bcast_dim = assert_dim(
+                const_nt_shape_var, assert_cond_op(*scalar_nonconst_nb_shapes)
+            )
         else:
             bcast_dim = const_nt_shape_var
     else:

@@ -1579,21 +1596,37 @@ def broadcast_shape_iter(
                 result_dims.append(maybe_non_bcast_shapes[0])
                 continue

+            scalar_maybe_non_bcast_shapes = [
+                at.scalar_from_tensor(s) if isinstance(s.type, TensorType) else s
+                for s in maybe_non_bcast_shapes
+            ]
+            dummy_maybe_non_bcast_shapes = [
+                aes.get_scalar_type(dtype=v.dtype)()
+                for v in scalar_maybe_non_bcast_shapes
+            ]
             non_bcast_vec = [
                 aes.switch(aes.eq(nbv, 1), -one_at, nbv)
-                for nbv in maybe_non_bcast_shapes
+                for nbv in dummy_maybe_non_bcast_shapes
             ]
             dim_max = aes.abs(reduce(aes.scalar_maximum, non_bcast_vec))
+            dim_max_op = Composite(dummy_maybe_non_bcast_shapes, [dim_max])
+
+            dummy_dim_max = dim_max_op(*dummy_maybe_non_bcast_shapes)

             assert_dim = Assert("Could not broadcast dimensions")
             assert_cond = reduce(
                 aes.and_,
                 (
-                    aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dim_max))
+                    aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dummy_dim_max))
                     for nbv in non_bcast_vec
                 ),
             )
-            bcast_dim = assert_dim(dim_max, assert_cond)
+            assert_cond_op = Composite(dummy_maybe_non_bcast_shapes, [assert_cond])
+
+            bcast_dim = assert_dim(
+                dim_max_op(*scalar_maybe_non_bcast_shapes),
+                assert_cond_op(*scalar_maybe_non_bcast_shapes),
+            )

             result_dims.append(bcast_dim)

@@ -1613,9 +1646,9 @@ def __call__(self, a, shape, **kwargs):
     def make_node(self, a, *shape):
         a = at.as_tensor_variable(a)

-        shape, bcast = at.infer_broadcastable(shape)
+        shape, static_shape = at.infer_static_shape(shape)

-        out = TensorType(dtype=a.type.dtype, shape=bcast)()
+        out = TensorType(dtype=a.type.dtype, shape=static_shape)()

         # Attempt to prevent in-place operations on this view-based output
         out.tag.indestructible = True

@@ -1637,11 +1670,14 @@ def grad(self, inputs, outputs_gradients):
         d_wrt_a = broadcast_to(dout, shape).sum(axis=new_dims)

         # Determine the dimensions that were broadcast
-        _, shape_bcast = at.infer_broadcastable(shape)
+        _, static_shape = at.infer_static_shape(shape)
+
+        # TODO: This needs to be performed at run-time when static shape
+        # information isn't available.
         bcast_sums = [
             i
-            for i, (a_b, s_b) in enumerate(zip(a.broadcastable, shape_bcast[-a.ndim :]))
-            if a_b and not s_b
+            for i, (a_s, s_s) in enumerate(zip(a.type.shape, static_shape[-a.ndim :]))
+            if a_s == 1 and s_s != 1
         ]

         if bcast_sums:

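The rewritten shape checks funnel the scalar comparisons through `Composite` Ops built over dummy scalar inputs. A minimal sketch of that pattern; the inputs and graph are illustrative:

    import aesara.scalar as aes
    from aesara.scalar.basic import Composite

    # Dummy scalar inputs standing in for two shape values.
    x = aes.int64("x")
    y = aes.int64("y")

    # A small scalar graph: "x is 1 or x equals y".
    cond = aes.or_(aes.eq(x, 1), aes.eq(x, y))

    # `Composite` packages the graph into a single scalar Op that can then
    # be applied to the actual scalar shape variables.
    cond_op = Composite([x, y], [cond])
    checked = cond_op(aes.int64("a"), aes.int64("b"))
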
aesara/tensor/random/op.py

Lines changed: 3 additions & 3 deletions
@@ -14,7 +14,7 @@
     constant,
     get_scalar_constant_value,
     get_vector_length,
-    infer_broadcastable,
+    infer_static_shape,
 )
 from aesara.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
 from aesara.tensor.random.utils import normalize_size_param, params_broadcast_shapes

@@ -322,7 +322,7 @@ def make_node(self, rng, size, dtype, *dist_params):
             )

         shape = self._infer_shape(size, dist_params)
-        _, bcast = infer_broadcastable(shape)
+        _, static_shape = infer_static_shape(shape)
         dtype = self.dtype or dtype

         if dtype == "floatX":

@@ -336,7 +336,7 @@ def make_node(self, rng, size, dtype, *dist_params):
             dtype_idx = constant(dtype, dtype="int64")
             dtype = all_dtypes[dtype_idx.data]

-        outtype = TensorType(dtype=dtype, shape=bcast)
+        outtype = TensorType(dtype=dtype, shape=static_shape)
         out_var = outtype()
         inputs = (rng, size, dtype_idx) + dist_params
         outputs = (rng.type(), out_var)

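With `infer_static_shape`, a constant `size` can now give `RandomVariable` outputs a fully static shape. An illustrative sketch, not from this commit:

    from aesara.tensor.random.basic import normal

    # With a constant `size`, the output type is built from the statically
    # inferred shape rather than a broadcastable pattern.
    x = normal(0, 1, size=(2, 3))

    # Expected to be (2, 3) once the constant `size` folds; previously the
    # type only recorded broadcastable flags.
    print(x.type.shape)
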
aesara/tensor/type.py

Lines changed: 0 additions & 7 deletions
@@ -331,13 +331,6 @@ def convert_variable(self, var):
         # `specify_shape` will combine the more precise shapes of the two types
         return aesara.tensor.specify_shape(var, self.shape)

-    def value_zeros(self, shape):
-        """Create an numpy ndarray full of 0 values.
-
-        TODO: Remove this trivial method.
-        """
-        return np.zeros(shape, dtype=self.dtype)
-
     @staticmethod
     def values_eq(a, b, force_same_dtype=True):
         # TODO: check to see if the shapes must match; for now, we err on safe
