Skip to content

Commit c972d54

Browse files
sarckk authored and facebook-github-bot committed
Add jit script KJT benchmarks (#2033)
Summary: Pull Request resolved: #2033 Add benchmarks for jit scripted KJT methods Reviewed By: gnahzg Differential Revision: D57701618 fbshipit-source-id: 4b09ab6841fb151f8d008ef0cbbc1ed5e78f1eef
1 parent a29c82e commit c972d54

File tree

1 file changed

+65
-21
lines changed

1 file changed

+65
-21
lines changed

torchrec/sparse/tests/keyed_jagged_tensor_benchmark.py

Lines changed: 65 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@
1818

1919
from torchrec.distributed.test_utils.test_model import ModelInput
2020
from torchrec.modules.embedding_configs import EmbeddingBagConfig
21-
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
21+
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
2222

2323

2424
def generate_kjt(
@@ -69,15 +69,17 @@ def wrapped_func(
6969
kjt: KeyedJaggedTensor,
7070
test_func: Callable[[KeyedJaggedTensor], object],
7171
fn_kwargs: Dict[str, Any],
72+
jit_script: bool,
7273
) -> Callable[..., object]:
7374
def fn() -> object:
7475
return test_func(kjt, **fn_kwargs)
7576

76-
return fn
77+
return fn if jit_script else torch.jit.script(fn)
7778

7879

7980
def benchmark_kjt(
80-
method_name: str,
81+
test_name: str,
82+
test_func: Callable[..., object],
8183
kjt: KeyedJaggedTensor,
8284
num_repeat: int,
8385
num_warmup: int,
@@ -86,21 +88,15 @@ def benchmark_kjt(
8688
mean_pooling_factor: int,
8789
fn_kwargs: Dict[str, Any],
8890
is_static_method: bool,
91+
jit_script: bool,
8992
) -> None:
90-
test_name = method_name
91-
92-
# pyre-ignore
93-
def test_func(kjt: KeyedJaggedTensor, **kwargs):
94-
return getattr(KeyedJaggedTensor if is_static_method else kjt, method_name)(
95-
**kwargs
96-
)
9793

9894
for _ in range(num_warmup):
99-
test_func(kjt, **fn_kwargs)
95+
test_func(**fn_kwargs)
10096

10197
times = []
10298
for _ in range(num_repeat):
103-
time_elapsed = timeit.timeit(wrapped_func(kjt, test_func, fn_kwargs), number=1)
99+
time_elapsed = timeit.timeit(lambda: test_func(**fn_kwargs), number=1)
104100
# remove length_per_key and offset_per_key cache for fairer comparison
105101
kjt.unsync()
106102
times.append(time_elapsed)
@@ -112,7 +108,7 @@ def test_func(kjt: KeyedJaggedTensor, **kwargs):
112108
)
113109

114110
print(
115-
f" {test_name : <{35}} | B: {batch_size : <{8}} | F: {num_features : <{8}} | Mean Pooling Factor: {mean_pooling_factor : <{8}} | Runtime (P50): {result.runtime_percentile(50, interpolation='linear'):5f} ms | Runtime (P90): {result.runtime_percentile(90, interpolation='linear'):5f} ms"
111+
f" {test_name : <{35}} | JIT Script: {'Yes' if jit_script else 'No' : <{8}} | B: {batch_size : <{8}} | F: {num_features : <{8}} | Mean Pooling Factor: {mean_pooling_factor : <{8}} | Runtime (P50): {result.runtime_percentile(50, interpolation='linear'):5f} ms | Runtime (P90): {result.runtime_percentile(90, interpolation='linear'):5f} ms"
116112
)
117113

118114

@@ -148,6 +144,31 @@ def gen_dist_split_input(
148144
return (kjt_lengths, kjt_values, batch_size_per_rank, recat)
149145

150146

147+
@torch.jit.script
148+
def permute(kjt: KeyedJaggedTensor, indices: List[int]) -> KeyedJaggedTensor:
149+
return kjt.permute(indices)
150+
151+
152+
@torch.jit.script
153+
def todict(kjt: KeyedJaggedTensor) -> Dict[str, JaggedTensor]:
154+
return kjt.to_dict()
155+
156+
157+
@torch.jit.script
158+
def split(kjt: KeyedJaggedTensor, segments: List[int]) -> List[KeyedJaggedTensor]:
159+
return kjt.split(segments)
160+
161+
162+
@torch.jit.script
163+
def getitem(kjt: KeyedJaggedTensor, key: str) -> JaggedTensor:
164+
return kjt[key]
165+
166+
167+
@torch.jit.script
168+
def dist_splits(kjt: KeyedJaggedTensor, key_splits: List[int]) -> List[List[int]]:
169+
return kjt.dist_splits(key_splits)
170+
171+
151172
def bench(
152173
num_repeat: int,
153174
num_warmup: int,
@@ -184,12 +205,13 @@ def bench(
184205
tables, batch_size, num_workers, num_features, mean_pooling_factor, device
185206
)
186207

187-
benchmarked_methods: List[Tuple[str, Dict[str, Any], bool]] = [
188-
("permute", {"indices": permute_indices}, False),
189-
("to_dict", {}, False),
190-
("split", {"segments": splits}, False),
191-
("__getitem__", {"key": key}, False),
192-
("dist_splits", {"key_splits": splits}, False),
208+
# pyre-ignore[33]
209+
benchmarked_methods: List[Tuple[str, Dict[str, Any], bool, Callable[..., Any]]] = [
210+
("permute", {"indices": permute_indices}, False, permute),
211+
("to_dict", {}, False, todict),
212+
("split", {"segments": splits}, False, split),
213+
("__getitem__", {"key": key}, False, getitem),
214+
("dist_splits", {"key_splits": splits}, False, dist_splits),
193215
(
194216
"dist_init",
195217
{
@@ -206,12 +228,33 @@ def bench(
206228
"stride_per_rank": strides_per_rank,
207229
},
208230
True, # is static method
231+
torch.jit.script(KeyedJaggedTensor.dist_init),
209232
),
210233
]
211234

212-
for method_name, fn_kwargs, is_static_method in benchmarked_methods:
235+
for method_name, fn_kwargs, is_static_method, jit_func in benchmarked_methods:
236+
test_func = getattr(KeyedJaggedTensor if is_static_method else kjt, method_name)
237+
benchmark_kjt(
238+
test_name=method_name,
239+
test_func=test_func,
240+
kjt=kjt,
241+
num_repeat=num_repeat,
242+
num_warmup=num_warmup,
243+
num_features=num_features,
244+
batch_size=batch_size,
245+
mean_pooling_factor=mean_pooling_factor,
246+
fn_kwargs=fn_kwargs,
247+
is_static_method=is_static_method,
248+
jit_script=False,
249+
)
250+
251+
if not is_static_method:
252+
# Explicitly pass in KJT for instance methods
253+
fn_kwargs = {"kjt": kjt, **fn_kwargs}
254+
213255
benchmark_kjt(
214-
method_name=method_name,
256+
test_name=method_name,
257+
test_func=jit_func,
215258
kjt=kjt,
216259
num_repeat=num_repeat,
217260
num_warmup=num_warmup,
@@ -220,6 +263,7 @@ def bench(
220263
mean_pooling_factor=mean_pooling_factor,
221264
fn_kwargs=fn_kwargs,
222265
is_static_method=is_static_method,
266+
jit_script=True,
223267
)
224268

225269

0 commit comments

Comments (0)