|
11 | 11 | import unittest
|
12 | 12 |
|
13 | 13 | import torch
|
14 |
| -from torchrec.sparse.jagged_tensor import ( |
15 |
| - _regroup_keyed_tensors, |
16 |
| - KeyedJaggedTensor, |
17 |
| - KeyedTensor, |
18 |
| -) |
| 14 | +from torchrec.sparse.jagged_tensor import _regroup_keyed_tensors, KeyedTensor |
19 | 15 | from torchrec.sparse.tests.utils import build_groups, build_kts
|
20 | 16 | from torchrec.test_utils import skip_if_asan_class
|
21 | 17 |
|
@@ -115,128 +111,3 @@ def test_regroup_backward(self) -> None:
|
115 | 111 |
|
116 | 112 | torch.allclose(actual_kt_0_grad, expected_kt_0_grad)
|
117 | 113 | torch.allclose(actual_kt_1_grad, expected_kt_1_grad)
|
118 |
| - |
119 |
| - # pyre-ignore |
120 |
| - @unittest.skipIf( |
121 |
| - torch.cuda.device_count() <= 0, |
122 |
| - "Not enough GPUs, this test requires at least one GPUs", |
123 |
| - ) |
124 |
| - def test_permute(self) -> None: |
125 |
| - values = torch.tensor( |
126 |
| - [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device |
127 |
| - ) |
128 |
| - lengths = torch.tensor([0, 2, 0, 1, 1, 1, 0, 3, 0], device=self.device) |
129 |
| - keys = ["index_0", "index_1", "index_2"] |
130 |
| - |
131 |
| - jag_tensor = KeyedJaggedTensor.from_lengths_sync( |
132 |
| - values=values, |
133 |
| - keys=keys, |
134 |
| - lengths=lengths, |
135 |
| - ) |
136 |
| - indices = [1, 0, 2] |
137 |
| - permuted_jag_tensor = jag_tensor.permute(indices) |
138 |
| - |
139 |
| - self.assertEqual(permuted_jag_tensor.keys(), ["index_1", "index_0", "index_2"]) |
140 |
| - self.assertEqual( |
141 |
| - permuted_jag_tensor.offset_per_key(), |
142 |
| - [0, 3, 5, 8], |
143 |
| - ) |
144 |
| - self.assertEqual( |
145 |
| - permuted_jag_tensor.values().tolist(), |
146 |
| - [3.0, 4.0, 5.0, 1.0, 2.0, 6.0, 7.0, 8.0], |
147 |
| - ) |
148 |
| - self.assertEqual( |
149 |
| - permuted_jag_tensor.lengths().tolist(), [1, 1, 1, 0, 2, 0, 0, 3, 0] |
150 |
| - ) |
151 |
| - self.assertEqual(permuted_jag_tensor.weights_or_none(), None) |
152 |
| - |
153 |
| - # pyre-ignore |
154 |
| - @unittest.skipIf( |
155 |
| - torch.cuda.device_count() <= 0, |
156 |
| - "Not enough GPUs, this test requires at least one GPUs", |
157 |
| - ) |
158 |
| - def test_permute_vb(self) -> None: |
159 |
| - values = torch.tensor( |
160 |
| - [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device |
161 |
| - ) |
162 |
| - lengths = torch.tensor([1, 0, 1, 3, 0, 1, 0, 2, 0], device=self.device) |
163 |
| - keys = ["index_0", "index_1", "index_2"] |
164 |
| - stride_per_key_per_rank = [[2], [4], [3]] |
165 |
| - |
166 |
| - jag_tensor = KeyedJaggedTensor.from_lengths_sync( |
167 |
| - values=values, |
168 |
| - keys=keys, |
169 |
| - lengths=lengths, |
170 |
| - stride_per_key_per_rank=stride_per_key_per_rank, |
171 |
| - ) |
172 |
| - |
173 |
| - indices = [1, 0, 2] |
174 |
| - permuted_jag_tensor = jag_tensor.permute(indices) |
175 |
| - |
176 |
| - self.assertEqual(permuted_jag_tensor.keys(), ["index_1", "index_0", "index_2"]) |
177 |
| - self.assertEqual( |
178 |
| - permuted_jag_tensor.offset_per_key(), |
179 |
| - [0, 5, 6, 8], |
180 |
| - ) |
181 |
| - self.assertEqual( |
182 |
| - permuted_jag_tensor.values().tolist(), |
183 |
| - [2.0, 3.0, 4.0, 5.0, 6.0, 1.0, 7.0, 8.0], |
184 |
| - ) |
185 |
| - self.assertEqual( |
186 |
| - permuted_jag_tensor.lengths().tolist(), [1, 3, 0, 1, 1, 0, 0, 2, 0] |
187 |
| - ) |
188 |
| - self.assertEqual(permuted_jag_tensor.weights_or_none(), None) |
189 |
| - |
190 |
| - # pyre-ignore |
191 |
| - @unittest.skipIf( |
192 |
| - torch.cuda.device_count() <= 0, |
193 |
| - "Not enough GPUs, this test requires at least one GPUs", |
194 |
| - ) |
195 |
| - def test_permute_duplicates(self) -> None: |
196 |
| - values = torch.tensor( |
197 |
| - [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device |
198 |
| - ) |
199 |
| - lengths = torch.tensor([0, 2, 0, 1, 1, 1, 0, 3, 0], device=self.device) |
200 |
| - keys = ["index_0", "index_1", "index_2"] |
201 |
| - |
202 |
| - jag_tensor = KeyedJaggedTensor.from_lengths_sync( |
203 |
| - values=values, |
204 |
| - keys=keys, |
205 |
| - lengths=lengths, |
206 |
| - ) |
207 |
| - |
208 |
| - indices = [1, 0, 2, 1, 1] |
209 |
| - permuted_jag_tensor = jag_tensor.permute(indices) |
210 |
| - |
211 |
| - self.assertEqual( |
212 |
| - permuted_jag_tensor.keys(), |
213 |
| - ["index_1", "index_0", "index_2", "index_1", "index_1"], |
214 |
| - ) |
215 |
| - self.assertEqual( |
216 |
| - permuted_jag_tensor.offset_per_key(), |
217 |
| - [0, 3, 5, 8, 11, 14], |
218 |
| - ) |
219 |
| - self.assertEqual( |
220 |
| - permuted_jag_tensor.values().tolist(), |
221 |
| - [ |
222 |
| - 3.0, |
223 |
| - 4.0, |
224 |
| - 5.0, |
225 |
| - 1.0, |
226 |
| - 2.0, |
227 |
| - 6.0, |
228 |
| - 7.0, |
229 |
| - 8.0, |
230 |
| - 3.0, |
231 |
| - 4.0, |
232 |
| - 5.0, |
233 |
| - 3.0, |
234 |
| - 4.0, |
235 |
| - 5.0, |
236 |
| - ], |
237 |
| - ) |
238 |
| - self.assertEqual( |
239 |
| - permuted_jag_tensor.lengths().tolist(), |
240 |
| - [1, 1, 1, 0, 2, 0, 0, 3, 0, 1, 1, 1, 1, 1, 1], |
241 |
| - ) |
242 |
| - self.assertEqual(permuted_jag_tensor.weights_or_none(), None) |
0 commit comments