This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit f6aa9e9

workaround for windows

1 parent: 727b84d

File tree

7 files changed: +224 −94 lines

src/operator/mshadow_op.h
Lines changed: 6 additions & 0 deletions

@@ -133,6 +133,7 @@ struct true_divide : public mxnet_op::tunable {
     return static_cast<float>(a) / static_cast<float>(b);
   }

+#ifndef _WIN32
   template<typename DType,
            typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
   MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
@@ -150,6 +151,7 @@ struct true_divide : public mxnet_op::tunable {
   MSHADOW_XINLINE static double Map(DType a, double b) {
     return static_cast<double>(a) / b;
   }
+#endif
 };

 struct rtrue_divide : public mxnet_op::tunable {
@@ -165,6 +167,7 @@ struct rtrue_divide : public mxnet_op::tunable {
     return static_cast<float>(b) / static_cast<float>(a);
   }

+#ifndef _WIN32
   template<typename DType,
            typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
   MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
@@ -182,6 +185,7 @@ struct rtrue_divide : public mxnet_op::tunable {
   MSHADOW_XINLINE static double Map(DType a, double b) {
     return b / static_cast<double>(a);
   }
+#endif
 };

 MXNET_BINARY_MATH_OP_NC(left, a);
@@ -190,13 +194,15 @@ MXNET_BINARY_MATH_OP_NC(right, b);

 MXNET_BINARY_MATH_OP_NC(mul, a * b);

+#ifndef _WIN32
 struct mixed_mul {
   template<typename DType,
            typename std::enable_if<!std::is_pointer<DType>::value, int>::type = 0>
   MSHADOW_XINLINE static DType Map(bool a, DType b) {
     return static_cast<DType>(a) * b;
   }
 };
+#endif

 MXNET_BINARY_MATH_OP_NC(div, a / b);

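For readers skimming the guarded block, the following is a minimal standalone sketch (plain C++, hypothetical names, no mshadow/MXNet headers) of the pattern mixed_mul uses: a static Map overload that takes a bool on one side and a value of another type on the other, constrained with std::enable_if. It only illustrates the idea and is not MXNet code.

#include <iostream>
#include <type_traits>

// Illustrative stand-in for mshadow_op::mixed_mul: multiply a bool with a
// value of another arithmetic type, producing that other type.
struct mixed_mul_sketch {
  template <typename DType,
            typename std::enable_if<!std::is_pointer<DType>::value, int>::type = 0>
  static DType Map(bool a, DType b) {
    return static_cast<DType>(a) * b;  // false -> 0, true -> b
  }
};

int main() {
  std::cout << mixed_mul_sketch::Map(true, 2.5f) << "\n";   // prints 2.5
  std::cout << mixed_mul_sketch::Map(false, 7.0) << "\n";   // prints 0
  return 0;
}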

src/operator/numpy/np_elemwise_broadcast_op.cc
Lines changed: 29 additions & 7 deletions

@@ -54,7 +54,6 @@ bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs,
   .add_argument("data", "NDArray-or-Symbol", "source input")  \
   .add_argument("scalar", "float", "scalar input")

-#ifndef _WIN32
 bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs,
                                    std::vector<int>* in_attrs,
                                    std::vector<int>* out_attrs) {
@@ -71,6 +70,28 @@ bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs,
   return true;
 }

+#ifdef _WIN32
+#define MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(name)  \
+  NNVM_REGISTER_OP(name)  \
+  .set_num_inputs(2)  \
+  .set_num_outputs(1)  \
+  .set_attr<nnvm::FListInputNames>("FListInputNames",  \
+    [](const NodeAttrs& attrs) {  \
+      return std::vector<std::string>{"lhs", "rhs"};  \
+    })  \
+  .set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)  \
+  .set_attr<nnvm::FInferType>("FInferType", NumpyBinaryMixedPrecisionType)  \
+  .set_attr<nnvm::FInplaceOption>("FInplaceOption",  \
+    [](const NodeAttrs& attrs){  \
+      return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};  \
+    })  \
+  .set_attr<FResourceRequest>("FResourceRequest",  \
+    [](const NodeAttrs& attrs) {  \
+      return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};  \
+    })  \
+  .add_argument("lhs", "NDArray-or-Symbol", "First input to the function")  \
+  .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function")
+#else
 #define MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(name)  \
   NNVM_REGISTER_OP(name)  \
   .set_num_inputs(2)  \
@@ -97,12 +118,18 @@ MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_subtract)
 .set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::minus>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_sub"});

-#ifndef _WIN32
 MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_multiply)
+#ifndef _WIN32
 .set_attr<FCompute>(
   "FCompute<cpu>",
   MixedBinaryBroadcastCompute<cpu, op::mshadow_op::mul, op::mshadow_op::mixed_mul,
                               op::mshadow_op::mixed_mul>)
+#else
+.set_attr<FCompute>(
+  "FCompute<cpu>",
+  MixedBinaryBroadcastCompute<cpu, op::mshadow_op::mul, op::mshadow_op::mul,
+                              op::mshadow_op::mul>)
+#endif
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_mul"});

 NNVM_REGISTER_OP(_backward_npi_broadcast_mul)
@@ -119,11 +146,6 @@ NNVM_REGISTER_OP(_backward_npi_broadcast_mul)
   })
 .set_attr<FCompute>("FCompute<cpu>", MixedBinaryBackwardUseIn<cpu, mshadow_op::right,
                                                               mshadow_op::left>);
-#else
-MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_multiply)
-.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::mul>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"});
-#endif

 MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_mod)
 .set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::mod>)
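Two details are easy to miss in this file. First, the Windows variant of the registration macro also declares an FResourceRequest for kTempSpace, which is what lets the Windows fallback in np_elemwise_broadcast_op.h obtain scratch memory through ctx.requested[0]. Second, wrapping only the .set_attr<FCompute> call in #ifndef _WIN32 / #else / #endif compiles because the whole registration is one chained statement, so the preprocessor simply selects which call participates in the chain. Below is a minimal sketch of that chaining idea, using hypothetical builder names rather than the NNVM API.

#include <iostream>
#include <string>

// Hypothetical builder, standing in for an NNVM_REGISTER_OP(...) chain.
struct OpBuilder {
  std::string compute;
  OpBuilder& set_compute(const std::string& fn) { compute = fn; return *this; }
  OpBuilder& set_gradient(const std::string&) { return *this; }
};

OpBuilder register_op() { return OpBuilder{}; }

int main() {
  // One statement; the preprocessor picks which set_compute call is kept.
  OpBuilder op = register_op()
#ifndef _WIN32
    .set_compute("mixed-precision kernel")
#else
    .set_compute("same-type kernel after casting")
#endif
    .set_gradient("backward_mul");
  std::cout << op.compute << "\n";
  return 0;
}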

src/operator/numpy/np_elemwise_broadcast_op.cu
Lines changed: 6 additions & 3 deletions

@@ -41,13 +41,16 @@ NNVM_REGISTER_OP(_npi_multiply)
   "FCompute<gpu>",
   MixedBinaryBroadcastCompute<gpu, op::mshadow_op::mul, op::mshadow_op::mixed_mul,
                               op::mshadow_op::mixed_mul>);
+#else
+.set_attr<FCompute>(
+  "FCompute<gpu>",
+  MixedBinaryBroadcastCompute<gpu, op::mshadow_op::mul, op::mshadow_op::mul,
+                              op::mshadow_op::mul>);
+#endif

 NNVM_REGISTER_OP(_backward_npi_broadcast_mul)
 .set_attr<FCompute>("FCompute<gpu>", MixedBinaryBackwardUseIn<gpu, mshadow_op::right,
                                                               mshadow_op::left>);
-#else
-.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::mul>);
-#endif

 NNVM_REGISTER_OP(_npi_mod)
 .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::mod>);

src/operator/numpy/np_elemwise_broadcast_op.h
Lines changed: 36 additions & 11 deletions

@@ -39,7 +39,6 @@ void MixedBinaryElemwiseCompute(const nnvm::NodeAttrs& attrs,
                                 const std::vector<TBlob>& inputs,
                                 const std::vector<OpReqType>& req,
                                 const std::vector<TBlob>& outputs) {
-  // TODO(haojin2): No mixed-precision multiply on windows temporarily due to CI issues.
 #ifndef _WIN32
   using namespace mshadow;
   using namespace mxnet_op;
@@ -71,7 +70,7 @@ void MixedBinaryElemwiseCompute(const nnvm::NodeAttrs& attrs,
     });
   });
 #else
-  LOG(ERROR) << "mixed precision multiply is not supported on windows yet...";
+  LOG(ERROR) << "windows should not reach here...";
 #endif
 }

@@ -92,22 +91,18 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,

   if ((out.shape_.Size() == 0U) || (req[0] == kNullOp)) return;

-  mxnet::TShape new_lshape, new_rshape, new_oshape;
-  int ndim = BinaryBroadcastShapeCompact(lhs.shape_, rhs.shape_, out.shape_,
-                                         &new_lshape, &new_rshape, &new_oshape);
-
-
   if (lhs.type_flag_ == rhs.type_flag_) {
     BinaryBroadcastCompute<xpu, OP>(attrs, ctx, inputs, req, outputs);
     return;
   }

-  // TODO(haojin2): No mixed-precision multiply on windows temporarily due to CI issues.
-#ifndef _WIN32
   CHECK((lhs.type_flag_ == mshadow::kBool) || (rhs.type_flag_ == mshadow::kBool))
     << "now supports bool with another type only";

-
+#ifndef _WIN32
+  mxnet::TShape new_lshape, new_rshape, new_oshape;
+  int ndim = BinaryBroadcastShapeCompact(lhs.shape_, rhs.shape_, out.shape_,
+                                         &new_lshape, &new_rshape, &new_oshape);
   if (!ndim) {
     MixedBinaryElemwiseCompute<xpu, LOP, ROP>(attrs, ctx, inputs, req, outputs);
   } else {
@@ -130,7 +125,37 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
     });
   }
 #else
-  LOG(ERROR) << "mixed precision multiply is not supported on windows yet...";
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) {
+    LOG(ERROR) << "not implemented yet...";
+  } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) {
+    TBlob temp_tblob;
+    // one is float, the other is bool
+    CHECK_EQ(out.type_flag_,
+             common::is_float(lhs.type_flag_) ? lhs.type_flag_ : rhs.type_flag_)
+      << "This case out type should be same as the float type";
+    if (common::is_float(lhs.type_flag_)) {
+      MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, LType, {
+        Tensor<xpu, 1, LType> temp_tensor =
+          ctx.requested[0].get_space_typed<xpu, 1, LType>(Shape1(rhs.Size()), s);
+        temp_tblob = TBlob(temp_tensor);
+      });
+      CastCompute<xpu>(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob});
+      BinaryBroadcastCompute<xpu, OP>(
+        attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs);
+    } else {
+      MSHADOW_REAL_TYPE_SWITCH(rhs.type_flag_, RType, {
+        Tensor<xpu, 1, RType> temp_tensor =
+          ctx.requested[0].get_space_typed<xpu, 1, RType>(Shape1(lhs.Size()), s);
+        temp_tblob = TBlob(temp_tensor);
+      });
+      CastCompute<xpu>(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob});
+      BinaryBroadcastCompute<xpu, OP>(
+        attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs);
+    }
+  } else {
+    LOG(ERROR) << "not implemented yet...";
+  }
 #endif
 }

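Taken together, the Windows branch sidesteps mixed-type kernels entirely: when one operand is a float type and the other is bool, it grabs the requested temp space, casts the bool operand to the float type with CastCompute, and then reuses the ordinary same-type BinaryBroadcastCompute. The sketch below is a self-contained plain C++ illustration of that promote-then-multiply idea; the helper names are hypothetical and none of this is MXNet API.

#include <cstdio>
#include <vector>

// Stand-in for CastCompute: widen the bool operand into a float buffer.
std::vector<float> cast_bool_to_float(const std::vector<bool>& in) {
  std::vector<float> out(in.size());
  for (std::size_t i = 0; i < in.size(); ++i) out[i] = in[i] ? 1.0f : 0.0f;
  return out;
}

// Stand-in for the same-type BinaryBroadcastCompute (elementwise, no broadcasting).
std::vector<float> multiply(const std::vector<float>& a, const std::vector<float>& b) {
  std::vector<float> out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) out[i] = a[i] * b[i];
  return out;
}

int main() {
  std::vector<float> lhs = {1.5f, 2.0f, -3.0f};
  std::vector<bool>  rhs = {true, false, true};
  std::vector<float> tmp = cast_bool_to_float(rhs);  // temp buffer, like kTempSpace
  std::vector<float> out = multiply(lhs, tmp);       // 1.5 0 -3
  for (float v : out) std::printf("%g ", v);
  std::printf("\n");
  return 0;
}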
