|
1 | 1 | #include <ATen/ATen.h>
|
2 |
| -#include <torch/torch.h> |
3 |
| - |
| 2 | +#include <c10/util/Exception.h> |
4 | 3 | #include <core/Device.h>
|
5 | 4 | #include <core/Memory.h>
|
6 | 5 | #include <runtime/Utils.h>
|
| 6 | +#include <torch/torch.h> |
7 | 7 | #include <utils/DPCPP.h>
|
8 | 8 |
|
9 | 9 | #include "BitonicMergeSort.h"
|
@@ -1182,6 +1182,21 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag(
|
1182 | 1182 | const c10::optional<at::Tensor>& per_sample_weights_opt,
|
1183 | 1183 | bool include_last_offset,
|
1184 | 1184 | int64_t padding_idx) {
|
| 1185 | + TORCH_CHECK( |
| 1186 | + indices.dim() == 1 || indices.dim() == 2, |
| 1187 | + "input has to be a 1D or 2D Tensor, but got Tensor of dimension ", |
| 1188 | + indices.dim()); |
| 1189 | + if (indices.dim() == 1) { |
| 1190 | + TORCH_CHECK( |
| 1191 | + offsets.dim() == 1, |
| 1192 | + "offsets has to be a 1D Tensor, but got Tensor of dimension ", |
| 1193 | + offsets.dim()); |
| 1194 | + } |
| 1195 | + TORCH_CHECK( |
| 1196 | + weight.dim() == 2, |
| 1197 | + "weight has to be a 2D Tensor, but got Tensor of dimension ", |
| 1198 | + weight.dim()); |
| 1199 | + |
1185 | 1200 | c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned =
|
1186 | 1201 | at::borrow_from_optional_tensor(per_sample_weights_opt);
|
1187 | 1202 | const Tensor& per_sample_weights = *per_sample_weights_maybe_owned;
|
@@ -1234,6 +1249,20 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_forward_only(
|
1234 | 1249 | const c10::optional<Tensor>& per_sample_weights_opt,
|
1235 | 1250 | bool include_last_offset,
|
1236 | 1251 | int64_t padding_idx) {
|
| 1252 | + TORCH_CHECK( |
| 1253 | + indices.dim() == 1 || indices.dim() == 2, |
| 1254 | + "input has to be a 1D or 2D Tensor, but got Tensor of dimension ", |
| 1255 | + indices.dim()); |
| 1256 | + if (indices.dim() == 1) { |
| 1257 | + TORCH_CHECK( |
| 1258 | + offsets.dim() == 1, |
| 1259 | + "offsets has to be a 1D Tensor, but got Tensor of dimension ", |
| 1260 | + offsets.dim()); |
| 1261 | + } |
| 1262 | + TORCH_CHECK( |
| 1263 | + weight.dim() == 2, |
| 1264 | + "weight has to be a 2D Tensor, but got Tensor of dimension ", |
| 1265 | + weight.dim()); |
1237 | 1266 | // See [Note: hacky wrapper removal for optional tensor]
|
1238 | 1267 | c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned =
|
1239 | 1268 | at::borrow_from_optional_tensor(per_sample_weights_opt);
|
|
0 commit comments