gchanan edited this page Jun 1, 2017 · 93 revisions

Purpose

I wrote this document as my notes for implementing broadcasting in PyTorch, so it is likely unclear to another reader in a number of places. It describes how broadcasting works in PyTorch, how I decided individual PyTorch functions should implement broadcasting, and how this compares to NumPy's broadcasting semantics.

PyTorch broadcasting semantics

Many PyTorch operations support broadcasting semantics.

In short, if a PyTorch operation supports broadcasting, then its tensor arguments can be automatically expanded to be of equal sizes (without making copies of the data).

General semantics

Two tensors are "broadcastable" if the following rules hold:

  • Each tensor has at least one dimension.
  • When iterating over the dimension sizes, starting at the trailing dimension, the dimension sizes must either be equal, or one of them must be 1, or one of them must not exist.

For example:

>>> x=torch.FloatTensor(5,7,3)
>>> y=torch.FloatTensor(5,7,3)
# same shapes are always broadcastable (i.e. the above rules always hold)

>>> x=torch.FloatTensor()
>>> y=torch.FloatTensor(2,2)
# x and y are not broadcastable, because x does not have at least 1 dimension

>>> x=torch.FloatTensor(5,1,4,1)
>>> y=torch.FloatTensor(3,1,1)
# x and y are broadcastable.
# 1st trailing dimension: both have size 1
# 2nd trailing dimension: y has size 1
# 3rd trailing dimension: x has size 1
# 4th trailing dimension: y dimension doesn't exist

# but:
>>> x=torch.FloatTensor(5,2,4,1)
>>> y=torch.FloatTensor(3,1,1)
# x and y are not broadcastable, because in the 3rd trailing dimension 2 != 3
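As a minimal sketch, the two rules above can be written as a shape check in plain Python. `is_broadcastable` is a hypothetical helper (not a PyTorch API) that takes shapes as tuples, with `()` standing in for a tensor with no dimensions. It follows the rules exactly as stated here, so it does not model the zero-dimensional scalar tensors that later PyTorch releases also broadcast.

```python
from itertools import zip_longest


def is_broadcastable(shape_a, shape_b):
    """Return True if two shapes satisfy the broadcasting rules described above.

    Hypothetical helper for illustration; shapes are tuples such as (5, 1, 4, 1).
    """
    # Rule 1: each tensor must have at least one dimension.
    if len(shape_a) == 0 or len(shape_b) == 0:
        return False
    # Rule 2: iterate from the trailing dimension; a missing dimension is None.
    for size_a, size_b in zip_longest(reversed(shape_a), reversed(shape_b)):
        if size_a is None or size_b is None:
            continue  # one of the dimensions does not exist
        if size_a != size_b and size_a != 1 and size_b != 1:
            return False
    return True


# The four cases from the examples above:
print(is_broadcastable((5, 7, 3), (5, 7, 3)))     # True
print(is_broadcastable((), (2, 2)))               # False
print(is_broadcastable((5, 1, 4, 1), (3, 1, 1)))  # True
print(is_broadcastable((5, 2, 4, 1), (3, 1, 1)))  # False
```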

If two tensors x, y are "broadcastable", the resulting tensor size is calculated as follows:

  • If the number of dimensions of x and y is not equal, prepend 1 to the dimensions of the tensor with fewer dimensions to make them equal length.
  • Then, for each dimension size, the resulting dimension size is the max of the sizes of x and y along that dimension.

For example:

# can line up trailing dimensions to make reading easier
>>> x=torch.FloatTensor(5,1,4,1)
>>> y=torch.FloatTensor(  3,1,1)
>>> (x+y).size()
torch.Size([5, 3, 4, 1])

# but not necessary:
>>> x=torch.FloatTensor(1)
>>> y=torch.FloatTensor(3,1,7)
>>> (x+y).size()
torch.Size([3, 1, 7])

>>> x=torch.FloatTensor(5,2,4,1)
>>> y=torch.FloatTensor(3,1,1)
>>> (x+y).size()
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1
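The two result-size steps can likewise be sketched in plain Python. `broadcast_shape` is a hypothetical helper operating on shape tuples; its error message is only modeled on the RuntimeError above, not copied from PyTorch. (Later PyTorch releases expose `torch.broadcast_shapes` for the same computation.)

```python
def broadcast_shape(shape_a, shape_b):
    """Compute the broadcast result shape of two shapes (hypothetical helper)."""
    # Step 1: prepend 1s to the shape with fewer dimensions so the lengths match.
    ndim = max(len(shape_a), len(shape_b))
    padded_a = (1,) * (ndim - len(shape_a)) + tuple(shape_a)
    padded_b = (1,) * (ndim - len(shape_b)) + tuple(shape_b)

    result = []
    for dim, (size_a, size_b) in enumerate(zip(padded_a, padded_b)):
        if size_a != size_b and size_a != 1 and size_b != 1:
            raise RuntimeError(
                f"size {size_a} does not match size {size_b} "
                f"at non-singleton dimension {dim}")
        # Step 2: the result size is the max of the two sizes.
        result.append(max(size_a, size_b))
    return tuple(result)


print(broadcast_shape((5, 1, 4, 1), (3, 1, 1)))  # (5, 3, 4, 1)
print(broadcast_shape((1,), (3, 1, 7)))          # (3, 1, 7)
try:
    broadcast_shape((5, 2, 4, 1), (3, 1, 1))
except RuntimeError as err:
    print(err)  # size 2 does not match size 3 at non-singleton dimension 1
```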

In-place semantics

One complication is that in-place operations do not allow the in-place tensor to change shape as a result of the broadcast.

For example:

>>> x=torch.FloatTensor(5,3,4,1)
>>> y=torch.FloatTensor(3,1,1)
>>> (x.add_(y)).size()
torch.Size([5, 3, 4, 1])

# but:
>>> x=torch.FloatTensor(1,3,1)
>>> y=torch.FloatTensor(3,1,7)
>>> (x.add_(y)).size()
RuntimeError: The expanded size of the tensor (1) must match the existing size (7) at non-singleton dimension 2.
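A minimal sketch of the in-place restriction, assuming shapes are plain tuples: `inplace_broadcast_ok` (a hypothetical helper, not a PyTorch API) accepts only the cases where broadcasting `other` leaves the in-place tensor's shape unchanged, i.e. `other` may be expanded but `self` never is.

```python
def inplace_broadcast_ok(self_shape, other_shape):
    """Return True if `other` can be broadcast to `self` without changing self's shape.

    Hypothetical helper illustrating the in-place rule.
    """
    # `other` may not have more dimensions than the in-place tensor.
    if len(other_shape) > len(self_shape):
        return False
    # Aligning trailing dimensions, each size of `other` must equal the
    # corresponding size of `self` or be 1; `self` itself is never expanded.
    return all(o == s or o == 1
               for o, s in zip(reversed(other_shape), reversed(self_shape)))


print(inplace_broadcast_ok((5, 3, 4, 1), (3, 1, 1)))  # True
print(inplace_broadcast_ok((1, 3, 1), (3, 1, 7)))     # False
```

Equivalently, an in-place operation is allowed only when the out-of-place broadcast result shape would be exactly the in-place tensor's shape.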

Backwards compatibility

Prior versions of PyTorch allowed certain pointwise functions to execute on tensors with different shapes, as long as the number of elements in each tensor was equal. The pointwise operation would then be carried out by viewing each tensor as 1-dimensional. PyTorch now supports broadcasting and the "1-dimensional" pointwise behavior is considered deprecated and will generate a Python warning in cases where tensors are not broadcastable, but have the same number of elements.

Note that the introduction of broadcasting can cause backwards-incompatible changes in the case where two tensors do not have the same shape, but are broadcastable and have the same number of elements. For example:

>>> torch.add(torch.ones(4,1), torch.randn(4))

would previously produce a Tensor with size torch.Size([4,1]), but now produces a Tensor with size torch.Size([4,4]). To help identify cases in your code where backwards incompatibilities introduced by broadcasting may exist, you may set torch.utils.backcompat.broadcast.warning.enabled to True, which will generate a Python warning in such cases.

For example:

>>> torch.utils.backcompat.broadcast.warning.enabled=True
>>> torch.add(torch.ones(4,1), torch.ones(4))
__main__:1: UserWarning: self and other do not have the same shape, but are broadcastable, and have the same number of elements.
Changing behavior in a backwards incompatible manner to broadcasting rather than viewing as 1-dimensional.
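To summarize the cases above, here is a hypothetical classifier (not a PyTorch API) for a pointwise function that has the deprecated 1-dimensional fallback (the functions listed in the pointwise section). It is a sketch of the decision described in this section; the returned strings are illustrative labels, not PyTorch messages, and shapes are assumed to be non-empty tuples.

```python
from itertools import zip_longest
from math import prod


def _broadcastable(shape_a, shape_b):
    # Trailing-dimension rule from "General semantics"; None marks a missing dim.
    return all(a is None or b is None or a == b or a == 1 or b == 1
               for a, b in zip_longest(reversed(shape_a), reversed(shape_b)))


def pointwise_behavior(shape_a, shape_b):
    """Classify how a pointwise function with the deprecated fallback treats two shapes.

    Hypothetical helper; the returned strings are illustrative labels only.
    """
    same_numel = prod(shape_a) == prod(shape_b)
    if tuple(shape_a) == tuple(shape_b):
        return "same shape"
    if _broadcastable(shape_a, shape_b):
        # Previously viewed as 1-D when numel matched; now broadcast instead.
        return "broadcast (backcompat warning if enabled)" if same_numel else "broadcast"
    if same_numel:
        # Not broadcastable but same numel: old 1-D behavior, now deprecated.
        return "deprecated 1-D fallback (warning)"
    return "error"


print(pointwise_behavior((4, 1), (4,)))    # broadcast (backcompat warning if enabled)
print(pointwise_behavior((2, 3), (3, 2)))  # deprecated 1-D fallback (warning)
print(pointwise_behavior((2, 3), (4,)))    # error
```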

A comparison to numpy broadcasting

I've looked into what numpy describes as broadcasting, and in my opinion it is really two different PyTorch features:

  • broadcasting (as defined above)
  • batching

The rest of this document will describe for each class of function how broadcasting and/or batching should be implemented.

Pointwise functions

Pointwise two tensor operand math functions:

torch function
add
atan2
div
fmod
lerp
mul
pow
remainder
dist
sub

In-place functions follow the same rules, with the additional restriction that the resulting tensor cannot change size. Note that this is the same behavior as numpy, although I couldn't find numpy documentation to this effect. Numpy in general does not allow output parameters (parameters passed in as out or output) to change size (torch in general resizes them automatically), and functions where the size is not known a priori (e.g. numpy.nonzero) do not support output parameters.

Pointwise two tensor operand comparison functions:

torch function
eq
ge
gt
le
lt
max
min
ne

The broadcasting behavior of these functions is the same as the pointwise math functions.

Pointwise three tensor operand math functions:

torch function
addcdiv
addcmul

When in-place, the non-result tensors are broadcast to the size of the resulting tensor (i.e. same behavior as the 2-operand case).

For out-of-place, the behavior should be equivalent to breaking up the operation. I.e. addcmul(C,A,B) is equivalent to add(C,mul(A,B)). With broadcasting behavior, that would mean we first broadcast A and B together, then broadcast the result with C. It's easier to implement this as broadcasting all 3 operands together to start; here's a proof sketch showing they are equivalent:

consider addcmul, under expansion we want: a + (b * c) = (a + b * c) [all expanded together]
Let e(i, j) be the expansion of i with j, e(i, j, k) be the expansion of i with j,k

Then a + (b * c) = e(a, e(b,c) * e(c,b)) + e(e(b,c)    * e(c,b), a)
                 = e(a, e(b,c))          + e(e(b,c)    * e(c,b), a)    (only size matters for second param)
                 = e(a,b,c)              + e(e(b,c)    * e(c,b), a)    (by associativity of max in expand)
                 = e(a,b,c)              + e(e(b,c),a) * e(e(c,b), a)  (see L1)
                 = e(a,b,c)              + e(b,c,a)    * e(c,b,a)      (associativity, as above)
which is a + b * c all expanded together

L1: Show e(i * j, a) = e(i,a) * e(j,a) where i,j have same size.
Consider any point (s0, s1, ..., sn-1):
e(i * j, a)(s0, s1, ..., sn-1) =
(i*j)(f(s0), f(s1)...,f(sn-1)) where f is the expansion of that dimension with a
= i(f(s0), f(s1)...,f(sn-1)) * j(f(s0), f(s1)...,f(sn-1)) by definition of pointwise operator
= e(i,a) * e(j,a)
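The proof sketch can be spot-checked numerically. The following snippet assumes NumPy (whose broadcasting rules match the ones described above) and needs NumPy 1.20 or later for `np.broadcast_shapes`. It computes a + b * c once as the broken-up operation and once with all three operands broadcast up front; a passing check on one set of random shapes is, of course, not a proof.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((5, 1, 4))
b = rng.standard_normal((3, 1))
c = rng.standard_normal((1, 4))

# Broken-up operation: broadcast b with c first, then the result with a.
stepwise = a + (b * c)

# Fused operation: broadcast all three operands to a common shape up front.
shape = np.broadcast_shapes(a.shape, b.shape, c.shape)
a_e, b_e, c_e = (np.broadcast_to(t, shape) for t in (a, b, c))
fused = a_e + b_e * c_e

print(shape)                         # (5, 3, 4)
print(np.allclose(stepwise, fused))  # True
```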

A note on backwards incompatibility for pointwise functions and keepdim for reduction functions:

Also note that the combination of changing to broadcasting semantics and changing the keepdim default to False on reductions can cause some calculations to "just work", even though introducing broadcasting alone (with keepdim still defaulting to True) would silently change their results. For example:

running_mean = torch.randn(4)
input = torch.randn(4,4)
input_mean = input.mean(1)
diff = torch.sum( (input_mean - running_mean).abs() )

With broadcasting and keepdim=False, the sum is over 4 elements, as desired. With broadcasting and keepdim=True, the sum is over 4*4 elements.
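The shape difference can be reproduced with NumPy, which uses the same broadcasting rules. Note that NumPy's reduction argument is spelled `keepdims`, while PyTorch's is `keepdim`; this is a NumPy sketch of the example above, not PyTorch code.

```python
import numpy as np

running_mean = np.random.randn(4)
inp = np.random.randn(4, 4)

# keepdims=False (matching the new PyTorch keepdim default): mean has shape (4,)
diff_flat = np.abs(inp.mean(axis=1) - running_mean)
# keepdims=True: mean has shape (4, 1), which broadcasts against (4,) to (4, 4)
diff_kept = np.abs(inp.mean(axis=1, keepdims=True) - running_mean)

print(diff_flat.shape)  # (4,)   -> the sum is over 4 elements
print(diff_kept.shape)  # (4, 4) -> the sum is over 16 elements
```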

The functions that have the deprecated pointwise fallback are:

  • pow
  • add
  • sub
  • mul
  • div
  • fmod
  • remainder
  • addcmul
  • addcdiv

matmul

The rules for numpy.matmul are described in the NumPy documentation.

PyTorch doesn't currently implement the following rule, which covers the case where the dimensionality is greater than 2:

If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indexes and broadcast accordingly.

For example:

>>> import numpy as np
>>> a_np = np.random.randn(2,5,7)
>>> b_np = np.random.randn(5,2,7,3)
>>> c_np = a_np @ b_np
>>> c_np.shape
(5, 2, 5, 3)

Note that while this is referred to as broadcasting, in PyTorch terms this is essentially broadcasting followed by batching+reshaping, i.e. PyTorch only supports the batching case where ndims(A) == ndims(B) == 3 without broadcasting.

For example:

>>> import torch
>>> a=torch.from_numpy(a_np)
>>> b=torch.from_numpy(b_np)
>>> c=torch.bmm(a.expand(5,2,5,7).contiguous().view(5*2,5,7),b.view(5*2,7,3)).view(5,2,5,3)
>>> (c-torch.from_numpy(c_np)).abs().max()
8.881784197001252e-16

Implementing the numpy behavior won't cause any backwards incompatibility because PyTorch currently errors out if the dimensionality of either input is greater than 2.
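The numpy matmul rule quoted above can be summarized as a shape computation: the last two dimensions hold the matrices, everything before them is broadcast, and 1-D operands are temporarily promoted to matrices. `matmul_shape` below is a hypothetical, shape-only helper sketched from those rules; the expand + view + bmm example above is one way to compute the values for the batched case.

```python
def _broadcast_batch(batch_a, batch_b):
    # Standard broadcasting (prepend 1s, then take the max) on the batch dims.
    ndim = max(len(batch_a), len(batch_b))
    a = [1] * (ndim - len(batch_a)) + list(batch_a)
    b = [1] * (ndim - len(batch_b)) + list(batch_b)
    out = []
    for x, y in zip(a, b):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"batch dims {tuple(batch_a)} and {tuple(batch_b)} do not broadcast")
        out.append(max(x, y))
    return out


def matmul_shape(shape_a, shape_b):
    """Result shape of numpy-style matmul (hypothetical helper, shapes only)."""
    if len(shape_a) == 0 or len(shape_b) == 0:
        raise ValueError("matmul does not accept 0-d operands")
    a, b = list(shape_a), list(shape_b)
    a_is_vec, b_is_vec = len(a) == 1, len(b) == 1
    if a_is_vec:
        a = [1] + a  # 1-D A is promoted to a row vector (1, k)
    if b_is_vec:
        b = b + [1]  # 1-D B is promoted to a column vector (k, 1)
    if a[-1] != b[-2]:
        raise ValueError(f"core dims mismatch: {a[-1]} != {b[-2]}")
    # Everything but the last two dims is a stack of matrices; broadcast those.
    out = _broadcast_batch(a[:-2], b[:-2]) + [a[-2], b[-1]]
    # Remove the dimensions that were added by the 1-D promotions.
    if a_is_vec:
        out.pop(-2)
    if b_is_vec:
        out.pop(-1)
    return tuple(out)


print(matmul_shape((2, 5, 7), (5, 2, 7, 3)))  # (5, 2, 5, 3), as in the numpy example
print(matmul_shape((3, 4), (4,)))             # (3,)
print(matmul_shape((7,), (2, 7, 3)))          # (2, 3)
```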

BLAS functions

matmul functions

torch function
mm
mv
bmm

If we take the view that matmul dispatches to these functions to implement numpy matmul semantics, then these functions themselves don't need extra broadcasting or batching semantics, i.e. they are lower-level functions.

1-dimensional functions

| torch function | numpy equivalent | supports batching | supports broadcasting |
|---|---|---|---|
| dot | numpy.dot | no | no |
| ger | numpy.outer | no | no |

Nothing to do here.

Fused BLAS functions

Non-batched fused functions

These are:

torch function
addmm
addmv
addr

These "fused" functions should behave as if the operation were broken up, e.g. addmm(C,A,B) is equivalent to add(C, mm(A,B)). Given that mm, mv, ger do not broadcast (see above), we should only broadcast the add.

Batched fused functions

torch function
baddbmm
addbmm

Similar logic to the non-batched functions. e.g.:

baddbmm(C,A,B) = add(C, bmm(A,B)). Since bmm does not broadcast, only the add should broadcast.

Putting these all together:

| Function | A.shape | B.shape | unfused equivalent | unfused size | broadcast of C |
|---|---|---|---|---|---|
| addmm(C,A,B) | (n,m) | (m,p) | add(C, mm(A,B)) | add(C, (n,p)) | (n,p) |
| addmv(C,A,B) | (n,m) | (m) | add(C, mv(A,B)) | add(C, (n)) | (n) |
| addr(C,A,B) | (n) | (m) | add(C, ger(A,B)) | add(C, (n,m)) | (n,m) |
| baddbmm(C,A,B) | (b,n,m) | (b,m,p) | add(C, bmm(A,B)) | add(C, (b,n,p)) | (b,n,p) |
| addbmm(C,A,B) | (b,n,m) | (b,m,p) | add(C, sum(bmm(A,B),0)) | add(C, sum((b,n,p),0)) = add(C, (n,p)) | (n,p) |

Note that because PyTorch is currently strict about tensor sizes for BLAS operations, there are no backwards compatibility concerns with implementing this type of broadcasting.
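A shape-only sketch of the rule in the table: compute the shape of the non-broadcasting product first, then broadcast C to it. `fused_result_shape` is a hypothetical helper (not a PyTorch API), and it assumes C may be expanded to the product's shape but never makes the result larger than the product.

```python
def _expandable(src, target):
    """True if shape `src` can be expanded to `target` without enlarging `target`."""
    if len(src) > len(target):
        return False
    return all(s == t or s == 1 for s, t in zip(reversed(src), reversed(target)))


def fused_result_shape(fn, c_shape, a_shape, b_shape):
    """Result shape of a fused BLAS call in which only the add broadcasts.

    Hypothetical helper: the product itself (mm, mv, ger, bmm) never broadcasts,
    so its operand shapes must line up exactly; C is then broadcast to its shape.
    """
    if fn == "addr":                    # add(C, ger(A, B)): (n) x (m) -> (n, m)
        (n,), (m,) = a_shape, b_shape
        product = (n, m)
    elif fn in ("addmm", "addmv"):      # add(C, mm(A, B)) / add(C, mv(A, B))
        n, m = a_shape
        expected_b_ndim = 2 if fn == "addmm" else 1
        if len(b_shape) != expected_b_ndim or b_shape[0] != m:
            raise ValueError(f"{fn}: A {a_shape} and B {b_shape} do not line up")
        product = (n, b_shape[1]) if fn == "addmm" else (n,)
    elif fn in ("baddbmm", "addbmm"):   # add(C, bmm(A, B)) / add(C, sum(bmm(A, B), 0))
        (bs, n, m), (bs2, m2, p) = a_shape, b_shape
        if bs != bs2 or m != m2:
            raise ValueError(f"{fn}: A {a_shape} and B {b_shape} do not line up")
        product = (bs, n, p) if fn == "baddbmm" else (n, p)
    else:
        raise ValueError(f"unknown function {fn!r}")
    if not _expandable(c_shape, product):
        raise ValueError(f"C of shape {c_shape} cannot be broadcast to {product}")
    return product


print(fused_result_shape("addmm", (1, 5), (3, 4), (4, 5)))       # (3, 5)
print(fused_result_shape("addr", (4,), (3,), (4,)))              # (3, 4)
print(fused_result_shape("addbmm", (5,), (2, 3, 4), (2, 4, 5)))  # (3, 5)
```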

LAPACK functions

There isn't a 1-to-1 mapping of numpy lapack functions to torch lapack functions, but the numpy.linalg package contains close analogs.

1 tensor operand LAPACK functions

Many numpy.linalg functions with a single ndarray operand claim they support broadcasting (see e.g. numpy.linalg.svd), which is strange because broadcasting is only described in terms of two operands. What numpy actually seems to mean is that it supports batching+reshaping (à la matmul), i.e.:

>>> import numpy as np
>>> a=np.array([[0,1], [1,1]])
>>> b=np.array([[0,1], [1,1]])
>>> np.linalg.inv(a)
array([[-1.,  1.],
       [ 1.,  0.]])
>>> np.linalg.inv(b)
array([[-1.,  1.],
       [ 1.,  0.]])
>>> np.linalg.inv([a,b])
array([[[-1.,  1.],
        [ 1.,  0.]],

       [[-1.,  1.],
        [ 1.,  0.]]])
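To make the distinction between batching and broadcasting concrete: for a single-operand function, batching amounts to flattening the leading dimensions, applying the function matrix by matrix, and restoring the shape. `batched` below is a hypothetical helper written with NumPy for illustration; no broadcasting between operands is involved.

```python
import numpy as np


def batched(fn, x, core_ndim=2):
    """Apply a single-matrix function over all leading (batch) dims of x.

    Hypothetical helper: flatten the batch dims, loop, and restore the shape.
    """
    batch_shape = x.shape[:-core_ndim]
    core_shape = x.shape[-core_ndim:]
    flat = x.reshape((-1,) + core_shape)            # (batch, m, m)
    results = np.stack([fn(mat) for mat in flat])  # apply per matrix
    return results.reshape(batch_shape + results.shape[1:])


rng = np.random.default_rng(0)
stack = rng.standard_normal((2, 3, 4, 4))  # a 2x3 batch of 4x4 matrices
looped = batched(np.linalg.inv, stack)
print(looped.shape)                               # (2, 3, 4, 4)
print(np.allclose(looped, np.linalg.inv(stack)))  # True
```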

| torch function | numpy/scipy equivalent | supports batching | supports broadcasting |
|---|---|---|---|
| inverse | numpy.linalg.inv | yes | no |
| eig | numpy.linalg.eig | yes | no |
| symeig | numpy.linalg.eigh | yes | no |
| qr | numpy.linalg.qr | no | no |
| svd | numpy.linalg.svd | yes | no |
| btrifact | scipy.linalg.lu_factor | no | no |
| btrisolve | scipy.linalg.lu_solve | no | no |
| geqrf | scipy.linalg.lapack.dgeqrf | no | no |
| orgqr | scipy.linalg.lapack.dorgqr | no | no |
| ormqr | scipy.linalg.lapack.dormqr | no | no |
| potrf | scipy.linalg.lapack.dpotrf | no | no |
| potri | scipy.linalg.lapack.dpotri | no | no |
| potrs | scipy.linalg.lapack.dpotrs | no | no |
| pstrf | none; numpy.linalg.cholesky is the closest? | yes | no |

Since none of these really support broadcasting (only batching), this should be viewed as a separate issue.

2 tensor operand LAPACK functions

| torch function | numpy/scipy equivalent | supports batching | supports broadcasting |
|---|---|---|---|
| gels | numpy.linalg.lstsq | yes | no |
| gesv | numpy.linalg.solve | yes | yes* (see below) |

numpy.linalg.solve is actually two functions:

  * solve:  (m,m), (m,n) -> (m,n)
  * solve1: (m,m), (m)   -> (m)

both of which support batching and broadcasting. solve1 is selected iff ndims(A) - 1 == ndims(B). This leads to some surprising results as the dimensions of the tensors change, e.g.:

# not solve-1
>>> np.linalg.solve(np.random.randn(2,4,5,9,6,6), np.random.randn(6,15)).shape
(2, 4, 5, 9, 6, 15)

# not solve-1
>>> np.linalg.solve(np.random.randn(2,4,5,9,6,6), np.random.randn(9,6,15)).shape
(2, 4, 5, 9, 6, 15)

# not solve-1
>>> np.linalg.solve(np.random.randn(2,4,5,9,6,6), np.random.randn(5,9,6,15)).shape
(2, 4, 5, 9, 6, 15)

# solve-1, old pattern doesn't work:
>>> np.linalg.solve(np.random.randn(2,4,5,9,6,6), np.random.randn(4,5,9,6,15)).shape
Traceback (most recent call last):
...
ValueError: solve1: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (m,m),(m)->(m) (size 15 is different from 6)

# need to match up dimensions according to solve1
>>> np.linalg.solve(np.random.randn(2,4,5,9,6,6), np.random.randn(2,4,5,9,6)).shape
(2, 4, 5, 9, 6)
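The selection rule can be written down as a shape computation. `numpy_solve_shape` is a hypothetical helper that mirrors the behavior described in this section (it does not call NumPy); it reproduces the shapes and the failure in the examples above.

```python
def _broadcast(batch_a, batch_b):
    # Prepend 1s to the shorter batch shape, then take the max per dimension.
    ndim = max(len(batch_a), len(batch_b))
    a = (1,) * (ndim - len(batch_a)) + tuple(batch_a)
    b = (1,) * (ndim - len(batch_b)) + tuple(batch_b)
    if any(x != y and 1 not in (x, y) for x, y in zip(a, b)):
        raise ValueError(f"batch shapes {a} and {b} are not broadcastable")
    return tuple(max(x, y) for x, y in zip(a, b))


def numpy_solve_shape(a_shape, b_shape):
    """Result shape of numpy.linalg.solve under the selection rule described above.

    Hypothetical shape-only helper.
    """
    if len(a_shape) < 2 or a_shape[-2] != a_shape[-1]:
        raise ValueError("A must be a (stack of) square matrices")
    m = a_shape[-1]
    if len(b_shape) == len(a_shape) - 1:
        # solve1, signature (m,m),(m)->(m): the last dim of B is the vector.
        if b_shape[-1] != m:
            raise ValueError(f"solve1: size {b_shape[-1]} is different from {m}")
        return _broadcast(a_shape[:-2], b_shape[:-1]) + (m,)
    # solve, signature (m,m),(m,n)->(m,n): the last two dims of B are the matrix.
    if len(b_shape) < 2 or b_shape[-2] != m:
        raise ValueError(f"solve: B must have shape (..., {m}, n)")
    return _broadcast(a_shape[:-2], b_shape[:-2]) + (m, b_shape[-1])


A = (2, 4, 5, 9, 6, 6)
print(numpy_solve_shape(A, (5, 9, 6, 15)))  # (2, 4, 5, 9, 6, 15)
print(numpy_solve_shape(A, (2, 4, 5, 9, 6)))  # (2, 4, 5, 9, 6)
```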

This seems unnecessarily complex; different behavior should be exposed as different functions.

torch.gesv currently supports A having shape (m,m) and B having shape (m) or (m,k). B having shape (m) is treated the same as B having shape (m,1) (note: torch.gels has the same behavior).

Note that we need to make a similar decision as numpy. Consider the case where B has shape (5,5) and A has shape (5,5,5). If B is interpreted as a matrix, then the result has shape (5,5,5). If B is interpreted as a vector, equivalent to (5,5,1) (when viewed as a matrix), then the result has shape (5,5,1).

To avoid these kinds of complications, it seems nicer to just expose two different functions, gesv and gesv1: gesv will interpret B as a matrix, except when B is 1-dimensional, in which case it will behave as it does now, and gesv1 will interpret B as a vector.

Note that this is backwards compatible because the previous shape of gesv[0] and new shape of gesv[0] for current valid inputs are equivalent:

| A.shape | B.shape | Previous shape gesv[0] | New shape gesv[0] | Shape gesv1[0] |
|---|---|---|---|---|
| (m,m) | (m) | (m,1) | (m,1) | (m) |
| (m,m) | (m,1) | (m,1) | (m,1) | (m,1) if m=1, otherwise Error (vector is size 1, not m) |
| (m,m) | (m,k), k != 1 | (m,k) | (m,k) | (m,k) if m=k, otherwise Error (vector is size k, not m) |

Note that this isn't ideal, because in the case where B is 1-dimensional, the shape of gesv[0] is not the same as the shape of gesv1[0], but that is necessary for backwards compatibility. In other cases (e.g. with pointwise functions), we have preferred numpy semantics over backwards compatibility, but in this case, given that numpy doesn't have (in my opinion) reasonable semantics, we should prefer backwards compatibility.

The alternative, preferring consistency over backwards compatibility, is to change the output shape in the 1-dimensional B case from (m,1) to (m), for gels as well as gesv.
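A shape-only sketch of the proposed split, following the table above. `gesv_shape` and `gesv1_shape` are hypothetical helpers (gesv1 is only a proposal in this document), and the batch dimensions are assumed to broadcast the same way as for matmul.

```python
def _broadcast(batch_a, batch_b):
    # Prepend 1s to the shorter batch shape, then take the max per dimension.
    ndim = max(len(batch_a), len(batch_b))
    a = (1,) * (ndim - len(batch_a)) + tuple(batch_a)
    b = (1,) * (ndim - len(batch_b)) + tuple(batch_b)
    if any(x != y and 1 not in (x, y) for x, y in zip(a, b)):
        raise ValueError(f"batch shapes {a} and {b} are not broadcastable")
    return tuple(max(x, y) for x, y in zip(a, b))


def _check_square(a_shape):
    if len(a_shape) < 2 or a_shape[-2] != a_shape[-1]:
        raise ValueError("A must be a (stack of) square matrices")
    return a_shape[-1]


def gesv_shape(a_shape, b_shape):
    """Proposed gesv: B is a (stack of) matrix; a 1-D B behaves like (m, 1)."""
    m = _check_square(a_shape)
    if len(b_shape) == 1:
        b_shape = (b_shape[0], 1)  # backwards compatible: result is (m, 1), not (m)
    if b_shape[-2] != m:
        raise ValueError(f"B must have shape (..., {m}, k)")
    return _broadcast(a_shape[:-2], b_shape[:-2]) + (m, b_shape[-1])


def gesv1_shape(a_shape, b_shape):
    """Proposed gesv1: B is always a (stack of) vector(s) of size m."""
    m = _check_square(a_shape)
    if len(b_shape) == 0:
        raise ValueError("B must have at least one dimension")
    if b_shape[-1] != m:
        raise ValueError(f"vector is size {b_shape[-1]}, not {m}")
    return _broadcast(a_shape[:-2], b_shape[:-1]) + (m,)


print(gesv_shape((4, 4), (4,)))        # (4, 1)
print(gesv1_shape((4, 4), (4,)))       # (4,)
print(gesv_shape((5, 5, 5), (5, 5)))   # (5, 5, 5): B interpreted as a matrix
```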

Indexing, Slicing, Joining, Mutating Ops

| torch function | numpy/scipy equivalent | supports broadcasting |
|---|---|---|
| cat | numpy.concatenate | no |
| gather | numpy.choose | yes, although numpy.choose only supports the equivalent of torch.gather with dim=0 |
| scatter | no equivalent | yes, for consistency with gather |
| index_select | numpy.take | no |
| masked_select | numpy boolean (mask) index arrays | no, see numpy's advanced indexing documentation |

Gather explanation:

If the tensor has shape (x0, x1, ..., xi, ..., xn) and dim=i, we want to expand index to (x0, x1, ..., x(i-1), eindex(i), x(i+1), ..., xn), where eindex(i) is the size of the i-th dimension of the (possibly expanded) index tensor. There are three cases to consider:

  • i > index.dim(): expand as above, with eindex(i) == 1, and squeeze dimension i:
      >>> x=np.random.randn(7,3,4)
      >>> y=np.zeros((1,3,4)).astype('int64')
      >>> z=np.zeros(4).astype('int64')
      >>> np.array_equal(np.choose(y,x).squeeze(0), np.choose(z,x))
      True
  • i <= index.dim() < n: note that this doesn't come up for numpy. Specific example: tensor.size() == (3,5,7) and index.size() == (5,7). Do we expand like (5,1,7) or (1,5,7)? Following the rule of prepending 1s, it should be (1,5,7) -> (3,5,7).
  • index.dim() == n: expand as above, with eindex(i) == index.size()[i].

Scatter explanation:

If out is an (x0, x1, ..., xj, ..., xn) tensor and dim == j, then index must be an (x0, x1, ..., ij, ..., xn) tensor and source must be an (x0, x1, ..., sj, ..., xn) tensor and ij must be <= sj (THC currently enforces that these are equal). So, following the gather rules we have:

  • out is guaranteed to be defined because this is an in-place call
  • If either index or source has at least j dimensions, and the j-th dimension is dj (let's assume they are equal, as in THC), we broadcast them to (x0, x1, ..., dj, ..., xn). Note that if dj != 1 this already violates the scatter rule that "rows" (axes) have unique values.
  • If neither index nor source has at least j dimensions, the j-th dimension is not defined under broadcast rules, and from above we already violate the scatter rule that "rows" have unique values, so it shouldn't actually matter what we do.

Given the complexity and ambiguity here, I'm punting on implementing this for now.

Uncategorized tensor functions

| torch function | numpy/scipy equivalent | supports broadcasting |
|---|---|---|
| copy_ | numpy.copyto | yes |
| masked_copy | numpy.copyto | yes |
| masked_fill | nothing direct, numpy.full is closest | yes, for consistency with masked_copy |
| map_ | numpy.vectorize | yes |
| map2_ | numpy.vectorize | yes |
| index_add_ | no equivalent | no, see example below |
| index_copy_ | no equivalent | no, see example below |

index_add_ example:

>>> x=torch.arange(1,7).view(2,3)
>>> torch.zeros(3,3).index_add_(0, torch.LongTensor([1,0]), x)
 4  5  6
 1  2  3
 0  0  0
[torch.FloatTensor of size 3x3]

index_copy_ example:

>>> x=torch.arange(1,7).view(2,3)
>>> torch.zeros(3,3).index_copy_(0, torch.LongTensor([1,0]), x)
 4  5  6
 1  2  3
 0  0  0
[torch.FloatTensor of size 3x3]

Current Status/Summary

| Type | Summary | Backwards Compatible? | Status |
|---|---|---|---|
| pointwise | same behavior as numpy | no | PR#1563 |
| matmul | same behavior as numpy | yes | PR#1563 |
| BLAS | add broadcasts for fused functions, e.g. addmm | yes | PR#1563 |
| LAPACK | gesv splits into gesv and gesv1 | yes | no plans currently |
| Indexing, Slicing, Joining, Mutating Ops | gather, scatter should broadcast according to the (complex) rules above | ? | no plans currently |
| Uncategorized tensor functions | copy, etc. | no | PR#1563 |

