This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit f9140bb

support mixed-precision binary operations
1 parent 0c5677e commit f9140bb

File tree

9 files changed: +632 −21 lines changed

src/common/utils.h

Lines changed: 3 additions & 3 deletions
@@ -842,7 +842,7 @@ inline bool is_float(const int dtype) {
   return dtype == mshadow::kFloat32 || dtype == mshadow::kFloat64 || dtype == mshadow::kFloat16;
 }
 
-inline int more_precise_type(const int type1, const int type2) {
+inline int get_more_precise_type(const int type1, const int type2) {
   if (type1 == type2) return type1;
   if (is_float(type1) && is_float(type2)) {
     if (type1 == mshadow::kFloat64 || type2 == mshadow::kFloat64) {
@@ -870,12 +870,12 @@ inline int more_precise_type(const int type1, const int type2) {
   return mshadow::kInt8;
 }
 
-inline int np_binary_out_type(const int type1, const int type2) {
+inline int np_binary_out_infer_type(const int type1, const int type2) {
   if ((type1 == mshadow::kUint8 && type2 == mshadow::kInt8) ||
       (type1 == mshadow::kInt8 && type2 == mshadow::kUint8)) {
     return mshadow::kInt32;
   }
-  return more_precise_type(type1, type2);
+  return get_more_precise_type(type1, type2);
 }
 
 }  // namespace common
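The renamed np_binary_out_infer_type is the promotion rule that the new FInferType hook registered later in np_elemwise_broadcast_op.cc relies on. Since only fragments of the two function bodies are visible in this hunk, the following standalone sketch shows the intended behaviour with an illustrative dtype enum standing in for mshadow's type flags; the branches are assumptions for illustration, not a copy of the real implementation.

#include <cassert>

// Illustrative stand-ins for mshadow's dtype flags; the values are arbitrary.
enum DTypeSketch { kFloat32, kFloat64, kFloat16, kUint8, kInt32, kInt8, kInt64 };

inline bool is_float_sketch(int t) {
  return t == kFloat32 || t == kFloat64 || t == kFloat16;
}

// Assumed promotion rule: same type wins, wider float beats narrower float,
// any float beats any integer, wider integer beats narrower integer.
inline int get_more_precise_type_sketch(int t1, int t2) {
  if (t1 == t2) return t1;
  if (is_float_sketch(t1) && is_float_sketch(t2)) {
    if (t1 == kFloat64 || t2 == kFloat64) return kFloat64;
    return kFloat32;
  }
  if (is_float_sketch(t1) || is_float_sketch(t2)) {
    return is_float_sketch(t1) ? t1 : t2;
  }
  if (t1 == kInt64 || t2 == kInt64) return kInt64;
  if (t1 == kInt32 || t2 == kInt32) return kInt32;
  return kInt8;
}

inline int np_binary_out_infer_type_sketch(int t1, int t2) {
  // Mixing signed and unsigned 8-bit integers widens to a signed 32-bit result.
  if ((t1 == kUint8 && t2 == kInt8) || (t1 == kInt8 && t2 == kUint8)) return kInt32;
  return get_more_precise_type_sketch(t1, t2);
}

int main() {
  assert(np_binary_out_infer_type_sketch(kInt32, kFloat16) == kFloat16);    // int + half -> half
  assert(np_binary_out_infer_type_sketch(kUint8, kInt8) == kInt32);         // uint8 + int8 -> int32
  assert(np_binary_out_infer_type_sketch(kFloat32, kFloat64) == kFloat64);  // float + double -> double
  return 0;
}

Compiled on its own, the asserts document the three cases that matter for this commit: an integer operand paired with a half-precision operand promotes to half, mixing int8 with uint8 widens to int32, and mixing two floats keeps the wider one.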

src/operator/mshadow_op.h

Lines changed: 94 additions & 0 deletions
@@ -194,6 +194,100 @@ MXNET_BINARY_MATH_OP_NC(right, b);
 
 MXNET_BINARY_MATH_OP_NC(mul, a * b);
 
+#ifndef _WIN32
+struct mixed_plus {
+  template<typename DType,
+           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
+  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
+    return static_cast<mshadow::half::half_t>(a) + b;
+  }
+
+  template<typename DType,
+           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
+                                   std::is_integral<DType>::value, int>::type = 0>
+  MSHADOW_XINLINE static float Map(DType a, float b) {
+    return static_cast<float>(a) + b;
+  }
+
+  template<typename DType,
+           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
+                                   std::is_same<DType, float>::value ||
+                                   std::is_integral<DType>::value, int>::type = 0>
+  MSHADOW_XINLINE static double Map(DType a, double b) {
+    return static_cast<double>(a) + b;
+  }
+};
+
+struct mixed_minus {
+  template<typename DType,
+           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
+  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
+    return static_cast<mshadow::half::half_t>(a) - b;
+  }
+
+  template<typename DType,
+           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
+                                   std::is_integral<DType>::value, int>::type = 0>
+  MSHADOW_XINLINE static float Map(DType a, float b) {
+    return static_cast<float>(a) - b;
+  }
+
+  template<typename DType,
+           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
+                                   std::is_same<DType, float>::value ||
+                                   std::is_integral<DType>::value, int>::type = 0>
+  MSHADOW_XINLINE static double Map(DType a, double b) {
+    return static_cast<double>(a) - b;
+  }
+};
+
+struct mixed_rminus {
+  template<typename DType,
+           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
+  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
+    return b - static_cast<mshadow::half::half_t>(a);
+  }
+
+  template<typename DType,
+           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
+                                   std::is_integral<DType>::value, int>::type = 0>
+  MSHADOW_XINLINE static float Map(DType a, float b) {
+    return b - static_cast<float>(a);
+  }
+
+  template<typename DType,
+           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
+                                   std::is_same<DType, float>::value ||
+                                   std::is_integral<DType>::value, int>::type = 0>
+  MSHADOW_XINLINE static double Map(DType a, double b) {
+    return b - static_cast<double>(a);
+  }
+};
+
+struct mixed_mul {
+  template<typename DType,
+           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
+  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
+    return static_cast<mshadow::half::half_t>(a) * b;
+  }
+
+  template<typename DType,
+           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
+                                   std::is_integral<DType>::value, int>::type = 0>
+  MSHADOW_XINLINE static float Map(DType a, float b) {
+    return static_cast<float>(a) * b;
+  }
+
+  template<typename DType,
+           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
+                                   std::is_same<DType, float>::value ||
+                                   std::is_integral<DType>::value, int>::type = 0>
+  MSHADOW_XINLINE static double Map(DType a, double b) {
+    return static_cast<double>(a) * b;
+  }
+};
+#endif
+
 MXNET_BINARY_MATH_OP_NC(div, a / b);
 
 MXNET_BINARY_MATH_OP_NC(plus, a + b);
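The mixed_* functors are compiled only outside Windows (note the #ifndef _WIN32 guard); the Windows registrations later in this commit fall back to the ordinary homogeneous functors. Each functor carries three Map overloads constrained with std::enable_if so that the wider operand always decides the return type: integral with half gives half, half or integral with float gives float, and anything narrower with double gives double. Below is a minimal self-contained sketch of that dispatch pattern, using a plain half placeholder instead of mshadow::half::half_t and dropping MSHADOW_XINLINE so the example compiles on its own; it illustrates the overload-resolution trick and is not the real functor.

#include <iostream>
#include <type_traits>

// Placeholder for mshadow::half::half_t, only here so the sketch is self-contained.
struct half {
  float v;
  operator float() const { return v; }  // enables static_cast<float>/<double> below
};

struct mixed_plus_sketch {
  // integral + half -> half
  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  static half Map(DType a, half b) {
    return half{static_cast<float>(a) + static_cast<float>(b)};
  }

  // (half or integral) + float -> float
  template<typename DType,
           typename std::enable_if<std::is_same<DType, half>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  static float Map(DType a, float b) {
    return static_cast<float>(a) + b;
  }

  // (half, float, or integral) + double -> double
  template<typename DType,
           typename std::enable_if<std::is_same<DType, half>::value ||
                                   std::is_same<DType, float>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  static double Map(DType a, double b) {
    return static_cast<double>(a) + b;
  }
};

int main() {
  half h{1.5f};
  std::cout << static_cast<float>(mixed_plus_sketch::Map(2, h)) << "\n";  // 3.5, half result
  std::cout << mixed_plus_sketch::Map(h, 2.0f) << "\n";                   // 3.5, float result
  std::cout << mixed_plus_sketch::Map(3, 2.0) << "\n";                    // 5, double result
  return 0;
}

The same pattern repeats for mixed_minus, mixed_rminus, and mixed_mul; only the arithmetic in the return statements changes.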

src/operator/mxnet_op.h

Lines changed: 10 additions & 4 deletions
@@ -859,14 +859,17 @@ struct op_with_req {
 
   /*! \brief inputs are two tensors with a float output tensor */
   template<typename DType,
-           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
+           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
+                                   std::is_integral<DType>::value, int>::type = 0>
   MSHADOW_XINLINE static void Map(index_t i, float *out, const DType *lhs, const float *rhs) {
     KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i]));
   }
 
   /*! \brief inputs are two tensors with a double output tensor */
   template<typename DType,
-           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
+           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
+                                   std::is_same<DType, float>::value ||
+                                   std::is_integral<DType>::value, int>::type = 0>
   MSHADOW_XINLINE static void Map(index_t i, double *out, const DType *lhs, const double *rhs) {
     KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i]));
   }
@@ -883,14 +886,17 @@ struct op_with_req {
 
   /*! \brief inputs are two tensors with a float output tensor */
   template<typename DType,
-           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
+           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
+                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static void Map(index_t i, float *out, const DType *lhs, const float value) {
     KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], value));
   }
 
   /*! \brief inputs are two tensors with a double output tensor */
   template<typename DType,
-           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
+           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
+                                   std::is_same<DType, float>::value ||
+                                   std::is_integral<DType>::value, int>::type = 0>
   MSHADOW_XINLINE static void Map(index_t i, double *out, const DType *lhs, const double value) {
     KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], value));
   }
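Both overload families of op_with_req::Map shown above take a narrow input array and write into a wider float or double output buffer; the change simply widens their std::enable_if guards so half_t inputs (and float inputs for the double variants) are accepted as well. The real call path goes through Kernel<op_with_req<OP, req>, xpu>::Launch and KERNEL_ASSIGN; the plain loop below is only an illustrative stand-in showing the calling shape, with plus_sketch as a hypothetical operator.

#include <cstdint>
#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical elementwise op: narrow lhs element, wide rhs element, wide result.
struct plus_sketch {
  template<typename L, typename R>
  static R Map(L a, R b) { return static_cast<R>(a) + b; }
};

// Plain-loop stand-in for the kernel launch: narrow lhs array, float rhs array, float out.
template<typename OP, typename DType>
void map_all(std::size_t n, float* out, const DType* lhs, const float* rhs) {
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = OP::Map(lhs[i], rhs[i]);  // e.g. int32 (or half) lhs element + float rhs element
  }
}

int main() {
  std::vector<int32_t> lhs = {1, 2, 3};
  std::vector<float> rhs = {0.5f, 0.25f, 0.125f};
  std::vector<float> out(lhs.size());
  map_all<plus_sketch>(out.size(), out.data(), lhs.data(), rhs.data());
  for (float v : out) std::cout << v << " ";  // 1.5 2.25 3.125
  std::cout << "\n";
  return 0;
}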

src/operator/numpy/np_elemwise_broadcast_op.cc

Lines changed: 92 additions & 8 deletions
@@ -23,8 +23,7 @@
  * \brief CPU Implementation of basic functions for elementwise numpy binary broadcast operator.
  */
 
-#include "../tensor/elemwise_binary_broadcast_op.h"
-#include "../tensor/elemwise_binary_scalar_op.h"
+#include "./np_elemwise_broadcast_op.h"
 
 namespace mxnet {
 namespace op {
@@ -55,17 +54,102 @@ bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs,
   .add_argument("data", "NDArray-or-Symbol", "source input")                   \
   .add_argument("scalar", "float", "scalar input")
 
+bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs,
+                                   std::vector<int>* in_attrs,
+                                   std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  const int ltype = in_attrs->at(0);
+  const int rtype = in_attrs->at(1);
+  if (ltype != -1 && rtype != -1 && (ltype != rtype)) {
+    // Only when both input types are known and not the same, we enter the mixed-precision mode
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, common::np_binary_out_infer_type(ltype, rtype));
+  } else {
+    return ElemwiseType<2, 1>(attrs, in_attrs, out_attrs);
+  }
+  return true;
+}
 
-MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_add)
-.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::plus>)
+#ifdef _WIN32
+#define MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(name)                \
+  NNVM_REGISTER_OP(name)                                                       \
+  .set_num_inputs(2)                                                           \
+  .set_num_outputs(1)                                                          \
+  .set_attr<nnvm::FListInputNames>("FListInputNames",                          \
+    [](const NodeAttrs& attrs) {                                               \
+      return std::vector<std::string>{"lhs", "rhs"};                           \
+    })                                                                         \
+  .set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)           \
+  .set_attr<nnvm::FInferType>("FInferType", NumpyBinaryMixedPrecisionType)     \
+  .set_attr<nnvm::FInplaceOption>("FInplaceOption",                            \
+    [](const NodeAttrs& attrs){                                                \
+      return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};                \
+    })                                                                         \
+  .set_attr<FResourceRequest>("FResourceRequest",                              \
+    [](const NodeAttrs& attrs) {                                               \
+      return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};        \
+    })                                                                         \
+  .add_argument("lhs", "NDArray-or-Symbol", "First input to the function")     \
+  .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function")
+#else
+#define MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(name)                \
+  NNVM_REGISTER_OP(name)                                                       \
+  .set_num_inputs(2)                                                           \
+  .set_num_outputs(1)                                                          \
+  .set_attr<nnvm::FListInputNames>("FListInputNames",                          \
+    [](const NodeAttrs& attrs) {                                               \
+      return std::vector<std::string>{"lhs", "rhs"};                           \
+    })                                                                         \
+  .set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)           \
+  .set_attr<nnvm::FInferType>("FInferType", NumpyBinaryMixedPrecisionType)     \
+  .set_attr<nnvm::FInplaceOption>("FInplaceOption",                            \
+    [](const NodeAttrs& attrs){                                                \
+      return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};                \
+    })                                                                         \
+  .add_argument("lhs", "NDArray-or-Symbol", "First input to the function")     \
+  .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function")
+#endif
+
+MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_add)
+#ifndef _WIN32
+.set_attr<FCompute>(
+  "FCompute<cpu>",
+  MixedBinaryBroadcastCompute<cpu, op::mshadow_op::plus, op::mshadow_op::mixed_plus,
+                              op::mshadow_op::mixed_plus>)
+#else
+.set_attr<FCompute>(
+  "FCompute<cpu>",
+  MixedBinaryBroadcastCompute<cpu, op::mshadow_op::plus, op::mshadow_op::plus,
+                              op::mshadow_op::plus>)
+#endif
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_add"});
 
-MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_subtract)
-.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::minus>)
+MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_subtract)
+#ifndef _WIN32
+.set_attr<FCompute>(
+  "FCompute<cpu>",
+  MixedBinaryBroadcastCompute<cpu, op::mshadow_op::minus, op::mshadow_op::mixed_minus,
+                              op::mshadow_op::mixed_rminus>)
+#else
+.set_attr<FCompute>(
+  "FCompute<cpu>",
+  MixedBinaryBroadcastCompute<cpu, op::mshadow_op::minus, op::mshadow_op::minus,
+                              op::mshadow_op::minus>)
+#endif
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_sub"});
 
-MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_multiply)
-.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::mul>)
+MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_multiply)
+#ifndef _WIN32
+.set_attr<FCompute>(
+  "FCompute<cpu>",
+  MixedBinaryBroadcastCompute<cpu, op::mshadow_op::mul, op::mshadow_op::mixed_mul,
+                              op::mshadow_op::mixed_mul>)
+#else
+.set_attr<FCompute>(
+  "FCompute<cpu>",
  MixedBinaryBroadcastCompute<cpu, op::mshadow_op::mul, op::mshadow_op::mul,
                              op::mshadow_op::mul>)
+#endif
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"});
 
 MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_mod)
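MixedBinaryBroadcastCompute takes the homogeneous functor plus a left and a right mixed-precision functor; its definition lives in the new np_elemwise_broadcast_op.h header included above and is not part of this hunk. Subtraction needs both mixed_minus and mixed_rminus because the mixed functors always receive the narrower operand as their first argument, so a non-commutative operation needs a reversed variant for the case where the right-hand side is the narrow one. The sketch below illustrates that dispatch idea under that assumption; it is not the actual implementation.

#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical per-element functors mirroring mshadow_op::mixed_minus / mixed_rminus
// for one concrete narrow/wide pair (int and float).
struct minus_sketch  { static float Map(int a, float b) { return static_cast<float>(a) - b; } };
struct rminus_sketch { static float Map(int a, float b) { return b - static_cast<float>(a); } };

// Sketch of the dispatch idea: the narrow operand is always the first argument of the
// mixed functor, so LOP handles "lhs is narrow" and ROP handles "rhs is narrow".
template<typename LOP, typename ROP>
void mixed_sub(const std::vector<int>& narrow, const std::vector<float>& wide,
               bool narrow_is_lhs, std::vector<float>* out) {
  out->resize(narrow.size());
  for (std::size_t i = 0; i < narrow.size(); ++i) {
    (*out)[i] = narrow_is_lhs ? LOP::Map(narrow[i], wide[i])   // lhs - rhs, lhs narrow
                              : ROP::Map(narrow[i], wide[i]);  // lhs - rhs, rhs narrow
  }
}

int main() {
  std::vector<int> i32 = {1, 2, 3};
  std::vector<float> f32 = {0.5f, 0.5f, 0.5f};
  std::vector<float> out;
  mixed_sub<minus_sketch, rminus_sketch>(i32, f32, true, &out);   // int lhs - float rhs
  std::cout << out[0] << "\n";                                    // 0.5
  mixed_sub<minus_sketch, rminus_sketch>(i32, f32, false, &out);  // float lhs - int rhs
  std::cout << out[0] << "\n";                                    // -0.5
  return 0;
}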

src/operator/numpy/np_elemwise_broadcast_op.cu

Lines changed: 35 additions & 5 deletions
@@ -22,20 +22,50 @@
  * \file np_elemwise_broadcast_op.cu
  * \brief GPU Implementation of basic functions for elementwise binary broadcast operator.
  */
-#include "../tensor/elemwise_binary_broadcast_op.h"
-#include "../tensor/elemwise_binary_scalar_op.h"
+
+#include "./np_elemwise_broadcast_op.h"
 
 namespace mxnet {
 namespace op {
 
 NNVM_REGISTER_OP(_npi_add)
-.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::plus>);
+#ifndef _WIN32
+.set_attr<FCompute>(
+  "FCompute<gpu>",
+  MixedBinaryBroadcastCompute<gpu, op::mshadow_op::plus, op::mshadow_op::mixed_plus,
+                              op::mshadow_op::mixed_plus>);
+#else
+.set_attr<FCompute>(
+  "FCompute<gpu>",
+  MixedBinaryBroadcastCompute<gpu, op::mshadow_op::plus, op::mshadow_op::plus,
+                              op::mshadow_op::plus>);
+#endif
 
 NNVM_REGISTER_OP(_npi_subtract)
-.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::minus>);
+#ifndef _WIN32
+.set_attr<FCompute>(
+  "FCompute<gpu>",
+  MixedBinaryBroadcastCompute<gpu, op::mshadow_op::minus, op::mshadow_op::mixed_minus,
+                              op::mshadow_op::mixed_rminus>);
+#else
+.set_attr<FCompute>(
+  "FCompute<gpu>",
+  MixedBinaryBroadcastCompute<gpu, op::mshadow_op::minus, op::mshadow_op::minus,
+                              op::mshadow_op::minus>);
+#endif
 
 NNVM_REGISTER_OP(_npi_multiply)
-.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::mul>);
+#ifndef _WIN32
+.set_attr<FCompute>(
+  "FCompute<gpu>",
+  MixedBinaryBroadcastCompute<gpu, op::mshadow_op::mul, op::mshadow_op::mixed_mul,
+                              op::mshadow_op::mixed_mul>);
+#else
+.set_attr<FCompute>(
+  "FCompute<gpu>",
+  MixedBinaryBroadcastCompute<gpu, op::mshadow_op::mul, op::mshadow_op::mul,
+                              op::mshadow_op::mul>);
+#endif
 
 NNVM_REGISTER_OP(_npi_mod)
 .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::mod>);
