25
25
26
26
from torchrec .modules .embedding_configs import EmbeddingBagConfig
27
27
from torchrec .modules .embedding_modules import EmbeddingBagCollection
28
- from torchrec .modules .feature_processor_ import PositionWeightedModuleCollection
28
+ from torchrec .modules .feature_processor_ import (
29
+ PositionWeightedModule ,
30
+ PositionWeightedModuleCollection ,
31
+ )
29
32
from torchrec .modules .fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
30
33
from torchrec .modules .utils import operator_registry_state
31
- from torchrec .sparse .jagged_tensor import KeyedJaggedTensor , KeyedTensor
34
+ from torchrec .sparse .jagged_tensor import KeyedJaggedTensor
32
35
33
36
34
37
class TestJsonSerializer (unittest .TestCase ):
38
+ # in the model we have 5 duplicated EBCs, 1 fpEBC with fpCollection, and 1 fpEBC with fpDict
35
39
def generate_model (self ) -> nn .Module :
36
40
class Model (nn .Module ):
37
- def __init__ (self , ebc , fpebc ):
41
+ def __init__ (self , ebc , fpebc1 , fpebc2 ):
38
42
super ().__init__ ()
39
43
self .ebc1 = ebc
40
44
self .ebc2 = copy .deepcopy (ebc )
41
45
self .ebc3 = copy .deepcopy (ebc )
42
46
self .ebc4 = copy .deepcopy (ebc )
43
47
self .ebc5 = copy .deepcopy (ebc )
44
- self .fpebc = fpebc
48
+ self .fpebc1 = fpebc1
49
+ self .fpebc2 = fpebc2
45
50
46
51
def forward (
47
52
self ,
@@ -53,22 +58,17 @@ def forward(
53
58
kt4 = self .ebc4 (features )
54
59
kt5 = self .ebc5 (features )
55
60
56
- fpebc_res = self .fpebc (features )
61
+ fpebc1_res = self .fpebc1 (features )
62
+ fpebc2_res = self .fpebc2 (features )
57
63
ebc_kt_vals = [kt .values () for kt in [kt1 , kt2 , kt3 , kt4 , kt5 ]]
58
- sparse_arch_vals = sum (ebc_kt_vals )
59
- sparse_arch_res = KeyedTensor (
60
- keys = kt1 .keys (),
61
- values = sparse_arch_vals ,
62
- length_per_key = kt1 .length_per_key (),
63
- )
64
64
65
- return KeyedTensor . regroup (
66
- [ sparse_arch_res , fpebc_res ], [[ "f1" ], [ "f2" , "f3" ]]
65
+ return (
66
+ ebc_kt_vals + list ( fpebc1_res . values ()) + list ( fpebc2_res . values ())
67
67
)
68
68
69
69
tb1_config = EmbeddingBagConfig (
70
70
name = "t1" ,
71
- embedding_dim = 4 ,
71
+ embedding_dim = 3 ,
72
72
num_embeddings = 10 ,
73
73
feature_names = ["f1" ],
74
74
)
@@ -80,7 +80,7 @@ def forward(
80
80
)
81
81
tb3_config = EmbeddingBagConfig (
82
82
name = "t3" ,
83
- embedding_dim = 4 ,
83
+ embedding_dim = 5 ,
84
84
num_embeddings = 10 ,
85
85
feature_names = ["f3" ],
86
86
)
@@ -91,7 +91,7 @@ def forward(
91
91
)
92
92
max_feature_lengths = {"f1" : 100 , "f2" : 100 }
93
93
94
- fpebc = FeatureProcessedEmbeddingBagCollection (
94
+ fpebc1 = FeatureProcessedEmbeddingBagCollection (
95
95
EmbeddingBagCollection (
96
96
tables = [tb1_config , tb2_config ],
97
97
is_weighted = True ,
@@ -100,8 +100,18 @@ def forward(
100
100
max_feature_lengths = max_feature_lengths ,
101
101
),
102
102
)
103
+ fpebc2 = FeatureProcessedEmbeddingBagCollection (
104
+ EmbeddingBagCollection (
105
+ tables = [tb1_config , tb3_config ],
106
+ is_weighted = True ,
107
+ ),
108
+ {
109
+ "f1" : PositionWeightedModule (max_feature_length = 10 ),
110
+ "f3" : PositionWeightedModule (max_feature_length = 20 ),
111
+ },
112
+ )
103
113
104
- model = Model (ebc , fpebc )
114
+ model = Model (ebc , fpebc1 , fpebc2 )
105
115
106
116
return model
107
117
@@ -132,12 +142,16 @@ def test_serialize_deserialize_ebc(self) -> None:
132
142
for i , tensor in enumerate (ep_output ):
133
143
self .assertEqual (eager_out [i ].shape , tensor .shape )
134
144
135
- # Only 2 custom op registered, as dimensions of ebc are same
136
- self .assertEqual (len (operator_registry_state .op_registry_schema ), 2 )
145
+ # Should have 3 custom op registered, as dimensions of ebc are same,
146
+ # and two fpEBCs have different dims
147
+ self .assertEqual (len (operator_registry_state .op_registry_schema ), 3 )
137
148
138
149
total_dim_ebc = sum (model .ebc1 ._lengths_per_embedding )
139
- total_dim_fpebc = sum (
140
- model .fpebc ._embedding_bag_collection ._lengths_per_embedding
150
+ total_dim_fpebc1 = sum (
151
+ model .fpebc1 ._embedding_bag_collection ._lengths_per_embedding
152
+ )
153
+ total_dim_fpebc2 = sum (
154
+ model .fpebc2 ._embedding_bag_collection ._lengths_per_embedding
141
155
)
142
156
# Check if custom op is registered with the correct name
143
157
# EmbeddingBagCollection type and total dim
@@ -146,7 +160,11 @@ def test_serialize_deserialize_ebc(self) -> None:
146
160
in operator_registry_state .op_registry_schema
147
161
)
148
162
self .assertTrue (
149
- f"EmbeddingBagCollection_{ total_dim_fpebc } "
163
+ f"FeatureProcessedEmbeddingBagCollection_{ total_dim_fpebc1 } "
164
+ in operator_registry_state .op_registry_schema
165
+ )
166
+ self .assertTrue (
167
+ f"FeatureProcessedEmbeddingBagCollection_{ total_dim_fpebc2 } "
150
168
in operator_registry_state .op_registry_schema
151
169
)
152
170
@@ -155,28 +173,68 @@ def test_serialize_deserialize_ebc(self) -> None:
155
173
# Deserialize EBC
156
174
deserialized_model = deserialize_embedding_modules (ep , JsonSerializer )
157
175
176
+ # check EBC config
158
177
for i in range (5 ):
159
178
ebc_name = f"ebc{ i + 1 } "
160
- assert isinstance (
179
+ self . assertIsInstance (
161
180
getattr (deserialized_model , ebc_name ), EmbeddingBagCollection
162
181
)
163
182
164
- for deserialized_config , org_config in zip (
183
+ for deserialized , orginal in zip (
165
184
getattr (deserialized_model , ebc_name ).embedding_bag_configs (),
166
185
getattr (model , ebc_name ).embedding_bag_configs (),
167
186
):
168
- assert deserialized_config .name == org_config .name
169
- assert deserialized_config .embedding_dim == org_config .embedding_dim
170
- assert deserialized_config .num_embeddings , org_config .num_embeddings
171
- assert deserialized_config .feature_names , org_config .feature_names
187
+ self .assertEqual (deserialized .name , orginal .name )
188
+ self .assertEqual (deserialized .embedding_dim , orginal .embedding_dim )
189
+ self .assertEqual (deserialized .num_embeddings , orginal .num_embeddings )
190
+ self .assertEqual (deserialized .feature_names , orginal .feature_names )
191
+
192
+ # check FPEBC config
193
+ for i in range (2 ):
194
+ fpebc_name = f"fpebc{ i + 1 } "
195
+ assert isinstance (
196
+ getattr (deserialized_model , fpebc_name ),
197
+ FeatureProcessedEmbeddingBagCollection ,
198
+ )
199
+
200
+ deserialized_fp = getattr (
201
+ deserialized_model , fpebc_name
202
+ )._feature_processors
203
+ original_fp = getattr (model , fpebc_name )._feature_processors
204
+ if isinstance (original_fp , nn .ModuleDict ):
205
+ for deserialized , orginal in zip (
206
+ deserialized_fp .values (), original_fp .values ()
207
+ ):
208
+ self .assertDictEqual (
209
+ deserialized .get_init_kwargs (), orginal .get_init_kwargs ()
210
+ )
211
+ else :
212
+ self .assertDictEqual (
213
+ deserialized_fp .get_init_kwargs (), original_fp .get_init_kwargs ()
214
+ )
215
+
216
+ for deserialized , orginal in zip (
217
+ getattr (
218
+ deserialized_model , fpebc_name
219
+ )._embedding_bag_collection .embedding_bag_configs (),
220
+ getattr (
221
+ model , fpebc_name
222
+ )._embedding_bag_collection .embedding_bag_configs (),
223
+ ):
224
+ self .assertEqual (deserialized .name , orginal .name )
225
+ self .assertEqual (deserialized .embedding_dim , orginal .embedding_dim )
226
+ self .assertEqual (deserialized .num_embeddings , orginal .num_embeddings )
227
+ self .assertEqual (deserialized .feature_names , orginal .feature_names )
172
228
173
229
deserialized_model .load_state_dict (model .state_dict ())
174
- # Run forward on deserialized model
230
+
231
+ # Run forward on deserialized model and compare the output
175
232
deserialized_out = deserialized_model (id_list_features )
176
233
177
- for i , tensor in enumerate (deserialized_out ):
178
- assert eager_out [i ].shape == tensor .shape
179
- assert torch .allclose (eager_out [i ], tensor )
234
+ self .assertEqual (len (deserialized_out ), len (eager_out ))
235
+ for deserialized , orginal in zip (deserialized_out , eager_out ):
236
+ self .assertEqual (deserialized .shape , orginal .shape )
237
+ self .assertTrue (torch .allclose (deserialized , orginal ))
180
238
181
239
def test_dynamic_shape_ebc (self ) -> None :
182
240
model = self .generate_model ()
0 commit comments