 from torchrec.distributed.test_utils.test_model import ModelInput
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
-from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
+from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor


 def generate_kjt(
@@ -69,15 +69,17 @@ def wrapped_func(
     kjt: KeyedJaggedTensor,
     test_func: Callable[[KeyedJaggedTensor], object],
     fn_kwargs: Dict[str, Any],
+    jit_script: bool,
 ) -> Callable[..., object]:
     def fn() -> object:
         return test_func(kjt, **fn_kwargs)

-    return fn
+    return torch.jit.script(fn) if jit_script else fn


 def benchmark_kjt(
-    method_name: str,
+    test_name: str,
+    test_func: Callable[..., object],
     kjt: KeyedJaggedTensor,
     num_repeat: int,
     num_warmup: int,
@@ -86,21 +88,15 @@ def benchmark_kjt(
     mean_pooling_factor: int,
     fn_kwargs: Dict[str, Any],
     is_static_method: bool,
+    jit_script: bool,
 ) -> None:
-    test_name = method_name
-
-    # pyre-ignore
-    def test_func(kjt: KeyedJaggedTensor, **kwargs):
-        return getattr(KeyedJaggedTensor if is_static_method else kjt, method_name)(
-            **kwargs
-        )

     for _ in range(num_warmup):
-        test_func(kjt, **fn_kwargs)
+        test_func(**fn_kwargs)

     times = []
     for _ in range(num_repeat):
-        time_elapsed = timeit.timeit(wrapped_func(kjt, test_func, fn_kwargs), number=1)
+        time_elapsed = timeit.timeit(lambda: test_func(**fn_kwargs), number=1)
         # remove length_per_key and offset_per_key cache for fairer comparison
         kjt.unsync()
         times.append(time_elapsed)
@@ -112,7 +108,7 @@ def test_func(kjt: KeyedJaggedTensor, **kwargs):
     )

     print(
-        f"{test_name: <{35}} | B: {batch_size: <{8}} | F: {num_features: <{8}} | Mean Pooling Factor: {mean_pooling_factor: <{8}} | Runtime (P50): {result.runtime_percentile(50, interpolation='linear'):5f} ms | Runtime (P90): {result.runtime_percentile(90, interpolation='linear'):5f} ms"
+        f"{test_name: <{35}} | JIT Script: {'Yes' if jit_script else 'No': <{8}} | B: {batch_size: <{8}} | F: {num_features: <{8}} | Mean Pooling Factor: {mean_pooling_factor: <{8}} | Runtime (P50): {result.runtime_percentile(50, interpolation='linear'):5f} ms | Runtime (P90): {result.runtime_percentile(90, interpolation='linear'):5f} ms"
     )

@@ -148,6 +144,31 @@ def gen_dist_split_input(
     return (kjt_lengths, kjt_values, batch_size_per_rank, recat)


+@torch.jit.script
+def permute(kjt: KeyedJaggedTensor, indices: List[int]) -> KeyedJaggedTensor:
+    return kjt.permute(indices)
+
+
+@torch.jit.script
+def todict(kjt: KeyedJaggedTensor) -> Dict[str, JaggedTensor]:
+    return kjt.to_dict()
+
+
+@torch.jit.script
+def split(kjt: KeyedJaggedTensor, segments: List[int]) -> List[KeyedJaggedTensor]:
+    return kjt.split(segments)
+
+
+@torch.jit.script
+def getitem(kjt: KeyedJaggedTensor, key: str) -> JaggedTensor:
+    return kjt[key]
+
+
+@torch.jit.script
+def dist_splits(kjt: KeyedJaggedTensor, key_splits: List[int]) -> List[List[int]]:
+    return kjt.dist_splits(key_splits)
+
+
 def bench(
     num_repeat: int,
     num_warmup: int,
@@ -184,12 +205,13 @@ def bench(
         tables, batch_size, num_workers, num_features, mean_pooling_factor, device
     )

-    benchmarked_methods: List[Tuple[str, Dict[str, Any], bool]] = [
-        ("permute", {"indices": permute_indices}, False),
-        ("to_dict", {}, False),
-        ("split", {"segments": splits}, False),
-        ("__getitem__", {"key": key}, False),
-        ("dist_splits", {"key_splits": splits}, False),
+    # pyre-ignore[33]
+    benchmarked_methods: List[Tuple[str, Dict[str, Any], bool, Callable[..., Any]]] = [
+        ("permute", {"indices": permute_indices}, False, permute),
+        ("to_dict", {}, False, todict),
+        ("split", {"segments": splits}, False, split),
+        ("__getitem__", {"key": key}, False, getitem),
+        ("dist_splits", {"key_splits": splits}, False, dist_splits),
         (
             "dist_init",
             {
@@ -206,12 +228,33 @@ def bench(
                 "stride_per_rank": strides_per_rank,
             },
             True,  # is static method
+            torch.jit.script(KeyedJaggedTensor.dist_init),
         ),
     ]

-    for method_name, fn_kwargs, is_static_method in benchmarked_methods:
+    for method_name, fn_kwargs, is_static_method, jit_func in benchmarked_methods:
+        test_func = getattr(KeyedJaggedTensor if is_static_method else kjt, method_name)
+        benchmark_kjt(
+            test_name=method_name,
+            test_func=test_func,
+            kjt=kjt,
+            num_repeat=num_repeat,
+            num_warmup=num_warmup,
+            num_features=num_features,
+            batch_size=batch_size,
+            mean_pooling_factor=mean_pooling_factor,
+            fn_kwargs=fn_kwargs,
+            is_static_method=is_static_method,
+            jit_script=False,
+        )
+
+        if not is_static_method:
+            # Explicitly pass in KJT for instance methods
+            fn_kwargs = {"kjt": kjt, **fn_kwargs}
+
         benchmark_kjt(
-            method_name=method_name,
+            test_name=method_name,
+            test_func=jit_func,
             kjt=kjt,
             num_repeat=num_repeat,
             num_warmup=num_warmup,
@@ -220,6 +263,7 @@ def bench(
             mean_pooling_factor=mean_pooling_factor,
             fn_kwargs=fn_kwargs,
             is_static_method=is_static_method,
+            jit_script=True,
         )

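For context, here is a minimal standalone sketch of the pattern this diff introduces. It is an illustrative example, not part of the change: a free function over KeyedJaggedTensor is compiled once with torch.jit.script and then timed with timeit, mirroring how the scripted wrappers above are handed to benchmark_kjt. The tiny hand-built KJT is assumed purely for demonstration; the real benchmark uses generated inputs.

import timeit
from typing import List

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


@torch.jit.script
def permute(kjt: KeyedJaggedTensor, indices: List[int]) -> KeyedJaggedTensor:
    # Scripted free function: TorchScript compiles the KJT op ahead of time.
    return kjt.permute(indices)


# Small hand-built KJT purely for illustration (two keys, stride 2).
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2"],
    values=torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]),
    lengths=torch.tensor([2, 0, 1, 2]),
)

# Time the scripted call the same way benchmark_kjt does (one call per timeit run).
elapsed = timeit.timeit(lambda: permute(kjt, [1, 0]), number=100)
print(f"permute (scripted), 100 calls: {elapsed * 1000:.3f} ms")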