# This example is provided only for explanatory and educational purposes.
import torch
import torch.nn as nn
import torch.fx as fx
from torch.fx import Proxy
from torch.fx.passes.shape_prop import ShapeProp
from typing import Tuple, Any, Optional

# Vmap
# ---------------
# `vmap` (short for vectorizing map) is a function transformation that takes in
# a model that operates on single examples and returns a model that operates on
# multiple examples. For example, if our model `M` originally takes in tensors
# of shape `(H, W)` and returns a scalar, then `vmap(M)` should take in tensors
# of shape `(B, H, W)` and return a vector. This procedure is also often called
# "batching" a model. (A reference-semantics sketch in plain Python follows
# right after this comment block.)
#
# How is this feat accomplished? One observation is that to "batch" a model, it
# suffices to batch each individual operation. In other words, given an
# operation that works on the current shape, how do we make it work with an
# additional batch dimension? This leads us to batching rules.
#
# Batching Rules
# ---------------
# A batching rule for a function `f` takes in the function `f` (which operates
# on unbatched values), a batched argument `x`, and performs the necessary
# transformations to apply `f` to `x`.
#
# One simple example is `torch.movedim(x, from_dim, to_dim)`, which moves a
# dimension from `from_dim` to `to_dim`. For example, if `x.shape = (1, 2, 3, 4)`,
# then `torch.movedim(x, 0, 2)` would result in a shape of `(2, 3, 1, 4)`.
#
# However, let's say that we introduce a batch dimension, so that `x.shape =
# (B, 1, 2, 3, 4)`. Now we can't simply execute the same `torch.movedim(x, 0, 2)`,
# as there is an extra batch dimension at the front. Instead, we must execute
# `torch.movedim(x, 1, 3)`. This procedure (plus some bookkeeping to make sure
# the batch dimension always stays at the front) is what's done in
# `movedim_batching_rule`.
#
# There is one final thing to note about these batching rules: they're written
# almost entirely in normal PyTorch, with the exception of a `bdim` attribute
# that's needed for tracking the batch dimension. That is because, in order to
# use these batching rules, we will trace them by passing in `Proxy` objects
# that record the operations performed on them and append those operations to
# the graph.
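
# Before diving into the FX-based implementation, here is a minimal
# reference-semantics sketch of what `vmap` is supposed to compute, written as
# an explicit Python loop: run the unbatched function on each slice along the
# batched dimension and stack the results. The helper name `vmap_reference` is
# purely illustrative; the real transformation below never materializes this loop.
def vmap_reference(f, in_axes, *args):
    # Infer the batch size from the first batched argument.
    batch_size = next(a.shape[ax] for a, ax in zip(args, in_axes) if ax is not None)
    outs = []
    for i in range(batch_size):
        # Select the i-th example from each batched argument; pass the rest through.
        sliced = [a.select(ax, i) if ax is not None else a
                  for a, ax in zip(args, in_axes)]
        outs.append(f(*sliced))
    # Stack per-example outputs along a new leading batch dimension.
    return torch.stack(outs)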

def move_bdim_to_front(x, result_ndim=None):
    """
    Returns a tensor with a batch dimension at the front. If a batch
    dimension already exists, move it. Otherwise, create a new batch
    dimension at the front. If `result_ndim` is not None, ensure that the
    resulting tensor has rank equal to `result_ndim`.
    """
    x_dim = len(x.shape)
    x_bdim = x.bdim
    if x_bdim is None:
        x = torch.unsqueeze(x, 0)
    else:
        x = torch.movedim(x, x_bdim, 0)
    if result_ndim is None:
        return x
    # Pad with size-1 dimensions (inserted right after the batch dimension)
    # until the result has rank `result_ndim`.
    diff = result_ndim - x_dim - (x_bdim is None)
    for _ in range(diff):
        x = torch.unsqueeze(x, 1)
    return x
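
# Two quick illustrative checks of `move_bdim_to_front`. We attach a `bdim`
# attribute to ordinary tensors here (plain Tensors accept ad-hoc Python
# attributes) only so the helper can be exercised eagerly; inside `vmap` the
# inputs are `Proxy` objects instead.
_t = torch.randn(7, 1, 2, 3)
_t.bdim = 2                        # batch dimension currently at index 2
assert move_bdim_to_front(_t).shape == (2, 7, 1, 3)
_u = torch.randn(5)
_u.bdim = None                     # unbatched input, padded to rank 3
assert move_bdim_to_front(_u, result_ndim=3).shape == (1, 1, 5)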

def movedim_batching_rule(x, from_dim, to_dim):
    x = move_bdim_to_front(x)
    return torch.movedim(x, from_dim + 1, to_dim + 1), 0
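
# This reproduces the example from the comment block above: for a batched
# input of shape (B, 1, 2, 3, 4), moving logical dim 0 to dim 2 turns into a
# move of physical dim 1 to dim 3. Again, `bdim` is attached eagerly purely
# for illustration.
_v = torch.randn(7, 1, 2, 3, 4)
_v.bdim = 0
_out, _out_bdim = movedim_batching_rule(_v, 0, 2)
assert _out.shape == (7, 2, 3, 1, 4) and _out_bdim == 0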

batching_rules = {}

def gen_binary_op_batching_rule(op):
    def binary_op_batching_rule(a, b):
        a_ndim = len(a.shape)
        b_ndim = len(b.shape)
        result_ndim = max(a_ndim, b_ndim)
        a = move_bdim_to_front(a, result_ndim)
        b = move_bdim_to_front(b, result_ndim)
        res = op(a, b)
        return res, 0
    return binary_op_batching_rule

def unsqueeze_batching_rule(x, dim):
    x = move_bdim_to_front(x)
    if dim >= 0:
        return torch.unsqueeze(x, dim + 1), 0
    else:
        # Negative dims count from the end, so they need no adjustment once
        # the batch dimension is at the front.
        return torch.unsqueeze(x, dim), 0


batching_rules[torch.mul] = gen_binary_op_batching_rule(torch.mul)
batching_rules[torch.unsqueeze] = unsqueeze_batching_rule
batching_rules[torch.movedim] = movedim_batching_rule
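
# A quick sanity check of the generated rule for `torch.mul`: a batched (4, 3)
# tensor times an unbatched (3,) tensor should give a batched (4, 3) result
# with the batch dimension still at the front. (Eager `bdim` attributes again,
# purely for illustration.)
_a = torch.randn(4, 3)
_a.bdim = 0
_b = torch.randn(3)
_b.bdim = None
_res, _res_bdim = batching_rules[torch.mul](_a, _b)
assert _res.shape == (4, 3) and _res_bdim == 0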


# In order to apply a batching rule, we will simply pass in `Proxy` objects as
# inputs to the batching rules. As the batching rules need some extra
# information, such as the batch dimension and shape, we do some bookkeeping
# here.
def gen_batching_rule_function(target, *args):
    def lift_shape(i):
        res = Proxy(i)
        res.shape = i.shape
        res.bdim = i.bdim
        return res
    proxy_args = [lift_shape(i) if isinstance(i, fx.Node) else i for i in args]
    out, bdim = batching_rules[target](*proxy_args)
    out_node = out.node
    out_node.bdim = bdim
    return out_node
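
# To make the Proxy mechanism concrete, here is a tiny standalone sketch
# (independent of vmap): wrapping a graph node in a `Proxy` and calling a
# torch op on it records that op as a new `call_function` node in the same
# graph, which is exactly how the batching rules get traced.
_demo_graph = fx.Graph()
_p = Proxy(_demo_graph.placeholder('inp'))
_q = torch.unsqueeze(_p, 0)   # dispatched through Proxy; appends a node to _demo_graph
assert _q.node.op == 'call_function'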

def vmap(model: torch.nn.Module, in_axes: Tuple[Optional[int], ...], example_args: Tuple[Any, ...]) -> torch.nn.Module:
    """vmap
    Given a model with inputs, vmap will return a function that works on
    batched versions of those inputs. Which inputs will be batched is
    determined by in_axes. In addition, as vmap requires shape (actually
    rank) information, we will pass in example_args (example inputs for the
    original module).
    """
    in_axes = iter(in_axes)
    fx_model = fx.symbolic_trace(model)
    # Here we run a shape propagation pass in order to annotate the graph with
    # shape information.
    ShapeProp(fx_model).propagate(*example_args)
    # As vmap rewrites the whole graph, it's easiest to create an entirely new
    # graph and append to that.
    new_graph: fx.Graph = fx.Graph()

    # We maintain an environment that maps each node in the original graph (by
    # name) to the corresponding node in the new graph.
    def lookup_env(l):
        return fx.node.map_aggregate(l, lambda x: env[x.name] if isinstance(x, fx.Node) else x)
    env = {}
    for node in fx_model.graph.nodes:
        if node.op == 'placeholder':
            # If the node is an input placeholder, we simply copy it over and
            # annotate it with the batch dimension from `in_axes`.
            new_node = new_graph.placeholder(node.name)
            new_node.bdim = next(in_axes)
            new_node.shape = node.shape
            env[node.name] = new_node
        elif node.op == 'output':
            new_graph.output(env[node.args[0].name])
        elif node.op == 'call_function':
            new_args = lookup_env(node.args)
            # If any of the inputs to the function has a batch dimension, we
            # need to use our batching rules. Otherwise, we simply copy the
            # node over.
            if any(x.bdim is not None for x in new_args if isinstance(x, fx.Node)):
                new_node = gen_batching_rule_function(node.target, *new_args)
            else:
                new_node = new_graph.node_copy(node, lambda x: env[x.name])
                new_node.bdim = None
            new_node.shape = node.shape
            env[node.name] = new_node
        else:
            raise RuntimeError(f"Node type {node.op} not yet implemented")

    res = fx.GraphModule(fx_model, new_graph)
    print(res.code)
    res.graph.lint()
    return res

x = torch.randn(3, 5)
y = torch.randn(2)

class M(nn.Module):
    def forward(self, a, b):
        return torch.mul(a, b)

# Although this function can actually accept inputs of many shapes (due to
# broadcasting rules and such), pretend that M() operates only on scalars.
# The first thing we do is turn this into a scalar-vector multiplication. To
# do so, we batch the second input along its first dimension, turning it into
# a vector. We provide example_args to specify the original input shapes of
# the function.
model = vmap(M(), in_axes=(None, 0), example_args=(x[0][0], y[0]))

# Now our shape signature is ((), (M,)) -> (M,): a scalar-vector
# multiplication.
print(model(x[0][0], y).shape)  # ((), (2,)) -> (2,)

# Next, we want to turn this scalar-vector multiplication into a vector-vector
# multiplication with the shape signature ((N,), (M,)) -> (N, M). To do so, we
# batch the first argument. This is also known as an outer product.
model = vmap(model, in_axes=(0, None), example_args=(x[0][0], y))

print(model(x[0], y).shape)  # ((5,), (2,)) -> (5, 2)


# We can continue to add an arbitrary number of batch dimensions to our input.
# If we add another batch dimension to the first input, we get a batched outer
# product computation: ((B, N), (M,)) -> (B, N, M).
model = vmap(model, in_axes=(0, None), example_args=(x[0], y))
print(model(x, y).shape)  # ((3, 5), (2,)) -> (3, 5, 2)
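
# Finally, a sanity check against a direct broadcasting computation of the
# same batched outer product: element (i, j, k) should be x[i, j] * y[k].
expected = x.unsqueeze(-1) * y  # (3, 5, 1) * (2,) broadcasts to (3, 5, 2)
assert torch.allclose(model(x, y), expected)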