Skip to content

Commit 4020e16

Browse files
PaulZhang12facebook-github-bot
authored andcommitted
Introduce Serializer in TorchRec for torch.export PEA (#1848)
Summary: Pull Request resolved: #1848 Adding a serializer class to support thrift serialization of PEA configs for regenerating the eager module after torch.export through saving the metadata as a bytes buffer in the module. Reviewed By: dstaay-fb Differential Revision: D55661704 fbshipit-source-id: 6ea47e495eaad032f17ad6c5f45206259c8009b3
1 parent 849a24f commit 4020e16

File tree

2 files changed

+88
-0
lines changed

2 files changed

+88
-0
lines changed

torchrec/ir/types.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
#!/usr/bin/env python3
9+
10+
import abc
11+
from typing import Any, Dict, Type
12+
13+
from torch import nn
14+
15+
16+
class SerializerInterface(abc.ABC):
17+
"""
18+
Interface for Serializer classes for torch.export IR.
19+
"""
20+
21+
@classmethod
22+
@property
23+
# pyre-ignore [3]: Returning `None` but type `Any` is specified.
24+
def module_to_serializer_cls(cls) -> Dict[Type[nn.Module], Type[Any]]:
25+
raise NotImplementedError
26+
27+
@classmethod
28+
@abc.abstractmethod
29+
# pyre-ignore [3]: Returning `None` but type `Any` is specified.
30+
def serialize(
31+
cls,
32+
module: nn.Module,
33+
) -> Any:
34+
# Take the eager embedding module and generate bytes in buffer
35+
pass
36+
37+
@classmethod
38+
@abc.abstractmethod
39+
# pyre-ignore [2]: Parameter `input` must have a type other than `Any`.
40+
def deserialize(cls, input: Any) -> nn.Module:
41+
# Take the bytes in the buffer and regenerate the eager embedding module
42+
pass

torchrec/ir/utils.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
#!/usr/bin/env python3
9+
10+
from typing import Type
11+
12+
from torch import nn
13+
from torchrec.ir.types import SerializerInterface
14+
15+
16+
# TODO: Replace the default interface with the python dataclass interface
17+
DEFAULT_SERIALIZER_CLS = SerializerInterface
18+
19+
20+
def serialize_embedding_modules(
21+
model: nn.Module,
22+
serializer_cls: Type[SerializerInterface] = DEFAULT_SERIALIZER_CLS,
23+
) -> nn.Module:
24+
for _, module in model.named_modules():
25+
if type(module) in serializer_cls.module_to_serializer_cls:
26+
serialized_module = serializer_cls.serialize(module)
27+
module.register_buffer("ir_metadata", serialized_module, persistent=False)
28+
29+
return model
30+
31+
32+
def deserialize_embedding_modules(
33+
model: nn.Module,
34+
serializer_cls: Type[SerializerInterface] = DEFAULT_SERIALIZER_CLS,
35+
) -> nn.Module:
36+
fqn_to_new_module = {}
37+
for name, module in model.named_modules():
38+
if "ir_metadata" in dict(module.named_buffers()):
39+
serialized_module = dict(module.named_buffers())["ir_metadata"]
40+
deserialized_module = serializer_cls.deserialize(serialized_module)
41+
fqn_to_new_module[name] = deserialized_module
42+
43+
for fqn, new_module in fqn_to_new_module.items():
44+
setattr(model, fqn, new_module)
45+
46+
return model

0 commit comments

Comments
 (0)