Skip to content

Commit ab3d1e9

Browse files
TroyGarden authored and facebook-github-bot committed
switch to new op for KeyedTensor.regroup (#2226)
Summary: Pull Request resolved: #2226 # context * the new op `permute_multi_embedding` outperforms the original op `_fbgemm_permute_pooled_embs` * this diff makes the move to switch to the new op # benchmark * more results: D58907223, [traces](https://drive.google.com/drive/folders/1DEYozPihmij2zRAyG9AMxaIbcjTPWRVU?usp=drive_link) * previous prod {F1755206204} * new prod {F1755207013} * metrics |Operator|CPU runtime|GPU runtime|GPU memory|notes| |---|---|---|---|---| |**[fallback] pytorch generic**|3.9 ms|3.2 ms|1.0 K|CPU-bounded, allow duplicates| |**[previous prod] permute_pooled_embs**|1.9 ms|4.9 ms|1.5 K|GPU-bounded, does **NOT** allow duplicates, PT2 non-compatible `pin_and_move`| |**[new prod] permute_multi_embedding**|1.0 ms|2.0 ms|1.0 K|both CPU and GPU runtime/memory improved, **ALLOW** duplicates, PT2 friendly| NOTE: the new op takes in `List[List[str]]` and `List[List[int]]`, it currently does not support dynamic_shape and produces error like the following: > 1) SerializeError: Failed serializing node kt_regroup_permutes in graph: %kt_regroup_permutes : [num_users=3] = call_function[target=torch.ops.fbgemm.kt_regroup_permutes.default](args = (%ir_custom_op, [[f1], [f2]], [[3], [5]], [[f1], [f2]]), kwargs = {}) ... Caused by SerializeError: Unsupported list/tuple argument type: [<class 'torch.fx.immutable_collections.immutable_list'>, <class 'torch.fx.immutable_collections.immutable_list'>] Reviewed By: dstaay-fb Differential Revision: D55277833 fbshipit-source-id: be47179c62b2df48445c78eabf5d7d44582a495b
1 parent 0221dfc commit ab3d1e9

File tree

2 files changed

+54
-5
lines changed

2 files changed

+54
-5
lines changed

torchrec/ir/tests/test_serializer.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,59 @@ def test_deserialized_device(self) -> None:
352352
continue
353353
assert param.device.type == device.type, f"{name} should be on {device}"
354354

# pyre-ignore
@unittest.skipIf(
    torch.cuda.device_count() <= 0,
    "this test needs a GPU machine to run",
)
def test_deserialize_device_kt_regroup(self) -> None:
    """Round-trip a model using KeyedTensor.regroup through IR
    serialization + torch.export, deserialize it onto CUDA, and verify
    its outputs match the eager-mode outputs."""

    class Model(nn.Module):
        # Thin wrapper: run the EBC and regroup its pooled output into
        # one tensor per key.
        def __init__(self, ebc):
            super().__init__()
            self.ebc = ebc

        def forward(
            self,
            features: KeyedJaggedTensor,
        ) -> List[torch.Tensor]:
            pooled = self.ebc(features)
            singleton_groups = [[key] for key in pooled.keys()]
            return KeyedTensor.regroup([pooled], singleton_groups)

    wrapped = Model(self.generate_model().ebc1)
    features = KeyedJaggedTensor.from_offsets_sync(
        keys=["f1", "f2", "f3"],
        values=torch.tensor([0, 1, 2, 3, 2, 3]),
        offsets=torch.tensor([0, 2, 2, 3, 4, 5, 6]),
    )
    eager_out = wrapped(features)

    # Serialize EBC
    wrapped, sparse_fqns = encapsulate_ir_modules(wrapped, JsonSerializer)
    exported = torch.export.export(
        wrapped,
        (features,),
        {},
        strict=False,
        # Allows KJT to not be unflattened and run a forward on unflattened EP
        preserve_module_call_signature=(tuple(sparse_fqns)),
    )
    device = torch.device("cuda")
    deserialized_model = decapsulate_ir_modules(
        torch.export.unflatten(exported), JsonSerializer, device
    )
    deserialized_model.to(device)
    features = features.to(device)

    deserialized_model.load_state_dict(wrapped.state_dict())
    # Run forward on deserialized model
    deserialized_out = deserialized_model(features)

    for expected, actual in zip(eager_out, deserialized_out):
        assert expected.shape == actual.shape
        assert torch.allclose(expected.to(actual), actual)
355408
def test_compound_module(self) -> None:
356409
tb1_config = EmbeddingBagConfig(
357410
name="t1",

torchrec/sparse/jagged_tensor.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2753,11 +2753,7 @@ def to_dict(self) -> Dict[str, torch.Tensor]:
def regroup(
    keyed_tensors: List["KeyedTensor"], groups: List[List[str]]
) -> List[torch.Tensor]:
    """Regroup the pooled values of *keyed_tensors* into one dense tensor
    per group of keys in *groups*.

    Delegates to the fbgemm `permute_multi_embedding` op, which (per the
    switch-over in this commit) outperforms `_fbgemm_permute_pooled_embs`
    and supports duplicate keys across groups.
    """
    regrouped = permute_multi_embedding(keyed_tensors, groups)
    return regrouped
27612757

27622758
@staticmethod
27632759
def regroup_as_dict(

0 commit comments

Comments
 (0)