
Commit cb5d996

TroyGarden authored and facebook-github-bot committed
util function for marking input KJT dynamic (#2058)
Summary:

# context
* In the IR export workflow, the module takes KJTs as input and produces an `ExportedProgram` of the module.
* A KJT actually has variable lengths for its values and weights.
* This dynamic nature of the KJT needs to be explicitly passed to torch.export.

# changes
* Add a util function to mark the input KJT's shape as dynamic.
* Add a test showing how to correctly specify the dynamic shapes for the input KJT.

# results
* Input KJTs with different value lengths:
```
(Pdb) feature1.values()
tensor([0, 1, 2, 3, 2, 3])
(Pdb) feature2.values()
tensor([0, 1, 2, 3, 2, 3, 4])
```
* The exported program can take those input KJTs:
```
(Pdb) ep.module()(feature1)
[tensor([[-2.8735e-16, -2.8735e-16, -2.8735e-16, -2.8735e-16],
        [-2.8735e-16, -2.8735e-16, -2.8735e-16, -2.8735e-16]]),
 tensor([[-2.8735e-16, -2.8735e-16, -2.8735e-16, -2.8735e-16, -1.4368e-15,
         -1.4368e-15, -1.4368e-15, -1.4368e-15],
        [-2.8735e-16, -2.8735e-16, -2.8735e-16, -2.8735e-16, -1.4368e-15,
         -1.4368e-15, -1.4368e-15, -1.4368e-15]])]
(Pdb) ep.module()(feature2)
[tensor([[-2.8735e-16, -2.8735e-16, -2.8735e-16, -2.8735e-16],
        [-2.8735e-16, -2.8735e-16, -2.8735e-16, -2.8735e-16]]),
 tensor([[-2.8735e-16, -2.8735e-16, -2.8735e-16, -2.8735e-16, -1.4368e-15,
         -1.4368e-15, -1.4368e-15, -1.4368e-15],
        [-2.8735e-16, -2.8735e-16, -2.8735e-16, -2.8735e-16, -1.4368e-15,
         -1.4368e-15, -1.4368e-15, -1.4368e-15]])]
```
* The deserialized module can take those input KJTs:
```
(Pdb) deserialized_model(feature1)
[tensor([[ 0.2630,  0.1473, -0.3691,  0.2261],
        [ 0.0000,  0.0000,  0.0000,  0.0000]], grad_fn=<SplitWithSizesBackward0>),
 tensor([[ 0.2198, -0.1648, -0.0121,  0.1998, -0.0384, -0.2458, -0.6844,  0.8741],
        [ 0.1313,  0.2968, -0.2979, -0.2150, -0.2593,  0.6758,  1.0010,  0.9052]],
       grad_fn=<SplitWithSizesBackward0>)]
(Pdb) deserialized_model(feature2)
[tensor([[ 0.2630,  0.1473, -0.3691,  0.2261],
        [ 0.0000,  0.0000,  0.0000,  0.0000]], grad_fn=<SplitWithSizesBackward0>),
 tensor([[ 0.2198, -0.1648, -0.0121,  0.1998, -0.0384, -0.2458, -0.6844,  0.8741],
        [ 0.1313,  0.2968, -0.2979, -0.2150, -0.9359,  0.1123,  0.5834, -0.1357]],
       grad_fn=<SplitWithSizesBackward0>)]
```

Differential Revision: D57824907
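For orientation, here is a minimal sketch of the call pattern this commit enables, distilled from the new test further below. `TinyKJTModule` is a hypothetical stand-in for the EBC-based model built by the test's `generate_model()`; everything else follows the test code.

```python
import torch
from torch import nn
from torchrec import KeyedJaggedTensor
from torchrec.ir.utils import mark_dynamic_kjt


class TinyKJTModule(nn.Module):
    """Hypothetical stand-in: any module whose forward takes a single KJT."""

    def forward(self, kjt: KeyedJaggedTensor) -> torch.Tensor:
        # Any computation over the variable-length values works here.
        return kjt.values().float().sum().reshape(1)


model = TinyKJTModule()
kjt = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1", "f2", "f3"],
    values=torch.tensor([0, 1, 2, 3, 2, 3]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 6]),
)

# Mark the KJT's values (and weights, if present) as having a dynamic
# length; this returns a torch.export.ShapesCollection.
collection = mark_dynamic_kjt(kjt)

# Resolve the collection into the `dynamic_shapes` argument of torch.export.
ep = torch.export.export(
    model,
    (kjt,),
    {},
    dynamic_shapes=collection.dynamic_shapes(model, (kjt,)),
    strict=False,  # non-strict mode, as in the test below
)
```

After this, `ep.module()` should accept KJTs whose values tensor has a different length, which is exactly what the test exercises with `feature1` and `feature2`.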
1 parent 2c18303 commit cb5d996

File tree

2 files changed: +117 −3 lines changed

torchrec/ir/tests/test_serializer.py

Lines changed: 51 additions & 1 deletion
```diff
@@ -17,7 +17,11 @@
 from torch import nn
 from torchrec.ir.serializer import JsonSerializer
 
-from torchrec.ir.utils import deserialize_embedding_modules, serialize_embedding_modules
+from torchrec.ir.utils import (
+    deserialize_embedding_modules,
+    mark_dynamic_kjt,
+    serialize_embedding_modules,
+)
 
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
@@ -174,6 +178,52 @@ def test_serialize_deserialize_ebc(self) -> None:
             assert eager_out[i].shape == tensor.shape
             assert torch.allclose(eager_out[i], tensor)
 
+    def test_dynamic_shape_ebc(self) -> None:
+        model = self.generate_model()
+        feature1 = KeyedJaggedTensor.from_offsets_sync(
+            keys=["f1", "f2", "f3"],
+            values=torch.tensor([0, 1, 2, 3, 2, 3]),
+            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 6]),
+        )
+
+        feature2 = KeyedJaggedTensor.from_offsets_sync(
+            keys=["f1", "f2", "f3"],
+            values=torch.tensor([0, 1, 2, 3, 2, 3, 4]),
+            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 7]),
+        )
+        eager_out = model(feature2)
+
+        # Serialize EBC
+        collection = mark_dynamic_kjt(feature1)
+        model, sparse_fqns = serialize_embedding_modules(model, JsonSerializer)
+        ep = torch.export.export(
+            model,
+            (feature1,),
+            {},
+            dynamic_shapes=collection.dynamic_shapes(model, (feature1,)),
+            strict=False,
+            # Allows KJT to not be unflattened and run a forward on unflattened EP
+            preserve_module_call_signature=(tuple(sparse_fqns)),
+        )
+
+        # Run forward on ExportedProgram
+        ep_output = ep.module()(feature2)
+
+        # other asserts
+        for i, tensor in enumerate(ep_output):
+            self.assertEqual(eager_out[i].shape, tensor.shape)
+
+        # Deserialize EBC
+        deserialized_model = deserialize_embedding_modules(ep, JsonSerializer)
+
+        deserialized_model.load_state_dict(model.state_dict())
+        # Run forward on deserialized model
+        deserialized_out = deserialized_model(feature2)
+
+        for i, tensor in enumerate(deserialized_out):
+            self.assertEqual(eager_out[i].shape, tensor.shape)
+            assert torch.allclose(eager_out[i], tensor)
+
     def test_deserialized_device(self) -> None:
         model = self.generate_model()
         id_list_features = KeyedJaggedTensor.from_offsets_sync(
```

torchrec/ir/utils.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,21 @@
99

1010
#!/usr/bin/env python3
1111

12-
from typing import List, Optional, Tuple, Type
12+
from collections import defaultdict
13+
from typing import Dict, List, Optional, Tuple, Type, Union
1314

1415
import torch
1516

1617
from torch import nn
17-
from torch.export.exported_program import ExportedProgram
18+
from torch.export import Dim, ExportedProgram, ShapesCollection
19+
from torch.export.dynamic_shapes import _Dim as DIM
20+
from torchrec import KeyedJaggedTensor
1821
from torchrec.ir.types import SerializerInterface
1922

2023

2124
# TODO: Replace the default interface with the python dataclass interface
2225
DEFAULT_SERIALIZER_CLS = SerializerInterface
26+
DYNAMIC_DIMS: Dict[str, int] = defaultdict(int)
2327

2428

2529
def serialize_embedding_modules(
@@ -88,3 +92,63 @@ def deserialize_embedding_modules(
8892
setattr(parent, attrs[-1], new_module)
8993

9094
return model
95+
96+
97+
def _get_dim(x: Union[DIM, str, None], s: str) -> DIM:
98+
if isinstance(x, DIM):
99+
return x
100+
elif isinstance(x, str):
101+
if x in DYNAMIC_DIMS:
102+
DYNAMIC_DIMS[x] += 1
103+
x += str(DYNAMIC_DIMS[x])
104+
dim = Dim(x)
105+
else:
106+
DYNAMIC_DIMS[s] += 1
107+
dim = Dim(s + str(DYNAMIC_DIMS[s]))
108+
return dim
109+
110+
111+
def mark_dynamic_kjt(
112+
kjt: KeyedJaggedTensor,
113+
shapes_collection: Optional[ShapesCollection] = None,
114+
variable_length: bool = False,
115+
vlen: Optional[Union[DIM, str]] = None,
116+
llen: Optional[Union[DIM, str]] = None,
117+
lofs: Optional[Union[DIM, str]] = None,
118+
) -> ShapesCollection:
119+
"""
120+
Makes the given KJT dynamic. If it's not variable length, it will only have
121+
one dynamic dimension, which is the length of the values (and weights).
122+
If it is variable length, then the lengths and offsets will be dynamic.
123+
124+
If a shapes collection is provided, it will be updated with the new shapes,
125+
otherwise a new shapes collection will be created. A passed-in shapes_collection is
126+
useful if you have multiple KJTs or other dynamic shapes that you want to trace.
127+
128+
If a dynamic dim/name is provided, it will directly use that dim/name. Otherwise,
129+
it will use the default name "vlen" for values, and "llen", "lofs" if variable length.
130+
A passed-in dynamic dim is useful if the dynamic dim is already used in other places.
131+
132+
Args:
133+
kjt (KeyedJaggedTensor): The KJT to make dynamic.
134+
shapes_collection (Optional[ShapesCollection]): The collection to update.
135+
variable_length (bool): Whether the KJT is variable length.
136+
vlen (Optional[Union[DIM, str]]): The dynamic length for the values.
137+
llen (Optional[Union[DIM, str]]): The dynamic length for the lengths.
138+
lofs (Optional[Union[DIM, str]]): The dynamic length for the offsets.
139+
"""
140+
global DYNAMIC_DIMS
141+
if shapes_collection is None:
142+
shapes_collection = ShapesCollection()
143+
vlen = _get_dim(vlen, "vlen")
144+
shapes_collection[kjt._values] = (vlen,)
145+
if kjt._weights is not None:
146+
shapes_collection[kjt._weights] = (vlen,)
147+
if variable_length:
148+
llen = _get_dim(llen, "llen")
149+
lofs = _get_dim(lofs, "lofs")
150+
if kjt._lengths is not None:
151+
shapes_collection[kjt._lengths] = (llen,)
152+
if kjt._offsets is not None:
153+
shapes_collection[kjt._offsets] = (lofs,)
154+
return shapes_collection
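The docstring above mentions two usage patterns worth spelling out: reusing one `ShapesCollection` across several KJTs, and passing an existing `Dim` so that two inputs share a symbolic size. A small sketch under those assumptions follows; `kjt1`, `kjt2`, and `shared_vlen` are made-up names for illustration.

```python
import torch
from torch.export import Dim
from torchrec import KeyedJaggedTensor
from torchrec.ir.utils import mark_dynamic_kjt

kjt1 = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3]),
    offsets=torch.tensor([0, 2, 2, 3, 4]),
)
kjt2 = KeyedJaggedTensor.from_offsets_sync(
    keys=["f3"],
    values=torch.tensor([4, 5, 6, 7]),
    offsets=torch.tensor([0, 1, 4]),
)

# Pattern 1: accumulate both KJTs into one collection. With the default names,
# the module-level DYNAMIC_DIMS counter suffixes the dims (vlen1, vlen2, ...),
# so each KJT gets its own independent symbolic value length.
collection = mark_dynamic_kjt(kjt1)
collection = mark_dynamic_kjt(kjt2, shapes_collection=collection)

# Pattern 2: pass an existing Dim so both KJTs share one symbolic length
# (which also asserts at export time that the two sizes are equal);
# variable_length=True additionally marks lengths/offsets as dynamic.
shared = Dim("shared_vlen")
collection2 = mark_dynamic_kjt(kjt1, vlen=shared, variable_length=True)
collection2 = mark_dynamic_kjt(kjt2, shapes_collection=collection2, vlen=shared)
```

Note that because `DYNAMIC_DIMS` is module-level state, default dim names keep incrementing across calls within a process; passing an explicit `Dim` or string name gives deterministic naming.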
