Commit 1e6e30d

TroyGarden authored and facebook-github-bot committed
register custom_op for fpEBC (#2067)
Summary: Pull Request resolved: #2067

# context
* Convert `FeatureProcessedEmbeddingBagCollection` (FPEBC) to a custom op in IR export.
* Add serialization and deserialization functions for FPEBC.
* Add an API to the `FeatureProcessorInterface` that exports the parameters needed to create an instance.
* Use this API (`get_init_kwargs`) in the serialize and deserialize functions to flatten and unflatten the feature processor.

# details
1. Added an `FPEBCMetadata` schema for FP_EBC, using an `fp_json` string to store the necessary parameters.
2. Added `FPEBCJsonSerializer`, which converts the init kwargs to a JSON string and stores it in the `fp_json` field of the metadata.
3. Added an fqn check against `serialized_fqns`, so that when a higher-level module is serialized, the lower-level modules can be skipped (they are already included in the higher-level module); see the sketch below.
4. Added an API called `get_init_kwargs` to `FeatureProcessorsCollection` and `FeatureProcessor`, and a `FeatureProcessorNameMap` that maps the class name to the feature processor class.
5. Added a `_non_strict_exporting_forward` function to FPEBC so that non-strict IR export goes through the custom_op logic.

Differential Revision: D57829276
1 parent 0ccc4b1 commit 1e6e30d
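The fqn check described in item 3 can be pictured with a short sketch. This is illustrative only, under the assumption that serialized fqns are plain dot-separated module paths; the helper name `is_covered_by_parent` and the example fqns are made up, not the names used in torchrec.

```python
from typing import List


def is_covered_by_parent(fqn: str, serialized_fqns: List[str]) -> bool:
    # Hypothetical helper: a child such as "sparse.fpebc._embedding_bag_collection"
    # is already captured once its parent "sparse.fpebc" has been serialized,
    # so the child can be skipped.
    return any(
        fqn == parent or fqn.startswith(parent + ".") for parent in serialized_fqns
    )


assert is_covered_by_parent("sparse.fpebc._embedding_bag_collection", ["sparse.fpebc"])
assert not is_covered_by_parent("sparse.other_ebc", ["sparse.fpebc"])
```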

File tree

6 files changed, +366 -78 lines changed


torchrec/ir/schema.py

Lines changed: 17 additions & 1 deletion
@@ -8,7 +8,7 @@
 # pyre-strict
 
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import List, Optional, Tuple
 
 from torchrec.modules.embedding_configs import DataType, PoolingType
 
@@ -32,3 +32,19 @@ class EBCMetadata:
     tables: List[EmbeddingBagConfigMetadata]
     is_weighted: bool
     device: Optional[str]
+
+
+@dataclass
+class FPEBCMetadata:
+    is_fp_collection: bool
+    feature_list: List[str]
+
+
+@dataclass
+class PositionWeightedModuleMetadata:
+    max_feature_length: int
+
+
+@dataclass
+class PositionWeightedModuleCollectionMetadata:
+    max_feature_lengths: List[Tuple[str, int]]
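The serializers in serializer.py below embed these metadata dataclasses into the exported IR as JSON bytes viewed as a uint8 tensor. Here is a minimal sketch of that round trip, mirroring the `torch.frombuffer(json.dumps(...).encode(), dtype=torch.uint8)` pattern used in the diff; the helpers `to_tensor` and `from_tensor` and the feature names are illustrative, not part of the torchrec API.

```python
import json
from dataclasses import dataclass
from typing import List

import torch


@dataclass
class FPEBCMetadata:
    is_fp_collection: bool
    feature_list: List[str]


def to_tensor(metadata: FPEBCMetadata) -> torch.Tensor:
    # Encode the dataclass as JSON bytes and view them as a uint8 tensor,
    # so the metadata can travel inside the exported IR.
    return torch.frombuffer(json.dumps(metadata.__dict__).encode(), dtype=torch.uint8)


def from_tensor(tensor: torch.Tensor) -> FPEBCMetadata:
    # Recover the JSON bytes and rebuild the dataclass.
    return FPEBCMetadata(**json.loads(tensor.numpy().tobytes().decode()))


meta = FPEBCMetadata(is_fp_collection=False, feature_list=["f1", "f2"])
assert from_tensor(to_tensor(meta)) == meta
```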

torchrec/ir/serializer.py

Lines changed: 158 additions & 3 deletions
@@ -14,11 +14,24 @@
 import torch
 
 from torch import nn
-from torchrec.ir.schema import EBCMetadata, EmbeddingBagConfigMetadata
+from torchrec.ir.schema import (
+    EBCMetadata,
+    EmbeddingBagConfigMetadata,
+    FPEBCMetadata,
+    PositionWeightedModuleCollectionMetadata,
+    PositionWeightedModuleMetadata,
+)
 
 from torchrec.ir.types import SerializerInterface
 from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig, PoolingType
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
+from torchrec.modules.feature_processor_ import (
+    FeatureProcessor,
+    FeatureProcessorsCollection,
+    PositionWeightedModule,
+    PositionWeightedModuleCollection,
+)
+from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -71,7 +84,7 @@ def get_deserialized_device(
 
 class EBCJsonSerializer(SerializerInterface):
     """
-    Serializer for torch.export IR using thrift.
+    Serializer for torch.export IR using json.
     """
 
     @classmethod
@@ -132,13 +145,155 @@ def deserialize(
         )
 
 
+class PWMJsonSerializer(SerializerInterface):
+    """
+    Serializer for torch.export IR using json.
+    """
+
+    @classmethod
+    def serialize(cls, module: nn.Module) -> torch.Tensor:
+        if not isinstance(module, PositionWeightedModule):
+            raise ValueError(
+                f"Expected module to be of type PositionWeightedModule, got {type(module)}"
+            )
+        metadata = PositionWeightedModuleMetadata(
+            max_feature_length=module.position_weight.shape[0],
+        )
+        return torch.frombuffer(
+            json.dumps(metadata.__dict__).encode(), dtype=torch.uint8
+        )
+
+    @classmethod
+    def deserialize(
+        cls,
+        input: torch.Tensor,
+        typename: str,
+        device: Optional[torch.device] = None,
+        children: Dict[str, nn.Module] = {},
+    ) -> nn.Module:
+        if typename != "PositionWeightedModule":
+            raise ValueError(
+                f"Expected typename to be PositionWeightedModule, got {typename}"
+            )
+        raw_bytes = input.numpy().tobytes()
+        metadata = json.loads(raw_bytes)
+        return PositionWeightedModule(metadata["max_feature_length"], device)
+
+
+class PWMCJsonSerializer(SerializerInterface):
+    """
+    Serializer for torch.export IR using json.
+    """
+
+    @classmethod
+    def serialize(cls, module: nn.Module) -> torch.Tensor:
+        if not isinstance(module, PositionWeightedModuleCollection):
+            raise ValueError(
+                f"Expected module to be of type PositionWeightedModuleCollection, got {type(module)}"
+            )
+        metadata = PositionWeightedModuleCollectionMetadata(
+            max_feature_lengths=[  # convert to list of tuples to preserve the order
+                (feature, len) for feature, len in module.max_feature_lengths.items()
+            ],
+        )
+        return torch.frombuffer(
+            json.dumps(metadata.__dict__).encode(), dtype=torch.uint8
+        )
+
+    @classmethod
+    def deserialize(
+        cls,
+        input: torch.Tensor,
+        typename: str,
+        device: Optional[torch.device] = None,
+        children: Dict[str, nn.Module] = {},
+    ) -> nn.Module:
+        if typename != "PositionWeightedModuleCollection":
+            raise ValueError(
+                f"Expected typename to be PositionWeightedModuleCollection, got {typename}"
+            )
+        raw_bytes = input.numpy().tobytes()
+        metadata = PositionWeightedModuleCollectionMetadata(**json.loads(raw_bytes))
+        max_feature_lengths = {
+            feature: len for feature, len in metadata.max_feature_lengths
+        }
+        return PositionWeightedModuleCollection(max_feature_lengths, device)
+
+
+class FPEBCJsonSerializer(SerializerInterface):
+    """
+    Serializer for torch.export IR using json.
+    """
+
+    @classmethod
+    def requires_children(cls, typename: str) -> bool:
+        return True
+
+    @classmethod
+    def serialize(
+        cls,
+        module: nn.Module,
+    ) -> torch.Tensor:
+        if not isinstance(module, FeatureProcessedEmbeddingBagCollection):
+            raise ValueError(
+                f"Expected module to be of type FeatureProcessedEmbeddingBagCollection, got {type(module)}"
+            )
+        elif isinstance(module._feature_processors, FeatureProcessorsCollection):
+            metadata = FPEBCMetadata(
+                is_fp_collection=True,
+                feature_list=[],
+            )
+        else:
+            metadata = FPEBCMetadata(
+                is_fp_collection=False,
+                feature_list=list(module._feature_processors.keys()),
+            )
+
+        return torch.frombuffer(
+            json.dumps(metadata.__dict__).encode(), dtype=torch.uint8
+        )
+
+    @classmethod
+    def deserialize(
+        cls,
+        input: torch.Tensor,
+        typename: str,
+        device: Optional[torch.device] = None,
+        children: Dict[str, nn.Module] = {},
+    ) -> nn.Module:
+        if typename != "FeatureProcessedEmbeddingBagCollection":
+            raise ValueError(
+                f"Expected typename to be FeatureProcessedEmbeddingBagCollection, got {typename}"
+            )
+        raw_bytes = input.numpy().tobytes()
+        metadata = FPEBCMetadata(**json.loads(raw_bytes.decode()))
+        if metadata.is_fp_collection:
+            feature_processors = children["_feature_processors"]
+            assert isinstance(feature_processors, FeatureProcessorsCollection)
+        else:
+            feature_processors: dict[str, FeatureProcessor] = {}
+            for feature in metadata.feature_list:
+                fp = children[f"_feature_processors.{feature}"]
+                assert isinstance(fp, FeatureProcessor)
+                feature_processors[feature] = fp
+        ebc = children["_embedding_bag_collection"]
+        assert isinstance(ebc, EmbeddingBagCollection)
+        return FeatureProcessedEmbeddingBagCollection(
+            ebc,
+            feature_processors,
+        )
+
+
 class JsonSerializer(SerializerInterface):
     """
-    Serializer for torch.export IR using thrift.
+    Serializer for torch.export IR using json.
     """
 
     module_to_serializer_cls: Dict[str, Type[SerializerInterface]] = {
         "EmbeddingBagCollection": EBCJsonSerializer,
+        "FeatureProcessedEmbeddingBagCollection": FPEBCJsonSerializer,
+        "PositionWeightedModule": PWMJsonSerializer,
+        "PositionWeightedModuleCollection": PWMCJsonSerializer,
     }
 
     @classmethod
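A hedged usage sketch of the new serializers: round-tripping a `PositionWeightedModuleCollection` through `PWMCJsonSerializer`, assuming a torchrec build that includes this commit. The import paths and call signatures follow the diff above; the feature names and lengths are made up for the example.

```python
from torchrec.ir.serializer import PWMCJsonSerializer
from torchrec.modules.feature_processor_ import PositionWeightedModuleCollection

# Serialize the module's max_feature_lengths into a uint8 tensor of JSON bytes.
pwmc = PositionWeightedModuleCollection({"f1": 10, "f2": 20})
buffer = PWMCJsonSerializer.serialize(pwmc)

# Deserialization rebuilds an equivalent module from the metadata alone.
restored = PWMCJsonSerializer.deserialize(buffer, "PositionWeightedModuleCollection")
assert restored.max_feature_lengths == {"f1": 10, "f2": 20}
```

The `module_to_serializer_cls` registry added above maps each module's type name to its serializer class, so the same round trip applies to these modules when they appear inside a larger serialized model.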
