28
28
from torchrec .modules .feature_processor_ import PositionWeightedModuleCollection
29
29
from torchrec .modules .fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
30
30
from torchrec .modules .utils import operator_registry_state
31
- from torchrec .sparse .jagged_tensor import KeyedJaggedTensor , KeyedTensor
31
+ from torchrec .sparse .jagged_tensor import KeyedJaggedTensor
32
32
33
33
34
34
class TestJsonSerializer (unittest .TestCase ):
35
35
def generate_model (self ) -> nn .Module :
36
36
class Model (nn .Module ):
37
- def __init__ (self , ebc , fpebc ):
37
+ def __init__ (self , ebc , fpebc1 , fpebc2 ):
38
38
super ().__init__ ()
39
39
self .ebc1 = ebc
40
40
self .ebc2 = copy .deepcopy (ebc )
41
41
self .ebc3 = copy .deepcopy (ebc )
42
42
self .ebc4 = copy .deepcopy (ebc )
43
43
self .ebc5 = copy .deepcopy (ebc )
44
- self .fpebc = fpebc
44
+ self .fpebc1 = fpebc1
45
+ self .fpebc2 = fpebc2
45
46
46
47
def forward (
47
48
self ,
@@ -53,22 +54,17 @@ def forward(
53
54
kt4 = self .ebc4 (features )
54
55
kt5 = self .ebc5 (features )
55
56
56
- fpebc_res = self .fpebc (features )
57
+ fpebc1_res = self .fpebc1 (features )
58
+ fpebc2_res = self .fpebc2 (features )
57
59
ebc_kt_vals = [kt .values () for kt in [kt1 , kt2 , kt3 , kt4 , kt5 ]]
58
- sparse_arch_vals = sum (ebc_kt_vals )
59
- sparse_arch_res = KeyedTensor (
60
- keys = kt1 .keys (),
61
- values = sparse_arch_vals ,
62
- length_per_key = kt1 .length_per_key (),
63
- )
64
60
65
- return KeyedTensor . regroup (
66
- [ sparse_arch_res , fpebc_res ], [[ "f1" ], [ "f2" , "f3" ]]
61
+ return (
62
+ ebc_kt_vals + list ( fpebc1_res . values ()) + list ( fpebc2_res . values ())
67
63
)
68
64
69
65
tb1_config = EmbeddingBagConfig (
70
66
name = "t1" ,
71
- embedding_dim = 4 ,
67
+ embedding_dim = 3 ,
72
68
num_embeddings = 10 ,
73
69
feature_names = ["f1" ],
74
70
)
@@ -80,7 +76,7 @@ def forward(
80
76
)
81
77
tb3_config = EmbeddingBagConfig (
82
78
name = "t3" ,
83
- embedding_dim = 4 ,
79
+ embedding_dim = 5 ,
84
80
num_embeddings = 10 ,
85
81
feature_names = ["f3" ],
86
82
)
@@ -91,7 +87,7 @@ def forward(
91
87
)
92
88
max_feature_lengths = {"f1" : 100 , "f2" : 100 }
93
89
94
- fpebc = FeatureProcessedEmbeddingBagCollection (
90
+ fpebc1 = FeatureProcessedEmbeddingBagCollection (
95
91
EmbeddingBagCollection (
96
92
tables = [tb1_config , tb2_config ],
97
93
is_weighted = True ,
@@ -100,8 +96,15 @@ def forward(
100
96
max_feature_lengths = max_feature_lengths ,
101
97
),
102
98
)
99
+ fpebc2 = FeatureProcessedEmbeddingBagCollection (
100
+ EmbeddingBagCollection (
101
+ tables = [tb1_config , tb3_config ],
102
+ is_weighted = True ,
103
+ ),
104
+ PositionWeightedModuleCollection ({"f1" : 10 , "f3" : 20 }),
105
+ )
103
106
104
- model = Model (ebc , fpebc )
107
+ model = Model (ebc , fpebc1 , fpebc2 )
105
108
106
109
return model
107
110
@@ -133,11 +136,14 @@ def test_serialize_deserialize_ebc(self) -> None:
133
136
self .assertEqual (eager_out [i ].shape , tensor .shape )
134
137
135
138
# Only 2 custom op registered, as dimensions of ebc are same
136
- self .assertEqual (len (operator_registry_state .op_registry_schema ), 2 )
139
+ self .assertEqual (len (operator_registry_state .op_registry_schema ), 3 )
137
140
138
141
total_dim_ebc = sum (model .ebc1 ._lengths_per_embedding )
139
- total_dim_fpebc = sum (
140
- model .fpebc ._embedding_bag_collection ._lengths_per_embedding
142
+ total_dim_fpebc1 = sum (
143
+ model .fpebc1 ._embedding_bag_collection ._lengths_per_embedding
144
+ )
145
+ total_dim_fpebc2 = sum (
146
+ model .fpebc2 ._embedding_bag_collection ._lengths_per_embedding
141
147
)
142
148
# Check if custom op is registered with the correct name
143
149
# EmbeddingBagCollection type and total dim
@@ -146,7 +152,11 @@ def test_serialize_deserialize_ebc(self) -> None:
146
152
in operator_registry_state .op_registry_schema
147
153
)
148
154
self .assertTrue (
149
- f"EmbeddingBagCollection_{ total_dim_fpebc } "
155
+ f"FeatureProcessedEmbeddingBagCollection_{ total_dim_fpebc1 } "
156
+ in operator_registry_state .op_registry_schema
157
+ )
158
+ self .assertTrue (
159
+ f"FeatureProcessedEmbeddingBagCollection_{ total_dim_fpebc2 } "
150
160
in operator_registry_state .op_registry_schema
151
161
)
152
162
@@ -155,28 +165,60 @@ def test_serialize_deserialize_ebc(self) -> None:
155
165
# Deserialize EBC
156
166
deserialized_model = deserialize_embedding_modules (ep , JsonSerializer )
157
167
168
+ # check EBC config
158
169
for i in range (5 ):
159
170
ebc_name = f"ebc{ i + 1 } "
160
- assert isinstance (
171
+ self . assertIsInstance (
161
172
getattr (deserialized_model , ebc_name ), EmbeddingBagCollection
162
173
)
163
174
164
- for deserialized_config , org_config in zip (
175
+ for deserialized , orginal in zip (
165
176
getattr (deserialized_model , ebc_name ).embedding_bag_configs (),
166
177
getattr (model , ebc_name ).embedding_bag_configs (),
167
178
):
168
- assert deserialized_config .name == org_config .name
169
- assert deserialized_config .embedding_dim == org_config .embedding_dim
170
- assert deserialized_config .num_embeddings , org_config .num_embeddings
171
- assert deserialized_config .feature_names , org_config .feature_names
179
+ self .assertEqual (deserialized .name , orginal .name )
180
+ self .assertEqual (deserialized .embedding_dim , orginal .embedding_dim )
181
+ self .assertEqual (deserialized .num_embeddings , orginal .num_embeddings )
182
+ self .assertEqual (deserialized .feature_names , orginal .feature_names )
183
+
184
+ # check FPEBC config
185
+ for i in range (2 ):
186
+ fpebc_name = f"fpebc{ i + 1 } "
187
+ assert isinstance (
188
+ getattr (deserialized_model , fpebc_name ),
189
+ FeatureProcessedEmbeddingBagCollection ,
190
+ )
191
+
192
+ deserialized_kwargs = getattr (
193
+ deserialized_model , fpebc_name
194
+ )._feature_processors .get_init_kwargs ()
195
+ orginal_kwargs = getattr (
196
+ model , fpebc_name
197
+ )._feature_processors .get_init_kwargs ()
198
+ self .assertDictEqual (deserialized_kwargs , orginal_kwargs )
199
+
200
+ for deserialized , orginal in zip (
201
+ getattr (
202
+ deserialized_model , fpebc_name
203
+ )._embedding_bag_collection .embedding_bag_configs (),
204
+ getattr (
205
+ model , fpebc_name
206
+ )._embedding_bag_collection .embedding_bag_configs (),
207
+ ):
208
+ self .assertEqual (deserialized .name , orginal .name )
209
+ self .assertEqual (deserialized .embedding_dim , orginal .embedding_dim )
210
+ self .assertEqual (deserialized .num_embeddings , orginal .num_embeddings )
211
+ self .assertEqual (deserialized .feature_names , orginal .feature_names )
172
212
173
213
deserialized_model .load_state_dict (model .state_dict ())
174
- # Run forward on deserialized model
214
+
215
+ # Run forward on deserialized model and compare the output
175
216
deserialized_out = deserialized_model (id_list_features )
176
217
177
- for i , tensor in enumerate (deserialized_out ):
178
- assert eager_out [i ].shape == tensor .shape
179
- assert torch .allclose (eager_out [i ], tensor )
218
+ self .assertEqual (len (deserialized_out ), len (eager_out ))
219
+ for deserialized , orginal in zip (deserialized_out , eager_out ):
220
+ self .assertEqual (deserialized .shape , orginal .shape )
221
+ self .assertTrue (torch .allclose (deserialized , orginal ))
180
222
181
223
def test_dynamic_shape_ebc (self ) -> None :
182
224
model = self .generate_model ()
0 commit comments