
Commit 01f226b

zeshengzong authored and pytorchmergebot committed
Add check for ctc_loss targets param (#150981)
Fixes #150835

## Test Result

```python
# cuda
>>> import torch
>>> import torch.nn.functional as F
>>> device = "cuda"  # "cpu" is fine
>>> num_classes = 4
>>> log_probs = torch.rand(0, 0, num_classes, device=device)
>>> targets = torch.tensor([], device=device, dtype=torch.long)
>>> input_lengths = torch.tensor([], device=device, dtype=torch.long)
>>> target_lengths = torch.tensor([], device=device, dtype=torch.long)
>>> result = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/zong/code/pytorch/torch/nn/functional.py", line 3079, in ctc_loss
    return torch.ctc_loss(
           ^^^^^^^^^^^^^^^
RuntimeError: log_probs tensor must not be empty

# cpu
>>> device = "cpu"
>>> num_classes = 4
>>> log_probs = torch.rand(0, 0, num_classes, device=device)
>>> targets = torch.tensor([], device=device, dtype=torch.long)
>>> input_lengths = torch.tensor([], device=device, dtype=torch.long)
>>> target_lengths = torch.tensor([], device=device, dtype=torch.long)
>>> result = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/zong/code/pytorch/torch/nn/functional.py", line 3079, in ctc_loss
    return torch.ctc_loss(
           ^^^^^^^^^^^^^^^
RuntimeError: log_probs tensor must not be empty
```

Pull Request resolved: #150981
Approved by: https://github.com/eqy
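For contrast with the failing calls above, here is a minimal sketch of a valid `F.ctc_loss` invocation with non-empty inputs; the shapes `T=50, N=2, C=4, S=10` are chosen purely for illustration and are not part of the commit:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: T time steps, N batch, C classes (index 0 is the blank), S target length.
T, N, C, S = 50, 2, 4, 10
log_probs = torch.randn(T, N, C).log_softmax(2)           # (T, N, C), non-empty
targets = torch.randint(1, C, (N, S), dtype=torch.long)   # labels drawn from 1..C-1
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.full((N,), S, dtype=torch.long)

loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
print(loss.shape)  # torch.Size([2]): one loss value per batch element
```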
1 parent bbc5fe8 commit 01f226b

3 files changed (+11, −0 lines)


aten/src/ATen/native/LossCTC.cpp

Lines changed: 1 addition & 0 deletions

@@ -126,6 +126,7 @@ std::tuple<Tensor, Tensor, size_t, std::vector<int64_t>> ctc_loss_allocate_outpu
 // the alphas from the user by only returning the loss.
 template<typename scalar_t, ScalarType target_scalar_type>
 std::tuple<Tensor, Tensor> ctc_loss_cpu_template(const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK) {
+  TORCH_CHECK(log_probs.numel() > 0, "log_probs tensor must not be empty");
   // log_probs: input_len x batch_size x num_labels
   // targets [int64]: batch_size x target_length OR sum(target_lengths)
   constexpr scalar_t neginf = -std::numeric_limits<scalar_t>::infinity();
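The guard above relies on `Tensor::numel()`, the total number of elements (the product of all dimension sizes), so a `log_probs` tensor with any zero-sized dimension is rejected before the kernel reads its data. A quick Python illustration of that precondition (not part of the diff):

```python
import torch

empty_log_probs = torch.rand(0, 0, 4)
print(empty_log_probs.numel())   # 0 -> TORCH_CHECK(log_probs.numel() > 0, ...) would fire

valid_log_probs = torch.rand(50, 2, 4)
print(valid_log_probs.numel())   # 400 -> check passes
```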

aten/src/ATen/native/cuda/LossCTC.cu

Lines changed: 1 addition & 0 deletions

@@ -219,6 +219,7 @@ ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
 // backward. The dispatch function will only return the loss.
 template<typename scalar_t, ScalarType target_scalar_type>
 std::tuple<Tensor, Tensor> ctc_loss_gpu_template(const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK) {
+  TORCH_CHECK(log_probs.numel() > 0, "log_probs tensor must not be empty");
   // log_probs: input_len x batch_size x num_labels
   // targets [int64]: batch_size x target_length OR sum(target_lengths)
   CheckedFrom c = "ctc_loss_gpu";

test/test_nn.py

Lines changed: 9 additions & 0 deletions

@@ -11532,6 +11532,15 @@ def test_ctc_loss_cudnn_tensor(self, device):
         grad_cudnn, = torch.autograd.grad(loss_cudnn, log_probs, grad_out)
         self.assertEqual(grad_cudnn, grad_native, atol=1e-4, rtol=0)
 
+    @expectedFailureMPS
+    def test_ctc_loss_error(self, device):
+        log_probs = torch.rand(0, 0, 4, device=device)
+        targets = torch.tensor([], device=device, dtype=torch.long)
+        input_lengths = torch.tensor([], device=device, dtype=torch.long)
+        target_lengths = torch.tensor([], device=device, dtype=torch.long)
+        with self.assertRaisesRegex(RuntimeError, "log_probs tensor must not be empty"):
+            F.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
+
     @expectedFailureMPS  # RuntimeError: LSTM with projections is not currently supported with MPS.
     @dtypesIfCUDA(torch.half, torch.float, torch.double)
     @dtypes(torch.float)
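Outside the `test_nn.py` harness, the new test can be approximated with a standalone snippet; the function name and the use of `pytest.raises` here are illustrative and not part of the change. On builds that include this commit, the call raises the new `RuntimeError` on both CPU and CUDA:

```python
import pytest
import torch
import torch.nn.functional as F

def test_ctc_loss_error_standalone(device="cpu"):
    # Mirrors the in-tree test: every input tensor is empty.
    log_probs = torch.rand(0, 0, 4, device=device)
    targets = torch.tensor([], device=device, dtype=torch.long)
    input_lengths = torch.tensor([], device=device, dtype=torch.long)
    target_lengths = torch.tensor([], device=device, dtype=torch.long)
    with pytest.raises(RuntimeError, match="log_probs tensor must not be empty"):
        F.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
```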
