Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 15c8cd1

Browse files
committed
add numpy operator remainder
1 parent 692f3c4 commit 15c8cd1

File tree

6 files changed

+88
-4
lines changed

6 files changed

+88
-4
lines changed

python/mxnet/ndarray/numpy/_op.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
3434
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
3535
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
36-
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var']
36+
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'remainder']
3737

3838

3939
@set_module('mxnet.ndarray.numpy')
@@ -2338,3 +2338,28 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint:
23382338
0.2025
23392339
"""
23402340
return _npi.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)
2341+
2342+
2343+
@set_module('mxnet.ndarray.numpy')
2344+
def remainder(x1, x2, out=None):
2345+
"""Return element-wise remainder of division.
2346+
2347+
Parameters
2348+
----------
2349+
x1 : ndarray or scalar
2350+
Dividend array.
2351+
2352+
x2 : ndarray or scalar
2353+
Divisor array.
2354+
2355+
out : ndarray
2356+
A location into which the result is stored. If provided, it must have a shape
2357+
that the inputs broadcast to. If not provided or None, a freshly-allocated array
2358+
is returned.
2359+
2360+
Returns
2361+
-------
2362+
out : ndarray or scalar
2363+
This is a scalar if both x1 and x2 are scalars.
2364+
"""
2365+
return _ufunc_helper(x1, x2, _npi.remainder, _np.remainder, _npi.remainder_scalar, _npi.rremainder_scalar, out)

python/mxnet/numpy/multiarray.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@
5252
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
5353
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
5454
'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
55-
'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var']
55+
'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var',
56+
'remainder']
5657

5758
# Return code for dispatching indexing function call
5859
_NDARRAY_UNSUPPORTED_INDEXING = -1
@@ -3779,3 +3780,28 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=None):
37793780
0.2025
37803781
"""
37813782
return _npi.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)
3783+
3784+
3785+
@set_module('mxnet.numpy')
3786+
def remainder(x1, x2, out=None):
3787+
"""Return element-wise remainder of division.
3788+
3789+
Parameters
3790+
----------
3791+
x1 : ndarray or scalar
3792+
Dividend array.
3793+
3794+
x2 : ndarray or scalar
3795+
Divisor array.
3796+
3797+
out : ndarray
3798+
A location into which the result is stored. If provided, it must have a shape
3799+
that the inputs broadcast to. If not provided or None, a freshly-allocated array
3800+
is returned.
3801+
3802+
Returns
3803+
-------
3804+
out : ndarray or scalar
3805+
This is a scalar if both x1 and x2 are scalars.
3806+
"""
3807+
return _mx_nd_np.remainder(x1, x2, out=out)

python/mxnet/symbol/numpy/_symbol.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
3636
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
3737
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
38-
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var']
38+
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'remainder']
3939

4040

4141
def _num_outputs(sym):
@@ -1066,7 +1066,6 @@ def divide(x1, x2, out=None):
10661066
def mod(x1, x2, out=None):
10671067
return _ufunc_helper(x1, x2, _npi.mod, _np.mod, _npi.mod_scalar, _npi.rmod_scalar, out)
10681068

1069-
10701069
@set_module('mxnet.symbol.numpy')
10711070
def power(x1, x2, out=None):
10721071
return _ufunc_helper(x1, x2, _npi.power, _np.power, _npi.power_scalar, _npi.rpower_scalar, out)
@@ -2669,4 +2668,26 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint:
26692668
return _npi.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)
26702669

26712670

2671+
@set_module('mxnet.symbol.numpy')
2672+
def remainder(x1, x2, out=None):
2673+
"""Return element-wise remainder of division.
2674+
2675+
Parameters
2676+
----------
2677+
x1 : _Symbol or scalar
2678+
Dividend array.
2679+
2680+
x2 : _Symbol or scalar
2681+
Divisor array.
2682+
2683+
out : _Symbol, optional
2684+
Dummy parameter to keep the consistency with the ndarray counterpart.
2685+
2686+
Returns
2687+
-------
2688+
out : _Symbol or scalar
2689+
This is a scalar if both x1 and x2 are scalars.
2690+
"""
2691+
return _ufunc_helper(x1, x2, _npi.remainder, _np.remainder, _npi.remainder_scalar, _npi.rremainder_scalar, out)
2692+
26722693
_set_np_symbol_class(_Symbol)

src/operator/numpy/np_elemwise_broadcast_op.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_multiply)
6969
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"});
7070

7171
MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_mod)
72+
.add_alias("_npi_remainder")
7273
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::mod>)
7374
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mod"});
7475

@@ -93,10 +94,12 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_multiply_scalar)
9394
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_mul_scalar"});
9495

9596
MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_mod_scalar)
97+
.add_alias("_npi_remainder_scalar")
9698
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::mod>)
9799
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_mod_scalar"});
98100

99101
MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rmod_scalar)
102+
.add_alias("_npi_rremainder_scalar")
100103
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::rmod>)
101104
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_rmod_scalar"});
102105

src/operator/numpy/np_elemwise_broadcast_op.cu

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ NNVM_REGISTER_OP(_npi_multiply)
3737
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::mul>);
3838

3939
NNVM_REGISTER_OP(_npi_mod)
40+
.add_alias("_npi_remainder")
4041
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::mod>);
4142

4243
NNVM_REGISTER_OP(_npi_power)
@@ -55,9 +56,11 @@ NNVM_REGISTER_OP(_npi_multiply_scalar)
5556
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, op::mshadow_op::mul>);
5657

5758
NNVM_REGISTER_OP(_npi_mod_scalar)
59+
.add_alias("_npi_remainder_scalar")
5860
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::mod>);
5961

6062
NNVM_REGISTER_OP(_npi_rmod_scalar)
63+
.add_alias("_npi_rremainder_scalar")
6164
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rmod>);
6265

6366
NNVM_REGISTER_OP(_npi_power_scalar)

tests/python/unittest/test_numpy_ndarray.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ def test_ndarray_binary_element_wise_ops():
155155
'*': _np.multiply,
156156
'-': _np.subtract,
157157
'/': _np.divide,
158+
'remainder': _np.remainder,
158159
'mod': _np.mod,
159160
'pow': _np.power,
160161
'==': _np.equal,
@@ -201,6 +202,11 @@ def hybrid_forward(self, F, x, *args):
201202
return x % self._scalar if not self._reverse else self._scalar % x
202203
else:
203204
return x % args[0] if not self._reverse else args[0] % x
205+
elif self._op == 'remainder':
206+
if self._scalar is not None:
207+
return x % self._scalar if not self._reverse else self._scalar % x
208+
else:
209+
return x % args[0] if not self._reverse else args[0] % x
204210
elif self._op == 'pow':
205211
if self._scalar is not None:
206212
return x ** self._scalar if not self._reverse else self._scalar ** x

0 commit comments

Comments
 (0)