import torch
import torch.fx
| 3 | +""" |
| 4 | +In this example we are going do define a library of |
| 5 | +"composite" operations. Composite operations are those |
| 6 | +that are defined as callable functions that are composed |
| 7 | +of several other operations in their implementation. |
| 8 | +
|
| 9 | +Composite operations allow you to choose at what level |
| 10 | +of abstraction you want to interpret/manipulate the |
| 11 | +code. We show that we can provide a function to inline |
| 12 | +these functions as well as use a custom Tracer to auto- |
| 13 | +matically inline such functions. |
| 14 | +
|
| 15 | +Composite operations can be useful for exposing higher- |
| 16 | +level context to a backend/transform while still |
| 17 | +maintaining the ability to examine things at a more |
| 18 | +fine-grained level. |
| 19 | +""" |


def sigmoid_lowp(x: torch.Tensor):
    x = x.float()
    x = x.sigmoid()
    return x.half()

# wrap() indicates that the passed-in function should always
# be recorded as a call_function node rather than being traced
# through. Later, we will see how we can:
# a. Inline the implementation of such a function, and
# b. Define a tracer that automatically traces through such a function
torch.fx.wrap(sigmoid_lowp)

def add_lowp(a: torch.Tensor, b: torch.Tensor):
    a, b = a.float(), b.float()
    c = a + b
    return c.half()

torch.fx.wrap(add_lowp)

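# For contrast, a function that is *not* wrapped is traced through:
# its body, rather than a call to it, is recorded in the graph.
# (A minimal sketch added for illustration; `unwrapped_double` and
# `Bar` are hypothetical names, not part of the library above.)
def unwrapped_double(x: torch.Tensor):
    return x + x

class Bar(torch.nn.Module):
    def forward(self, x):
        return unwrapped_double(x)

# Printing torch.fx.symbolic_trace(Bar()).code would show `x + x`
# inline, with no call_function node for `unwrapped_double`.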

# Let's see what happens when we symbolically trace through some code
# that uses these functions

class Foo(torch.nn.Module):
    def forward(self, x, y):
        x = sigmoid_lowp(x)
        y = sigmoid_lowp(y)
        return add_lowp(x, y)


traced = torch.fx.symbolic_trace(Foo())
print(traced.code)
| 54 | +""" |
| 55 | +def forward(self, x, y): |
| 56 | + sigmoid_lowp = __main___sigmoid_lowp(x); x = None |
| 57 | + sigmoid_lowp_1 = __main___sigmoid_lowp(y); y = None |
| 58 | + add_lowp = __main___add_lowp(sigmoid_lowp, sigmoid_lowp_1); sigmoid_lowp = sigmoid_lowp_1 = None |
| 59 | + return add_lowp |
| 60 | +""" |

# Notice that the calls to `sigmoid_lowp` and `add_lowp`
# appear literally in the trace; they are not traced through

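# Sanity check (an illustrative addition, not part of the original
# walkthrough): the traced module should match eager execution.
x, y = torch.randn(3), torch.randn(3)
assert torch.allclose(traced(x, y).float(), Foo()(x, y).float())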

# ***** Inlining calls *****
# Now let's define a function that allows for inlining these calls
# during graph manipulation.

def inline_lowp_func(n: torch.fx.Node):
    # If we find a call to a function in our "lowp" module, inline it
    if n.op == 'call_function' and n.target.__module__ == inline_lowp_func.__module__:
        # We want to insert the operations comprising the implementation of the
        # function before the function itself. Then, we can swap the output value
        # of the function call with the output value for its implementation nodes
        with n.graph.inserting_before(n):
            # We can inline code by using `fx.Proxy` instances.
            # map_arg traverses all aggregate types and applies the given function
            # to Node instances in the data structure. In this case, we are applying
            # the fx.Proxy constructor.
            proxy_args = torch.fx.node.map_arg(n.args, torch.fx.Proxy)
            proxy_kwargs = torch.fx.node.map_arg(n.kwargs, torch.fx.Proxy)
            # Call the function itself with proxy arguments. This will emit
            # nodes in the graph corresponding to the operations in the
            # implementation of the function
            output_proxy = n.target(*proxy_args, **proxy_kwargs)
            # Now replace the original node's uses with the output node of
            # the implementation.
            n.replace_all_uses_with(output_proxy.node)
            # Delete the old node
            n.graph.erase_node(n)

for node in traced.graph.nodes:
    if node.op == 'call_function' and node.target is sigmoid_lowp:
        inline_lowp_func(node)

# Don't forget to recompile after graph manipulation
traced.recompile()

print(traced.code)
| 101 | +""" |
| 102 | +def forward(self, x, y): |
| 103 | + float_1 = x.float(); x = None |
| 104 | + sigmoid = float_1.sigmoid(); float_1 = None |
| 105 | + half = sigmoid.half(); sigmoid = None |
| 106 | + float_2 = y.float(); y = None |
| 107 | + sigmoid_1 = float_2.sigmoid(); float_2 = None |
| 108 | + half_1 = sigmoid_1.half(); sigmoid_1 = None |
| 109 | + add_lowp = __main___add_lowp(half, half_1); half = half_1 = None |
| 110 | + return add_lowp |
| 111 | +""" |

# At this point, the implementation of `sigmoid_lowp` has been substituted
# in for all of the calls to that function.

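# The same helper also works for `add_lowp`; a brief follow-up sketch
# (an addition to the original walkthrough) that lowers the remaining call:
for node in traced.graph.nodes:
    if node.op == 'call_function' and node.target is add_lowp:
        inline_lowp_func(node)
traced.recompile()
# `traced.code` now contains only float()/sigmoid()/half() calls and an
# add; no calls into our "lowp" library remain.
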
# ***** Inlining calls during tracing *****
# Now we are going to define a custom tracer that can selectively inline
# calls to certain composite operations on the fly.

# New instance of our module
f = Foo()

class InliningTracer(torch.fx.Tracer):
    FNS_TO_INLINE = [add_lowp]

    def create_node(self, kind, target, args, kwargs, name=None, type_expr=None):
        if kind == 'call_function' and target in self.FNS_TO_INLINE:
            # Trace through the implementation of the function rather than
            # create a node
            proxy_args = torch.fx.node.map_arg(args, torch.fx.Proxy)
            proxy_kwargs = torch.fx.node.map_arg(kwargs, torch.fx.Proxy)
            return target(*proxy_args, **proxy_kwargs).node
        else:
            return super().create_node(kind, target, args, kwargs, name, type_expr)


tracer = InliningTracer()
graph = tracer.trace(f)
module = torch.fx.GraphModule(f, graph)
print(module.code)
| 141 | +""" |
| 142 | +def forward(self, x, y): |
| 143 | + sigmoid_lowp = __main___sigmoid_lowp(x); x = None |
| 144 | + sigmoid_lowp_1 = __main___sigmoid_lowp(y); y = None |
| 145 | + float_1 = sigmoid_lowp.float(); sigmoid_lowp = None |
| 146 | + float_2 = sigmoid_lowp_1.float(); sigmoid_lowp_1 = None |
| 147 | + add = float_1 + float_2; float_1 = float_2 = None |
| 148 | + half = add.half(); add = None |
| 149 | + return half |
| 150 | +""" |

# As you can see, the implementation of `add_lowp` has been
# inlined in the course of tracing with our InliningTracer.
# Such functionality can be used, for example, to implement
# a backend that wants to see the lowered form of some operations
# but the high-level form of others.
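
# We can confirm the mixed granularity programmatically (an illustrative
# addition): `sigmoid_lowp` remains as a leaf call_function target, while
# `add_lowp` no longer appears in the graph.
targets = [n.target for n in graph.nodes if n.op == 'call_function']
assert sigmoid_lowp in targets
assert add_lowp not in targets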

# ***** Future direction *****
#
# We may define an API, such as `Tracer.is_leaf_function`, that
# Tracer implementers can use to more easily specify the inlining
# behavior implemented in InliningTracer. Such a method would return
# True by default, but a Tracer could override it and return `False`
# for functions it wants traced through.
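
# A hypothetical sketch of that proposed API (illustrative only; this
# method does not exist on torch.fx.Tracer at the time of writing):
#
#     class AutoInliningTracer(torch.fx.Tracer):
#         def is_leaf_function(self, fn) -> bool:
#             # Trace through `add_lowp`; keep all other wrapped
#             # functions as leaf call_function nodes.
#             return fn is not add_lowp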