Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 90091b1

Browse files
hzfansxjscience
authored andcommitted
[Numpy] Numpy copysign (#15851)
* add numpy compatible copysign * fix scalar op registration error * add test
1 parent e9e267e commit 90091b1

File tree

8 files changed

+316
-3
lines changed

8 files changed

+316
-3
lines changed

python/mxnet/ndarray/numpy/_op.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
3434
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
3535
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
36-
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices']
36+
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign']
3737

3838

3939
@set_module('mxnet.ndarray.numpy')
@@ -2432,3 +2432,54 @@ def indices(dimensions, dtype=_np.int32, ctx=None):
24322432
else:
24332433
raise ValueError("The dimensions must be sequence of ints")
24342434
# pylint: enable=redefined-outer-name
2435+
2436+
2437+
@set_module('mxnet.ndarray.numpy')
2438+
def copysign(x1, x2, out=None):
2439+
r"""copysign(x1, x2, out=None)
2440+
2441+
Change the sign of x1 to that of x2, element-wise.
2442+
2443+
If `x2` is a scalar, its sign will be copied to all elements of `x1`.
2444+
2445+
Parameters
2446+
----------
2447+
x1 : ndarray or scalar
2448+
Values to change the sign of.
2449+
x2 : ndarray or scalar
2450+
The sign of `x2` is copied to `x1`.
2451+
out : ndarray or None, optional
2452+
A location into which the result is stored. It must be of the
2453+
right shape and right type to hold the output. If not provided
2454+
or `None`,a freshly-allocated array is returned.
2455+
2456+
Returns
2457+
-------
2458+
out : ndarray or scalar
2459+
The values of `x1` with the sign of `x2`.
2460+
This is a scalar if both `x1` and `x2` are scalars.
2461+
2462+
Notes
2463+
-------
2464+
This function differs from the original `numpy.copysign
2465+
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.copysign.html>`_ in
2466+
the following aspects:
2467+
2468+
- ``where`` param is not supported.
2469+
2470+
Examples
2471+
--------
2472+
>>> np.copysign(1.3, -1)
2473+
-1.3
2474+
>>> 1/np.copysign(0, 1)
2475+
inf
2476+
>>> 1/np.copysign(0, -1)
2477+
-inf
2478+
2479+
>>> a = np.array([-1, 0, 1])
2480+
>>> np.copysign(a, -1.1)
2481+
array([-1., -0., -1.])
2482+
>>> np.copysign(a, np.arange(3)-1)
2483+
array([-1., 0., 1.])
2484+
"""
2485+
return _ufunc_helper(x1, x2, _npi.copysign, _np.copysign, _npi.copysign_scalar, _npi.rcopysign_scalar, out)

python/mxnet/numpy/multiarray.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
5353
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
5454
'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
55-
'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices']
55+
'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign']
5656

