Skip to content

Commit 5717479

Browse files
authored
port embedding bag check for 2.1.40 (#4504)(#4482)
* add check before embedding bag * fix for clang-format * add headers
1 parent 8b74d6c commit 5717479

File tree

1 file changed

+31
-2
lines changed

1 file changed

+31
-2
lines changed

csrc/gpu/aten/operators/EmbeddingBag.cpp

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
#include <ATen/ATen.h>
2-
#include <torch/torch.h>
3-
2+
#include <c10/util/Exception.h>
43
#include <core/Device.h>
54
#include <core/Memory.h>
65
#include <runtime/Utils.h>
6+
#include <torch/torch.h>
77
#include <utils/DPCPP.h>
88

99
#include "BitonicMergeSort.h"
@@ -1182,6 +1182,21 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag(
11821182
const c10::optional<at::Tensor>& per_sample_weights_opt,
11831183
bool include_last_offset,
11841184
int64_t padding_idx) {
1185+
TORCH_CHECK(
1186+
indices.dim() == 1 || indices.dim() == 2,
1187+
"input has to be a 1D or 2D Tensor, but got Tensor of dimension ",
1188+
indices.dim());
1189+
if (indices.dim() == 1) {
1190+
TORCH_CHECK(
1191+
offsets.dim() == 1,
1192+
"offsets has to be a 1D Tensor, but got Tensor of dimension ",
1193+
offsets.dim());
1194+
}
1195+
TORCH_CHECK(
1196+
weight.dim() == 2,
1197+
"weight has to be a 2D Tensor, but got Tensor of dimension ",
1198+
weight.dim());
1199+
11851200
c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned =
11861201
at::borrow_from_optional_tensor(per_sample_weights_opt);
11871202
const Tensor& per_sample_weights = *per_sample_weights_maybe_owned;
@@ -1234,6 +1249,20 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_forward_only(
12341249
const c10::optional<Tensor>& per_sample_weights_opt,
12351250
bool include_last_offset,
12361251
int64_t padding_idx) {
1252+
TORCH_CHECK(
1253+
indices.dim() == 1 || indices.dim() == 2,
1254+
"input has to be a 1D or 2D Tensor, but got Tensor of dimension ",
1255+
indices.dim());
1256+
if (indices.dim() == 1) {
1257+
TORCH_CHECK(
1258+
offsets.dim() == 1,
1259+
"offsets has to be a 1D Tensor, but got Tensor of dimension ",
1260+
offsets.dim());
1261+
}
1262+
TORCH_CHECK(
1263+
weight.dim() == 2,
1264+
"weight has to be a 2D Tensor, but got Tensor of dimension ",
1265+
weight.dim());
12371266
// See [Note: hacky wrapper removal for optional tensor]
12381267
c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned =
12391268
at::borrow_from_optional_tensor(per_sample_weights_opt);

0 commit comments

Comments
 (0)