Commit f56c638

dcci authored and pytorchmergebot committed
[c10/metal] Add a vectype variant for short/int/long (pytorch#145430)
Some of the kernels (exp_complex/atan_complex) need the specialization.

Pull Request resolved: pytorch#145430
Approved by: https://github.com/malfet, https://github.com/jansel
1 parent c581981 commit f56c638

1 file changed: +22 -0 lines changed

c10/metal/utils.h

Lines changed: 22 additions & 0 deletions
@@ -29,6 +29,28 @@ struct vectypes<bfloat> {
   using type2 = bfloat2;
 };
 #endif
+
+template <>
+struct vectypes<short> {
+  using type4 = short4;
+  using type3 = short3;
+  using type2 = short2;
+};
+
+template <>
+struct vectypes<int> {
+  using type4 = int4;
+  using type3 = int3;
+  using type2 = int2;
+};
+
+template <>
+struct vectypes<long> {
+  using type4 = short4;
+  using type3 = short3;
+  using type2 = short2;
+};
+
 } // namespace detail
 
 template <typename T>
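
For readers unfamiliar with the pattern, the following is a minimal host-side C++ sketch of how a `vectypes`-style trait lets generic code pick a vector type from a scalar type. It is not the PyTorch or Metal code: the `short2`/`int2`-style structs are stand-ins for Metal's built-in vector types, and the `vec2_of` alias and `make_pair_vec` helper are hypothetical names introduced only for illustration. Only the `short` and `int` specializations are mirrored here.

```cpp
#include <type_traits>

// Stand-ins (assumption): plain structs imitating Metal's built-in vector types,
// so the sketch compiles as ordinary C++ on the host.
struct short2 { short x, y; };
struct short3 { short x, y, z; };
struct short4 { short x, y, z, w; };
struct int2 { int x, y; };
struct int3 { int x, y, z; };
struct int4 { int x, y, z, w; };

namespace detail {
// Primary template has no members; only specialized scalar types are usable.
template <typename T>
struct vectypes {};

template <>
struct vectypes<short> {
  using type4 = short4;
  using type3 = short3;
  using type2 = short2;
};

template <>
struct vectypes<int> {
  using type4 = int4;
  using type3 = int3;
  using type2 = int2;
};
} // namespace detail

// Hypothetical convenience alias: the 2-wide vector type for scalar T.
template <typename T>
using vec2_of = typename detail::vectypes<T>::type2;

// Hypothetical generic helper: packs two scalars into the matching 2-wide vector.
// It compiles only for scalar types that have a vectypes specialization.
template <typename T>
constexpr vec2_of<T> make_pair_vec(T a, T b) {
  return vec2_of<T>{a, b};
}
```

With this shape, a template written once over `T` (for example a kernel that treats a complex value as a 2-wide vector) can be instantiated for any scalar type that has a specialization, which is presumably why the commit message says some kernels need the new ones.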