diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index f703fe8ec..3dbb857dc 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -1956,15 +1956,15 @@ def permute( self.weights_or_none(), ) else: - ( permuted_lengths, permuted_values, permuted_weights, - ) = torch.ops.fbgemm.permute_2D_sparse_data( + ) = torch.ops.fbgemm.permute_2D_sparse_data_input1D( indices_tensor, - self.lengths().view(len(self._keys), -1), + self.lengths(), self.values(), + self.stride(), self.weights_or_none(), permuted_length_per_key_sum, ) @@ -1977,7 +1977,7 @@ def permute( keys=permuted_keys, values=permuted_values, weights=permuted_weights, - lengths=permuted_lengths.view(-1), + lengths=permuted_lengths, offsets=None, stride=stride, stride_per_key_per_rank=optional_permuted_stride_per_key_per_rank, @@ -2343,14 +2343,14 @@ def dist_init( lengths, values, weights, - ) = torch.ops.fbgemm.permute_2D_sparse_data( + ) = torch.ops.fbgemm.permute_2D_sparse_data_input1D( torch.jit._unwrap_optional(recat), - lengths.view(-1, stride), + lengths, values, + stride, weights, values.numel(), ) - lengths = lengths.view(-1) else: # variable batch size per rank ( lengths, diff --git a/torchrec/sparse/tests/test_jagged_tensor.py b/torchrec/sparse/tests/test_jagged_tensor.py index 9efeb444c..5c1e83705 100644 --- a/torchrec/sparse/tests/test_jagged_tensor.py +++ b/torchrec/sparse/tests/test_jagged_tensor.py @@ -1316,7 +1316,6 @@ def test_permute(self) -> None: keys=keys, lengths=lengths, ) - indices = [1, 0, 2] permuted_jag_tensor = jag_tensor.permute(indices)