Skip to content

Commit aef78a5

Browse files
author
James Reed
authored
Merge pull request #899 from jamesr66a/primitive_lib
[FX] Add primitive library example
2 parents 507493d + 25530d6 commit aef78a5

File tree

1 file changed

+164
-0
lines changed

1 file changed

+164
-0
lines changed

fx/primitive_library.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
import torch
2+
import torch.fx
3+
"""
4+
In this example we are going do define a library of
5+
"composite" operations. Composite operations are those
6+
that are defined as callable functions that are composed
7+
of several other operations in their implementation.
8+
9+
Composite operations allow you to choose at what level
10+
of abstraction you want to interpret/manipulate the
11+
code. We show that we can provide a function to inline
12+
these functions as well as use a custom Tracer to auto-
13+
matically inline such functions.
14+
15+
Composite operations can be useful for exposing higher-
16+
level context to a backend/transform while still
17+
maintaining the ability to examine things at a more
18+
fine-grained level.
19+
"""
20+
21+
22+
def sigmoid_lowp(x : torch.Tensor):
23+
x = x.float()
24+
x = x.sigmoid()
25+
return x.half()
26+
27+
# wrap() indicates that the passed-in function should always
28+
# be recorded as a call_function node rather than being traced
29+
# through. Later, we will see how we can:
30+
# a. Inline the implementation of such a function and
31+
# b. Define a tracer that automatically traces through such a function
32+
torch.fx.wrap(sigmoid_lowp)
33+
34+
def add_lowp(a : torch.Tensor, b : torch.Tensor):
35+
a, b = a.float(), b.float()
36+
c = a + b
37+
return c.half()
38+
39+
torch.fx.wrap(add_lowp)
40+
41+
42+
# Let's see what happens when we symbolically trace through some code
43+
# that uses these functions
44+
45+
class Foo(torch.nn.Module):
46+
def forward(self, x, y):
47+
x = sigmoid_lowp(x)
48+
y = sigmoid_lowp(y)
49+
return add_lowp(x, y)
50+
51+
52+
traced = torch.fx.symbolic_trace(Foo())
53+
print(traced.code)
54+
"""
55+
def forward(self, x, y):
56+
sigmoid_lowp = __main___sigmoid_lowp(x); x = None
57+
sigmoid_lowp_1 = __main___sigmoid_lowp(y); y = None
58+
add_lowp = __main___add_lowp(sigmoid_lowp, sigmoid_lowp_1); sigmoid_lowp = sigmoid_lowp_1 = None
59+
return add_lowp
60+
"""
61+
62+
# Notice that the calls to `sigmoid_lowp` and `add_lowp`
63+
# appear literally in the trace; they are not traced through
64+
65+
66+
# ***** Inlining calls *****
67+
# Now let's define a function that allows for inlining these calls
68+
# during graph manipulation.
69+
70+
def inline_lowp_func(n : torch.fx.Node):
71+
# If we find a call to a function in our "lowp" module, inline it
72+
if n.op == 'call_function' and n.target.__module__ == inline_lowp_func.__module__:
73+
# We want to insert the operations comprising the implementation of the
74+
# function before the function itself. Then, we can swap the output value
75+
# of the function call with the output value for its implementation nodes
76+
with n.graph.inserting_before(n):
77+
# We can inline code by using `fx.Proxy` instances.
78+
# map_arg traverses all aggregate types and applies the given function
79+
# to Node instances in the data structure. In this case, we are applying
80+
# the fx.Proxy constructor.
81+
proxy_args = torch.fx.node.map_arg(n.args, torch.fx.Proxy)
82+
proxy_kwargs = torch.fx.node.map_arg(n.kwargs, torch.fx.Proxy)
83+
# Call the function itself with proxy arguments. This will emit
84+
# nodes in the graph corresponding to the operations in the im-
85+
# plementation of the function
86+
output_proxy = n.target(*proxy_args, **proxy_kwargs)
87+
# Now replace the original node's uses with the output node of
88+
# the implementation.
89+
node.replace_all_uses_with(output_proxy.node)
90+
# Delete the old node
91+
node.graph.erase_node(node)
92+
93+
for node in traced.graph.nodes:
94+
if node.op == 'call_function' and node.target is sigmoid_lowp:
95+
inline_lowp_func(node)
96+
97+
# Don't forget to recompile after graph manipulation
98+
traced.recompile()
99+
100+
print(traced.code)
101+
"""
102+
def forward(self, x, y):
103+
float_1 = x.float(); x = None
104+
sigmoid = float_1.sigmoid(); float_1 = None
105+
half = sigmoid.half(); sigmoid = None
106+
float_2 = y.float(); y = None
107+
sigmoid_1 = float_2.sigmoid(); float_2 = None
108+
half_1 = sigmoid_1.half(); sigmoid_1 = None
109+
add_lowp = __main___add_lowp(half, half_1); half = half_1 = None
110+
return add_lowp
111+
"""
112+
113+
# At this point, the implementation of `sigmoid_lowp` has been substituted
114+
# in for all of the calls to that function.
115+
116+
# ***** Inlining calls during tracing *****
117+
# Now we are going to define a custom tracer that can selectively inline
118+
# calls to certain composite operations on-the-fly.
119+
120+
# New instance of our module
121+
f = Foo()
122+
123+
class InliningTracer(torch.fx.Tracer):
124+
FNS_TO_INLINE = [add_lowp]
125+
126+
def create_node(self, kind, target, args, kwargs, name=None, type_expr=None):
127+
if kind == 'call_function' and target in self.FNS_TO_INLINE:
128+
# Trace through the implementation of the function rather than
129+
# create a node
130+
proxy_args = torch.fx.node.map_arg(args, torch.fx.Proxy)
131+
proxy_kwargs = torch.fx.node.map_arg(kwargs, torch.fx.Proxy)
132+
return target(*proxy_args, **proxy_kwargs).node
133+
else:
134+
return super().create_node(kind, target, args, kwargs, name, type_expr)
135+
136+
137+
tracer = InliningTracer()
138+
graph = tracer.trace(f)
139+
module = torch.fx.GraphModule(f, graph)
140+
print(module.code)
141+
"""
142+
def forward(self, x, y):
143+
sigmoid_lowp = __main___sigmoid_lowp(x); x = None
144+
sigmoid_lowp_1 = __main___sigmoid_lowp(y); y = None
145+
float_1 = sigmoid_lowp.float(); sigmoid_lowp = None
146+
float_2 = sigmoid_lowp_1.float(); sigmoid_lowp_1 = None
147+
add = float_1 + float_2; float_1 = float_2 = None
148+
half = add.half(); add = None
149+
return half
150+
"""
151+
152+
# As you can see, the implementation for `add_lowp` has been
153+
# inlined in the course of tracing with our InliningTracer.
154+
# Such functionality can be used to, for example, implement
155+
# a backend that wants to see the lowered form of some operations
156+
# but the high-level form of another.
157+
158+
# ***** Future direction *****
159+
#
160+
# We may define an API, such as `Tracer.is_leaf_function`, that
161+
# Tracer implementers can use to more easily specify the inlining
162+
# behavior implemented in InliningTracer. Such a method would return
163+
# True by default, but a Tracer can override it and return `False` for
164+
# functions the Tracer wants to be traced through.

0 commit comments

Comments
 (0)