From ab2a1616b385dd1bb3755eb1a048aeb1f041b60f Mon Sep 17 00:00:00 2001
From: Huanyu He
Date: Thu, 12 Sep 2024 12:41:02 -0700
Subject: [PATCH 1/2] KT unflatten issue with torch.export

Summary:
# context

current error:
```
1) torchrec.fb.ir.tests.test_serializer.TestSerializer: test_deserialized_device_vle
1) RuntimeError: Node ir_dynamic_batch_emb_lookup_default referenced nonexistent value id_list_features__values! Run Graph.lint() to diagnose such issues

While executing %ir_dynamic_batch_emb_lookup_default : [num_users=1] = call_function[target=torch.ops.torchrec.ir_dynamic_batch_emb_lookup.default](args = ([%id_list_features__values, None, %id_list_features__lengths, None], %floordiv, [4, 5]), kwargs = {})
Original traceback:
  File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/009ebbab256a7e75/torchrec/fb/ir/tests/__test_serializer__/test_serializer#link-tree/torchrec/fb/ir/tests/test_serializer.py", line 142, in forward
    return self.sparse_arch(id_list_features)

  File "torchrec/fb/ir/tests/test_serializer.py", line 446, in test_deserialized_device_vle
    output = deserialized_model(features_batch_3.to(device))
  File "torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "torch/export/unflatten.py", line 482, in forward
    tree_out = torch.fx.Interpreter(self, graph=self.graph).run(
  File "torch/fx/interpreter.py", line 146, in run
    self.env[node] = self.run_node(node)
  File "torch/fx/interpreter.py", line 200, in run_node
    args, kwargs = self.fetch_args_kwargs_from_env(n)
  File "torch/fx/interpreter.py", line 372, in fetch_args_kwargs_from_env
    args = self.map_nodes_to_values(n.args, n)
  File "torch/fx/interpreter.py", line 394, in map_nodes_to_values
    return map_arg(args, load_arg)
  File "torch/fx/node.py", line 760, in map_arg
    return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)
  File "torch/fx/node.py", line 768, in map_aggregate
    t = tuple(map_aggregate(elem, fn) for elem in a)
  File "torch/fx/node.py", line 768, in <genexpr>
    t = tuple(map_aggregate(elem, fn) for elem in a)
  File "torch/fx/node.py", line 772, in map_aggregate
    return immutable_list(map_aggregate(elem, fn) for elem in a)
  File "torch/fx/node.py", line 772, in <genexpr>
    return immutable_list(map_aggregate(elem, fn) for elem in a)
  File "torch/fx/node.py", line 778, in map_aggregate
    return fn(a)
  File "torch/fx/node.py", line 760, in <lambda>
    return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)
  File "torch/fx/interpreter.py", line 391, in load_arg
    raise RuntimeError(f'Node {n} referenced nonexistent value {n_arg}! Run Graph.lint() '
```
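
# note

the change below makes `_kt_flatten_spec` flatten the `KeyedTensor` according to the keys stored in the `TreeSpec` context via `KeyedTensor.regroup`, instead of returning `_kt_flatten(kt)[0]` directly. A minimal standalone sketch of that path for readability (it mirrors the diff below and adds only the imports; the diff is the authoritative change):
```
from typing import List

import torch
from torch.utils._pytree import TreeSpec
from torchrec.sparse.jagged_tensor import KeyedTensor


# sketch: flatten a KeyedTensor by the key order captured at export time
# (spec.context), not by whatever order kt.keys() happens to have at call time
def _kt_flatten_spec(kt: KeyedTensor, spec: TreeSpec) -> List[torch.Tensor]:
    _keys, _length_per_key = spec.context  # (keys, length_per_key) recorded at export
    # KeyedTensor.regroup returns one concatenated tensor per key group,
    # with columns ordered by _keys
    res = KeyedTensor.regroup([kt], [_keys])
    return [res[0]]
```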

Differential Revision: D59238744
---
 torchrec/sparse/jagged_tensor.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py
index 9125d1979..7dbe84921 100644
--- a/torchrec/sparse/jagged_tensor.py
+++ b/torchrec/sparse/jagged_tensor.py
@@ -2943,7 +2943,9 @@ def _kt_unflatten(
 
 
 def _kt_flatten_spec(kt: KeyedTensor, spec: TreeSpec) -> List[torch.Tensor]:
-    return _kt_flatten(kt)[0]
+    _keys, _length_per_key = spec.context
+    res = KeyedTensor.regroup([kt], [_keys])
+    return [res[0]]
 
 
 # The assumption here in torch.exporting KeyedTensor is that _length_per_key is static

From 91a004b8d02c52fca8b13d3de161f77484fe177c Mon Sep 17 00:00:00 2001
From: Huanyu He
Date: Thu, 12 Sep 2024 16:27:25 -0700
Subject: [PATCH 2/2] Test case for EBC key-order change (#2388)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2388

# context
* [post](https://fb.workplace.com/groups/1028545332188949/permalink/1042204770823005/)
* this test case mimics the EBC key-order change after sharding
{F1864056306}

# details
* it's a very simple model: EBC ---> KTRegroupAsDict
* we generate two EBCs, ebc1 and ebc2, such that their table orders differ:
```
ebc1 = EmbeddingBagCollection(
    tables=[tb1_config, tb2_config, tb3_config],
    is_weighted=False,
)
ebc2 = EmbeddingBagCollection(
    tables=[tb1_config, tb3_config, tb2_config],
    is_weighted=False,
)
```
* we export the model with ebc1, unflatten it, and then swap in ebc2 (you can think of this as the sharding process producing a sharded EBC), so that we can mimic the key-order change shown in the graph above (see the key-order sketch at the end of this list)
* the test checks that the final results after KTRegroupAsDict are consistent with the original eager model
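* for reference, a rough sketch of the key orders involved; it continues the ebc1/ebc2 snippet above, assumes the EBC output KeyedTensor lists keys in table order (the behavior this test relies on), and takes the per-key dims from the table configs in the test:
```
kt1 = ebc1(id_list_features)  # keys ["f1", "f2", "f3"], per-key dims [3, 4, 5]
kt2 = ebc2(id_list_features)  # keys ["f1", "f3", "f2"], per-key dims [3, 5, 4]
# regrouping by key names, e.g. KTRegroupAsDict([["f1", "f3"], ["f2"]], ["odd", "even"]),
# should therefore produce the same "odd"/"even" tensors for either module
```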

Reviewed By: PaulZhang12

Differential Revision: D62604419
---
 torchrec/ir/tests/test_serializer.py | 70 ++++++++++++++++++++++++++++
 1 file changed, 70 insertions(+)

diff --git a/torchrec/ir/tests/test_serializer.py b/torchrec/ir/tests/test_serializer.py
index 060c2d224..8b6926553 100644
--- a/torchrec/ir/tests/test_serializer.py
+++ b/torchrec/ir/tests/test_serializer.py
@@ -521,3 +521,73 @@ def forward(
         self.assertTrue(deserialized_model.regroup._is_inited)
         for key in eager_out.keys():
             self.assertEqual(deserialized_out[key].shape, eager_out[key].shape)
+
+    def test_key_order_with_ebc_and_regroup(self) -> None:
+        tb1_config = EmbeddingBagConfig(
+            name="t1",
+            embedding_dim=3,
+            num_embeddings=10,
+            feature_names=["f1"],
+        )
+        tb2_config = EmbeddingBagConfig(
+            name="t2",
+            embedding_dim=4,
+            num_embeddings=10,
+            feature_names=["f2"],
+        )
+        tb3_config = EmbeddingBagConfig(
+            name="t3",
+            embedding_dim=5,
+            num_embeddings=10,
+            feature_names=["f3"],
+        )
+        id_list_features = KeyedJaggedTensor.from_offsets_sync(
+            keys=["f1", "f2", "f3", "f4", "f5"],
+            values=torch.tensor([0, 1, 2, 3, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1, 2]),
+            offsets=torch.tensor([0, 2, 2, 3, 4, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15]),
+        )
+        ebc1 = EmbeddingBagCollection(
+            tables=[tb1_config, tb2_config, tb3_config],
+            is_weighted=False,
+        )
+        ebc2 = EmbeddingBagCollection(
+            tables=[tb1_config, tb3_config, tb2_config],
+            is_weighted=False,
+        )
+        ebc2.load_state_dict(ebc1.state_dict())
+        regroup = KTRegroupAsDict([["f1", "f3"], ["f2"]], ["odd", "even"])
+
+        class myModel(nn.Module):
+            def __init__(self, ebc, regroup):
+                super().__init__()
+                self.ebc = ebc
+                self.regroup = regroup
+
+            def forward(
+                self,
+                features: KeyedJaggedTensor,
+            ) -> Dict[str, torch.Tensor]:
+                return self.regroup([self.ebc(features)])
+
+        model = myModel(ebc1, regroup)
+        eager_out = model(id_list_features)
+
+        model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
+        ep = torch.export.export(
+            model,
+            (id_list_features,),
+            {},
+            strict=False,
+            # Allows KJT to not be unflattened and run a forward on unflattened EP
+            preserve_module_call_signature=(tuple(sparse_fqns)),
+        )
+        unflatten_ep = torch.export.unflatten(ep)
+        deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)
+        # we export the model with ebc1 and unflatten the model,
+        # and then swap in ebc2 (you can think of this as the sharding process
+        # resulting in a sharded EBC), so that we can mimic the key-order change
+        deserialized_model.ebc = ebc2
+
+        deserialized_out = deserialized_model(id_list_features)
+        for key in eager_out.keys():
+            torch.testing.assert_close(deserialized_out[key], eager_out[key])