|
9 | 9 |
|
10 | 10 | #!/usr/bin/env python3
|
11 | 11 |
|
12 |
| -from typing import List, Optional, Tuple, Type |
| 12 | +from collections import defaultdict |
| 13 | +from typing import Dict, List, Optional, Tuple, Type, Union |
13 | 14 |
|
14 | 15 | import torch
|
15 | 16 |
|
16 | 17 | from torch import nn
|
17 |
| -from torch.export.exported_program import ExportedProgram |
| 18 | +from torch.export import Dim, ExportedProgram, ShapesCollection |
| 19 | +from torch.export.dynamic_shapes import _Dim as DIM |
| 20 | +from torchrec import KeyedJaggedTensor |
18 | 21 | from torchrec.ir.types import SerializerInterface
|
19 | 22 |
|
20 | 23 |
|
21 | 24 | # TODO: Replace the default interface with the python dataclass interface
|
22 | 25 | DEFAULT_SERIALIZER_CLS = SerializerInterface
|
| 26 | +DYNAMIC_DIMS: Dict[str, int] = defaultdict(int) |
23 | 27 |
|
24 | 28 |
|
25 | 29 | def serialize_embedding_modules(
|
@@ -88,3 +92,63 @@ def deserialize_embedding_modules(
|
88 | 92 | setattr(parent, attrs[-1], new_module)
|
89 | 93 |
|
90 | 94 | return model
|
| 95 | + |
| 96 | + |
| 97 | +def _get_dim(x: Union[DIM, str, None], s: str) -> DIM: |
| 98 | + if isinstance(x, DIM): |
| 99 | + return x |
| 100 | + elif isinstance(x, str): |
| 101 | + if x in DYNAMIC_DIMS: |
| 102 | + DYNAMIC_DIMS[x] += 1 |
| 103 | + x += str(DYNAMIC_DIMS[x]) |
| 104 | + dim = Dim(x) |
| 105 | + else: |
| 106 | + DYNAMIC_DIMS[s] += 1 |
| 107 | + dim = Dim(s + str(DYNAMIC_DIMS[s])) |
| 108 | + return dim |
| 109 | + |
| 110 | + |
| 111 | +def mark_dynamic_kjt( |
| 112 | + kjt: KeyedJaggedTensor, |
| 113 | + shapes_collection: Optional[ShapesCollection] = None, |
| 114 | + variable_length: bool = False, |
| 115 | + vlen: Optional[Union[DIM, str]] = None, |
| 116 | + llen: Optional[Union[DIM, str]] = None, |
| 117 | + lofs: Optional[Union[DIM, str]] = None, |
| 118 | +) -> ShapesCollection: |
| 119 | + """ |
| 120 | + Makes the given KJT dynamic. If it's not variable length, it will only have |
| 121 | + one dynamic dimension, which is the length of the values (and weights). |
| 122 | + If it is variable length, then the lengths and offsets will be dynamic. |
| 123 | +
|
| 124 | + If a shapes collection is provided, it will be updated with the new shapes, |
| 125 | + otherwise a new shapes collection will be created. A passed-in shapes_collection is |
| 126 | + useful if you have multiple KJTs or other dynamic shapes that you want to trace. |
| 127 | +
|
| 128 | + If a dynamic dim/name is provided, it will directly use that dim/name. Otherwise, |
| 129 | + it will use the default name "vlen" for values, and "llen", "lofs" if variable length. |
| 130 | + A passed-in dynamic dim is useful if the dynamic dim is already used in other places. |
| 131 | +
|
| 132 | + Args: |
| 133 | + kjt (KeyedJaggedTensor): The KJT to make dynamic. |
| 134 | + shapes_collection (Optional[ShapesCollection]): The collection to update. |
| 135 | + variable_length (bool): Whether the KJT is variable length. |
| 136 | + vlen (Optional[Union[DIM, str]]): The dynamic length for the values. |
| 137 | + llen (Optional[Union[DIM, str]]): The dynamic length for the lengths. |
| 138 | + lofs (Optional[Union[DIM, str]]): The dynamic length for the offsets. |
| 139 | + """ |
| 140 | + global DYNAMIC_DIMS |
| 141 | + if shapes_collection is None: |
| 142 | + shapes_collection = ShapesCollection() |
| 143 | + vlen = _get_dim(vlen, "vlen") |
| 144 | + shapes_collection[kjt._values] = (vlen,) |
| 145 | + if kjt._weights is not None: |
| 146 | + shapes_collection[kjt._weights] = (vlen,) |
| 147 | + if variable_length: |
| 148 | + llen = _get_dim(llen, "llen") |
| 149 | + lofs = _get_dim(lofs, "lofs") |
| 150 | + if kjt._lengths is not None: |
| 151 | + shapes_collection[kjt._lengths] = (llen,) |
| 152 | + if kjt._offsets is not None: |
| 153 | + shapes_collection[kjt._offsets] = (lofs,) |
| 154 | + return shapes_collection |
0 commit comments