Skip to content

Commit 360f8f3

Browse files
TroyGardenfacebook-github-bot
authored andcommitted
debug on KJT issue
Summary: # context * In the IR export workflow, the module takes KJTs as input and produces an `ExportedProgram` of the module * KJT actually has a variable length for the values and weights * This dynamic nature of KJT needs to be explicitly passed to torch.export # changes * add a util function to mark the input KJT's dynamic shape * add in the test of how to correctly specify the dynamics shapes for the input KJT # results * input KJTs with different value lengths ``` (Pdb) feature1.values() tensor([0, 1, 2, 3, 2, 3]) (Pdb) feature2.values() tensor([0, 1, 2, 3, 2, 3, 4]) ``` * exported_program can take those input KJTs ``` (Pdb) ep.module()(feature1) [tensor([[-2.8735e-16, -2.8735e-16, -2.8735e-16, -2.8735e-16], [-2.8735e-16, -2.8735e-16, -2.8735e-16, -2.8735e-16]]), tensor([[-2.8735e-16, -2.8735e-16, -2.8735e-16, -2.8735e-16, -1.4368e-15, -1.4368e-15, -1.4368e-15, -1.4368e-15], [-2.8735e-16, -2.8735e-16, -2.8735e-16, -2.8735e-16, -1.4368e-15, -1.4368e-15, -1.4368e-15, -1.4368e-15]])] (Pdb) ep.module()(feature2) [tensor([[-2.8735e-16, -2.8735e-16, -2.8735e-16, -2.8735e-16], [-2.8735e-16, -2.8735e-16, -2.8735e-16, -2.8735e-16]]), tensor([[-2.8735e-16, -2.8735e-16, -2.8735e-16, -2.8735e-16, -1.4368e-15, -1.4368e-15, -1.4368e-15, -1.4368e-15], [-2.8735e-16, -2.8735e-16, -2.8735e-16, -2.8735e-16, -1.4368e-15, -1.4368e-15, -1.4368e-15, -1.4368e-15]])] ``` * deserialized module can take those input KJTs ``` (Pdb) deserialized_model(feature1) [tensor([[ 0.2630, 0.1473, -0.3691, 0.2261], [ 0.0000, 0.0000, 0.0000, 0.0000]], grad_fn=<SplitWithSizesBackward0>), tensor([[ 0.2198, -0.1648, -0.0121, 0.1998, -0.0384, -0.2458, -0.6844, 0.8741], [ 0.1313, 0.2968, -0.2979, -0.2150, -0.2593, 0.6758, 1.0010, 0.9052]], grad_fn=<SplitWithSizesBackward0>)] (Pdb) deserialized_model(feature2) [tensor([[ 0.2630, 0.1473, -0.3691, 0.2261], [ 0.0000, 0.0000, 0.0000, 0.0000]], grad_fn=<SplitWithSizesBackward0>), tensor([[ 0.2198, -0.1648, -0.0121, 0.1998, -0.0384, -0.2458, -0.6844, 0.8741], [ 0.1313, 0.2968, -0.2979, -0.2150, -0.9359, 0.1123, 0.5834, -0.1357]], grad_fn=<SplitWithSizesBackward0>)] ``` Differential Revision: D57824907
1 parent 2c18303 commit 360f8f3

File tree

1 file changed

+48
-2
lines changed

1 file changed

+48
-2
lines changed

torchrec/ir/utils.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,21 @@
99

1010
#!/usr/bin/env python3
1111

12-
from typing import List, Optional, Tuple, Type
12+
from collections import defaultdict
13+
from typing import Dict, List, Optional, Tuple, Type, Union
1314

1415
import torch
1516

1617
from torch import nn
17-
from torch.export.exported_program import ExportedProgram
18+
from torch.export import Dim, ExportedProgram, ShapesCollection
19+
from torch.export.dynamic_shapes import _Dim as DIM
20+
from torchrec import KeyedJaggedTensor
1821
from torchrec.ir.types import SerializerInterface
1922

2023

2124
# TODO: Replace the default interface with the python dataclass interface
2225
DEFAULT_SERIALIZER_CLS = SerializerInterface
26+
DYNAMIC_DIMS: Dict[str, int] = defaultdict(int)
2327

2428

2529
def serialize_embedding_modules(
@@ -88,3 +92,45 @@ def deserialize_embedding_modules(
8892
setattr(parent, attrs[-1], new_module)
8993

9094
return model
95+
96+
97+
def _get_dim(x: Union[DIM, str, None], s: str) -> DIM:
98+
if isinstance(x, DIM):
99+
return x
100+
elif isinstance(x, str):
101+
if x in DYNAMIC_DIMS:
102+
DYNAMIC_DIMS[x] += 1
103+
x += str(DYNAMIC_DIMS[x])
104+
dim = Dim(x)
105+
else:
106+
DYNAMIC_DIMS[s] += 1
107+
dim = Dim(s + str(DYNAMIC_DIMS[s]))
108+
return dim
109+
110+
111+
def mark_dynamic_kjt(
112+
kjt: KeyedJaggedTensor,
113+
vlen: Union[DIM, str, None] = None,
114+
llen: Union[DIM, str, None] = None,
115+
lofs: Union[DIM, str, None] = None,
116+
variable_length: bool = False,
117+
shapes: Optional[ShapesCollection] = None,
118+
) -> ShapesCollection:
119+
"""
120+
Makes the given KJT dynamic.
121+
"""
122+
global DYNAMIC_DIMS
123+
if shapes is None:
124+
shapes = ShapesCollection()
125+
vlen = _get_dim(vlen, "vlen")
126+
shapes[kjt._values] = (vlen,)
127+
if kjt._weights is not None:
128+
shapes[kjt._weights] = (vlen,)
129+
if variable_length:
130+
llen = _get_dim(llen, "llen")
131+
lofs = _get_dim(lofs, "lofs")
132+
if kjt._lengths is not None:
133+
shapes[kjt._lengths] = (llen,)
134+
if kjt._offsets is not None:
135+
shapes[kjt._offsets] = (lofs,)
136+
return shapes

0 commit comments

Comments
 (0)