
Commit ea3d82b

Author: James Reed
Merge pull request #883 from Chillee/vmap
Added vmap example

2 parents 6e6e0d4 + 745e2a8, commit ea3d82b


fx/vmap.py

Lines changed: 197 additions & 0 deletions
@@ -0,0 +1,197 @@
# This example is provided only for explanatory and educational purposes.
import torch
import torch.nn as nn
import torch.fx as fx
from torch.fx import Proxy
from typing import Tuple, Any, Optional

import torch.fx
from torch.fx.passes.shape_prop import ShapeProp

# Vmap
# ---------------
# `vmap` (short for vectorizing map) is a function transformation that takes in
# a model that operates on single examples and returns a model that operates on
# multiple examples. For example, if our model `M` originally takes in tensors
# of shape `(H, W)` and returns a scalar, then `vmap(M)` should take in tensors
# of shape `(B, H, W)` and return a vector. This procedure is also often called
# "batching" a model.
#
# How is this feat accomplished? One observation is that to "batch" a model, it
# suffices to batch each individual operation. In other words, given an
# operation that works on the current shape, how do we make it work with an
# additional batch dimension? This leads us to batching rules.
#
# Batching Rules
# ---------------
# A batching rule for a function `f` takes in the function `f` (which operates
# on unbatched values) and a batched argument `x`, and performs the necessary
# transformations to apply `f` to `x`.
#
# One simple example is `torch.movedim(x, from_dim, to_dim)`, which moves a
# dimension from `from_dim` to `to_dim`. For example, if `x.shape = (1, 2, 3, 4)`,
# then `torch.movedim(x, 0, 2)` would result in a shape of `(2, 3, 1, 4)`.
#
# However, let's say that we introduce a batch dimension - `x.shape =
# (B, 1, 2, 3, 4)`. Now, we can't simply execute the same `torch.movedim(x, 0, 2)`,
# as there is an extra batch dimension at the front. Instead, we must execute
# `torch.movedim(x, 1, 3)`. This procedure (plus some bookkeeping to make sure
# the batch dimension always ends up at the front) is what's done in
# `movedim_batching_rule`. A quick runnable check of these shapes follows right
# after this comment block.
#
# There is one final thing to note about these batching rules - they're almost
# entirely written in normal PyTorch, with the exception of the `bdim` attribute
# that is needed for tracking the batch dimension. That is because, in order to
# use these batching rules, we will trace them by passing in `Proxy` objects,
# which record the operations performed on them and append those operations to
# the graph.

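# A quick, purely illustrative sanity check of the `torch.movedim` shapes
# described above. The tensors `_demo` and `_demo_batched` are throwaway names
# introduced here only for this check and are not used anywhere else.
_demo = torch.randn(1, 2, 3, 4)
assert torch.movedim(_demo, 0, 2).shape == (2, 3, 1, 4)
_demo_batched = torch.randn(7, 1, 2, 3, 4)  # the same tensor with a batch dim B=7 in front
assert torch.movedim(_demo_batched, 1, 3).shape == (7, 2, 3, 1, 4)
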
def move_bdim_to_front(x, result_ndim=None):
    """
    Returns a tensor with a batch dimension at the front. If a batch
    dimension already exists, move it. Otherwise, create a new batch
    dimension at the front. If `result_ndim` is not None, ensure that the
    resulting tensor has rank equal to `result_ndim`.
    """
    x_dim = len(x.shape)
    x_bdim = x.bdim
    if x_bdim is None:
        x = torch.unsqueeze(x, 0)
    else:
        x = torch.movedim(x, x_bdim, 0)
    if result_ndim is None:
        return x
    diff = result_ndim - x_dim - (x_bdim is None)
    for _ in range(diff):
        x = torch.unsqueeze(x, 1)
    return x

def movedim_batching_rule(x, from_dim, to_dim):
    x = move_bdim_to_front(x)
    return torch.movedim(x, from_dim + 1, to_dim + 1), 0

batching_rules = {}
def gen_binary_op_batching_rule(op):
    def binary_op_batching_rule(a, b):
        a_ndim = len(a.shape)
        b_ndim = len(b.shape)
        result_ndim = max(a_ndim, b_ndim)
        a = move_bdim_to_front(a, result_ndim)
        b = move_bdim_to_front(b, result_ndim)
        res = op(a, b)
        return res, 0
    return binary_op_batching_rule

def unsqueeze_batching_rule(x, dim):
    x = move_bdim_to_front(x)
    if dim >= 0:
        return torch.unsqueeze(x, dim + 1), 0
    else:
        return torch.unsqueeze(x, dim), 0


batching_rules[torch.mul] = gen_binary_op_batching_rule(torch.mul)
batching_rules[torch.unsqueeze] = unsqueeze_batching_rule
batching_rules[torch.movedim] = movedim_batching_rule

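# A small, illustrative sanity check of the batching rules defined above. These
# rules are normally only ever invoked on `Proxy` objects during tracing; here
# we mimic that setting by attaching a `bdim` attribute to ordinary tensors by
# hand. The names `_a`, `_b`, `_res`, and `_res_bdim` are introduced only for
# this check and are not used elsewhere.
_a = torch.randn(7, 3)   # batched along dim 0 (B = 7)
_a.bdim = 0
_b = torch.randn(3)      # unbatched
_b.bdim = None
_res, _res_bdim = gen_binary_op_batching_rule(torch.mul)(_a, _b)
assert _res.shape == (7, 3) and _res_bdim == 0
assert torch.allclose(_res, _a * _b)
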
# In order to apply a batching rule, we will simply pass in `Proxy` objects as
# inputs to the functions. As the batching rules need some extra information
# such as the batch dimension and shape, we will do some bookkeeping here.
def gen_batching_rule_function(target, *args):
    def lift_shape(i):
        res = Proxy(i)
        res.shape = i.shape
        res.bdim = i.bdim
        return res
    proxy_args = [lift_shape(i) if isinstance(i, fx.Node) else i for i in args]
    out, bdim = batching_rules[target](*proxy_args)
    out_node = out.node
    out_node.bdim = bdim
    return out_node

def vmap(model: torch.nn.Module, in_axes: Tuple[Optional[int], ...], example_args: Tuple[Any, ...]) -> torch.nn.Module:
    """vmap
    Given a model with inputs, vmap will return a function that works on
    batched versions of those inputs. Which inputs will be batched is
    determined by in_axes. In addition, as vmap requires shape (actually
    rank) information, we will pass in example_args (example inputs for the
    original module).
    """
    in_axes = iter(in_axes)
    fx_model = fx.symbolic_trace(model)
    # Here we run a shape propagation pass in order to annotate the graph with shape information.
    ShapeProp(fx_model).propagate(*example_args)
    # As vmap rewrites the whole graph, it's easiest to create an entirely new
    # graph and append to that.
    new_graph: fx.Graph = fx.Graph()

    # We will create an environment that maps each node in the original graph
    # (by name) to its newly created counterpart.
    def lookup_env(l):
        return fx.node.map_aggregate(l, lambda x: env[x.name] if isinstance(x, fx.Node) else x)
    env = {}
    for node in fx_model.graph.nodes:
        if node.op == 'placeholder':
            # If the node is an input placeholder, we simply copy it over and
            # annotate it with the batch dimension from `in_axes`.
            new_node = new_graph.placeholder(node.name)
            new_node.bdim = next(in_axes)
            new_node.shape = node.shape
            env[node.name] = new_node
        elif node.op == 'output':
            new_graph.output(env[node.args[0].name])
        elif node.op == 'call_function':
            new_args = lookup_env(node.args)
            # If any of the inputs to the function has a batch dimension, we
            # need to use our batching rules. Otherwise, we simply copy the
            # node over.
            if any(x.bdim is not None for x in new_args if isinstance(x, fx.Node)):
                new_node = gen_batching_rule_function(node.target, *new_args)
            else:
                new_node = new_graph.node_copy(node, lambda x: env[x.name])
                new_node.bdim = None
            new_node.shape = node.shape
            env[node.name] = new_node
        else:
            raise RuntimeError("Not yet implemented")

    res = fx.GraphModule(fx_model, new_graph)
    print(res.code)
    res.graph.lint()
    return res

x = torch.randn(3, 5)
y = torch.randn(2)
class M(nn.Module):
    def forward(self, a, b):
        return torch.mul(a, b)

# Although this function can actually take in many shapes (due to broadcasting
# rules and such), pretend that M() operates only on scalars.
# The first thing we do is turn this into a scalar-vector multiplication. To do
# so, we batch the second input along its first dimension to turn it into a
# vector.
# We provide example_args to specify the original (unbatched) shapes of the inputs.
model = vmap(M(), in_axes=(None, 0), example_args=(x[0][0], y[0]))

# Now, our shape signature is ((), (M,)) -> (M,). That is, the model now
# computes a scalar-vector multiplication.
print(model(x[0][0], y).shape)  # ((), (2,)) -> (2,)
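# As a quick sanity check (illustrative only), the batched model should agree
# with plain broadcasting of the underlying scalar multiplication.
assert torch.allclose(model(x[0][0], y), x[0][0] * y)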

# Now, we want to turn this scalar-vector multiplication into a vector-vector
# multiplication. That is, we would like to have the shape signature
# ((N,), (M,)) -> (N, M). To do so, we batch along the first argument this
# time. This is also known as an outer product.

model = vmap(model, in_axes=(0, None), example_args=(x[0][0], y))

print(model(x[0], y).shape)  # ((5,), (2,)) -> (5, 2)
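# Again, a purely illustrative check: the result should match a manual outer
# product of the two vectors.
assert torch.allclose(model(x[0], y), torch.outer(x[0], y))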


# We can continue to add an arbitrary number of batch dimensions to our input.
# If we add another batch dimension to the first input, we now get a batched
# outer product computation: ((B, N), (M,)) -> (B, N, M).

model = vmap(model, in_axes=(0, None), example_args=(x[0], y))
print(model(x, y).shape)  # ((3, 5), (2,)) -> (3, 5, 2)
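
# One more illustrative check: the batched outer product should match the
# broadcasted computation done by hand.
assert torch.allclose(model(x, y), x.unsqueeze(-1) * y)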
