
Commit 18cdef6

TroyGarden authored and facebook-github-bot committed
short circuit the flatten/unflatten between EBC and KTRegroupAsDict modules
Summary:

# context

* For the root cause and background, please refer to this [post](https://fb.workplace.com/groups/1028545332188949/permalink/1042204770823005/).
* The basic idea of this diff is to **short circuit the pytree flatten-unflatten function pairs** between two preserved modules, i.e., EBC/fpEBC and KTRegroupAsDict. NOTE: there could be multiple EBCs and a single KTRegroupAsDict, as shown in the [pic](https://fburl.com/gslide/lcyt8eh3) {F1864810545}.
* Short-circuiting the EBC-KTRegroupAsDict pairs is a special case, and a must in most scenarios due to the EBC key-order issue with distributed table lookup.
* All the operations are hidden behind a control flag `short_circuit_pytree_ebc_regroup` on the torchrec main API call `decapsulate_ir_modules`, which should only be visible to the infra layer, not to users.

# details

* The `_short_circuit_pytree_ebc_regroup` function finds all the EBC/fpEBC and KTRegroupAsDict modules in an unflattened module, retrieves their fqns, and sorts them into in_fqns (regroup_fqns) and out_fqns (ebc_fqns). Because the fpEBC is currently swapped as a whole, some extra fqn logic filters out an EBC that belongs to an up-level fpEBC.
* A util function `prune_pytree_flatten_unflatten` removes the in-coming and out-going pytree flatten/unflatten function calls in the graph module, based on the given fqns. WARNING: the flag `short_circuit_pytree_ebc_regroup` should be turned on if EBCs are used and EBC sharding is needed. Assertions are also added for the cases where a `KTRegroupAsDict` module can't be found or `finalize_interpreter_modules` is not `True`.

# additional changes

* Absorb the `finalize_interpreter_modules` process inside the torchrec main API `decapsulate_ir_modules`.
* Set `graph.owning_module` in export.unflatten, as required for graph modification.
* Add one more layer of `sparse_module` to closely mimic the APF model structure.

Differential Revision: D62606738
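For orientation, below is a hypothetical end-to-end sketch of the flow these flags plug into, modeled on the updated test in this commit. The table configs and input values are invented, and `encapsulate_ir_modules` is assumed to be the existing torchrec counterpart that serializes the sparse modules and returns their fqns:

```python
# Hypothetical sketch; table sizes and inputs are made up for illustration.
from typing import Dict

import torch
from torch import nn
from torchrec.ir.serializer import JsonSerializer
from torchrec.ir.utils import decapsulate_ir_modules, encapsulate_ir_modules
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.regroup import KTRegroupAsDict
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


class Sparse(nn.Module):
    """Mirrors the test's mySparse wrapper: an EBC feeding a KTRegroupAsDict."""

    def __init__(self) -> None:
        super().__init__()
        self.ebc = EmbeddingBagCollection(
            tables=[
                EmbeddingBagConfig(
                    name=f"t{i}",
                    num_embeddings=10,
                    embedding_dim=4,
                    feature_names=[f"f{i}"],
                )
                for i in range(1, 4)
            ]
        )
        self.regroup = KTRegroupAsDict([["f1", "f3"], ["f2"]], ["odd", "even"])

    def forward(self, features: KeyedJaggedTensor) -> Dict[str, torch.Tensor]:
        return self.regroup([self.ebc(features)])


model = Sparse()
features = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2", "f3"],
    values=torch.tensor([0, 1, 2, 3, 2, 3]),
    lengths=torch.tensor([1, 1, 1, 1, 1, 1]),
)

# 1) swap the sparse modules for IR-serializable stand-ins, record their fqns
model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
# 2) export with the sparse-module boundaries preserved
ep = torch.export.export(
    model,
    (features,),
    strict=False,
    preserve_module_call_signature=tuple(sparse_fqns),
)
# 3) unflatten, restore the real modules, and short circuit the pytree calls
deserialized = decapsulate_ir_modules(
    torch.export.unflatten(ep),
    JsonSerializer,
    short_circuit_pytree_ebc_regroup=True,
    finalize_interpreter_modules=True,
)
out = deserialized(features)  # {"odd": ..., "even": ...}
```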
1 parent 15c912e commit 18cdef6

File tree

torchrec/ir/tests/test_serializer.py
torchrec/ir/utils.py

2 files changed: +128 −3 lines changed


torchrec/ir/tests/test_serializer.py

Lines changed: 20 additions & 3 deletions
@@ -557,7 +557,7 @@ def test_key_order_with_ebc_and_regroup(self) -> None:
         ebc2.load_state_dict(ebc1.state_dict())
         regroup = KTRegroupAsDict([["f1", "f3"], ["f2"]], ["odd", "even"])
 
-        class myModel(nn.Module):
+        class mySparse(nn.Module):
             def __init__(self, ebc, regroup):
                 super().__init__()
                 self.ebc = ebc
@@ -569,6 +569,17 @@ def forward(
             ) -> Dict[str, torch.Tensor]:
                 return self.regroup([self.ebc(features)])
 
+        class myModel(nn.Module):
+            def __init__(self, ebc, regroup):
+                super().__init__()
+                self.sparse = mySparse(ebc, regroup)
+
+            def forward(
+                self,
+                features: KeyedJaggedTensor,
+            ) -> Dict[str, torch.Tensor]:
+                return self.sparse(features)
+
         model = myModel(ebc1, regroup)
         eager_out = model(id_list_features)
 
@@ -582,11 +593,17 @@ def forward(
             preserve_module_call_signature=(tuple(sparse_fqns)),
         )
         unflatten_ep = torch.export.unflatten(ep)
-        deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)
+        deserialized_model = decapsulate_ir_modules(
+            unflatten_ep,
+            JsonSerializer,
+            short_circuit_pytree_ebc_regroup=True,
+            finalize_interpreter_modules=True,
+        )
+
         # we export the model with ebc1 and unflatten the model,
         # and then swap in ebc2 (you can think of this as the sharding process
         # resulting in a shardedEBC), so that we can mimic the key-order change
-        deserialized_model.ebc = ebc2
+        deserialized_model.sparse.ebc = ebc2
 
         deserialized_out = deserialized_model(id_list_features)
         for key in eager_out.keys():

torchrec/ir/utils.py

Lines changed: 108 additions & 0 deletions
@@ -10,6 +10,7 @@
 #!/usr/bin/env python3
 
 import logging
+import operator
 from collections import defaultdict
 from typing import Dict, List, Optional, Tuple, Type
 
@@ -18,7 +19,12 @@
 from torch import nn
 from torch.export import Dim, ShapesCollection
 from torch.export.dynamic_shapes import _Dim as DIM
+from torch.export.unflatten import InterpreterModule
+from torch.fx import Node
 from torchrec.ir.types import SerializerInterface
+from torchrec.modules.embedding_modules import EmbeddingBagCollection
+from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
+from torchrec.modules.regroup import KTRegroupAsDict
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
 
@@ -129,6 +135,8 @@ def decapsulate_ir_modules(
     module: nn.Module,
     serializer: Type[SerializerInterface] = DEFAULT_SERIALIZER_CLS,
     device: Optional[torch.device] = None,
+    finalize_interpreter_modules: bool = False,
+    short_circuit_pytree_ebc_regroup: bool = False,
 ) -> nn.Module:
     """
     Takes a module and decapsulates its embedding modules by retrieving the buffer.
@@ -147,6 +155,16 @@ def decapsulate_ir_modules(
     # we use "ir_metadata" as a convention to identify the deserializable module
     if "ir_metadata" in dict(module.named_buffers()):
         module = serializer.decapsulate_module(module, device)
+
+    if short_circuit_pytree_ebc_regroup:
+        module = _short_circuit_pytree_ebc_regroup(module)
+        assert finalize_interpreter_modules, "need finalize_interpreter_modules=True"
+
+    if finalize_interpreter_modules:
+        for mod in module.modules():
+            if isinstance(mod, InterpreterModule):
+                mod.finalize()
+
     return module
 
 
@@ -233,3 +251,93 @@ def move_to_copy_nodes_to_device(
     nodes.kwargs = new_kwargs
 
     return unflattened_module
+
+
+def _short_circuit_pytree_ebc_regroup(module: nn.Module) -> nn.Module:
+    """
+    Bypasses the pytree flatten and unflatten functions between EBC and
+    KTRegroupAsDict to avoid the key-order issue.
+    https://fb.workplace.com/groups/1028545332188949/permalink/1042204770823005/
+    EBC ==> (out-going) pytree.flatten ==> tensors and specs
+        ==> (in-coming) pytree.unflatten ==> KTRegroupAsDict
+    """
+    ebc_fqns: List[str] = []
+    regroup_fqns: List[str] = []
+    for fqn, m in module.named_modules():
+        if isinstance(m, FeatureProcessedEmbeddingBagCollection):
+            ebc_fqns.append(fqn)
+        elif isinstance(m, EmbeddingBagCollection):
+            # the fpEBC is swapped as a whole, so skip an EBC that
+            # belongs to an up-level fpEBC
+            if len(ebc_fqns) > 0 and fqn.startswith(ebc_fqns[-1]):
+                continue
+            ebc_fqns.append(fqn)
+        elif isinstance(m, KTRegroupAsDict):
+            regroup_fqns.append(fqn)
+    if (len(ebc_fqns) == 0) != (len(regroup_fqns) == 0):
+        logger.warning("Perf impact if EBC and KTRegroupAsDict are not used together.")
+        return module
+    else:
+        return prune_pytree_flatten_unflatten(
+            module, in_fqns=regroup_fqns, out_fqns=ebc_fqns
+        )
+
+
+def prune_pytree_flatten_unflatten(
+    module: nn.Module, in_fqns: List[str], out_fqns: List[str]
+) -> nn.Module:
+    """
+    Removes the pytree flatten and unflatten function calls at the given in_fqns and out_fqns.
+    "preserved module" ==> (out-going) pytree.flatten ==> [tensors and specs]
+    [tensors and specs] ==> (in-coming) pytree.unflatten ==> "preserved module"
+    """
+
+    def _get_graph_node(mod: nn.Module, fqn: str) -> Tuple[nn.Module, Node]:
+        for node in mod.graph.nodes:
+            if node.op == "call_module" and node.target == fqn:
+                return mod, node
+        assert "." in fqn, f"can't find {fqn} in the graph of {mod}"
+        curr, fqn = fqn.split(".", maxsplit=1)
+        mod = getattr(mod, curr)
+        return _get_graph_node(mod, fqn)
+
+    # remove tree_unflatten from the in_fqns (in-coming nodes)
+    for fqn in in_fqns:
+        submodule, node = _get_graph_node(module, fqn)
+        assert len(node.args) == 1
+        getitem_getitem: Node = node.args[0]  # pyre-ignore[9]
+        assert (
+            getitem_getitem.op == "call_function"
+            and getitem_getitem.target == operator.getitem
+        )
+        tree_unflatten_getitem = node.args[0].args[0]  # pyre-ignore[16]
+        assert (
+            tree_unflatten_getitem.op == "call_function"
+            and tree_unflatten_getitem.target == operator.getitem
+        )
+        tree_unflatten = tree_unflatten_getitem.args[0]
+        assert (
+            tree_unflatten.op == "call_function"
+            and tree_unflatten.target == torch.utils._pytree.tree_unflatten
+        )
+        logger.info(f"Removing tree_unflatten from {fqn}")
+        input_nodes = tree_unflatten.args[0]
+        node.args = (input_nodes,)
+        submodule.graph.eliminate_dead_code()
+
+    # remove tree_flatten_spec from the out_fqns (out-going nodes)
+    for fqn in out_fqns:
+        submodule, node = _get_graph_node(module, fqn)
+        users = list(node.users.keys())
+        assert (
+            len(users) == 1
+            and users[0].op == "call_function"
+            and users[0].target == torch.fx._pytree.tree_flatten_spec
+        )
+        tree_flatten_users = list(users[0].users.keys())
+        assert (
+            len(tree_flatten_users) == 1
+            and tree_flatten_users[0].op == "call_function"
+            and tree_flatten_users[0].target == operator.getitem
+        )
+        logger.info(f"Removing tree_flatten_spec from {fqn}")
+        getitem_node = tree_flatten_users[0]
+        getitem_node.replace_all_uses_with(node)
+        submodule.graph.eliminate_dead_code()
+    return module
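As a standalone illustration (not part of this commit), the snippet below shows the pytree round trip that export.unflatten places at a preserved-module boundary, i.e., the exact call pair pruned above; the plain dict stands in for an EBC output:

```python
import torch
import torch.utils._pytree as pytree

# stand-in for an EBC output: pooled embeddings keyed by feature name
ebc_out = {"f1": torch.randn(2, 4), "f2": torch.randn(2, 4)}

# (out-going) flatten after the EBC: leaves plus a treespec recorded at export time
leaves, spec = pytree.tree_flatten(ebc_out)
# (in-coming) unflatten before KTRegroupAsDict: rebuilds the container from the spec
rebuilt = pytree.tree_unflatten(leaves, spec)
assert list(rebuilt.keys()) == list(ebc_out.keys())

# The spec is frozen into the exported graph. If sharding later reorders the
# EBC's output keys, the stale spec re-associates tensors with the wrong keys;
# short circuiting passes the leaves straight through and drops that dependency.
```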
