Skip to content

Test case for EBC key-order change #2388

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions torchrec/ir/tests/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,3 +521,73 @@ def forward(
self.assertTrue(deserialized_model.regroup._is_inited)
for key in eager_out.keys():
self.assertEqual(deserialized_out[key].shape, eager_out[key].shape)

def test_key_order_with_ebc_and_regroup(self) -> None:
tb1_config = EmbeddingBagConfig(
name="t1",
embedding_dim=3,
num_embeddings=10,
feature_names=["f1"],
)
tb2_config = EmbeddingBagConfig(
name="t2",
embedding_dim=4,
num_embeddings=10,
feature_names=["f2"],
)
tb3_config = EmbeddingBagConfig(
name="t3",
embedding_dim=5,
num_embeddings=10,
feature_names=["f3"],
)
id_list_features = KeyedJaggedTensor.from_offsets_sync(
keys=["f1", "f2", "f3", "f4", "f5"],
values=torch.tensor([0, 1, 2, 3, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1, 2]),
offsets=torch.tensor([0, 2, 2, 3, 4, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15]),
)
ebc1 = EmbeddingBagCollection(
tables=[tb1_config, tb2_config, tb3_config],
is_weighted=False,
)
ebc2 = EmbeddingBagCollection(
tables=[tb1_config, tb3_config, tb2_config],
is_weighted=False,
)
ebc2.load_state_dict(ebc1.state_dict())
regroup = KTRegroupAsDict([["f1", "f3"], ["f2"]], ["odd", "even"])

class myModel(nn.Module):
def __init__(self, ebc, regroup):
super().__init__()
self.ebc = ebc
self.regroup = regroup

def forward(
self,
features: KeyedJaggedTensor,
) -> Dict[str, torch.Tensor]:
return self.regroup([self.ebc(features)])

model = myModel(ebc1, regroup)
eager_out = model(id_list_features)

model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
ep = torch.export.export(
model,
(id_list_features,),
{},
strict=False,
# Allows KJT to not be unflattened and run a forward on unflattened EP
preserve_module_call_signature=(tuple(sparse_fqns)),
)
unflatten_ep = torch.export.unflatten(ep)
deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)
# we export the model with ebc1 and unflatten the model,
# and then swap with ebc2 (you can think this as the the sharding process
# resulting a shardedEBC), so that we can mimic the key-order change
deserialized_model.ebc = ebc2

deserialized_out = deserialized_model(id_list_features)
for key in eager_out.keys():
torch.testing.assert_close(deserialized_out[key], eager_out[key])
4 changes: 3 additions & 1 deletion torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2943,7 +2943,9 @@ def _kt_unflatten(


def _kt_flatten_spec(kt: KeyedTensor, spec: TreeSpec) -> List[torch.Tensor]:
return _kt_flatten(kt)[0]
_keys, _length_per_key = spec.context
res = KeyedTensor.regroup([kt], [_keys])
return [res[0]]


# The assumption here in torch.exporting KeyedTensor is that _length_per_key is static
Expand Down
Loading