Skip to content

Commit b8a1c40

Browse files
Gufan Yinfacebook-github-bot
Gufan Yin
authored andcommitted
Revert D59031938: Split of "[TorchRec][PT2] KJT custom op for 1d lengths input"
Differential Revision: D59031938 Original commit changeset: 3a80e2acedf0 Original Phabricator Diff: D59031938 fbshipit-source-id: ee71998434827a88799f9192a1f7d82eb7be18a3
1 parent 1db15f8 commit b8a1c40

File tree

2 files changed

+3
-157
lines changed

2 files changed

+3
-157
lines changed

torchrec/sparse/jagged_tensor.py

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1955,20 +1955,8 @@ def permute(
19551955
indices_tensor,
19561956
self.weights_or_none(),
19571957
)
1958-
elif is_torchdynamo_compiling():
1959-
(
1960-
permuted_lengths,
1961-
permuted_values,
1962-
permuted_weights,
1963-
) = torch.ops.fbgemm.permute_2D_sparse_data_input1D(
1964-
indices_tensor,
1965-
self.lengths(),
1966-
self.values(),
1967-
self.stride(),
1968-
self.weights_or_none(),
1969-
permuted_length_per_key_sum,
1970-
)
19711958
else:
1959+
19721960
(
19731961
permuted_lengths,
19741962
permuted_values,
@@ -2350,20 +2338,7 @@ def dist_init(
23502338
s == stride for s in stride_per_rank
23512339
)
23522340

2353-
if single_batch_per_rank and is_torchdynamo_compiling():
2354-
(
2355-
lengths,
2356-
values,
2357-
weights,
2358-
) = torch.ops.fbgemm.permute_2D_sparse_data_input1D(
2359-
torch.jit._unwrap_optional(recat),
2360-
lengths,
2361-
values,
2362-
stride,
2363-
weights,
2364-
values.numel(),
2365-
)
2366-
elif single_batch_per_rank:
2341+
if single_batch_per_rank:
23672342
(
23682343
lengths,
23692344
values,

torchrec/sparse/tests/test_jagged_tensor_gpu.py

Lines changed: 1 addition & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,7 @@
1111
import unittest
1212

1313
import torch
14-
from torchrec.sparse.jagged_tensor import (
15-
_regroup_keyed_tensors,
16-
KeyedJaggedTensor,
17-
KeyedTensor,
18-
)
14+
from torchrec.sparse.jagged_tensor import _regroup_keyed_tensors, KeyedTensor
1915
from torchrec.sparse.tests.utils import build_groups, build_kts
2016
from torchrec.test_utils import skip_if_asan_class
2117

@@ -115,128 +111,3 @@ def test_regroup_backward(self) -> None:
115111

116112
torch.allclose(actual_kt_0_grad, expected_kt_0_grad)
117113
torch.allclose(actual_kt_1_grad, expected_kt_1_grad)
118-
119-
# pyre-ignore
120-
@unittest.skipIf(
121-
torch.cuda.device_count() <= 0,
122-
"Not enough GPUs, this test requires at least one GPUs",
123-
)
124-
def test_permute(self) -> None:
125-
values = torch.tensor(
126-
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device
127-
)
128-
lengths = torch.tensor([0, 2, 0, 1, 1, 1, 0, 3, 0], device=self.device)
129-
keys = ["index_0", "index_1", "index_2"]
130-
131-
jag_tensor = KeyedJaggedTensor.from_lengths_sync(
132-
values=values,
133-
keys=keys,
134-
lengths=lengths,
135-
)
136-
indices = [1, 0, 2]
137-
permuted_jag_tensor = jag_tensor.permute(indices)
138-
139-
self.assertEqual(permuted_jag_tensor.keys(), ["index_1", "index_0", "index_2"])
140-
self.assertEqual(
141-
permuted_jag_tensor.offset_per_key(),
142-
[0, 3, 5, 8],
143-
)
144-
self.assertEqual(
145-
permuted_jag_tensor.values().tolist(),
146-
[3.0, 4.0, 5.0, 1.0, 2.0, 6.0, 7.0, 8.0],
147-
)
148-
self.assertEqual(
149-
permuted_jag_tensor.lengths().tolist(), [1, 1, 1, 0, 2, 0, 0, 3, 0]
150-
)
151-
self.assertEqual(permuted_jag_tensor.weights_or_none(), None)
152-
153-
# pyre-ignore
154-
@unittest.skipIf(
155-
torch.cuda.device_count() <= 0,
156-
"Not enough GPUs, this test requires at least one GPUs",
157-
)
158-
def test_permute_vb(self) -> None:
159-
values = torch.tensor(
160-
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device
161-
)
162-
lengths = torch.tensor([1, 0, 1, 3, 0, 1, 0, 2, 0], device=self.device)
163-
keys = ["index_0", "index_1", "index_2"]
164-
stride_per_key_per_rank = [[2], [4], [3]]
165-
166-
jag_tensor = KeyedJaggedTensor.from_lengths_sync(
167-
values=values,
168-
keys=keys,
169-
lengths=lengths,
170-
stride_per_key_per_rank=stride_per_key_per_rank,
171-
)
172-
173-
indices = [1, 0, 2]
174-
permuted_jag_tensor = jag_tensor.permute(indices)
175-
176-
self.assertEqual(permuted_jag_tensor.keys(), ["index_1", "index_0", "index_2"])
177-
self.assertEqual(
178-
permuted_jag_tensor.offset_per_key(),
179-
[0, 5, 6, 8],
180-
)
181-
self.assertEqual(
182-
permuted_jag_tensor.values().tolist(),
183-
[2.0, 3.0, 4.0, 5.0, 6.0, 1.0, 7.0, 8.0],
184-
)
185-
self.assertEqual(
186-
permuted_jag_tensor.lengths().tolist(), [1, 3, 0, 1, 1, 0, 0, 2, 0]
187-
)
188-
self.assertEqual(permuted_jag_tensor.weights_or_none(), None)
189-
190-
# pyre-ignore
191-
@unittest.skipIf(
192-
torch.cuda.device_count() <= 0,
193-
"Not enough GPUs, this test requires at least one GPUs",
194-
)
195-
def test_permute_duplicates(self) -> None:
196-
values = torch.tensor(
197-
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device
198-
)
199-
lengths = torch.tensor([0, 2, 0, 1, 1, 1, 0, 3, 0], device=self.device)
200-
keys = ["index_0", "index_1", "index_2"]
201-
202-
jag_tensor = KeyedJaggedTensor.from_lengths_sync(
203-
values=values,
204-
keys=keys,
205-
lengths=lengths,
206-
)
207-
208-
indices = [1, 0, 2, 1, 1]
209-
permuted_jag_tensor = jag_tensor.permute(indices)
210-
211-
self.assertEqual(
212-
permuted_jag_tensor.keys(),
213-
["index_1", "index_0", "index_2", "index_1", "index_1"],
214-
)
215-
self.assertEqual(
216-
permuted_jag_tensor.offset_per_key(),
217-
[0, 3, 5, 8, 11, 14],
218-
)
219-
self.assertEqual(
220-
permuted_jag_tensor.values().tolist(),
221-
[
222-
3.0,
223-
4.0,
224-
5.0,
225-
1.0,
226-
2.0,
227-
6.0,
228-
7.0,
229-
8.0,
230-
3.0,
231-
4.0,
232-
5.0,
233-
3.0,
234-
4.0,
235-
5.0,
236-
],
237-
)
238-
self.assertEqual(
239-
permuted_jag_tensor.lengths().tolist(),
240-
[1, 1, 1, 0, 2, 0, 0, 3, 0, 1, 1, 1, 1, 1, 1],
241-
)
242-
self.assertEqual(permuted_jag_tensor.weights_or_none(), None)

0 commit comments

Comments
 (0)