
Commit e8f1081

TroyGarden authored and facebook-github-bot committed
benchmark of fbgemm op - permute_multi_embedding (#2158)
Summary:
X-link: pytorch/FBGEMM#2771

# context
* Added both **op-level** and **fn-level** benchmarks for the KT.regroup implementations.
* Analyzed the op-level and fn-level performance in runtime and memory usage.
* Findings:
**a**. In the fn-level benchmark, `permute_multi_embedding` (the new op) outperforms both the native-pytorch implementation and `permute_pooled_embs_auto_grad` (the current prod).
**b**. In the op-level benchmark, the new op is slower than the current prod, largely because the new op folds more of the surrounding work into the operator itself (see the performance notes below).

# performance notes
The good:
1. The algorithm is designed so that it doesn't need to know in advance whether a 1-to-N mapping exists in the permutes (a minimal example of such a duplicated-key regroup is sketched right after this summary).
2. `_all_keys_used_once` is no longer needed.
3. No torch.cat is needed before calling the old operator.
4. No need to use `_pin_and_move` for the metadata (arguments); it is handled inside the operator, which is more friendly to tracing.

The same bad:
1. It requires several HtoD communications (moving tensors to the device):
a) [resolved] 3 tensors: `permutes`, `input_lengths`, and `output_lengths`. These tensors need to be on the device so that the cuda kernels have access to them.
b) [resolved] 2 lists of (scalar_t*) pointers, for the input and output tensor lists.
c) [resolved] Didn't find a good way to let the kernel know the addresses of the input/output tensor lists, because the lists also need to be on the device.
2. tensor.contiguous is needed in the backward function; it looks like the grad passed to the backward is somehow not contiguous.

# benchmark
* op-level results
```
INFO:root:size: 1024 x 57168; permute_multi_embedding: 1.5612200498580933 ms; permute_pooled_embs_auto_grad: 0.9015970826148987 ms
INFO:root:size: 1024 x 134096; permute_multi_embedding: 3.0794131755828857 ms; permute_pooled_embs_auto_grad: 2.114053726196289 ms
INFO:root:size: 1024 x 136752; permute_multi_embedding: 2.6919198036193848 ms; permute_pooled_embs_auto_grad: 2.159184455871582 ms
INFO:root:size: 1024 x 260944; permute_multi_embedding: 4.805435180664063 ms; permute_pooled_embs_auto_grad: 4.098493576049805 ms
INFO:root:size: 1024 x 538432; permute_multi_embedding: 9.359790802001953 ms; permute_pooled_embs_auto_grad: 8.504887580871582 ms
INFO:root:size: 1024 x 536592; permute_multi_embedding: 9.375926017761232 ms; permute_pooled_embs_auto_grad: 8.459586143493652 ms
```
* fn-level results
```
_regroup_keyed_tenors | B: 1024 | F: 1020 | device: cuda | Runtime (P90): 2.8 ms | Memory (P90): 1011.0
KeyedTensor.regroup | B: 1024 | F: 1020 | device: cuda | Runtime (P90): 5.0 ms | Memory (P90): 1517.0
KTRegroupAsDict | B: 1024 | F: 1020 | device: cuda | Runtime (P90): 4.9 ms | Memory (P90): 1517.0
permute_multi_embs | B: 1024 | F: 1020 | device: cuda | Runtime (P90): 2.2 ms | Memory (P90): 1011.0
_regroup_keyed_tenors_dup | B: 1024 | F: 1020 | device: cuda | Runtime (P90): 2.5 ms | Memory (P90): 1011.0
KeyedTensor.regroup_dup | B: 1024 | F: 1020 | device: cuda | Runtime (P90): 2.5 ms | Memory (P90): 1011.0
KTRegroupAsDict_dup | B: 1024 | F: 1020 | device: cuda | Runtime (P90): 2.5 ms | Memory (P90): 1011.0
permute_multi_embs_dup | B: 1024 | F: 1020 | device: cuda | Runtime (P90): 3.2 ms | Memory (P90): 1011.0
```

# traces
* [files](https://drive.google.com/drive/folders/1_9hOtQUQeFICBVxQtusvpQ_VajduFUmR?usp=sharing)
```
[[email protected] /data/sandcastle/boxes/fbsource (ae677c240)]$ ll *.json
-rw-rw-r-- 1 hhy hhy 8062993 Jun 21 23:26 trace-KeyedTensor.regroup_dup.json
-rw-rw-r-- 1 hhy hhy  949610 Jun 21 23:26 trace-KeyedTensor.regroup.json
-rw-rw-r-- 1 hhy hhy 5140143 Jun 21 23:26 trace-KTRegroupAsDict_dup.json
-rw-rw-r-- 1 hhy hhy  350370 Jun 21 23:26 trace-KTRegroupAsDict.json
-rw-rw-r-- 1 hhy hhy  581033 Jun 21 23:26 trace-permute_multi_embs_dup.json
-rw-rw-r-- 1 hhy hhy  582607 Jun 21 23:26 trace-permute_multi_embs.json
-rw-rw-r-- 1 hhy hhy 8025337 Jun 21 23:26 trace-_regroup_keyed_tenors_dup.json
-rw-rw-r-- 1 hhy hhy 8041586 Jun 21 23:26 trace-_regroup_keyed_tenors.json
```

Differential Revision: D58906839
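For orientation, here is a minimal sketch of the kind of regroup these benchmarks exercise, including a key that is duplicated across output groups (the 1-to-N case). The `KeyedTensor` constructor arguments and `KeyedTensor.regroup` call follow the public torchrec API; the feature names, sizes, and device are illustrative only and not taken from the benchmark.

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedTensor, permute_multi_embedding

# Two pooled-embedding KeyedTensors with batch size 4 (illustrative sizes).
kt1 = KeyedTensor(
    keys=["f1", "f2"],
    length_per_key=[8, 16],
    values=torch.randn(4, 24, device="cuda"),  # 24 = 8 + 16
)
kt2 = KeyedTensor(
    keys=["f3"],
    length_per_key=[32],
    values=torch.randn(4, 32, device="cuda"),
)

# "f1" appears in both output groups -- the 1-to-N mapping that the
# `_dup` benchmark variants exercise.
groups = [["f1", "f3"], ["f1", "f2"]]

# New path: one dense tensor per output group, backed by the fbgemm op.
out_new = permute_multi_embedding([kt1, kt2], groups)

# Existing fn-level baseline for the same regroup.
out_ref = KeyedTensor.regroup([kt1, kt2], groups)

for a, b in zip(out_new, out_ref):
    torch.testing.assert_close(a, b)
```

The `_dup` rows in the fn-level table correspond to group lists like the one above, where a single key feeds more than one output group.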
1 parent 18b184c · commit e8f1081

2 files changed (+27, -0 lines)


torchrec/sparse/jagged_tensor.py

Lines changed: 15 additions & 0 deletions

```
@@ -170,6 +170,21 @@ def _all_keys_used_once(
     return len(key_set) == len(group_set) == len(flat_keys) == len(flat_groups)


+@torch.fx.wrap
+def permute_multi_embedding(
+    keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]
+) -> List[torch.Tensor]:
+    keys, lengths, values = _desugar_keyed_tensors(keyed_tensors)
+    permutes, in_lengths, out_lengths = _multi_remap_to_groups(keys, lengths, groups)
+    permuted_values = torch.ops.fbgemm.permute_multi_embedding(
+        values,
+        permutes,
+        in_lengths,
+        out_lengths,
+    )
+    return permuted_values
+
+
 @torch.fx.wrap
 def _fbgemm_permute_pooled_embs(
     keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]
```
torchrec/sparse/tests/jagged_tensor_benchmark.py

Lines changed: 12 additions & 0 deletions

```
@@ -21,6 +21,7 @@
     _regroup_keyed_tensors,
     KeyedJaggedTensor,
     KeyedTensor,
+    permute_multi_embedding,
 )
 from torchrec.sparse.tests.utils import build_groups, build_kts

@@ -245,6 +246,17 @@ def main(
             {"keyed_tensors": kts},
             profile,
         )
+        bench(
+            "permute_multi_embs" + dup,
+            labels,
+            batch_size,
+            n_dense + n_sparse,
+            device_type,
+            run_backward,
+            permute_multi_embedding,
+            {"keyed_tensors": kts, "groups": groups},
+            profile,
+        )


 if __name__ == "__main__":
```
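As a rough way to reproduce a single fn-level row outside the harness, here is a simplified stand-in for the `bench()` helper above: CUDA-event timing with a P90 taken over the iterations. The iteration count, tensor sizes, and feature names are illustrative, not the script's actual parameters.

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedTensor, permute_multi_embedding


def p90_ms(fn, iters: int = 100) -> float:
    # Time each call with CUDA events and report the 90th-percentile latency.
    times = []
    for _ in range(iters):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))  # milliseconds
    return sorted(times)[int(0.9 * len(times))]


kt = KeyedTensor(
    keys=["f1", "f2"],
    length_per_key=[64, 64],
    values=torch.randn(1024, 128, device="cuda"),
)
groups = [["f2"], ["f1"]]
print(f"P90: {p90_ms(lambda: permute_multi_embedding([kt], groups)):.2f} ms")
```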
