
benchmark of fbgemm op - permute_multi_embedding #2158


Closed · wants to merge 3 commits
6 changes: 3 additions & 3 deletions torchrec/distributed/benchmark/benchmark_utils.py
@@ -492,9 +492,9 @@ def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module:
def benchmark(
name: str,
model: torch.nn.Module,
warmup_inputs: List[KeyedJaggedTensor],
bench_inputs: List[KeyedJaggedTensor],
prof_inputs: List[KeyedJaggedTensor],
warmup_inputs: Union[List[KeyedJaggedTensor], List[Dict[str, Any]]],
bench_inputs: Union[List[KeyedJaggedTensor], List[Dict[str, Any]]],
prof_inputs: Union[List[KeyedJaggedTensor], List[Dict[str, Any]]],
world_size: int,
output_dir: str,
num_benchmarks: int,
97 changes: 97 additions & 0 deletions torchrec/sparse/jagged_tensor.py
@@ -36,6 +36,12 @@
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_cpu"
)
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_cpu"
)
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_gpu"
)
except OSError:
pass

@@ -164,6 +170,21 @@ def _all_keys_used_once(
return len(key_set) == len(group_set) == len(flat_keys) == len(flat_groups)


@torch.fx.wrap
def permute_multi_embedding(
keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]
) -> List[torch.Tensor]:
keys, lengths, values = _desugar_keyed_tensors(keyed_tensors)
permutes, in_lengths, out_lengths = _multi_remap_to_groups(keys, lengths, groups)
permuted_values = torch.ops.fbgemm.permute_multi_embedding(
values,
permutes,
in_lengths,
out_lengths,
)
return permuted_values
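
For reference, here is a minimal sketch of how this new wrapper could be exercised (assuming a fbgemm_gpu build that ships the permute_multi_embedding ops; the key names and dims below are illustrative only, not taken from this PR):

    import torch
    from torchrec.sparse.jagged_tensor import KeyedTensor, permute_multi_embedding

    # two KeyedTensors sharing batch_size=4 (values are random placeholders)
    kt1 = KeyedTensor(keys=["f1", "f2"], length_per_key=[2, 3], values=torch.randn(4, 5))
    kt2 = KeyedTensor(keys=["f3"], length_per_key=[4], values=torch.randn(4, 4))

    # regroup into two output tensors; "f1" is intentionally duplicated across groups
    groups = [["f1", "f3"], ["f2", "f1"]]
    out = permute_multi_embedding([kt1, kt2], groups)
    # expected: out[0].shape == (4, 6) and out[1].shape == (4, 5)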


@torch.fx.wrap
def _fbgemm_permute_pooled_embs(
keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]
@@ -240,6 +261,82 @@ def _remap_to_groups(
return permute, inv_permute, offsets, inv_offsets, splits


def _multi_remap_to_groups(
keys: List[List[str]],
key_lengths: List[List[int]],
groups: List[List[str]],
) -> Tuple[List[int], List[int], List[int]]:
"""
Given a list of keys and lengths per key for each group, return the permute 2D tensor, and 1D tensor lengths:
[[input_tensor_idx, output_tensor_idx, input_start, output_start, length]], [length]
"""
# key => (tensor_idx, key_index)
key_map: Dict[str, Tuple[int, int]] = {
key: (tensor_idx, key_idx)
for tensor_idx, tensor in enumerate(keys)
for key_idx, key in enumerate(tensor)
}

# [offsets per tensor]
in_offsets: List[List[int]] = [[] for _ in key_lengths]
for i, tensor in enumerate(key_lengths):
in_offsets[i] = _cumsum(tensor)
in_lengths: List[int] = [sum(lengths) for lengths in key_lengths]

# total_permutes serves as the jump-stop marker
total_permutes: int = sum(len(tensor) for tensor in groups)
out_lengths: List[int] = [0] * len(groups)

# [input_tensor_idx, output_tensor_idx, input_start, output_start, length, jump]
permute_param = 6
permutes: List[int] = [0] * (total_permutes * permute_param)

# record the last-seen permute index so we can link the jump from last_seen to the current entry
last_seen: Dict[str, int] = {}
permute_idx = 0
for output_tensor_idx, output_tensor in enumerate(groups):
output_start = 0
for output_key in output_tensor:
input_tensor_idx, input_key_idx = key_map[output_key]
input_start = in_offsets[input_tensor_idx][input_key_idx]
length = key_lengths[input_tensor_idx][input_key_idx]

# add jump data
if output_key not in last_seen:
jump = 0 # don't need to jump yet
# positive as a potential jump start
last_seen[output_key] = permute_idx
else:
prev = last_seen[output_key]
if prev >= 0: # positive ==> it's a jump start
# jump to current idx, positive as the jump start
permutes[prev * permute_param + 5] = permute_idx
else: # it's already in a jump sequence, mark as negative
permutes[-prev * permute_param + 5] = -permute_idx
# mark last_seen negative since it's already in jump
last_seen[output_key] = -permute_idx
# it's a potential jump stop
jump = -total_permutes

permutes[permute_idx * permute_param : permute_idx * permute_param + 6] = [
input_tensor_idx,
output_tensor_idx,
input_start,
output_start,
length,
jump,
]
permute_idx += 1
output_start += length
out_lengths[output_tensor_idx] = output_start

return (
permutes,
in_lengths,
out_lengths,
)
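
To make the encoding concrete, here is a hand-traced example of the table this helper would produce (the numbers were derived by following the code above, so treat them as illustrative rather than authoritative):

    permutes, in_lengths, out_lengths = _multi_remap_to_groups(
        keys=[["f1", "f2"], ["f3"]],
        key_lengths=[[2, 3], [4]],
        groups=[["f1", "f3"], ["f2", "f1"]],
    )
    # in_lengths == [5, 4], out_lengths == [6, 5]
    # permutes, read 6 ints at a time:
    #   [0, 0, 0, 0, 2,  3]  f1 -> group 0; jump field points at entry 3 (its duplicate)
    #   [1, 0, 0, 2, 4,  0]  f3 -> group 0
    #   [0, 1, 2, 0, 3,  0]  f2 -> group 1
    #   [0, 1, 0, 3, 2, -4]  f1 -> group 1; -total_permutes marks the end of the jump chain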


def _values_string(values: torch.Tensor, start: int, end: int) -> str:
size = values.size()
if len(size) == 1:
135 changes: 80 additions & 55 deletions torchrec/sparse/tests/jagged_tensor_benchmark.py
@@ -21,6 +21,7 @@
_regroup_keyed_tensors,
KeyedJaggedTensor,
KeyedTensor,
permute_multi_embedding,
)
from torchrec.sparse.tests.utils import build_groups, build_kts

@@ -40,6 +41,7 @@ def bench(
run_backward: bool,
fn: Callable[..., List[torch.Tensor]],
fn_kwargs: Dict[str, Any],
output_dir: str = "",
) -> None:

# initial call
@@ -49,8 +51,8 @@ def wrapped_func(
model: torch.nn.Module, # not used
bench_inputs: List[KeyedJaggedTensor], # not used
fn: Callable[..., List[torch.Tensor]],
fn_kwargs: Dict[str, Any],
run_backward: bool,
**kwargs: Dict[str, Any],
) -> None:
result = fn(**fn_kwargs)
if run_backward:
@@ -64,26 +66,27 @@ def wrapped_func(
loss = torch.nn.functional.l1_loss(pred, labels)
loss.sum().backward()

model = DummyModel()
setattr(model, "forward", lambda kwargs: fn(**kwargs))
if device_type == "cuda":
result = benchmark(
name=name,
model=DummyModel(),
warmup_inputs=[],
model=model,
warmup_inputs=[fn_kwargs] * 10,
bench_inputs=[],
prof_inputs=[],
prof_inputs=[fn_kwargs] * 10,
world_size=1,
output_dir="",
output_dir=output_dir,
num_benchmarks=20,
func_to_benchmark=functools.partial(
wrapped_func, fn=fn, run_backward=run_backward, fn_kwargs=fn_kwargs
),
benchmark_func_kwargs={},
rank=0,
enable_logging=False,
enable_logging=True,
)

else: # cpu
model = DummyModel()
times = timeit.repeat(
lambda: wrapped_func(
model=model,
@@ -160,6 +163,12 @@ def wrapped_func(
default=2,
help="Total num of regrouping",
)
@click.option(
"--profile",
type=str,
default="",
help="profile output directory",
)
def main(
cuda_matrix: bool,
run_backward: bool,
@@ -170,6 +179,7 @@ def main(
dim_sparse: int,
batch_size: int,
n_groups: int,
profile: str,
) -> None:
if cuda_matrix:
n_denses = [64, 128, 256, 512, 1024]
@@ -184,54 +194,69 @@

for device_type in device_types:
for batch_size in batch_sizes:
for n_dense, n_sparse in zip(n_denses, n_sparses):

device = torch.device(device_type)
kts = build_kts(
n_dense,
n_sparse,
dim_dense,
dim_sparse,
batch_size,
device,
run_backward,
)
labels = torch.randint(
0, 1, (batch_size,), device=torch.device(device_type)
).float()
groups = build_groups(kts, n_groups)
bench(
"[fallback] _regroup_keyed_tenors",
labels,
batch_size,
n_dense + n_sparse,
device_type,
run_backward,
_regroup_keyed_tensors,
{"keyed_tensors": kts, "groups": groups},
)
bench(
"[prod] KeyedTensor.regroup",
labels,
batch_size,
n_dense + n_sparse,
device_type,
run_backward,
KeyedTensor.regroup,
{"keyed_tensors": kts, "groups": groups},
)
bench(
"[prod] KTRegroupAsDict",
labels,
batch_size,
n_dense + n_sparse,
device_type,
run_backward,
KTRegroupAsDict(
groups=groups, keys=[str(i) for i in range(n_groups)]
),
{"keyed_tensors": kts},
)
for duplicates in [False, True]:
for n_dense, n_sparse in zip(n_denses, n_sparses):
dup = "_dup" if duplicates else ""
device = torch.device(device_type)
kts = build_kts(
n_dense,
n_sparse,
dim_dense,
dim_sparse,
batch_size,
device,
run_backward,
)
labels = torch.randint(
0, 1, (batch_size,), device=torch.device(device_type)
).float()
groups = build_groups(kts, n_groups, duplicates=duplicates)
bench(
"_regroup_keyed_tenors" + dup,
labels,
batch_size,
n_dense + n_sparse,
device_type,
run_backward,
_regroup_keyed_tensors,
{"keyed_tensors": kts, "groups": groups},
profile,
)
bench(
"KeyedTensor.regroup" + dup,
labels,
batch_size,
n_dense + n_sparse,
device_type,
run_backward,
KeyedTensor.regroup,
{"keyed_tensors": kts, "groups": groups},
profile,
)
bench(
"KTRegroupAsDict" + dup,
labels,
batch_size,
n_dense + n_sparse,
device_type,
run_backward,
KTRegroupAsDict(
groups=groups, keys=[str(i) for i in range(n_groups)]
),
{"keyed_tensors": kts},
profile,
)
bench(
"permute_multi_embs" + dup,
labels,
batch_size,
n_dense + n_sparse,
device_type,
run_backward,
permute_multi_embedding,
{"keyed_tensors": kts, "groups": groups},
profile,
)


if __name__ == "__main__":
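
With these changes the benchmark script can also emit profiler traces. A hedged invocation example (assuming the script is run directly from an OSS checkout; only the --profile flag comes from this diff, and the output path is illustrative):

    python torchrec/sparse/tests/jagged_tensor_benchmark.py --profile /tmp/kt_regroup_traces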