Skip to content

Commit 4db1116

Browse files
author
James Reed
authored
Merge pull request #891 from jamesr66a/fx_module_tracer
[FX] Add example of tracer that records module qualname for each node
2 parents c8423c8 + 8f3739f commit 4db1116

File tree

1 file changed

+148
-0
lines changed

1 file changed

+148
-0
lines changed

fx/module_tracer.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
"""
2+
Recording Module Hierarchy With a Custom Tracer
3+
4+
In this example, we are going to define a custom `fx.Tracer` instance that--
5+
for each recorded operation--also notes down the qualified name of the module
6+
from which that operation originated. The _qualified name_ is the path to the
7+
Module from the root module. More information about this concept can be
8+
found in the documentation for `Module.get_submodule`:
9+
https://github.com/pytorch/pytorch/blob/9f2aea7b88f69fc74ad90b1418663802f80c1863/torch/nn/modules/module.py#L385
10+
"""
11+
import torch
12+
import torch.fx
13+
from typing import Any, Callable, Dict, Optional, Tuple
14+
15+
class ModulePathTracer(torch.fx.Tracer):
16+
"""
17+
ModulePathTracer is an FX tracer that--for each operation--also records
18+
the qualified name of the Module from which the operation originated.
19+
"""
20+
21+
# The current qualified name of the Module being traced. The top-level
22+
# module is signified by empty string. This is updated when entering
23+
# call_module and restored when exiting call_module
24+
current_module_qualified_name : str = ''
25+
# A map from FX Node to the qualname of the Module from which it
26+
# originated. This is recorded by `create_proxy` when recording an
27+
# operation
28+
node_to_originating_module : Dict[torch.fx.Node, str] = {}
29+
30+
def call_module(self, m: torch.nn.Module, forward: Callable[..., Any],
31+
args : Tuple[Any, ...], kwargs : Dict[str, Any]) -> Any:
32+
"""
33+
Override of Tracer.call_module (see
34+
https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer.call_module).
35+
36+
This override:
37+
1) Stores away the qualified name of the caller for restoration later
38+
2) Installs the qualified name of the caller in `current_module_qualified_name`
39+
for retrieval by `create_proxy`
40+
3) Delegates into the normal Tracer.call_module method
41+
4) Restores the caller's qualified name into current_module_qualified_name
42+
"""
43+
old_qualname = self.current_module_qualified_name
44+
try:
45+
self.current_module_qualified_name = self.path_of_module(m)
46+
return super().call_module(m, forward, args, kwargs)
47+
finally:
48+
self.current_module_qualified_name = old_qualname
49+
50+
def create_proxy(self, kind: str, target: torch.fx.node.Target, args: Tuple[Any, ...],
51+
kwargs: Dict[str, Any], name: Optional[str] = None, type_expr: Optional[Any] = None):
52+
"""
53+
Override of `Tracer.create_proxy`. This override intercepts the recording
54+
of every operation and stores away the current traced module's qualified
55+
name in `node_to_originating_module`
56+
"""
57+
proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr)
58+
self.node_to_originating_module[proxy.node] = self.current_module_qualified_name
59+
return proxy
60+
61+
62+
# Testing: let's see how this works on a torchvision ResNet18 model
63+
import torchvision.models as models
64+
65+
# Model under test
66+
rn18 = models.resnet18()
67+
68+
# Instantiate our ModulePathTracer and use that to trace our ResNet18
69+
tracer = ModulePathTracer()
70+
traced_rn18 = tracer.trace(rn18)
71+
72+
# Print (node, module qualified name) for every node in the Graph
73+
for node in traced_rn18.nodes:
74+
module_qualname = tracer.node_to_originating_module.get(node)
75+
print('Node', node, 'is from module', module_qualname)
76+
"""
77+
Node x is from module
78+
Node conv1 is from module conv1
79+
Node bn1 is from module bn1
80+
Node relu is from module relu
81+
Node maxpool is from module maxpool
82+
Node layer1_0_conv1 is from module layer1.0.conv1
83+
Node layer1_0_bn1 is from module layer1.0.bn1
84+
Node layer1_0_relu is from module layer1.0.relu
85+
Node layer1_0_conv2 is from module layer1.0.conv2
86+
Node layer1_0_bn2 is from module layer1.0.bn2
87+
Node add is from module layer1.0
88+
Node layer1_0_relu_1 is from module layer1.0.relu
89+
Node layer1_1_conv1 is from module layer1.1.conv1
90+
Node layer1_1_bn1 is from module layer1.1.bn1
91+
Node layer1_1_relu is from module layer1.1.relu
92+
Node layer1_1_conv2 is from module layer1.1.conv2
93+
Node layer1_1_bn2 is from module layer1.1.bn2
94+
Node add_1 is from module layer1.1
95+
Node layer1_1_relu_1 is from module layer1.1.relu
96+
Node layer2_0_conv1 is from module layer2.0.conv1
97+
Node layer2_0_bn1 is from module layer2.0.bn1
98+
Node layer2_0_relu is from module layer2.0.relu
99+
Node layer2_0_conv2 is from module layer2.0.conv2
100+
Node layer2_0_bn2 is from module layer2.0.bn2
101+
Node layer2_0_downsample_0 is from module layer2.0.downsample.0
102+
Node layer2_0_downsample_1 is from module layer2.0.downsample.1
103+
Node add_2 is from module layer2.0
104+
Node layer2_0_relu_1 is from module layer2.0.relu
105+
Node layer2_1_conv1 is from module layer2.1.conv1
106+
Node layer2_1_bn1 is from module layer2.1.bn1
107+
Node layer2_1_relu is from module layer2.1.relu
108+
Node layer2_1_conv2 is from module layer2.1.conv2
109+
Node layer2_1_bn2 is from module layer2.1.bn2
110+
Node add_3 is from module layer2.1
111+
Node layer2_1_relu_1 is from module layer2.1.relu
112+
Node layer3_0_conv1 is from module layer3.0.conv1
113+
Node layer3_0_bn1 is from module layer3.0.bn1
114+
Node layer3_0_relu is from module layer3.0.relu
115+
Node layer3_0_conv2 is from module layer3.0.conv2
116+
Node layer3_0_bn2 is from module layer3.0.bn2
117+
Node layer3_0_downsample_0 is from module layer3.0.downsample.0
118+
Node layer3_0_downsample_1 is from module layer3.0.downsample.1
119+
Node add_4 is from module layer3.0
120+
Node layer3_0_relu_1 is from module layer3.0.relu
121+
Node layer3_1_conv1 is from module layer3.1.conv1
122+
Node layer3_1_bn1 is from module layer3.1.bn1
123+
Node layer3_1_relu is from module layer3.1.relu
124+
Node layer3_1_conv2 is from module layer3.1.conv2
125+
Node layer3_1_bn2 is from module layer3.1.bn2
126+
Node add_5 is from module layer3.1
127+
Node layer3_1_relu_1 is from module layer3.1.relu
128+
Node layer4_0_conv1 is from module layer4.0.conv1
129+
Node layer4_0_bn1 is from module layer4.0.bn1
130+
Node layer4_0_relu is from module layer4.0.relu
131+
Node layer4_0_conv2 is from module layer4.0.conv2
132+
Node layer4_0_bn2 is from module layer4.0.bn2
133+
Node layer4_0_downsample_0 is from module layer4.0.downsample.0
134+
Node layer4_0_downsample_1 is from module layer4.0.downsample.1
135+
Node add_6 is from module layer4.0
136+
Node layer4_0_relu_1 is from module layer4.0.relu
137+
Node layer4_1_conv1 is from module layer4.1.conv1
138+
Node layer4_1_bn1 is from module layer4.1.bn1
139+
Node layer4_1_relu is from module layer4.1.relu
140+
Node layer4_1_conv2 is from module layer4.1.conv2
141+
Node layer4_1_bn2 is from module layer4.1.bn2
142+
Node add_7 is from module layer4.1
143+
Node layer4_1_relu_1 is from module layer4.1.relu
144+
Node avgpool is from module avgpool
145+
Node flatten is from module
146+
Node fc is from module fc
147+
Node output is from module None
148+
"""

0 commit comments

Comments
 (0)