5757
# Return code for dispatching indexing function call
5858
_NDARRAY_UNSUPPORTED_INDEXING = -1
@@ -3935,3 +3935,54 @@ def indices(dimensions, dtype=_np.int32, ctx=None):
39353935
"""
39363936
return _mx_nd_np.indices(dimensions=dimensions, dtype=dtype, ctx=ctx)
39373937
# pylint: enable=redefined-outer-name
3938+
3939+
3940+
@set_module('mxnet.numpy')
3941+
def copysign(x1, x2, out=None):
3942+
r"""copysign(x1, x2, out=None)
3943+
3944+
Change the sign of x1 to that of x2, element-wise.
3945+
3946+
If `x2` is a scalar, its sign will be copied to all elements of `x1`.
3947+
3948+
Parameters
3949+
----------
3950+
x1 : ndarray or scalar
3951+
Values to change the sign of.
3952+
x2 : ndarray or scalar
3953+
The sign of `x2` is copied to `x1`.
3954+
out : ndarray or None, optional
3955+
A location into which the result is stored. It must be of the
3956+
right shape and right type to hold the output. If not provided
3957+
or `None`,a freshly-allocated array is returned.
3958+
3959+
Returns
3960+
-------
3961+
out : ndarray or scalar
3962+
The values of `x1` with the sign of `x2`.
3963+
This is a scalar if both `x1` and `x2` are scalars.
3964+
3965+
Notes
3966+
-------
3967+
This function differs from the original `numpy.copysign
3968+
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.copysign.html>`_ in
3969+
the following aspects:
3970+
3971+
- ``where`` param is not supported.
3972+
3973+
Examples
3974+
--------
3975+
>>> np.copysign(1.3, -1)
3976+
-1.3
3977+
>>> 1/np.copysign(0, 1)
3978+
inf
3979+
>>> 1/np.copysign(0, -1)
3980+
-inf
3981+
3982+
>>> a = np.array([-1, 0, 1])
3983+
>>> np.copysign(a, -1.1)
3984+
array([-1., -0., -1.])
3985+
>>> np.copysign(a, np.arange(3)-1)
3986+
array([-1., 0., 1.])
3987+
"""
3988+
return _mx_nd_np.copysign(x1, x2, out=out)

python/mxnet/symbol/numpy/_symbol.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
3636
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
3737
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
38-
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices']
38+
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign']
3939

4040

4141
def _num_outputs(sym):
@@ -2744,4 +2744,38 @@ def indices(dimensions, dtype=_np.int32, ctx=None):
27442744
# pylint: enable=redefined-outer-name
27452745

27462746

2747+
@set_module('mxnet.symbol.numpy')
2748+
def copysign(x1, x2, out=None):
2749+
r"""copysign(x1, x2, out=None)
2750+
2751+
Change the sign of x1 to that of x2, element-wise.
2752+
2753+
If `x2` is a scalar, its sign will be copied to all elements of `x1`.
2754+
2755+
Parameters
2756+
----------
2757+
x1 : _Symbol or scalar
2758+
Values to change the sign of.
2759+
x2 : _Symbol or scalar
2760+
The sign of `x2` is copied to `x1`.
2761+
out : _Symbol or None
2762+
Dummy parameter to keep the consistency with the ndarray counterpart.
2763+
2764+
Returns
2765+
-------
2766+
out : _Symbol
2767+
The values of `x1` with the sign of `x2`.
2768+
This is a scalar if both `x1` and `x2` are scalars.
2769+
2770+
Notes
2771+
-------
2772+
This function differs from the original `numpy.copysign
2773+
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.copysign.html>`_ in
2774+
the following aspects:
2775+
2776+
- ``where`` param is not supported.
2777+
"""
2778+
return _ufunc_helper(x1, x2, _npi.copysign, _np.copysign, _npi.copysign_scalar, _npi.rcopysign_scalar, out)
2779+
2780+
27472781
_set_np_symbol_class(_Symbol)

src/operator/mshadow_op.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,16 @@ MXNET_BINARY_MATH_OP(rdiv, math::id(b) / math::id(a));
417417

418418
MXNET_BINARY_MATH_OP(rdiv_grad, -math::id(b) / math::sqr(a));
419419

420+
MXNET_BINARY_MATH_OP(copysign, (a >= 0 && b >= 0) || (a < 0 && b < 0) ? a : -a);
421+
422+
MXNET_BINARY_MATH_OP(copysign_grad, (a >= 0 && b >= 0) || (a < 0 && b < 0) ? 1: -1);
423+
424+
MXNET_BINARY_MATH_OP(copysign_rgrad, 0);
425+
426+
MXNET_BINARY_MATH_OP(rcopysign, (b >= 0 && a >= 0) || (b < 0 && a < 0) ? b : -b);
427+
428+
MXNET_BINARY_MATH_OP(rcopysign_grad, 0);
429+
420430
struct mod : public mxnet_op::tunable {
421431
template<typename DType>
422432
MSHADOW_XINLINE static typename enable_if<!is_unsigned<DType>::value, DType>::type

src/operator/numpy/np_elemwise_broadcast_op.cc

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,26 @@ MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_power)
7676
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::power>)
7777
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_power"});
7878

79+
MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_copysign)
80+
.describe(R"code()code" ADD_FILELINE)
81+
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::copysign>)
82+
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_copysign"});
83+
84+
NNVM_REGISTER_OP(_backward_npi_copysign)
85+
.set_num_inputs(3)
86+
.set_num_outputs(2)
87+
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
88+
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
89+
[](const NodeAttrs& attrs){
90+
return std::vector<std::pair<int, int> >{{0, 1}};
91+
})
92+
.set_attr<FResourceRequest>("FResourceRequest",
93+
[](const NodeAttrs& attrs) {
94+
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
95+
})
96+
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastBackwardUseIn<cpu, mshadow_op::copysign_grad,
97+
mshadow_op::copysign_rgrad>);
98+
7999
MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_add_scalar)
80100
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, op::mshadow_op::plus>)
81101
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_copy"});
@@ -108,5 +128,21 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rpower_scalar)
108128
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::rpower>)
109129
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_rpower_scalar"});
110130

131+
MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_copysign_scalar)
132+
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::copysign>)
133+
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_copysign_scalar"});
134+
135+
MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rcopysign_scalar)
136+
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::rcopysign>)
137+
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_rcopysign_scalar"});
138+
139+
MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_backward_npi_copysign_scalar)
140+
.set_attr<FCompute>("FCompute<cpu>",
141+
BinaryScalarOp::Backward<cpu, mshadow_op::copysign_grad>);
142+
143+
MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_backward_npi_rcopysign_scalar)
144+
.set_attr<FCompute>("FCompute<cpu>",
145+
BinaryScalarOp::Backward<cpu, mshadow_op::rcopysign_grad>);
146+
111147
} // namespace op
112148
} // namespace mxnet

src/operator/numpy/np_elemwise_broadcast_op.cu

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@ NNVM_REGISTER_OP(_npi_mod)
4242
NNVM_REGISTER_OP(_npi_power)
4343
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::power>);
4444

45+
NNVM_REGISTER_OP(_npi_copysign)
46+
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::copysign>);
47+
48+
NNVM_REGISTER_OP(_backward_npi_copysign)
49+
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseIn<gpu, mshadow_op::copysign_grad,
50+
mshadow_op::copysign_rgrad>);
51+
4552
NNVM_REGISTER_OP(_npi_add_scalar)
4653
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, op::mshadow_op::plus>);
4754

@@ -66,5 +73,19 @@ NNVM_REGISTER_OP(_npi_power_scalar)
6673
NNVM_REGISTER_OP(_npi_rpower_scalar)
6774
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rpower>);
6875

76+
NNVM_REGISTER_OP(_npi_copysign_scalar)
77+
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::copysign>);
78+
79+
NNVM_REGISTER_OP(_npi_rcopysign_scalar)
80+
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rcopysign>);
81+
82+
NNVM_REGISTER_OP(_backward_npi_copysign_scalar)
83+
.set_attr<FCompute>("FCompute<gpu>",
84+
BinaryScalarOp::Backward<gpu, mshadow_op::copysign_grad>);
85+
86+
NNVM_REGISTER_OP(_backward_npi_rcopysign_scalar)
87+
.set_attr<FCompute>("FCompute<gpu>",
88+
BinaryScalarOp::Backward<gpu, mshadow_op::rcopysign_grad>);
89+
6990
} // namespace op
7091
} // namespace mxnet

src/operator/operator_tune.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,11 @@ IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::elu); // NOLINT()
328328
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_grad); // NOLINT()
329329
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rpower_grad); // NOLINT()
330330
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_rgrad); // NOLINT()
331+
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::copysign); // NOLINT()
332+
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rcopysign); // NOLINT()
333+
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::copysign_grad); // NOLINT()
334+
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::copysign_rgrad); // NOLINT()
335+
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rcopysign_grad); // NOLINT()
331336
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::xelu_grad); // NOLINT()
332337
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gelu_grad); // NOLINT()
333338
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::prelu_grad); // NOLINT()

tests/python/unittest/test_numpy_op.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1853,6 +1853,111 @@ def hybrid_forward(self, F, x):
18531853
assert_almost_equal(mx_ret.asnumpy(), np_ret, atol=1e-5, rtol=1e-4)
18541854

18551855

1856+
@with_seed()
1857+
@use_np
1858+
def test_np_copysign():
1859+
class TestCopysign(HybridBlock):
1860+
def __init__(self):
1861+
super(TestCopysign, self).__init__()
1862+
1863+
def hybrid_forward(self, F, a1, a2):
1864+
return F.np.copysign(a1, a2)
1865+
1866+
def get_grad(a1, a2):
1867+
sign = _np.logical_or(_np.logical_and(a1 < 0, a2 < 0),
1868+
_np.logical_and(a1 >= 0, a2 >= 0))
1869+
sign = 2 * sign.astype(int) - 1
1870+
sign = sign.reshape(-1, *a1.shape)
1871+
sign = _np.sum(sign, axis=0)
1872+
return sign, _np.zeros_like(a2)
1873+
1874+
def get_grad_left(a1, a2):
1875+
sign = _np.logical_or(_np.logical_and(a1 < 0, a2 < 0),
1876+
_np.logical_and(a1 >= 0, a2 >= 0))
1877+
sign = 2 * sign.astype(int) - 1
1878+
sign = sign.reshape(a1.shape)
1879+
return sign
1880+
1881+
def get_grad_right(a1, a2):
1882+
return _np.zeros_like(a2)
1883+
1884+
shapes = [
1885+
(),
1886+
(1),
1887+
(2, 1),
1888+
(3, 2, 1),
1889+
(4, 3, 2, 1),
1890+
(2, 4, 3, 2, 1)
1891+
]
1892+
types = ['float16', 'float32', 'float64', 'int8', 'int32', 'int64']
1893+
for a1shape in shapes:
1894+
for a2shape in shapes:
1895+
for hybridize in [True, False]:
1896+
for dtype in types:
1897+
test_copysign = TestCopysign()
1898+
if hybridize:
1899+
test_copysign.hybridize()
1900+
rtol = 1e-3
1901+
atol = 1e-5
1902+
a1_np = _np.array(_np.random.uniform(-1.0, 1.0, a1shape), dtype=dtype)
1903+
a2_np = _np.array(_np.random.uniform(-1.0, 1.0, a2shape), dtype=dtype)
1904+
a1 = np.array(a1_np, dtype=dtype)
1905+
a2 = np.array(a2_np, dtype=dtype)
1906+
a1.attach_grad()
1907+
a2.attach_grad()
1908+
expected_np = _np.copysign(a1_np, a2_np)
1909+
with mx.autograd.record():
1910+
mx_out = test_copysign(a1, a2)
1911+
assert mx_out.shape == expected_np.shape
1912+
assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)
1913+
1914+
# Test gradient
1915+
mx_out.backward()
1916+
a1_grad, a2_grad = get_grad(a1_np, a2_np)
1917+
assert_almost_equal(a1.grad.asnumpy(), a1_grad, rtol=rtol, atol=atol)
1918+
assert_almost_equal(a2.grad.asnumpy(), a2_grad, rtol=rtol, atol=atol)
1919+
1920+
# Test imperative once again
1921+
mx_out = np.copysign(a1, a2)
1922+
expected_np = _np.copysign(a1_np, a2_np)
1923+
assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)
1924+
1925+
types = ['float16', 'float32', 'float64']
1926+
for x_shape in shapes:
1927+
for dtype in types:
1928+
# Test left
1929+
x_np = _np.array(_np.random.uniform(-2.0, 2.0, x_shape), dtype=dtype)
1930+
scalar = _np.random.uniform(-2.0, 2.0)
1931+
x = np.array(x_np, dtype=dtype)
1932+
x.attach_grad()
1933+
expected_np = _np.copysign(x_np, scalar)
1934+
with mx.autograd.record():
1935+
mx_out = np.copysign(x, scalar)
1936+
assert mx_out.shape == expected_np.shape
1937+
assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)
1938+
1939+
# Test gradient
1940+
mx_out.backward()
1941+
x_grad = get_grad_left(x_np, scalar)
1942+
assert_almost_equal(x.grad.asnumpy(), x_grad, rtol=rtol, atol=atol)
1943+
1944+
# Test right
1945+
x_np = _np.array(_np.random.uniform(-2.0, 2.0, x_shape), dtype=dtype)
1946+
scalar = _np.random.uniform(-2.0, 2.0)
1947+
x = np.array(x_np, dtype=dtype)
1948+
x.attach_grad()
1949+
expected_np = _np.copysign(scalar, x_np)
1950+
with mx.autograd.record():
1951+
mx_out = np.copysign(scalar, x)
1952+
assert mx_out.shape == expected_np.shape
1953+
assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)
1954+
1955+
# Test gradient
1956+
mx_out.backward()
1957+
x_grad = get_grad_right(scalar, x_np)
1958+
assert_almost_equal(x.grad.asnumpy(), x_grad, rtol=rtol, atol=atol)
1959+
1960+
18561961
if __name__ == '__main__':
18571962
import nose
18581963
nose.runmodule()

0 commit comments

Comments
 (0)