|
9 | 9 |
|
10 | 10 | #!/usr/bin/env python3
|
11 | 11 |
|
12 |
| -from typing import List, Optional, Tuple, Type |
| 12 | +from collections import defaultdict |
| 13 | +from typing import Dict, List, Optional, Tuple, Type, Union |
13 | 14 |
|
14 | 15 | import torch
|
15 | 16 |
|
16 | 17 | from torch import nn
|
17 |
| -from torch.export.exported_program import ExportedProgram |
| 18 | +from torch.export import Dim, ExportedProgram, ShapesCollection |
| 19 | +from torch.export.dynamic_shapes import _Dim as DIM |
| 20 | +from torchrec import KeyedJaggedTensor |
18 | 21 | from torchrec.ir.types import SerializerInterface
|
19 | 22 |
|
20 | 23 |
|
21 | 24 | # TODO: Replace the default interface with the python dataclass interface
|
22 | 25 | DEFAULT_SERIALIZER_CLS = SerializerInterface
|
| 26 | +DYNAMIC_DIMS: Dict[str, int] = defaultdict(int) |
23 | 27 |
|
24 | 28 |
|
25 | 29 | def serialize_embedding_modules(
|
@@ -88,3 +92,45 @@ def deserialize_embedding_modules(
|
88 | 92 | setattr(parent, attrs[-1], new_module)
|
89 | 93 |
|
90 | 94 | return model
|
| 95 | + |
| 96 | + |
| 97 | +def _get_dim(x: Union[DIM, str, None], s: str) -> DIM: |
| 98 | + if isinstance(x, DIM): |
| 99 | + return x |
| 100 | + elif isinstance(x, str): |
| 101 | + if x in DYNAMIC_DIMS: |
| 102 | + DYNAMIC_DIMS[x] += 1 |
| 103 | + x += str(DYNAMIC_DIMS[x]) |
| 104 | + dim = Dim(x) |
| 105 | + else: |
| 106 | + DYNAMIC_DIMS[s] += 1 |
| 107 | + dim = Dim(s + str(DYNAMIC_DIMS[s])) |
| 108 | + return dim |
| 109 | + |
| 110 | + |
| 111 | +def mark_dynamic_kjt( |
| 112 | + kjt: KeyedJaggedTensor, |
| 113 | + vlen: Union[DIM, str, None] = None, |
| 114 | + llen: Union[DIM, str, None] = None, |
| 115 | + lofs: Union[DIM, str, None] = None, |
| 116 | + variable_length: bool = False, |
| 117 | + shapes: Optional[ShapesCollection] = None, |
| 118 | +) -> ShapesCollection: |
| 119 | + """ |
| 120 | + Makes the given KJT dynamic. |
| 121 | + """ |
| 122 | + global DYNAMIC_DIMS |
| 123 | + if shapes is None: |
| 124 | + shapes = ShapesCollection() |
| 125 | + vlen = _get_dim(vlen, "vlen") |
| 126 | + shapes[kjt._values] = (vlen,) |
| 127 | + if kjt._weights is not None: |
| 128 | + shapes[kjt._weights] = (vlen,) |
| 129 | + if variable_length: |
| 130 | + llen = _get_dim(llen, "llen") |
| 131 | + lofs = _get_dim(lofs, "lofs") |
| 132 | + if kjt._lengths is not None: |
| 133 | + shapes[kjt._lengths] = (llen,) |
| 134 | + if kjt._offsets is not None: |
| 135 | + shapes[kjt._offsets] = (lofs,) |
| 136 | + return shapes |
0 commit comments