
Commit 4e41b4d

TroyGarden authored and facebook-github-bot committed
register custom_op for fpEBC (pytorch#2067)
Summary:

# context
* convert `FeatureProcessedEmbeddingBagCollection` to a custom op in IR export
* add serialization and deserialization functions for FPEBC
* add an API to the `FeatureProcessorInterface` to export the parameters needed to create an instance
* use this API (`get_init_kwargs`) in the serialize and deserialize functions to flatten and unflatten the feature processor

# details
1. Added the `FPEBCMetadata` schema for FP_EBC, using an `fp_json` string to store the necessary parameters.
2. Added `FPEBCJsonSerializer`, which converts the init kwargs to a JSON string and stores it in the `fp_json` field of the metadata.
3. Added an fqn check against `serialized_fqns`, so that when a higher-level module is serialized, its lower-level modules can be skipped (they are already included in the higher-level module).
4. Added an API called `get_init_kwargs` to `FeatureProcessorsCollection` and `FeatureProcessor`, and used a `FeatureProcessorNameMap` to map class names to feature processor classes.
5. Added a `_non_strict_exporting_forward` function to FPEBC so that non-strict IR export goes through the custom_op logic.

Differential Revision: D57829276
1 parent 3d79962 commit 4e41b4d
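
For orientation, here is a minimal sketch (not part of the diff) of the `get_init_kwargs` round trip described above. `MyFeatureProcessor` and `FEATURE_PROCESSOR_NAME_MAP` are hypothetical stand-ins for a real `FeatureProcessor` subclass and the `FeatureProcessorNameMap` registry:

```python
# Illustrative sketch only: flatten a feature processor's constructor kwargs to
# JSON and rebuild an equivalent instance from them, as the FPEBC serializer
# does via get_init_kwargs().
import json
from typing import Any, Dict


class MyFeatureProcessor:  # hypothetical stand-in for a FeatureProcessor
    def __init__(self, max_feature_length: int) -> None:
        self.max_feature_length = max_feature_length

    def get_init_kwargs(self) -> Dict[str, Any]:
        # export exactly the kwargs needed to re-create this instance
        return {"max_feature_length": self.max_feature_length}


# hypothetical name -> class registry, analogous to FeatureProcessorNameMap
FEATURE_PROCESSOR_NAME_MAP = {"MyFeatureProcessor": MyFeatureProcessor}

original = MyFeatureProcessor(max_feature_length=10)

# serialize: class name + init kwargs -> JSON string (what fp_json stores)
fp_json = json.dumps(
    {"type": type(original).__name__, "kwargs": original.get_init_kwargs()}
)

# deserialize: JSON string -> a new, equivalent feature processor
payload = json.loads(fp_json)
rebuilt = FEATURE_PROCESSOR_NAME_MAP[payload["type"]](**payload["kwargs"])
assert rebuilt.get_init_kwargs() == original.get_init_kwargs()
```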

File tree

8 files changed: +382 −76 lines


torchrec/ir/schema.py

Lines changed: 9 additions & 0 deletions
@@ -32,3 +32,12 @@ class EBCMetadata:
     tables: List[EmbeddingBagConfigMetadata]
     is_weighted: bool
     device: Optional[str]
+
+
+@dataclass
+class FPEBCMetadata:
+    tables: List[EmbeddingBagConfigMetadata]
+    is_weighted: bool
+    device: Optional[str]
+    fp_type: str
+    fp_json: Optional[str]
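
A rough sketch (not from the diff) of the two shapes the `fp_json` payload can take, matching the two branches of the serializer below; the concrete keys come from each processor's `get_init_kwargs`, so the values here are made up:

```python
# Illustrative sketch only: fp_json layouts for the two fp_type modes.
import json

# fp_type == "dict": per-feature processors keyed by feature name
fp_json_dict_mode = json.dumps(
    {
        "param_dict": {"f1": {"max_feature_length": 10}},
        "type_dict": {"f1": "PositionWeightedModule"},
    }
)

# fp_type == "<collection class name>": one kwargs dict for the whole collection
fp_json_collection_mode = json.dumps({"max_feature_lengths": {"f1": 100, "f2": 100}})

print(fp_json_dict_mode)
print(fp_json_collection_mode)
```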

torchrec/ir/serializer.py

Lines changed: 107 additions & 1 deletion
@@ -14,11 +14,17 @@
 import torch
 
 from torch import nn
-from torchrec.ir.schema import EBCMetadata, EmbeddingBagConfigMetadata
+from torchrec.ir.schema import EBCMetadata, EmbeddingBagConfigMetadata, FPEBCMetadata
 
 from torchrec.ir.types import SerializerInterface
 from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig, PoolingType
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
+from torchrec.modules.feature_processor_ import (
+    FeatureProcessor,
+    FeatureProcessorNameMap,
+    FeatureProcessorsCollection,
+)
+from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -128,13 +134,113 @@ def deserialize(
         )
 
 
+class FPEBCJsonSerializer(SerializerInterface):
+    """
+    Serializer for torch.export IR using thrift.
+    """
+
+    @classmethod
+    def serialize(
+        cls,
+        module: nn.Module,
+    ) -> torch.Tensor:
+        if not isinstance(module, FeatureProcessedEmbeddingBagCollection):
+            raise ValueError(
+                f"Expected module to be of type FeatureProcessedEmbeddingBagCollection, got {type(module)}"
+            )
+        if isinstance(module._feature_processors, nn.ModuleDict):
+            fp_type = "dict"
+            param_dict = {
+                feature: processor.get_init_kwargs()
+                for feature, processor in module._feature_processors.items()
+            }
+            type_dict = {
+                feature: type(processor).__name__
+                for feature, processor in module._feature_processors.items()
+            }
+            fp_json = json.dumps(
+                {
+                    "param_dict": param_dict,
+                    "type_dict": type_dict,
+                }
+            )
+        elif isinstance(module._feature_processors, FeatureProcessorsCollection):
+            fp_type = type(module._feature_processors).__name__
+            param_dict = module._feature_processors.get_init_kwargs()
+            fp_json = json.dumps(param_dict)
+        else:
+            raise ValueError(
+                f"Expected module._feature_processors to be of type dict or FeatureProcessorsCollection, got {type(module._feature_processors)}"
+            )
+        ebc = module._embedding_bag_collection
+        ebc_metadata = FPEBCMetadata(
+            tables=[
+                embedding_bag_config_to_metadata(table_config)
+                for table_config in ebc.embedding_bag_configs()
+            ],
+            is_weighted=ebc.is_weighted(),
+            device=str(ebc.device),
+            fp_type=fp_type,
+            fp_json=fp_json,
+        )
+
+        ebc_metadata_dict = ebc_metadata.__dict__
+        ebc_metadata_dict["tables"] = [
+            table_config.__dict__ for table_config in ebc_metadata_dict["tables"]
+        ]
+
+        return torch.frombuffer(
+            json.dumps(ebc_metadata_dict).encode(), dtype=torch.uint8
+        )
+
+    @classmethod
+    def deserialize(
+        cls, input: torch.Tensor, typename: str, device: Optional[torch.device] = None
+    ) -> nn.Module:
+        if typename != "FeatureProcessedEmbeddingBagCollection":
+            raise ValueError(
+                f"Expected typename to be EmbeddingBagCollection, got {typename}"
+            )
+
+        raw_bytes = input.numpy().tobytes()
+        ebc_metadata_dict = json.loads(raw_bytes.decode())
+        tables = [
+            EmbeddingBagConfigMetadata(**table_config)
+            for table_config in ebc_metadata_dict["tables"]
+        ]
+        device = get_deserialized_device(ebc_metadata_dict.get("device"), device)
+        ebc = EmbeddingBagCollection(
+            tables=[
+                embedding_metadata_to_config(table_config) for table_config in tables
+            ],
+            is_weighted=ebc_metadata_dict["is_weighted"],
+            device=device,
+        )
+        fp_dict = json.loads(ebc_metadata_dict["fp_json"])
+        if ebc_metadata_dict["fp_type"] == "dict":
+            feature_processors: Dict[str, FeatureProcessor] = {}
+            for feature, fp_type in fp_dict["type_dict"].items():
+                feature_processors[feature] = FeatureProcessorNameMap[fp_type](
+                    **fp_dict["param_dict"][feature]
+                )
+        else:
+            feature_processors = FeatureProcessorNameMap[ebc_metadata_dict["fp_type"]](
+                **fp_dict
+            )
+        return FeatureProcessedEmbeddingBagCollection(
+            ebc,
+            feature_processors,
+        )
+
+
 class JsonSerializer(SerializerInterface):
     """
     Serializer for torch.export IR using thrift.
     """
 
     module_to_serializer_cls: Dict[str, Type[SerializerInterface]] = {
         "EmbeddingBagCollection": EBCJsonSerializer,
+        "FeatureProcessedEmbeddingBagCollection": FPEBCJsonSerializer,
     }
 
     @classmethod
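
A small sketch (not from the diff) of the metadata packing that `serialize` and `deserialize` above rely on: the JSON metadata is encoded into a `uint8` tensor and decoded back on the other side; the metadata values here are made up:

```python
# Illustrative sketch only: dict -> JSON bytes -> uint8 tensor, and back.
import json

import torch

metadata = {
    "fp_type": "dict",
    "fp_json": json.dumps({"param_dict": {}, "type_dict": {}}),
}

# serialize: pack the metadata into a tensor that can live on the module
packed = torch.frombuffer(json.dumps(metadata).encode(), dtype=torch.uint8)

# deserialize: unpack the tensor back into the metadata dict
unpacked = json.loads(packed.numpy().tobytes().decode())
assert unpacked == metadata
```

End to end, `serialize_embedding_modules` (see torchrec/ir/utils.py below) registers this tensor on the module as the non-persistent `ir_metadata` buffer before export, and `deserialize_embedding_modules` reads it back to rebuild the module, as exercised in the test below.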

torchrec/ir/tests/test_serializer.py

Lines changed: 90 additions & 32 deletions
@@ -25,23 +25,28 @@
 
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
-from torchrec.modules.feature_processor_ import PositionWeightedModuleCollection
+from torchrec.modules.feature_processor_ import (
+    PositionWeightedModule,
+    PositionWeightedModuleCollection,
+)
 from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
 from torchrec.modules.utils import operator_registry_state
-from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
+from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
 
 class TestJsonSerializer(unittest.TestCase):
+    # in the model we have 5 duplicated EBCs, 1 fpEBC with fpCollection, and 1 fpEBC with fpDict
     def generate_model(self) -> nn.Module:
         class Model(nn.Module):
-            def __init__(self, ebc, fpebc):
+            def __init__(self, ebc, fpebc1, fpebc2):
                 super().__init__()
                 self.ebc1 = ebc
                 self.ebc2 = copy.deepcopy(ebc)
                 self.ebc3 = copy.deepcopy(ebc)
                 self.ebc4 = copy.deepcopy(ebc)
                 self.ebc5 = copy.deepcopy(ebc)
-                self.fpebc = fpebc
+                self.fpebc1 = fpebc1
+                self.fpebc2 = fpebc2
 
             def forward(
                 self,
@@ -53,22 +58,17 @@ def forward(
                 kt4 = self.ebc4(features)
                 kt5 = self.ebc5(features)
 
-                fpebc_res = self.fpebc(features)
+                fpebc1_res = self.fpebc1(features)
+                fpebc2_res = self.fpebc2(features)
                 ebc_kt_vals = [kt.values() for kt in [kt1, kt2, kt3, kt4, kt5]]
-                sparse_arch_vals = sum(ebc_kt_vals)
-                sparse_arch_res = KeyedTensor(
-                    keys=kt1.keys(),
-                    values=sparse_arch_vals,
-                    length_per_key=kt1.length_per_key(),
-                )
 
-                return KeyedTensor.regroup(
-                    [sparse_arch_res, fpebc_res], [["f1"], ["f2", "f3"]]
+                return (
+                    ebc_kt_vals + list(fpebc1_res.values()) + list(fpebc2_res.values())
                 )
 
         tb1_config = EmbeddingBagConfig(
             name="t1",
-            embedding_dim=4,
+            embedding_dim=3,
             num_embeddings=10,
             feature_names=["f1"],
         )
@@ -80,7 +80,7 @@ def forward(
         )
         tb3_config = EmbeddingBagConfig(
             name="t3",
-            embedding_dim=4,
+            embedding_dim=5,
             num_embeddings=10,
             feature_names=["f3"],
         )
@@ -91,7 +91,7 @@ def forward(
         )
         max_feature_lengths = {"f1": 100, "f2": 100}
 
-        fpebc = FeatureProcessedEmbeddingBagCollection(
+        fpebc1 = FeatureProcessedEmbeddingBagCollection(
             EmbeddingBagCollection(
                 tables=[tb1_config, tb2_config],
                 is_weighted=True,
@@ -100,8 +100,18 @@ def forward(
                 max_feature_lengths=max_feature_lengths,
             ),
         )
+        fpebc2 = FeatureProcessedEmbeddingBagCollection(
+            EmbeddingBagCollection(
+                tables=[tb1_config, tb3_config],
+                is_weighted=True,
+            ),
+            {
+                "f1": PositionWeightedModule(max_feature_length=10),
+                "f3": PositionWeightedModule(max_feature_length=20),
+            },
+        )
 
-        model = Model(ebc, fpebc)
+        model = Model(ebc, fpebc1, fpebc2)
 
         return model
@@ -132,12 +142,16 @@ def test_serialize_deserialize_ebc(self) -> None:
         for i, tensor in enumerate(ep_output):
             self.assertEqual(eager_out[i].shape, tensor.shape)
 
-        # Only 2 custom op registered, as dimensions of ebc are same
-        self.assertEqual(len(operator_registry_state.op_registry_schema), 2)
+        # Should have 3 custom op registered, as dimensions of ebc are same,
+        # and two fpEBCs have different dims
+        self.assertEqual(len(operator_registry_state.op_registry_schema), 3)
 
         total_dim_ebc = sum(model.ebc1._lengths_per_embedding)
-        total_dim_fpebc = sum(
-            model.fpebc._embedding_bag_collection._lengths_per_embedding
+        total_dim_fpebc1 = sum(
+            model.fpebc1._embedding_bag_collection._lengths_per_embedding
+        )
+        total_dim_fpebc2 = sum(
+            model.fpebc2._embedding_bag_collection._lengths_per_embedding
         )
         # Check if custom op is registered with the correct name
         # EmbeddingBagCollection type and total dim
@@ -146,7 +160,11 @@ def test_serialize_deserialize_ebc(self) -> None:
             in operator_registry_state.op_registry_schema
         )
        self.assertTrue(
-            f"EmbeddingBagCollection_{total_dim_fpebc}"
+            f"FeatureProcessedEmbeddingBagCollection_{total_dim_fpebc1}"
+            in operator_registry_state.op_registry_schema
+        )
+        self.assertTrue(
+            f"FeatureProcessedEmbeddingBagCollection_{total_dim_fpebc2}"
             in operator_registry_state.op_registry_schema
         )
 
@@ -155,28 +173,68 @@ def test_serialize_deserialize_ebc(self) -> None:
         # Deserialize EBC
         deserialized_model = deserialize_embedding_modules(ep, JsonSerializer)
 
+        # check EBC config
         for i in range(5):
             ebc_name = f"ebc{i + 1}"
-            assert isinstance(
+            self.assertIsInstance(
                 getattr(deserialized_model, ebc_name), EmbeddingBagCollection
             )
 
-            for deserialized_config, org_config in zip(
+            for deserialized, orginal in zip(
                 getattr(deserialized_model, ebc_name).embedding_bag_configs(),
                 getattr(model, ebc_name).embedding_bag_configs(),
             ):
-                assert deserialized_config.name == org_config.name
-                assert deserialized_config.embedding_dim == org_config.embedding_dim
-                assert deserialized_config.num_embeddings, org_config.num_embeddings
-                assert deserialized_config.feature_names, org_config.feature_names
+                self.assertEqual(deserialized.name, orginal.name)
+                self.assertEqual(deserialized.embedding_dim, orginal.embedding_dim)
+                self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings)
+                self.assertEqual(deserialized.feature_names, orginal.feature_names)
+
+        # check FPEBC config
+        for i in range(2):
+            fpebc_name = f"fpebc{i + 1}"
+            assert isinstance(
+                getattr(deserialized_model, fpebc_name),
+                FeatureProcessedEmbeddingBagCollection,
+            )
+
+            deserialized_fp = getattr(
+                deserialized_model, fpebc_name
+            )._feature_processors
+            original_fp = getattr(model, fpebc_name)._feature_processors
+            if isinstance(original_fp, nn.ModuleDict):
+                for deserialized, orginal in zip(
+                    deserialized_fp.values(), original_fp.values()
+                ):
+                    self.assertDictEqual(
+                        deserialized.get_init_kwargs(), orginal.get_init_kwargs()
+                    )
+            else:
+                self.assertDictEqual(
+                    deserialized_fp.get_init_kwargs(), original_fp.get_init_kwargs()
+                )
+
+            for deserialized, orginal in zip(
+                getattr(
+                    deserialized_model, fpebc_name
+                )._embedding_bag_collection.embedding_bag_configs(),
+                getattr(
                    model, fpebc_name
+                )._embedding_bag_collection.embedding_bag_configs(),
+            ):
+                self.assertEqual(deserialized.name, orginal.name)
+                self.assertEqual(deserialized.embedding_dim, orginal.embedding_dim)
+                self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings)
+                self.assertEqual(deserialized.feature_names, orginal.feature_names)
 
         deserialized_model.load_state_dict(model.state_dict())
-        # Run forward on deserialized model
+
+        # Run forward on deserialized model and compare the output
         deserialized_out = deserialized_model(id_list_features)
 
-        for i, tensor in enumerate(deserialized_out):
-            assert eager_out[i].shape == tensor.shape
-            assert torch.allclose(eager_out[i], tensor)
+        self.assertEqual(len(deserialized_out), len(eager_out))
+        for deserialized, orginal in zip(deserialized_out, eager_out):
+            self.assertEqual(deserialized.shape, orginal.shape)
+            self.assertTrue(torch.allclose(deserialized, orginal))
 
     def test_dynamic_shape_ebc(self) -> None:
         model = self.generate_model()

torchrec/ir/utils.py

Lines changed: 6 additions & 0 deletions
@@ -37,8 +37,14 @@ def serialize_embedding_modules(
     Returns the modified module and the list of fqns that had the buffer added.
     """
     preserve_fqns = []
+    serialized_fqns = set()
     for fqn, module in model.named_modules():
         if type(module).__name__ in serializer_cls.module_to_serializer_cls:
+            # this avoid serializing the submodule within a module that is already serialized
+            if any(fqn.startswith(s_fqn) for s_fqn in serialized_fqns):
+                continue
+            else:
+                serialized_fqns.add(fqn)
             serialized_module = serializer_cls.serialize(module)
             module.register_buffer("ir_metadata", serialized_module, persistent=False)
             preserve_fqns.append(fqn)
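
A standalone sketch (not from the diff) of the prefix check added above, run on made-up fqns; `named_modules()` yields a parent module before its children, so once a parent fqn is recorded its submodules are skipped:

```python
# Illustrative sketch only: skip fqns nested under an already-serialized module.
serialized_fqns = set()
fqns = [
    "sparse.ebc1",
    "sparse.fpebc1",
    "sparse.fpebc1._embedding_bag_collection",
]
for fqn in fqns:
    if any(fqn.startswith(s_fqn) for s_fqn in serialized_fqns):
        print(f"skip      {fqn}")  # nested in an already-serialized module
        continue
    serialized_fqns.add(fqn)
    print(f"serialize {fqn}")
```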
