
Commit 1ae73ea (parent: f9140bb)

Improvements to documentation and error messages

4 files changed: +129, -17 lines


python/mxnet/ndarray/numpy/_op.py

Lines changed: 40 additions & 0 deletions
@@ -523,6 +523,14 @@ def add(x1, x2, out=None, **kwargs):
     -------
     add : ndarray or scalar
         The sum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.
+
+    Notes
+    -----
+    This operator now supports automatic type promotion. The resulting type will be determined
+    according to the following rules:
+    * If both inputs are of floating number types, the output is the more precise type.
+    * If only one of the inputs is floating number type, the result is that type.
+    * If both inputs are of integer types (including boolean), not supported yet.
     """
     return _ufunc_helper(x1, x2, _npi.add, _np.add, _npi.add_scalar, None, out)
 

@@ -549,6 +557,14 @@ def subtract(x1, x2, out=None, **kwargs):
     -------
     subtract : ndarray or scalar
         The difference of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.
+
+    Notes
+    -----
+    This operator now supports automatic type promotion. The resulting type will be determined
+    according to the following rules:
+    * If both inputs are of floating number types, the output is the more precise type.
+    * If only one of the inputs is floating number type, the result is that type.
+    * If both inputs are of integer types (including boolean), not supported yet.
     """
     return _ufunc_helper(x1, x2, _npi.subtract, _np.subtract, _npi.subtract_scalar,
                          _npi.rsubtract_scalar, out)
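
For illustration only (not part of this commit), here is a minimal sketch of how the promotion rules documented in the Notes above behave from the user side, assuming the mxnet.numpy front end with NumPy-compatible semantics enabled via npx.set_np(); shapes and values are arbitrary:

from mxnet import np, npx

npx.set_np()  # enable NumPy-compatible semantics

a16 = np.ones((2, 2), dtype='float16')
b32 = np.ones((2, 2), dtype='float32')
i64 = np.ones((2, 2), dtype='int64')

# Both inputs are floating types: the output takes the more precise type.
print(np.add(a16, b32).dtype)       # expected: float32

# Only one input is a floating type: the output takes that floating type.
print(np.subtract(b32, i64).dtype)  # expected: float32
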
@@ -576,6 +592,14 @@ def multiply(x1, x2, out=None, **kwargs):
     out : ndarray or scalar
         The multiplication of x1 and x2, element-wise. This is a scalar if both x1 and x2
         are scalars.
+
+    Notes
+    -----
+    This operator now supports automatic type promotion. The resulting type will be determined
+    according to the following rules:
+    * If both inputs are of floating number types, the output is the more precise type.
+    * If only one of the inputs is floating number type, the result is that type.
+    * If both inputs are of integer types (including boolean), not supported yet.
     """
     return _ufunc_helper(x1, x2, _npi.multiply, _np.multiply, _npi.multiply_scalar, None, out)
 

@@ -603,6 +627,14 @@ def divide(x1, x2, out=None, **kwargs):
     -------
     out : ndarray or scalar
         This is a scalar if both x1 and x2 are scalars.
+
+    Notes
+    -----
+    This operator now supports automatic type promotion. The resulting type will be determined
+    according to the following rules:
+    * If both inputs are of floating number types, the output is the more precise type.
+    * If only one of the inputs is floating number type, the result is that type.
+    * If both inputs are of integer types (including boolean), the output is of float32 type.
     """
     return _ufunc_helper(x1, x2, _npi.true_divide, _np.divide, _npi.true_divide_scalar,
                          _npi.rtrue_divide_scalar, out)
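
A quick sketch (again illustrative, not part of the diff) of the integer-input rule for divide and true_divide, under the same assumptions as the earlier example:

from mxnet import np, npx

npx.set_np()

x = np.array([1, 2, 3], dtype='int32')
y = np.array([2, 2, 2], dtype='int32')

# Both inputs are integer types, so true division promotes the output to float32.
q = np.divide(x, y)
print(q.dtype)  # expected: float32
print(q)        # expected values: [0.5, 1.0, 1.5]
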
@@ -633,6 +665,14 @@ def true_divide(x1, x2, out=None):
     -------
     out : ndarray or scalar
         This is a scalar if both x1 and x2 are scalars.
+
+    Notes
+    -----
+    This operator now supports automatic type promotion. The resulting type will be determined
+    according to the following rules:
+    * If both inputs are of floating number types, the output is the more precise type.
+    * If only one of the inputs is floating number type, the result is that type.
+    * If both inputs are of integer types (including boolean), the output is of float32 type.
     """
     return _ufunc_helper(x1, x2, _npi.true_divide, _np.divide, _npi.true_divide_scalar,
                          _npi.rtrue_divide_scalar, out)

python/mxnet/numpy/multiarray.py

Lines changed: 40 additions & 0 deletions
@@ -2402,6 +2402,14 @@ def add(x1, x2, out=None, **kwargs):
     add : ndarray or scalar
         The sum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.
 
+    Notes
+    -----
+    This operator now supports automatic type promotion. The resulting type will be determined
+    according to the following rules:
+    * If both inputs are of floating number types, the output is the more precise type.
+    * If only one of the inputs is floating number type, the result is that type.
+    * If both inputs are of integer types (including boolean), not supported yet.
+
     Examples
     --------
     >>> np.add(1.0, 4.0)

@@ -2440,6 +2448,14 @@ def subtract(x1, x2, out=None, **kwargs):
     subtract : ndarray or scalar
         The difference of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.
 
+    Notes
+    -----
+    This operator now supports automatic type promotion. The resulting type will be determined
+    according to the following rules:
+    * If both inputs are of floating number types, the output is the more precise type.
+    * If only one of the inputs is floating number type, the result is that type.
+    * If both inputs are of integer types (including boolean), not supported yet.
+
     Examples
     --------
     >>> np.subtract(1.0, 4.0)

@@ -2476,6 +2492,14 @@ def multiply(x1, x2, out=None, **kwargs):
     out : ndarray or scalar
         The difference of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.
 
+    Notes
+    -----
+    This operator now supports automatic type promotion. The resulting type will be determined
+    according to the following rules:
+    * If both inputs are of floating number types, the output is the more precise type.
+    * If only one of the inputs is floating number type, the result is that type.
+    * If both inputs are of integer types (including boolean), not supported yet.
+
     Examples
     --------
     >>> np.multiply(2.0, 4.0)

@@ -2514,6 +2538,14 @@ def divide(x1, x2, out=None, **kwargs):
     out : ndarray or scalar
         This is a scalar if both x1 and x2 are scalars.
 
+    Notes
+    -----
+    This operator now supports automatic type promotion. The resulting type will be determined
+    according to the following rules:
+    * If both inputs are of floating number types, the output is the more precise type.
+    * If only one of the inputs is floating number type, the result is that type.
+    * If both inputs are of integer types (including boolean), the output is of float32 type.
+
     Examples
     --------
     >>> np.true_divide(x, 4)

@@ -2548,6 +2580,14 @@ def true_divide(x1, x2, out=None):
     out : ndarray or scalar
         This is a scalar if both x1 and x2 are scalars.
 
+    Notes
+    -----
+    This operator now supports automatic type promotion. The resulting type will be determined
+    according to the following rules:
+    * If both inputs are of floating number types, the output is the more precise type.
+    * If only one of the inputs is floating number type, the result is that type.
+    * If both inputs are of integer types (including boolean), the output is of float32 type.
+
     Examples
     --------
     >>> x = np.arange(5)

src/common/utils.h

Lines changed: 24 additions & 0 deletions
@@ -365,6 +365,30 @@ inline bool ContainsStorageType(const std::vector<int>& ndstypes,
   return false;
 }
 
+inline std::string dtype_string(const int dtype) {
+  switch (dtype) {
+    case mshadow::kFloat32:
+      return "float";
+    case mshadow::kFloat64:
+      return "double";
+    case mshadow::kFloat16:
+      return "half";
+    case mshadow::kUint8:
+      return "unsigned char";
+    case mshadow::kInt8:
+      return "char";
+    case mshadow::kInt32:
+      return "int";
+    case mshadow::kInt64:
+      return "long long";
+    case mshadow::kBool:
+      return "bool";
+    default:
+      LOG(FATAL) << "Unknown type enum " << dtype;
+  }
+  return "unknown";
+}
+
 /*! \brief get string representation of dispatch_mode */
 inline std::string dispatch_mode_string(const DispatchMode x) {
   switch (x) {
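
For readers cross-referencing error messages from the Python side, the mapping implemented by dtype_string above can be summarized as the following hypothetical dictionary (MXNet dtype name to the C type name used in messages); this is a reference sketch, not code from the commit:

# Reference only: mirrors the switch statement in dtype_string above.
DTYPE_STRING = {
    'float32': 'float',
    'float64': 'double',
    'float16': 'half',
    'uint8': 'unsigned char',
    'int8': 'char',
    'int32': 'int',
    'int64': 'long long',
    'bool': 'bool',
}
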

src/operator/numpy/np_elemwise_broadcast_op.h

Lines changed: 25 additions & 17 deletions
@@ -20,7 +20,7 @@
 /*!
  * Copyright (c) 2019 by Contributors
  * \file np_elemwise_binary_op.h
- * \brief
+ * \brief Function definition of elemwise and broadcast operators
  */
 #ifndef MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_
 #define MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_

@@ -33,9 +33,16 @@
 namespace mxnet {
 namespace op {
 
+inline void PrintErrorMessage(const std::string& name, const int dtype1, const int dtype2) {
+  LOG(FATAL) << "Operator " << name << " does not support combination of "
+             << common::dtype_string(dtype1) << " with " << common::dtype_string(dtype2)
+             << " yet...";
+}
+
 #ifndef _WIN32
 template<typename xpu, typename OP>
-void MixedAllRealBinaryElemwiseCompute(const OpContext& ctx,
+void MixedAllRealBinaryElemwiseCompute(const std::string& op_name,
+                                       const OpContext& ctx,
                                        const TBlob& lhs,
                                        const TBlob& rhs,
                                        const TBlob& out,
@@ -61,7 +68,7 @@ void MixedAllRealBinaryElemwiseCompute(const OpContext& ctx,
           lhs.dptr<float>());
         });
       } else {
-        LOG(ERROR) << "Should not reach here!";
+        PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_);
       }
       break;
     }

@@ -80,13 +87,13 @@ void MixedAllRealBinaryElemwiseCompute(const OpContext& ctx,
           lhs.dptr<double>());
         });
       } else {
-        LOG(ERROR) << "Should not reach here!";
+        PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_);
       }
       break;
     }
     default:
     {
-      LOG(ERROR) << "Not supported case of ...";
+      PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_);
       break;
     }
   }

@@ -137,9 +144,9 @@ void MixedBinaryElemwiseCompute(const nnvm::NodeAttrs& attrs,
 
   if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) {
     if (lhs.type_flag_ == out.type_flag_) {
-      MixedAllRealBinaryElemwiseCompute<xpu, ROP>(ctx, lhs, rhs, out, req[0]);
+      MixedAllRealBinaryElemwiseCompute<xpu, ROP>(attrs.op->name, ctx, lhs, rhs, out, req[0]);
     } else {
-      MixedAllRealBinaryElemwiseCompute<xpu, LOP>(ctx, rhs, lhs, out, req[0]);
+      MixedAllRealBinaryElemwiseCompute<xpu, LOP>(attrs.op->name, ctx, rhs, lhs, out, req[0]);
     }
   } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) {
     if (lhs.type_flag_ == out.type_flag_) {
@@ -148,12 +155,13 @@ void MixedBinaryElemwiseCompute(const nnvm::NodeAttrs& attrs,
       MixedIntRealBinaryElemwiseCompute<xpu, LOP>(ctx, rhs, lhs, out, req[0]);
     }
   } else {
-    LOG(ERROR) << "not implemented yet...";
+    PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_);
   }
 }
 
 template<typename xpu, typename OP>
-void MixedAllRealBinaryBroadcastCompute(const OpContext& ctx,
+void MixedAllRealBinaryBroadcastCompute(const std::string& op_name,
+                                        const OpContext& ctx,
                                         const TBlob& lhs,
                                         const TBlob& rhs,
                                         const TBlob& out,

@@ -180,7 +188,7 @@ void MixedAllRealBinaryBroadcastCompute(const OpContext& ctx,
         template LaunchEx(s, new_oshape.Size(), req, rstride, lstride, oshape,
           rhs.dptr<mshadow::half::half_t>(), lhs.dptr<float>(), out.dptr<float>());
       } else {
-        LOG(ERROR) << "Should not reach here!";
+        PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_);
      }
      break;
    }

@@ -195,13 +203,13 @@ void MixedAllRealBinaryBroadcastCompute(const OpContext& ctx,
        template LaunchEx(s, new_oshape.Size(), req, rstride, lstride, oshape,
          rhs.dptr<float>(), lhs.dptr<double>(), out.dptr<double>());
      } else {
-       LOG(ERROR) << "Should not reach here!";
+       PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_);
      }
      break;
    }
    default:
    {
-     LOG(ERROR) << "Not supported case of ...";
+     PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_);
      break;
    }
  }
@@ -242,10 +250,10 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
   if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) {
     if (lhs.type_flag_ == out.type_flag_) {
       MixedAllRealBinaryBroadcastCompute<xpu, ROP>(
-        ctx, lhs, rhs, out, req[0], ndim, new_oshape, new_lshape, new_rshape);
+        attrs.op->name, ctx, lhs, rhs, out, req[0], ndim, new_oshape, new_lshape, new_rshape);
     } else {
       MixedAllRealBinaryBroadcastCompute<xpu, LOP>(
-        ctx, rhs, lhs, out, req[0], ndim, new_oshape, new_rshape, new_lshape);
+        attrs.op->name, ctx, rhs, lhs, out, req[0], ndim, new_oshape, new_rshape, new_lshape);
     }
   } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) {
     CHECK(lhs.type_flag_ == out.type_flag_ || rhs.type_flag_ == out.type_flag_)

@@ -273,7 +281,7 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
       }
     });
   } else {
-    LOG(ERROR) << "not implemented yet...";
+    PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_);
   }
 }
 #else

@@ -303,7 +311,7 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
         attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs);
     }
   } else {
-    LOG(ERROR) << "not implemented yet...";
+    PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_);
   }
 #endif
 }

@@ -324,7 +332,7 @@ void MixedBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
     return;
   }
 
-  LOG(ERROR) << "Binary operation with mixed input data types does not support backward yet...";
+  PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_);
 }
 
 }  // namespace op
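
To see the new error path from Python, here is a hedged sketch (not from the commit; the exact message wording, the operator name shown, and the point at which the error surfaces under asynchronous execution are assumptions) of triggering PrintErrorMessage with an integer-only input combination:

from mxnet import np, npx
from mxnet.base import MXNetError

npx.set_np()

a = np.ones((2,), dtype='int32')
b = np.ones((2,), dtype='int64')

try:
    out = np.add(a, b)  # integer-only mixed inputs are not supported yet for add
    out.asnumpy()       # force synchronization so a backend failure surfaces here
except MXNetError as err:
    # The message is expected to resemble:
    # "Operator <name> does not support combination of int with long long yet..."
    print(err)
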
