Broadcasting Notes
I wrote this document as my notes for implementing broadcasting in PyTorch, so it is likely unclear to another reader in a number of places. This document is intended to describe how broadcasting works in PyTorch, how I decided individual PyTorch functions should implement broadcasting, and a comparison to NumPy's broadcasting semantics.
Many PyTorch operations support Broadcasting Semantics.
In short, if a PyTorch operation supports broadcasting, then its tensor arguments can be automatically expanded to be of equal sizes (without making copies of the data).
Two tensors are "broadcastable" if the following rules hold:
- Each tensor has at least one dimension.
- When iterating over the dimension sizes, starting at the trailing dimension, the dimension sizes must either be equal, one of them is 1, or one of them does not exist.
For Example:
>>> x=torch.FloatTensor(5,7,3)
>>> y=torch.FloatTensor(5,7,3)
# same shapes are always broadcastable (i.e. the above rules always hold)
>>> x=torch.FloatTensor()
>>> y=torch.FloatTensor(2,2)
# x and y are not broadcastable, because x does not have at least 1 dimension
>>> x=torch.FloatTensor(5,1,4,1)
>>> y=torch.FloatTensor(3,1,1)
# x and y are broadcastable.
# 1st trailing dimension: both have size 1
# 2nd trailing dimension: y has size 1
# 3rd trailing dimension: x has size 1
# 4th trailing dimension: y dimension doesn't exist
# but:
>>> x=torch.FloatTensor(5,2,4,1)
>>> y=torch.FloatTensor(3,1,1)
# x and y are not broadcastable, because in the 3rd trailing dimension 2 != 3
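The two broadcastability rules translate directly into a shape-only check. This is a minimal sketch in plain Python that works on shape tuples rather than tensors; is_broadcastable is a hypothetical helper, not a PyTorch function:

```python
from itertools import zip_longest


def is_broadcastable(shape_x, shape_y):
    """Apply the two broadcastability rules to a pair of shape tuples."""
    # Rule 1: each tensor has at least one dimension.
    if len(shape_x) == 0 or len(shape_y) == 0:
        return False
    # Rule 2: iterate from the trailing dimension. zip_longest yields None
    # once the shorter shape runs out, i.e. "one of them does not exist".
    for dx, dy in zip_longest(reversed(shape_x), reversed(shape_y)):
        if dx is None or dy is None:
            continue
        if dx != dy and dx != 1 and dy != 1:
            return False
    return True


print(is_broadcastable((5, 7, 3), (5, 7, 3)))     # True
print(is_broadcastable((), (2, 2)))               # False
print(is_broadcastable((5, 1, 4, 1), (3, 1, 1)))  # True
print(is_broadcastable((5, 2, 4, 1), (3, 1, 1)))  # False
```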
If two tensors x, y are "broadcastable", the resulting tensor size is calculated as follows:
- If the number of dimensions of x and y are not equal, prepend 1 to the dimensions of the tensor with fewer dimensions to make them equal length.
- Then, for each dimension size, the resulting dimension size is the max of the sizes of x and y along that dimension.
For Example:
# can line up trailing dimensions to make reading easier
>>> x=torch.FloatTensor(5,1,4,1)
>>> y=torch.FloatTensor( 3,1,1)
>>> (x+y).size()
torch.Size([5, 3, 4, 1])
# but not necessary:
>>> x=torch.FloatTensor(1)
>>> y=torch.FloatTensor(3,1,7)
>>> (x+y).size()
torch.Size([3, 1, 7])
>>> x=torch.FloatTensor(5,2,4,1)
>>> y=torch.FloatTensor(3,1,1)
>>> (x+y).size()
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1
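The result-size rule can be sketched the same way. broadcast_shape below is a hypothetical helper; later PyTorch releases expose the same computation as torch.broadcast_shapes. Instead of literally taking the max, the sketch keeps whichever size is not 1, which agrees with the max for positive sizes and also handles zero-size dimensions:

```python
from itertools import zip_longest


def broadcast_shape(shape_x, shape_y):
    """Result shape of broadcasting two shape tuples, per the rules above."""
    result = []
    # Aligning from the trailing end with fillvalue=1 is the same as
    # prepending 1s to the shorter shape.
    for dx, dy in zip_longest(reversed(shape_x), reversed(shape_y), fillvalue=1):
        if dx != dy and dx != 1 and dy != 1:
            raise ValueError(f"sizes {dx} and {dy} are not broadcastable")
        # Keep the size that is not 1 (equals max(dx, dy) for positive sizes).
        result.append(dy if dx == 1 else dx)
    return tuple(reversed(result))


print(broadcast_shape((5, 1, 4, 1), (3, 1, 1)))  # (5, 3, 4, 1)
print(broadcast_shape((1,), (3, 1, 7)))          # (3, 1, 7)
```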
One complication is that in-place operations do not allow the in-place tensor to change shape as a result of the broadcast.
For Example:
>>> x=torch.FloatTensor(5,3,4,1)
>>> y=torch.FloatTensor(3,1,1)
>>> (x.add_(y)).size()
torch.Size([5, 3, 4, 1])
# but:
>>> x=torch.FloatTensor(1,3,1)
>>> y=torch.FloatTensor(3,1,7)
>>> (x.add_(y)).size()
RuntimeError: The expanded size of the tensor (1) must match the existing size (7) at non-singleton dimension 2.
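Put differently, the in-place call succeeds only when broadcasting self with the other operand leaves self's shape unchanged. A shape-only sketch of that check (inplace_ok is a hypothetical name):

```python
from itertools import zip_longest


def inplace_ok(self_shape, other_shape):
    """True if other_shape broadcasts to self_shape without changing it."""
    # other may have fewer dimensions (missing ones act as 1),
    # but never more, or the result would have more dims than self.
    if len(other_shape) > len(self_shape):
        return False
    for ds, do in zip_longest(reversed(self_shape), reversed(other_shape), fillvalue=1):
        # Only other may be expanded, so its size must equal self's or be 1.
        if do != ds and do != 1:
            return False
    return True


print(inplace_ok((5, 3, 4, 1), (3, 1, 1)))  # True
print(inplace_ok((1, 3, 1), (3, 1, 7)))     # False
```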
Prior versions of PyTorch allowed certain pointwise functions to execute on tensors with different shapes, as long as the number of elements in each tensor was equal. The pointwise operation would then be carried out by viewing each tensor as 1-dimensional. PyTorch now supports broadcasting and the "1-dimensional" pointwise behavior is considered deprecated and will generate a Python warning in cases where tensors are not broadcastable, but have the same number of elements.
Note that the introduction of broadcasting can cause backwards incompatible changes in the case where two tensors do not have the same shape, but are broadcastable and have the same number of elements. For Example:
>>> torch.add(torch.ones(4,1), torch.randn(4))
would previously produce a Tensor with size: torch.Size([4,1]), but now produces a Tensor with size: torch.Size([4,4]).
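To make the difference concrete, the sketch below computes the broadcasting result and then emulates the old behavior by hand (view both operands as 1-dimensional, add, and restore the first operand's shape). The emulation is an assumption based on the description above, not a call into any legacy PyTorch code path:

```python
import torch

a = torch.ones(4, 1)
b = torch.arange(4, dtype=torch.float32)

# Broadcasting: (4, 1) and (4,) expand to (4, 4).
new = torch.add(a, b)
print(new.shape)  # torch.Size([4, 4])

# Emulated pre-broadcasting behavior: same number of elements, so view
# both as 1-D, add pointwise, then restore the shape of the first operand.
old = torch.add(a.view(-1), b.view(-1)).view_as(a)
print(old.shape)  # torch.Size([4, 1])
```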
In order to help identify cases in your code where backwards incompatibilities introduced by broadcasting may exist, you may set torch.utils.backcompat.broadcast.warning.enabled to True, which will generate a Python warning in such cases.
For Example:
>>> torch.utils.backcompat.broadcast.warning.enabled=True
>>> torch.add(torch.ones(4,1), torch.ones(4))
__main__:1: UserWarning: self and other do not have the same shape, but are broadcastable, and have the same number of elements.
Changing behavior in a backwards incompatible manner to broadcasting rather than viewing as 1-dimensional.
I've looked into what numpy describes as broadcasting and in my opinion it is really two different PyTorch features:
- broadcasting (as defined above)
- batching
The rest of this document will describe for each class of function how broadcasting and/or batching should be implemented.
Pointwise math functions:
torch function |
---|
add |
atan2 |
div |
fmod |
lerp |
mul |
pow |
remainder |
dist |
sub |
In-place functions follow the same rules with the additional restriction that the resulting tensor cannot change size. Note that this is the same behavior as numpy, although I couldn't find numpy documentation to this effect. Numpy in general does not allow output parameters (parameters passed in as out, output) to change size (torch in general automatically resizes them), and functions where the size is not known a priori (e.g. numpy.nonzero) do not support output parameters.
torch function |
---|
eq |
ge |
gt |
le |
lt |
max |
min |
ne |
The broadcasting behavior of these functions is the same as the pointwise math functions.
Pointwise functions with three operands:
torch function |
---|
addcdiv |
addcmul |
When in-place, the non-result tensors are broadcast to the size of the resulting tensor (i.e. same behavior as the 2-operand case).
For out-of-place, the behavior should be equivalent to breaking up the operation. I.e. addcmul(C,A,B) is equivalent to add(C,mul(A,B)). With broadcasting behavior, that would mean we first broadcast A and B together, then broadcast the result with C. It's easier to implement this as broadcasting all 3 operands together to start; here's a proof sketch showing they are equivalent:
Consider addcmul. Under expansion we want: a + (b * c) = (a + b * c) [all expanded together]
Let e(i, j) be the expansion of i with j, and e(i, j, k) the expansion of i with j and k.
Then a + (b * c) = e(a, e(b,c) * e(c,b)) + e(e(b,c) * e(c,b), a)
= e(a, e(b,c)) + e(e(b,c) * e(c,b), a) (only size matters for second param)
= e(a,b,c) + e(e(b,c) * e(c,b), a) (by associativity of max in expand)
= e(a,b,c) + e(e(b,c),a) * e(e(c,b), a) (see L1)
= e(a,b,c) + e(b,c,a) * e(c,b,a) (associativity, as above)
which is a + b * c all expanded together
L1: Show e(i * j, a) = e(i,a) * e(j,a) where i,j have same size.
Consider any point (s0, s1, ..., sn-1):
e(i * j, a)(s0, s1, ..., sn-1) =
(i*j)(f(s0), f(s1)...,f(sn-1)) where f is the expansion of that dimension with a
= i(f(s0), f(s1)...,f(sn-1)) * j(f(s0), f(s1)...,f(sn-1)) by definition of pointwise operator
= e(i,a) * e(j,a)
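The equivalence can also be spot-checked numerically. A small sketch with arbitrary shapes, chosen so that every operand needs expanding:

```python
import torch

torch.manual_seed(0)
a = torch.randn(3, 1, 1)
b = torch.randn(1, 4, 1)
c = torch.randn(1, 1, 5)

# Unfused: b and c broadcast together first, then the product with a.
stepwise = torch.add(a, torch.mul(b, c))

# Expand all three operands to one common shape up front.
a_e, b_e, c_e = torch.broadcast_tensors(a, b, c)
all_at_once = a_e + b_e * c_e

# The fused op (value defaults to 1): a + b * c.
fused = torch.addcmul(a, b, c)

print(stepwise.shape)                         # torch.Size([3, 4, 5])
print(torch.allclose(stepwise, all_at_once))  # True
print(torch.allclose(stepwise, fused, atol=1e-6))
```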
Also note that the combination of changing to broadcasting semantics and changing the keepdim default to False on reductions can cause some calculations to "just work" when either change introduced independently would cause them to fail. For example:
running_mean = torch.randn(4)
input = torch.randn(4,4)
input_mean = input.mean(1)
diff = torch.sum( (input_mean - running_mean).abs() )
With broadcasting and keepdim=False, the sum is over 4 elements, as desired. With broadcasting and keepdim=True, the sum is over 4*4 elements.
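The shapes involved can be checked directly using the keepdim argument of mean (the variable is named batch here rather than input to avoid shadowing the Python builtin):

```python
import torch

running_mean = torch.randn(4)
batch = torch.randn(4, 4)

mean_dropped = batch.mean(1)             # keepdim=False: shape (4,)
mean_kept = batch.mean(1, keepdim=True)  # keepdim=True:  shape (4, 1)

# (4,) - (4,) stays (4,): the sum runs over 4 elements.
print((mean_dropped - running_mean).shape)  # torch.Size([4])
# (4, 1) - (4,) broadcasts to (4, 4): the sum runs over 16 elements.
print((mean_kept - running_mean).shape)     # torch.Size([4, 4])
```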
The functions that have a deprecated pointwise fallback are:
- pow
- add
- sub
- mul
- div
- fmod
- remainder
- addcmul
- addcdiv
The rules for numpy matmul are given in the numpy.matmul documentation.
PyTorch doesn't currently implement the case where the dimensionality is greater than 2:
If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indexes and broadcast accordingly.
For example:
>>> a_np = np.random.randn(2,5,7)
>>> b_np = np.random.randn(5,2,7,3)
>>> c_np = a_np @ b_np
>>> c_np.shape
(5, 2, 5, 3)
Note that while this is referred to as broadcasting, in PyTorch terms this is essentially broadcasting followed by batching+reshaping, i.e. PyTorch only supports the batching case where ndims(A) == ndims(B) == 3 without broadcasting.
For example:
>>> a=torch.from_numpy(a_np)
>>> b=torch.from_numpy(b_np)
>>> c=torch.bmm(a.expand(5,2,5,7).contiguous().view(5*2,5,7),b.view(5*2,7,3)).view(5,2,5,3)
>>> (c-torch.from_numpy(c_np)).abs().max()
8.881784197001252e-16
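The expand + bmm trick above generalizes to arbitrary batch shapes. The following is a minimal sketch, assuming a PyTorch version that provides torch.broadcast_shapes (added in later releases); batched_matmul is a hypothetical helper name, and 1-D operands (which numpy matmul promotes to matrices) are deliberately not handled:

```python
import math

import numpy as np
import torch


def batched_matmul(a, b):
    """numpy-style matmul for inputs with at least 2 dims:
    broadcast the batch dims, flatten them into one, then call bmm."""
    if a.dim() < 2 or b.dim() < 2:
        raise ValueError("sketch only handles operands with >= 2 dims")
    n, m = a.shape[-2:]
    m2, p = b.shape[-2:]
    if m != m2:
        raise ValueError(f"inner sizes differ: {m} vs {m2}")
    # Broadcast everything except the trailing two (matrix) dimensions.
    batch = torch.broadcast_shapes(a.shape[:-2], b.shape[:-2])
    nb = math.prod(batch)  # 1 when there are no batch dims
    # Expand to the common batch shape, then collapse it into one dim for bmm.
    a3 = a.expand(*batch, n, m).reshape(nb, n, m)
    b3 = b.expand(*batch, m, p).reshape(nb, m, p)
    return torch.bmm(a3, b3).view(*batch, n, p)


a_np = np.random.randn(2, 5, 7)
b_np = np.random.randn(5, 2, 7, 3)
c = batched_matmul(torch.from_numpy(a_np), torch.from_numpy(b_np))
print(c.shape)                              # torch.Size([5, 2, 5, 3])
print(np.allclose(c.numpy(), a_np @ b_np))  # True
```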
Implementing the numpy behavior won't cause any backwards incompatibility because PyTorch currently errors out if the dimensionality of either input is greater than 2.
torch function |
---|
mm |
mv |
bmm |
If we take the view that matmul dispatches to these functions to perform numpy matmul semantics, then these functions themselves don't need extra broadcasting or batching semantics, i.e. they are lower-level functions.
torch function | numpy equivalent | supports batching | supports broadcasting |
---|---|---|---|
dot | numpy.dot | no | no |
ger | numpy.outer | no | no |
Nothing to do here.
The fused BLAS functions are:
torch function |
---|
addmm |
addmv |
addr |
These "fused" functions should behave as if the operation were broken up, e.g. addmm(C,A,B) is equivalent to add(C, mm(A,B)). Given that mm, mv, ger do not broadcast (see above), we should only broadcast the add.
torch function |
---|
baddbmm |
addbmm |
The same logic applies as for the non-batched functions, e.g. baddbmm(C,A,B) = add(C, bmm(A,B)). Since bmm does not broadcast, only the add should broadcast.
Putting these all together:
Function | A.shape | B.shape | unfused equivalent | unfused size | broadcast of C |
---|---|---|---|---|---|
addmm(C,A,B) | (n,m) | (m,p) | add(C, mm(A,B)) | add(C, (n,p)) | (n,p) |
addmv(C,A,B) | (n,m) | (m) | add(C, mv(A,B)) | add(C, (n)) | (n) |
addr(C,A,B) | (n) | (m) | add(C, ger(A,B)) | add(C, (n,m)) | (n,m) |
baddbmm(C,A,B) | (b,n,m) | (b,m,p) | add(C, bmm(A,B)) | add(C, (b,n,p)) | (b,n,p) |
addbmm(C,A,B) | (b,n,m) | (b,m,p) | add(C, sum(bmm(A,B),0)) | add(C, sum((b,n,p),0)) = add(C, (n,p)) | (n,p) |
Note that because PyTorch is currently strict about tensor sizes for BLAS operations, there are no backwards compatibility concerns with implementing this type of broadcasting.
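Later PyTorch releases do broadcast C in the fused functions. A quick check against the unfused forms from the table, assuming such a release (the tolerance is loosened slightly because fused and unfused kernels may round differently):

```python
import torch

torch.manual_seed(0)
n, m, p, nb = 2, 3, 4, 5
C = torch.randn(p)  # broadcastable to the (n, p) result in both cases below

# addmm(C, A, B) == add(C, mm(A, B)): only the add broadcasts.
A = torch.randn(n, m)
B = torch.randn(m, p)
fused = torch.addmm(C, A, B)
unfused = torch.add(C, torch.mm(A, B))
print(fused.shape)  # torch.Size([2, 4])

# addbmm(C, A, B) == add(C, sum(bmm(A, B), 0)).
A3 = torch.randn(nb, n, m)
B3 = torch.randn(nb, m, p)
fused3 = torch.addbmm(C, A3, B3)
unfused3 = torch.add(C, torch.bmm(A3, B3).sum(0))
print(fused3.shape)  # torch.Size([2, 4])
```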
There isn't a 1-to-1 mapping of numpy lapack functions to torch lapack functions, but the numpy.linalg package contains close analogs.
Many numpy.linalg functions with a single ndarray operand claim they support broadcasting (see e.g. numpy.linalg.svd), which is strange because broadcasting is only described in terms of two operands. What numpy actually seems to mean is that it supports batching+reshaping (à la matmul), i.e.:
>>> a=np.array([[0,1], [1,1]])
>>> b=np.array([[0,1], [1,1]])
>>> np.linalg.inv(a)
array([[-1., 1.],
[ 1., 0.]])
>>> np.linalg.inv(b)
array([[-1., 1.],
[ 1., 0.]])
>>> np.linalg.inv([a,b])
array([[[-1., 1.],
[ 1., 0.]],
[[-1., 1.],
[ 1., 0.]]])
torch function | numpy/scipy equivalent | supports batching | supports broadcasting |
---|---|---|---|
inverse | numpy.linalg.inv | yes | no |
eig | numpy.linalg.eig | yes | no |
symeig | numpy.linalg.eigh | yes | no |
qr | numpy.linalg.qr | no | no |
svd | numpy.linalg.svd | yes | no |
btrifact | scipy.linalg.lu_factor | no | no |
btrisolve | scipy.linalg.lu_solve | no | no |
geqrf | scipy.linalg.lapack.dgeqrf | no | no |
orgqr | scipy.linalg.lapack.dorgqr | no | no |
ormqr | scipy.linalg.lapack.dormqr | no | no |
potrf | scipy.linalg.lapack.dpotrf | no | no |
potri | scipy.linalg.lapack.dpotri | no | no |
potrs | scipy.linalg.lapack.dpotrs | no | no |
pstrf | none, numpy.linalg.cholesky closest? | yes | no |
Since none of these really support broadcasting (only batching), this should be viewed as a separate issue.
torch function | numpy/scipy equivalent | supports batching | supports broadcasting |
---|---|---|---|
gels | numpy.linalg.lstsq | yes | no |
gesv | numpy.linalg.solve | yes | yes*, see below |
numpy.linalg.solve is actually two functions:
* solve: (m,m), (m,n) -> (m,n)
* solve1: (m,m), (m) -> (m)
both of which support batching and broadcasting. solve1 is selected iff ndims(A) - 1 == ndims(B). This leads to some weird results as the dimensions of the tensors change, e.g.:
# not solve-1
>>> np.linalg.solve(np.random.randn(2,4,5,9,6,6), np.random.randn(6,15)).shape
(2, 4, 5, 9, 6, 15)
# not solve-1
>>> np.linalg.solve(np.random.randn(2,4,5,9,6,6), np.random.randn(9,6,15)).shape
(2, 4, 5, 9, 6, 15)
# not solve-1
>>> np.linalg.solve(np.random.randn(2,4,5,9,6,6), np.random.randn(5,9,6,15)).shape
(2, 4, 5, 9, 6, 15)
# solve-1, old pattern doesn't work:
>>> np.linalg.solve(np.random.randn(2,4,5,9,6,6), np.random.randn(4,5,9,6,15)).shape
Traceback (most recent call last):
...
ValueError: solve1: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (m,m),(m)->(m) (size 15 is different from 6)
# need to match up dimensions according to solve1
>>> np.linalg.solve(np.random.randn(2,4,5,9,6,6), np.random.randn(2,4,5,9,6)).shape
(2, 4, 5, 9, 6)
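The selection rule can be modeled purely on shapes. This sketch predicts numpy.linalg.solve output shapes under the rule as stated in this section (ndims(A) - 1 == ndims(B) selects solve1); the helper names are hypothetical, and newer NumPy releases may apply a different rule, so treat it as a model of the behavior described here rather than of every NumPy version:

```python
from itertools import zip_longest


def broadcast_batch(x, y):
    """Broadcast two batch-shape tuples (prepend 1s, keep the non-1 size)."""
    out = []
    for dx, dy in zip_longest(reversed(x), reversed(y), fillvalue=1):
        if dx != dy and dx != 1 and dy != 1:
            raise ValueError(f"batch sizes {dx} and {dy} do not broadcast")
        out.append(dy if dx == 1 else dx)
    return tuple(reversed(out))


def numpy_solve_shape(a_shape, b_shape):
    """Predicted numpy.linalg.solve(a, b) result shape under the rule
    described above: solve1 iff ndims(a) - 1 == ndims(b)."""
    m = a_shape[-1]
    if len(b_shape) == len(a_shape) - 1:
        # solve1, (m,m),(m)->(m): b's core dimension is its last one.
        if b_shape[-1] != m:
            raise ValueError(f"solve1: core dimension {b_shape[-1]} != {m}")
        return broadcast_batch(a_shape[:-2], b_shape[:-1]) + (m,)
    # solve, (m,m),(m,n)->(m,n): b's core dimensions are its last two.
    if len(b_shape) < 2 or b_shape[-2] != m:
        raise ValueError(f"solve: b must end in ({m}, n)")
    return broadcast_batch(a_shape[:-2], b_shape[:-2]) + (m, b_shape[-1])


A = (2, 4, 5, 9, 6, 6)
print(numpy_solve_shape(A, (6, 15)))          # (2, 4, 5, 9, 6, 15)
print(numpy_solve_shape(A, (5, 9, 6, 15)))    # (2, 4, 5, 9, 6, 15)
print(numpy_solve_shape(A, (2, 4, 5, 9, 6)))  # (2, 4, 5, 9, 6)
```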
This seems unnecessarily complex, i.e. different behavior should be different functions.
torch.gesv currently supports A having shape (m,m) and B having shape (m) or (m,k). B having shape (m) is treated the same as B having shape (m,1) (note: torch.gels has the same behavior).
Note that we need to make a similar decision to numpy. Consider the case where B has shape (5,5) and A has shape (5,5,5). If B is interpreted as a matrix, then the result has shape (5,5,5). If B is interpreted as a vector, equivalent to (5,5,1) (when viewed as a matrix), then the result has shape (5,5,1).
To avoid these kinds of complications, it seems nicer to just expose two different functions, gesv and gesv1, where gesv will interpret B as a matrix, except in the case where B has 1 dimension, in which case it will behave as now, and gesv1 will interpret B as a vector.
Note that this is backwards compatible because the previous shape of gesv[0] and new shape of gesv[0] for current valid inputs are equivalent:
A.shape | B.shape | Previous shape gesv[0] | New shape gesv[0] | Shape gesv1[0] |
---|---|---|---|---|
(m,m) | (m) | (m,1) | (m,1) | (m) |
(m,m) | (m,1) | (m,1) | (m,1) | (m,1) if m=1, otherwise Error (vector is size 1, not m) |
(m,m) | (m,k), k != 1 | (m,k) | (m,k) | (m,k) if m=k, otherwise Error (vector is size k, not m) |
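A minimal sketch of the proposed split, written against torch.linalg.solve from later PyTorch releases rather than the original torch.gesv. The functions gesv and gesv1 below are hypothetical illustrations of the proposal (they return only the solution, not the (solution, LU) pair that torch.gesv returned), and batch dimensions are broadcast explicitly so that B's interpretation never depends on torch.linalg.solve's own vector/matrix heuristics:

```python
import torch


def gesv(A, B):
    """Hypothetical gesv from the proposal: B is a (batch of) matrix;
    a 1-D B keeps the old behavior and is treated as (m, 1)."""
    if B.dim() == 1:
        B = B.unsqueeze(-1)
    batch = torch.broadcast_shapes(A.shape[:-2], B.shape[:-2])
    # Giving A and B the same number of dims makes linalg.solve
    # read B as matrices rather than as a batch of vectors.
    A_e = A.expand(*batch, *A.shape[-2:])
    B_e = B.expand(*batch, *B.shape[-2:])
    return torch.linalg.solve(A_e, B_e)


def gesv1(A, B):
    """Hypothetical gesv1: B is a (batch of) vector of size m."""
    batch = torch.broadcast_shapes(A.shape[:-2], B.shape[:-1])
    A_e = A.expand(*batch, *A.shape[-2:])
    # View each vector as an (m, 1) column, solve, then drop that column dim.
    B_e = B.expand(*batch, B.shape[-1]).unsqueeze(-1)
    return torch.linalg.solve(A_e, B_e).squeeze(-1)


torch.manual_seed(0)
eye = torch.eye(5, dtype=torch.float64)
A = torch.randn(5, 5, 5, dtype=torch.float64) + 10 * eye  # well conditioned
B = torch.randn(5, 5, dtype=torch.float64)
print(gesv(A, B).shape)   # torch.Size([5, 5, 5]): B read as a matrix
print(gesv1(A, B).shape)  # torch.Size([5, 5]): B read as 5 vectors
```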
Note that this isn't ideal, because in the case where B is 1-dimensional, the shape of gesv[0] is not the same as the shape of gesv1[0], but that is necessary for backwards compatibility. In other cases (e.g. with pointwise functions), we have preferred numpy semantics over backwards compatibility, but in this case, given that numpy doesn't have (in my opinion) reasonable semantics, we should prefer backwards compatibility.
The alternative, preferring consistency over backwards compatibility, is to change the output shape in the 1-dimensional B case to (m) [from (m,1)], and to do the same for gels.
torch function | numpy/scipy equivalent | supports broadcasting |
---|---|---|
cat | numpy.concatenate | no |
gather | numpy.choose | yes, although numpy.choose only supports the equivalent of torch.gather with dim=0. |
scatter | No equivalent | Yes, for consistency with gather |
index_select | numpy.take | no |
masked_select | numpy boolean (mask) index arrays | no, see advanced indexing |
For gather: if the tensor has shape (x0, x1, ..., xi, ..., xn) and dim=i, we want to expand index to (x0, x1, ..., xi-1, eindex(i), xi+1, ..., xn), where eindex(i) is the size of dimension i of the (possibly expanded) index tensor. There are three cases to consider:
Case | Solution |
---|---|
i > index.dim() | expand as above, with eindex(i) == 1, and squeeze dimension i. For example, with x=np.random.randn(7,3,4), y=np.zeros((1,3,4)).astype('int64') and z=np.zeros(4).astype('int64'), np.array_equal(np.choose(y,x).squeeze(0), np.choose(z,x)) returns True. |
i <= index.dim() < n | Note this doesn't come up for numpy. Specific example: tensor.size() == (3,5,7), index.size() == (5,7): do we expand like (5,1,7) or (1,5,7)? Following the rule of prepending 1s, it should be (1,5,7) -> (3,5,7). |
index.dim() == n | expand as above, with eindex(i) == index.size()[i] |
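One reading of the expansion rule above, as a sketch: prepend 1s to index until it has as many dimensions as the input, keep the index's own size along dim (eindex(i)), and expand every other dimension to the input's size before calling torch.gather. broadcast_gather is a hypothetical name, and the squeeze described in the first table row is not performed:

```python
import torch


def broadcast_gather(src, dim, index):
    """Hypothetical broadcasting gather: prepend 1s to index, keep its own
    size along dim, and expand every other dim to src's size."""
    while index.dim() < src.dim():
        index = index.unsqueeze(0)  # the "prepend 1s" rule
    shape = list(src.shape)
    shape[dim] = index.size(dim)    # eindex(i)
    return torch.gather(src, dim, index.expand(*shape))


x = torch.arange(3 * 5 * 7).view(3, 5, 7)
# An index of shape (2, 1, 1) along dim 0 picks the whole slices x[2] and x[1].
idx = torch.tensor([2, 1]).view(2, 1, 1)
out = broadcast_gather(x, 0, idx)
print(out.shape)  # torch.Size([2, 5, 7])
```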
For scatter: if out is an (x0, x1, ..., xj, ..., xn) tensor and dim == j, then index must be an (x0, x1, ..., ij, ..., xn) tensor, source must be an (x0, x1, ..., sj, ..., xn) tensor, and ij must be <= sj (THC currently enforces that these are equal). So, following the gather rules, we have:
- out is guaranteed to be defined because this is an in-place call.
- If either index or source has at least j dimensions, and the j-th dimension is dj (let's assume they are equal, as in THC), we broadcast them to (x0, x1, ..., dj, ..., xn). Note that if dj != 1 this already violates the scatter rule that "rows" (axes) have unique values.
- If neither index nor source has at least j dimensions, the j-th dimension is not defined under broadcast rules and, from above, we already violate the scatter rule that "rows" have unique values, so it shouldn't actually matter what we do.
Given the complexity and ambiguity here, I'm punting on implementing this for now.
torch function | numpy/scipy equivalent | supports broadcasting |
---|---|---|
copy_ | numpy.copyto | yes |
masked_copy | numpy.copyto | yes |
masked_fill | nothing direct, numpy.full is closest | yes, for consistency with masked_copy |
map_ | numpy.vectorize | yes |
map2_ | numpy.vectorize | yes |
index_add_ | No equivalent | No, e.g. with x=torch.arange(1,7).view(2,3), torch.zeros(3,3).index_add_(0, torch.LongTensor([1,0]), x) gives [[4,5,6],[1,2,3],[0,0,0]] (a torch.FloatTensor of size 3x3) |
index_copy_ | No equivalent | No, e.g. with x=torch.arange(1,7).view(2,3), torch.zeros(3,3).index_copy_(0, torch.LongTensor([1,0]), x) gives [[4,5,6],[1,2,3],[0,0,0]] (a torch.FloatTensor of size 3x3) |
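For the functions marked as broadcasting above, later PyTorch releases do accept a broadcastable source or mask. A small sketch with copy_ and masked_fill (the out-of-place counterpart of masked_fill_), assuming such a release:

```python
import torch

# copy_: the source only needs to be broadcastable to the destination.
dst = torch.zeros(2, 3)
dst.copy_(torch.tensor([1.0, 2.0, 3.0]))  # (3,) -> (2, 3)
print(dst)  # every row is [1., 2., 3.]

# masked_fill: the mask only needs to be broadcastable to the tensor.
mask = torch.tensor([True, False, True])  # (3,) -> (2, 3)
filled = torch.zeros(2, 3).masked_fill(mask, -1.0)
print(filled)  # every row is [-1., 0., -1.]
```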
Type | Summary | Backwards Compatible? | Status |
---|---|---|---|
pointwise | same behavior as numpy | no | PR#1563 |
matmul | same behavior as numpy | yes | PR#1563 |
BLAS | add broadcasts for fused functions, e.g. addmm | yes | PR#1563 |
LAPACK | gesv splits into gesv and gesv1 | yes | no plans currently |
Indexing, Slicing, Joining, Mutating Ops | gather, scatter should broadcast according to the (complex) rules above | ? | no plans currently |
Uncategorized tensor functions | copy, etc. | no | PR#1563 |