Commit 1e4a525

TroyGarden authored and facebook-github-bot committed
register custom_op for fpEBC
Summary:

# context
* Convert `FeatureProcessedEmbeddingBagCollection` to a custom op in IR export.
* Add serialization and deserialization functions for FPEBC.
* Add an API to the `FeatureProcessorInterface` that exports the parameters needed to create an instance.
* Use this API (`get_init_kwargs`) in the serialize and deserialize functions to flatten and unflatten the feature processor.

# details
1. Added an `FPEBCMetadata` schema for FPEBC, which uses an `fp_json` string to store the necessary parameters.
2. Added `FPEBCJsonSerializer`, which converts the init kwargs to a JSON string and stores it in the `fp_json` field of the metadata.
3. Added an fqn check against `serialized_fqns`, so that when a higher-level module is serialized, its lower-level submodules are skipped (they are already included in the higher-level module).
4. Added an API called `get_init_kwargs` to `FeatureProcessorsCollection` and `FeatureProcessor`, and used a `FeatureProcessorNameMap` to map class names to feature processor classes (a minimal illustrative sketch follows below).
5. Added a `_non_strict_exporting_forward` function for FPEBC so that non-strict IR export goes through the custom_op logic.

Differential Revision: D57829276
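To make the `get_init_kwargs` contract concrete, here is a minimal sketch of a feature processor implementing it. This is illustrative, not the actual torchrec implementation: the `max_feature_length` parameter and module body are assumptions modeled on the position-weighted modules used in the tests. The essential property is that the returned kwargs are JSON-serializable and sufficient to re-instantiate the processor:

```python
from typing import Any, Dict

import torch
from torch import nn


class PositionWeightedModule(nn.Module):
    """Toy feature processor; the real one lives in torchrec/modules/feature_processor_.py."""

    def __init__(self, max_feature_length: int) -> None:
        super().__init__()
        # learnable per-position weights, one per possible position in a feature
        self.position_weight = nn.Parameter(torch.ones(max_feature_length))

    def get_init_kwargs(self) -> Dict[str, Any]:
        # Must be JSON-serializable: the serializer dumps these kwargs into the
        # `fp_json` string of FPEBCMetadata, and deserialization rebuilds the
        # module via FeatureProcessorNameMap[type_name](**kwargs).
        return {"max_feature_length": self.position_weight.numel()}
```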
1 parent df78731 commit 1e4a525

File tree: 8 files changed (+370, -74 lines)

torchrec/ir/schema.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -32,3 +32,12 @@ class EBCMetadata:
     tables: List[EmbeddingBagConfigMetadata]
     is_weighted: bool
     device: Optional[str]
+
+
+@dataclass
+class FPEBCMetadata:
+    tables: List[EmbeddingBagConfigMetadata]
+    is_weighted: bool
+    device: Optional[str]
+    fp_type: str
+    fp_json: Optional[str]
```
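To see how this metadata travels through the exported IR, here is a self-contained sketch of the encode/decode pattern the serializer builds on (dataclass -> dict -> JSON -> `uint8` tensor and back); the concrete field values below are made up for illustration:

```python
import json
from dataclasses import asdict, dataclass
from typing import List, Optional

import torch


@dataclass
class FPEBCMetadata:
    tables: List[dict]  # EmbeddingBagConfigMetadata entries, flattened to dicts
    is_weighted: bool
    device: Optional[str]
    fp_type: str
    fp_json: Optional[str]


meta = FPEBCMetadata(
    tables=[{"name": "t1", "embedding_dim": 3, "num_embeddings": 10}],
    is_weighted=True,
    device="cpu",
    fp_type="dict",
    fp_json=json.dumps(
        {
            "param_dict": {"f1": {"max_feature_length": 100}},
            "type_dict": {"f1": "PositionWeightedModule"},
        }
    ),
)

# Encode: metadata -> JSON -> uint8 tensor, stored as a non-persistent buffer.
buffer = torch.frombuffer(json.dumps(asdict(meta)).encode(), dtype=torch.uint8)

# Decode: uint8 tensor -> JSON -> plain dict, ready for reconstruction.
decoded = json.loads(buffer.numpy().tobytes().decode())
assert decoded["fp_type"] == "dict"
```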

torchrec/ir/serializer.py

Lines changed: 111 additions & 1 deletion
```diff
@@ -14,11 +14,17 @@
 import torch
 
 from torch import nn
-from torchrec.ir.schema import EBCMetadata, EmbeddingBagConfigMetadata
+from torchrec.ir.schema import EBCMetadata, EmbeddingBagConfigMetadata, FPEBCMetadata
 
 from torchrec.ir.types import SerializerInterface
 from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig, PoolingType
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
+from torchrec.modules.feature_processor_ import (
+    FeatureProcessor,
+    FeatureProcessorNameMap,
+    FeatureProcessorsCollection,
+)
+from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -128,13 +134,117 @@ def deserialize(
         )
 
 
+class FPEBCJsonSerializer(SerializerInterface):
+    """
+    Serializer for torch.export IR using json.
+    """
+
+    @classmethod
+    def serialize(
+        cls,
+        module: nn.Module,
+    ) -> torch.Tensor:
+        if not isinstance(module, FeatureProcessedEmbeddingBagCollection):
+            raise ValueError(
+                f"Expected module to be of type FeatureProcessedEmbeddingBagCollection, got {type(module)}"
+            )
+        if isinstance(module._feature_processors, dict):
+            fp_type = "dict"
+            param_dict = {
+                feature: processor.get_init_kwargs()
+                for feature, processor in module._feature_processors.items()
+            }
+            type_dict = {
+                feature: type(processor).__name__
+                for feature, processor in module._feature_processors.items()
+            }
+            fp_json = json.dumps(
+                {
+                    "param_dict": param_dict,
+                    "type_dict": type_dict,
+                }
+            )
+        elif isinstance(module._feature_processors, FeatureProcessorsCollection):
+            fp_type = type(module._feature_processors).__name__
+            param_dict = module._feature_processors.get_init_kwargs()
+            fp_json = json.dumps(param_dict)
+        else:
+            raise ValueError(
+                f"Expected module._feature_processors to be of type dict or FeatureProcessorsCollection, got {type(module._feature_processors)}"
+            )
+        ebc = module._embedding_bag_collection
+        ebc_metadata = FPEBCMetadata(
+            tables=[
+                embedding_bag_config_to_metadata(table_config)
+                for table_config in ebc.embedding_bag_configs()
+            ],
+            is_weighted=ebc.is_weighted(),
+            device=str(ebc.device),
+            fp_type=fp_type,
+            fp_json=fp_json,
+        )
+
+        ebc_metadata_dict = ebc_metadata.__dict__
+        ebc_metadata_dict["tables"] = [
+            table_config.__dict__ for table_config in ebc_metadata_dict["tables"]
+        ]
+
+        return torch.frombuffer(
+            json.dumps(ebc_metadata_dict).encode(), dtype=torch.uint8
+        )
+
+    @classmethod
+    def deserialize(
+        cls, input: torch.Tensor, typename: str, device: Optional[torch.device] = None
+    ) -> nn.Module:
+        if typename != "FeatureProcessedEmbeddingBagCollection":
+            raise ValueError(
+                f"Expected typename to be FeatureProcessedEmbeddingBagCollection, got {typename}"
+            )
+
+        raw_bytes = input.numpy().tobytes()
+        ebc_metadata_dict = json.loads(raw_bytes.decode())
+        tables = [
+            EmbeddingBagConfigMetadata(**table_config)
+            for table_config in ebc_metadata_dict["tables"]
+        ]
+        device = get_deserialized_device(ebc_metadata_dict.get("device"), device)
+        ebc = EmbeddingBagCollection(
+            tables=[
+                embedding_metadata_to_config(table_config) for table_config in tables
+            ],
+            is_weighted=ebc_metadata_dict["is_weighted"],
+            device=device,
+        )
+        fp_dict = json.loads(ebc_metadata_dict["fp_json"])
+        if ebc_metadata_dict["fp_type"] == "dict":
+            feature_processors: Dict[str, FeatureProcessor] = {}
+            for feature, fp_type in fp_dict["type_dict"].items():
+                feature_processors[feature] = FeatureProcessorNameMap[fp_type](
+                    **fp_dict["param_dict"][feature]
+                )
+        elif ebc_metadata_dict["fp_type"] in FeatureProcessorNameMap:
+            feature_processors = FeatureProcessorNameMap[ebc_metadata_dict["fp_type"]](
+                **fp_dict
+            )
+        else:
+            raise ValueError(
+                f"Unknown fp_type {ebc_metadata_dict['fp_type']} in FPEBC metadata"
+            )
+        return FeatureProcessedEmbeddingBagCollection(
+            ebc,
+            feature_processors,
+        )
+
+
 class JsonSerializer(SerializerInterface):
     """
     Serializer for torch.export IR using thrift.
     """
 
     module_to_serializer_cls: Dict[str, Type[SerializerInterface]] = {
         "EmbeddingBagCollection": EBCJsonSerializer,
+        "FeatureProcessedEmbeddingBagCollection": FPEBCJsonSerializer,
     }
 
     @classmethod
```
torchrec/ir/tests/test_serializer.py

Lines changed: 72 additions & 30 deletions
```diff
@@ -28,20 +28,21 @@
 from torchrec.modules.feature_processor_ import PositionWeightedModuleCollection
 from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
 from torchrec.modules.utils import operator_registry_state
-from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
+from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
 
 class TestJsonSerializer(unittest.TestCase):
     def generate_model(self) -> nn.Module:
         class Model(nn.Module):
-            def __init__(self, ebc, fpebc):
+            def __init__(self, ebc, fpebc1, fpebc2):
                 super().__init__()
                 self.ebc1 = ebc
                 self.ebc2 = copy.deepcopy(ebc)
                 self.ebc3 = copy.deepcopy(ebc)
                 self.ebc4 = copy.deepcopy(ebc)
                 self.ebc5 = copy.deepcopy(ebc)
-                self.fpebc = fpebc
+                self.fpebc1 = fpebc1
+                self.fpebc2 = fpebc2
 
             def forward(
                 self,
@@ -53,22 +54,17 @@ def forward(
             kt4 = self.ebc4(features)
             kt5 = self.ebc5(features)
 
-            fpebc_res = self.fpebc(features)
+            fpebc1_res = self.fpebc1(features)
+            fpebc2_res = self.fpebc2(features)
             ebc_kt_vals = [kt.values() for kt in [kt1, kt2, kt3, kt4, kt5]]
-            sparse_arch_vals = sum(ebc_kt_vals)
-            sparse_arch_res = KeyedTensor(
-                keys=kt1.keys(),
-                values=sparse_arch_vals,
-                length_per_key=kt1.length_per_key(),
-            )
 
-            return KeyedTensor.regroup(
-                [sparse_arch_res, fpebc_res], [["f1"], ["f2", "f3"]]
+            return (
+                ebc_kt_vals + list(fpebc1_res.values()) + list(fpebc2_res.values())
             )
 
         tb1_config = EmbeddingBagConfig(
             name="t1",
-            embedding_dim=4,
+            embedding_dim=3,
             num_embeddings=10,
             feature_names=["f1"],
         )
@@ -80,7 +76,7 @@ def forward(
         )
         tb3_config = EmbeddingBagConfig(
             name="t3",
-            embedding_dim=4,
+            embedding_dim=5,
             num_embeddings=10,
             feature_names=["f3"],
         )
@@ -91,7 +87,7 @@ def forward(
             is_weighted=True,
         )
         max_feature_lengths = {"f1": 100, "f2": 100}
 
-        fpebc = FeatureProcessedEmbeddingBagCollection(
+        fpebc1 = FeatureProcessedEmbeddingBagCollection(
             EmbeddingBagCollection(
                 tables=[tb1_config, tb2_config],
                 is_weighted=True,
@@ -100,8 +96,15 @@ def forward(
                 max_feature_lengths=max_feature_lengths,
             ),
         )
+        fpebc2 = FeatureProcessedEmbeddingBagCollection(
+            EmbeddingBagCollection(
+                tables=[tb1_config, tb3_config],
+                is_weighted=True,
+            ),
+            PositionWeightedModuleCollection({"f1": 10, "f3": 20}),
+        )
 
-        model = Model(ebc, fpebc)
+        model = Model(ebc, fpebc1, fpebc2)
 
         return model
 
@@ -133,11 +136,14 @@ def test_serialize_deserialize_ebc(self) -> None:
             self.assertEqual(eager_out[i].shape, tensor.shape)
 
-        # Only 2 custom op registered, as dimensions of ebc are same
-        self.assertEqual(len(operator_registry_state.op_registry_schema), 2)
+        # Only 3 custom ops registered: the EBCs share one, plus one per FPEBC
+        self.assertEqual(len(operator_registry_state.op_registry_schema), 3)
 
         total_dim_ebc = sum(model.ebc1._lengths_per_embedding)
-        total_dim_fpebc = sum(
-            model.fpebc._embedding_bag_collection._lengths_per_embedding
+        total_dim_fpebc1 = sum(
+            model.fpebc1._embedding_bag_collection._lengths_per_embedding
+        )
+        total_dim_fpebc2 = sum(
+            model.fpebc2._embedding_bag_collection._lengths_per_embedding
         )
         # Check if custom op is registered with the correct name
         # EmbeddingBagCollection type and total dim
@@ -146,7 +152,11 @@ def test_serialize_deserialize_ebc(self) -> None:
             in operator_registry_state.op_registry_schema
         )
         self.assertTrue(
-            f"EmbeddingBagCollection_{total_dim_fpebc}"
+            f"FeatureProcessedEmbeddingBagCollection_{total_dim_fpebc1}"
+            in operator_registry_state.op_registry_schema
+        )
+        self.assertTrue(
+            f"FeatureProcessedEmbeddingBagCollection_{total_dim_fpebc2}"
             in operator_registry_state.op_registry_schema
         )
 
@@ -155,28 +165,60 @@ def test_serialize_deserialize_ebc(self) -> None:
         # Deserialize EBC
         deserialized_model = deserialize_embedding_modules(ep, JsonSerializer)
 
+        # check EBC config
         for i in range(5):
             ebc_name = f"ebc{i + 1}"
-            assert isinstance(
+            self.assertIsInstance(
                 getattr(deserialized_model, ebc_name), EmbeddingBagCollection
             )
 
-            for deserialized_config, org_config in zip(
+            for deserialized, original in zip(
                 getattr(deserialized_model, ebc_name).embedding_bag_configs(),
                 getattr(model, ebc_name).embedding_bag_configs(),
             ):
-                assert deserialized_config.name == org_config.name
-                assert deserialized_config.embedding_dim == org_config.embedding_dim
-                assert deserialized_config.num_embeddings, org_config.num_embeddings
-                assert deserialized_config.feature_names, org_config.feature_names
+                self.assertEqual(deserialized.name, original.name)
+                self.assertEqual(deserialized.embedding_dim, original.embedding_dim)
+                self.assertEqual(deserialized.num_embeddings, original.num_embeddings)
+                self.assertEqual(deserialized.feature_names, original.feature_names)
+
+        # check FPEBC config
+        for i in range(2):
+            fpebc_name = f"fpebc{i + 1}"
+            self.assertIsInstance(
+                getattr(deserialized_model, fpebc_name),
+                FeatureProcessedEmbeddingBagCollection,
+            )
+
+            deserialized_kwargs = getattr(
+                deserialized_model, fpebc_name
+            )._feature_processors.get_init_kwargs()
+            original_kwargs = getattr(
+                model, fpebc_name
+            )._feature_processors.get_init_kwargs()
+            self.assertDictEqual(deserialized_kwargs, original_kwargs)
+
+            for deserialized, original in zip(
+                getattr(
+                    deserialized_model, fpebc_name
+                )._embedding_bag_collection.embedding_bag_configs(),
+                getattr(
+                    model, fpebc_name
+                )._embedding_bag_collection.embedding_bag_configs(),
+            ):
+                self.assertEqual(deserialized.name, original.name)
+                self.assertEqual(deserialized.embedding_dim, original.embedding_dim)
+                self.assertEqual(deserialized.num_embeddings, original.num_embeddings)
+                self.assertEqual(deserialized.feature_names, original.feature_names)
 
         deserialized_model.load_state_dict(model.state_dict())
-        # Run forward on deserialized model
+
+        # Run forward on deserialized model and compare the output
        deserialized_out = deserialized_model(id_list_features)
 
-        for i, tensor in enumerate(deserialized_out):
-            assert eager_out[i].shape == tensor.shape
-            assert torch.allclose(eager_out[i], tensor)
+        self.assertEqual(len(deserialized_out), len(eager_out))
+        for deserialized, original in zip(deserialized_out, eager_out):
+            self.assertEqual(deserialized.shape, original.shape)
+            self.assertTrue(torch.allclose(deserialized, original))
 
     def test_dynamic_shape_ebc(self) -> None:
         model = self.generate_model()
```
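For orientation, the flow this test exercises looks roughly like the sketch below, reusing `model` and `id_list_features` from the test body; the exact export arguments (e.g. dynamic shapes) are elided, and `strict=False` is what routes FPEBC through the `_non_strict_exporting_forward` path mentioned in the summary:

```python
import torch

from torchrec.ir.serializer import JsonSerializer
from torchrec.ir.utils import (
    deserialize_embedding_modules,
    serialize_embedding_modules,
)

# 1. Attach ir_metadata buffers to every serializable (FP)EBC in the model.
model, sparse_fqns = serialize_embedding_modules(model, JsonSerializer)

# 2. Non-strict torch.export: (FP)EBC forwards hit the registered custom ops.
ep = torch.export.export(model, (id_list_features,), strict=False)

# 3. Rebuild eager (FP)EBC modules from the buffers carried in the IR.
deserialized_model = deserialize_embedding_modules(ep, JsonSerializer)
```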

torchrec/ir/utils.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -37,8 +37,14 @@ def serialize_embedding_modules(
     Returns the modified module and the list of fqns that had the buffer added.
     """
     preserve_fqns = []
+    serialized_fqns = set()
     for fqn, module in model.named_modules():
         if type(module).__name__ in serializer_cls.module_to_serializer_cls:
+            # this avoids serializing a submodule of a module that has already been serialized
+            if any(fqn.startswith(s_fqn) for s_fqn in serialized_fqns):
+                continue
+            else:
+                serialized_fqns.add(fqn)
             serialized_module = serializer_cls.serialize(module)
             module.register_buffer("ir_metadata", serialized_module, persistent=False)
             preserve_fqns.append(fqn)
```
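Since `named_modules()` yields a parent before its children, the `serialized_fqns` guard above skips nested serializable modules. A toy illustration with hypothetical fqns:

```python
# fqns in the order named_modules() would yield them (parents first)
candidate_fqns = [
    "sparse.fpebc",                            # FPEBC: gets serialized
    "sparse.fpebc._embedding_bag_collection",  # nested EBC: skipped
    "sparse.ebc",                              # sibling EBC: gets serialized
]

serialized_fqns = set()
for fqn in candidate_fqns:
    if any(fqn.startswith(s_fqn) for s_fqn in serialized_fqns):
        continue
    serialized_fqns.add(fqn)

print(sorted(serialized_fqns))  # ['sparse.ebc', 'sparse.fpebc']
```

Note the check is a plain string-prefix test, so it relies on parents appearing before their children in the traversal order.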
