RegroupAsDict module #2007

Closed · wants to merge 1 commit

158 changes: 158 additions & 0 deletions torchrec/modules/regroup.py
@@ -0,0 +1,158 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List, Optional, Tuple

import torch
from torchrec.sparse.jagged_tensor import (
_all_keys_used_once,
_desugar_keyed_tensors,
_remap_to_groups,
KeyedTensor,
)


@torch.fx.wrap
def _concat_values(kts: List[KeyedTensor], dim: int) -> torch.Tensor:
return torch.cat([kt.values() for kt in kts], dim=dim)


@torch.fx.wrap
def _permuted_values(
kts: List[KeyedTensor], remap: List[Tuple[int, str]], dim: int
) -> torch.Tensor:
embedding_dicts = [kt.to_dict() for kt in kts]
values = [embedding_dicts[idx][key] for (idx, key) in remap]
return torch.cat(values, dim=dim)


@torch.fx.wrap
def _build_dict(
keys: List[str], values: torch.Tensor, splits: List[int], dim: int
) -> Dict[str, torch.Tensor]:
return {
key: tensor for key, tensor in zip(keys, torch.split(values, splits, dim=dim))
}


class KTRegroupAsDict(torch.nn.Module):
"""
KTRegroupAsDict is an nn.Module that mirrors the behavior of the static method KeyedTensor.regroup_as_dict().

The advantage of using this module is that it caches the regrouping logic after the first batch.

Args:
groups (List[List[str]]): features per output group
keys (List[str]): key of each output group

Example::

keys = ['object', 'user']
groups = [['f1', 'f2'], ['f3']]
regroup_module = KTRegroupAsDict(groups, keys)


tensor_list = [torch.randn(2, 4), torch.randn(2, 8), torch.randn(2, 2)]
kts = [KeyedTensor.from_tensor_list(['f1', 'f2', 'f3'], tensor_list)]
out = regroup_module(kts)

"""

def __init__(self, groups: List[List[str]], keys: List[str]) -> None:
super().__init__()
torch._C._log_api_usage_once(f"torchrec.modules.{self.__class__.__name__}")
assert len(groups) == len(keys), "Groups and keys should have same length"
self._groups = groups
self._keys = keys
self._is_inited = False

# cached values populated on first forward call
self.device: Optional[torch.device] = None
self._dim: int = 1
self._use_fbgemm_regroup: bool = False
self._splits: List[int] = []
self._idx_key_pairs: List[Tuple[int, str]] = []
self._permute_tensor: Optional[torch.Tensor] = None
self._inv_permute_tensor: Optional[torch.Tensor] = None
self._offsets_tensor: Optional[torch.Tensor] = None
self._inv_offsets_tensor: Optional[torch.Tensor] = None

def _init_fbgemm_regroup(self, kts: List[KeyedTensor]) -> None:
self._use_fbgemm_regroup = True
keys, lengths, values = _desugar_keyed_tensors(kts)
permute, inv_permute, offsets, inv_offsets, splits = _remap_to_groups(
keys, lengths, self._groups
)
# no need for pin_memory() or to(..., non_blocking=True) since this occurs only once
self._permute_tensor = permute.to(self.device)
self._inv_permute_tensor = inv_permute.to(self.device)
self._offsets_tensor = offsets.to(self.device)
self._inv_offsets_tensor = inv_offsets.to(self.device)
self._splits = splits

def _init_regroup(self, kts: List[KeyedTensor]) -> None:
lengths = [kt.length_per_key() for kt in kts]
indices = [kt._key_indices() for kt in kts]

key_to_idx: Dict[str, int] = {}
for i, kt in enumerate(kts):
for key in kt.keys():
if key in key_to_idx:
raise RuntimeError(
f"Duplicate key {key} found in KeyedTensors, undefined behavior"
)
key_to_idx[key] = i

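# for each output group, record which input KT and key each feature comes from
# and the total width of the group along the concat dim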
splits: List[int] = []
idx_key_pairs: List[Tuple[int, str]] = []
for group in self._groups:
group_length = 0
for name in group:
idx_key_pairs.append((key_to_idx[name], name))
group_length += lengths[key_to_idx[name]][
indices[key_to_idx[name]][name]
]
splits.append(group_length)

self._splits = splits
self._idx_key_pairs = idx_key_pairs

def forward(self, keyed_tensors: List[KeyedTensor]) -> Dict[str, torch.Tensor]:
if not self._is_inited:
assert len(keyed_tensors) > 0, "Empty list provided"
assert all(
kt.device == keyed_tensors[0].device for kt in keyed_tensors
), "All inputs should be on the same device."
self.device = keyed_tensors[0].device
assert all(
kt.key_dim() == keyed_tensors[0].key_dim() for kt in keyed_tensors
), "All inputs should have the same key_dim"
self._dim = keyed_tensors[0].key_dim()

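# the fused FBGEMM path requires every key to be used exactly once across
# groups and concatenation along dim 1; otherwise fall back to per-key regrouping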
if _all_keys_used_once(keyed_tensors, self._groups) and self._dim == 1:
self._init_fbgemm_regroup(keyed_tensors)
else:
self._init_regroup(keyed_tensors)
self._is_inited = True

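# fast path: apply the cached permutation to the concatenated values with a
# single fbgemm permute_pooled_embs_auto_grad call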
if self._use_fbgemm_regroup:
values = _concat_values(keyed_tensors, self._dim)
permuted_values = torch.ops.fbgemm.permute_pooled_embs_auto_grad(
values,
self._offsets_tensor,
self._permute_tensor,
self._inv_offsets_tensor,
self._inv_permute_tensor,
)
else:
permuted_values = _permuted_values(
keyed_tensors, self._idx_key_pairs, self._dim
)

return _build_dict(self._keys, permuted_values, self._splits, self._dim)
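
For reference, a minimal usage sketch (not part of this diff) of how the module's cached regrouping plan can be reused for fx tracing and TorchScript after a single warm-up batch; the keys, groups, and tensor shapes are illustrative, and KeyedTensor.from_tensor_list is the same API used in the docstring above:

# Minimal sketch, assuming the KeyedTensor.from_tensor_list API shown in the
# docstring; names and shapes are illustrative only.
import torch
import torch.fx
from torchrec.modules.regroup import KTRegroupAsDict
from torchrec.sparse.jagged_tensor import KeyedTensor

keys = ["object", "user"]
groups = [["f1", "f2"], ["f3"]]
regroup_module = KTRegroupAsDict(groups, keys)

tensor_list = [torch.randn(2, 4), torch.randn(2, 8), torch.randn(2, 2)]
kts = [KeyedTensor.from_tensor_list(["f1", "f2", "f3"], tensor_list)]

out = regroup_module(kts)  # first call caches the splits / permute plan
gm = torch.fx.symbolic_trace(regroup_module)  # later calls reuse the cached plan
scripted = torch.jit.script(gm)
out = scripted(kts)  # Dict[str, torch.Tensor] keyed by "object" and "user"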
134 changes: 134 additions & 0 deletions torchrec/modules/tests/test_regroup.py
@@ -0,0 +1,134 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest

import torch
import torch.fx

from torchrec.modules.regroup import KTRegroupAsDict
from torchrec.sparse.jagged_tensor import _all_keys_used_once, KeyedTensor
from torchrec.sparse.tests.utils import build_groups, build_kts


class KTRegroupAsDictTest(unittest.TestCase):
def setUp(self) -> None:
super().setUp()
self.kts = build_kts(
dense_features=20,
sparse_features=20,
dim_dense=64,
dim_sparse=128,
batch_size=128,
device=torch.device("cpu"),
run_backward=True,
)
self.num_groups = 2
self.keys = ["user", "object"]
self.labels = torch.randint(0, 1, (128,), device=torch.device("cpu")).float()

def test_regroup_backward_skips_and_duplicates(self) -> None:
groups = build_groups(
kts=self.kts, num_groups=self.num_groups, skips=True, duplicates=True
)
assert _all_keys_used_once(self.kts, groups) is False

regroup_module = KTRegroupAsDict(groups=groups, keys=self.keys)
tensor_groups = regroup_module(self.kts)
pred0 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1))
loss = torch.nn.functional.l1_loss(pred0, self.labels).sum()
actual_kt_0_grad, actual_kt_1_grad = torch.autograd.grad(
loss, [self.kts[0].values(), self.kts[1].values()]
)

# clear grads so the inputs can be reused
self.kts[0].values().grad = None
self.kts[1].values().grad = None

tensor_groups = KeyedTensor.regroup_as_dict(
keyed_tensors=self.kts, groups=groups, keys=self.keys
)
pred1 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1))
loss = torch.nn.functional.l1_loss(pred1, self.labels).sum()
expected_kt_0_grad, expected_kt_1_grad = torch.autograd.grad(
loss, [self.kts[0].values(), self.kts[1].values()]
)

self.assertTrue(torch.allclose(pred0, pred1))
self.assertTrue(torch.allclose(actual_kt_0_grad, expected_kt_0_grad))
self.assertTrue(torch.allclose(actual_kt_1_grad, expected_kt_1_grad))

def test_regroup_backward(self) -> None:
groups = build_groups(
kts=self.kts, num_groups=self.num_groups, skips=False, duplicates=False
)
assert _all_keys_used_once(self.kts, groups) is True

regroup_module = KTRegroupAsDict(groups=groups, keys=self.keys)
tensor_groups = regroup_module(self.kts)
pred0 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1))
loss = torch.nn.functional.l1_loss(pred0, self.labels).sum()
actual_kt_0_grad, actual_kt_1_grad = torch.autograd.grad(
loss, [self.kts[0].values(), self.kts[1].values()]
)

# clear grads so the inputs can be reused
self.kts[0].values().grad = None
self.kts[1].values().grad = None

tensor_groups = KeyedTensor.regroup_as_dict(
keyed_tensors=self.kts, groups=groups, keys=self.keys
)
pred1 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1))
loss = torch.nn.functional.l1_loss(pred1, self.labels).sum()
expected_kt_0_grad, expected_kt_1_grad = torch.autograd.grad(
loss, [self.kts[0].values(), self.kts[1].values()]
)

self.assertTrue(torch.allclose(pred0, pred1))
self.assertTrue(torch.allclose(actual_kt_0_grad, expected_kt_0_grad))
self.assertTrue(torch.allclose(actual_kt_1_grad, expected_kt_1_grad))

def test_fx_and_jit_regroup(self) -> None:
groups = build_groups(
kts=self.kts, num_groups=self.num_groups, skips=False, duplicates=False
)
assert _all_keys_used_once(self.kts, groups) is True

regroup_module = KTRegroupAsDict(groups=groups, keys=self.keys)
# first pass
regroup_module(self.kts)

# now trace
gm = torch.fx.symbolic_trace(regroup_module)
jit_gm = torch.jit.script(gm)

out = jit_gm(self.kts)
eager_out = regroup_module(self.kts)
for key in out.keys():
self.assertTrue(torch.allclose(out[key], eager_out[key]))

def test_fx_and_jit_regroup_skips_and_duplicates(self) -> None:
groups = build_groups(
kts=self.kts, num_groups=self.num_groups, skips=True, duplicates=True
)
assert _all_keys_used_once(self.kts, groups) is False

regroup_module = KTRegroupAsDict(groups=groups, keys=self.keys)
# first pass
regroup_module(self.kts)

# now trace
gm = torch.fx.symbolic_trace(regroup_module)
jit_gm = torch.jit.script(gm)

out = jit_gm(self.kts)
eager_out = regroup_module(self.kts)
for key in out.keys():
self.assertTrue(torch.allclose(out[key], eager_out[key]))
8 changes: 8 additions & 0 deletions torchrec/sparse/jagged_tensor.py
@@ -650,6 +650,10 @@ def to_padded_dense_weights(
self.weights(), [self.offsets()], [N], padding_value
)

@property
def device(self) -> torch.device:
return self._values.device

def lengths(self) -> torch.Tensor:
_lengths = _maybe_compute_lengths(self._lengths, self._offsets)
self._lengths = _lengths
@@ -2570,6 +2574,10 @@ def values(self) -> torch.Tensor:
def key_dim(self) -> int:
return self._key_dim

@property
def device(self) -> torch.device:
return self._values.device

def offset_per_key(self) -> List[int]:
_offset_per_key = _maybe_compute_offset_per_key_kt(
self._length_per_key,
18 changes: 17 additions & 1 deletion torchrec/sparse/tests/jagged_tensor_benchmark.py
@@ -16,6 +16,7 @@

import torch
from torchrec.distributed.benchmark.benchmark_utils import benchmark, BenchmarkResult
from torchrec.modules.regroup import KTRegroupAsDict
from torchrec.sparse.jagged_tensor import (
_regroup_keyed_tensors,
KeyedJaggedTensor,
@@ -53,7 +54,10 @@ def wrapped_func(
) -> None:
result = fn(**fn_kwargs)
if run_backward:
vectors = [tensor.sum(dim=1) for tensor in result]
if isinstance(result, dict):
vectors = [tensor.sum(dim=1) for tensor in result.values()]
else:
vectors = [tensor.sum(dim=1) for tensor in result]
pred = vectors[0]
for vector in vectors[1:]:
pred.mul(vector)
@@ -216,6 +220,18 @@ def main(
KeyedTensor.regroup,
{"keyed_tensors": kts, "groups": groups},
)
bench(
"[prod] KTRegroupAsDict",
labels,
batch_size,
n_dense + n_sparse,
device_type,
run_backward,
KTRegroupAsDict(
groups=groups, keys=[str(i) for i in range(n_groups)]
),
{"keyed_tensors": kts},
)


if __name__ == "__main__":