Skip to content

Commit 451ee17

Browse files
committed
Move test location
The newly added tests no longer require `sox` CLI, thus it is better located at transforms_test.
1 parent c9983a8 commit 451ee17

File tree

2 files changed

+15
-12
lines changed

2 files changed

+15
-12
lines changed

test/torchaudio_unittest/transforms/sox_compatibility_test.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -70,18 +70,6 @@ def test_vad(self, filename):
7070
result = T.Vad(sample_rate)(data)
7171
self.assert_sox_effect(result, path, ["vad"])
7272

73-
@parameterized.expand(
74-
[
75-
(torch.zeros(32000), torch.zeros(0), 16000),
76-
(torch.zeros(1, 32000), torch.zeros(1, 0), 32000),
77-
(torch.zeros(2, 44100), torch.zeros(2, 0), 32000),
78-
(torch.zeros(2, 2, 44100), torch.zeros(2, 2, 0), 32000),
79-
]
80-
)
81-
def test_vad_on_zero_audio(self, inpt: torch.Tensor, expected_output: torch.Tensor, sample_rate: int):
82-
result = T.Vad(sample_rate)(inpt)
83-
self.assertEqual(result, expected_output)
84-
8573
def test_vad_warning(self):
8674
"""vad should throw a warning if input dimension is greater than 2"""
8775
sample_rate = 41100

test/torchaudio_unittest/transforms/transforms_test_impl.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,3 +478,18 @@ def test_specaugment(self, n_time_masks, time_mask_param, n_freq_masks, freq_mas
478478
self.assertTrue(diff > 0)
479479
else:
480480
self.assertTrue(diff == 0)
481+
482+
@parameterized.expand(
483+
[
484+
((32000,), (0,), 16000),
485+
((1, 32000), (1, 0), 32000),
486+
((2, 44100), (2, 0), 32000),
487+
((2, 2, 44100), (2, 2, 0), 32000),
488+
]
489+
)
490+
def test_vad_on_zero_audio(self, input_shape, output_shape, sample_rate: int):
491+
"""VAD should return zero when input is zero Tensor"""
492+
inpt = torch.zeros(input_shape, dtype=self.dtype, device=self.device)
493+
expected_output = torch.zeros(output_shape, dtype=self.dtype, device=self.device)
494+
result = T.Vad(sample_rate)(inpt)
495+
self.assertEqual(result, expected_output)

0 commit comments

Comments
 (0)