Skip to content

Commit fc82d0c

Browse files
zoeygxyaaronmarkham
authored andcommitted
Numpy compatible vsplit; minor changes to split (apache#15983)
1 parent 9235e0e commit fc82d0c

File tree

4 files changed

+283
-23
lines changed

4 files changed

+283
-23
lines changed

python/mxnet/ndarray/numpy/_op.py

Lines changed: 85 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,11 @@
3232
'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2',
3333
'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
3434
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye',
35-
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'dstack',
36-
'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
37-
'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
38-
'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
39-
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal']
35+
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'stack',
36+
'vstack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var',
37+
'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot',
38+
'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot',
39+
'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal']
4040

4141

4242
@set_module('mxnet.ndarray.numpy')
@@ -823,7 +823,6 @@ def eye(N, M=None, k=0, dtype=_np.float32, **kwargs):
823823
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, ctx=None): # pylint: disable=too-many-arguments
824824
r"""
825825
Return evenly spaced numbers over a specified interval.
826-
827826
Returns num evenly spaced samples, calculated over the interval [start, stop].
828827
The endpoint of the interval can optionally be excluded.
829828
@@ -2354,7 +2353,7 @@ def split(ary, indices_or_sections, axis=0):
23542353
----------
23552354
ary : ndarray
23562355
Array to be divided into sub-arrays.
2357-
indices_or_sections : int or 1-D array
2356+
indices_or_sections : int or 1-D python tuple, list or set.
23582357
If `indices_or_sections` is an integer, N, the array will be divided
23592358
into N equal arrays along `axis`. If such a split is not possible,
23602359
an error is raised.
@@ -2386,17 +2385,94 @@ def split(ary, indices_or_sections, axis=0):
23862385
raise ValueError('array split does not result in an equal division')
23872386
section_size = int(axis_size / sections)
23882387
indices = [i * section_size for i in range(sections)]
2389-
elif isinstance(indices_or_sections, tuple):
2388+
elif isinstance(indices_or_sections, (list, set, tuple)):
23902389
indices = [0] + list(indices_or_sections)
23912390
else:
2392-
raise ValueError('indices_or_sections must either int or tuple of ints')
2391+
raise ValueError('indices_or_sections must either int, or tuple / list / set of ints')
23932392
ret = _npi.split(ary, indices, axis, False)
23942393
if not isinstance(ret, list):
23952394
return [ret]
23962395
return ret
23972396
# pylint: enable=redefined-outer-name
23982397

23992398

2399+
@set_module('mxnet.ndarray.numpy')
2400+
def vsplit(ary, indices_or_sections):
2401+
r"""
2402+
vsplit(ary, indices_or_sections)
2403+
2404+
Split an array into multiple sub-arrays vertically (row-wise).
2405+
2406+
``vsplit`` is equivalent to ``split`` with `axis=0` (default): the array is always split
2407+
along the first axis regardless of the array dimension.
2408+
2409+
Parameters
2410+
----------
2411+
ary : ndarray
2412+
Array to be divided into sub-arrays.
2413+
indices_or_sections : int or 1-D Python tuple, list or set.
2414+
If `indices_or_sections` is an integer, N, the array will be divided into N equal arrays
2415+
along axis 0. If such a split is not possible, an error is raised.
2416+
2417+
If `indices_or_sections` is a 1-D array of sorted integers, the entries indicate where
2418+
along axis 0 the array is split. For example, ``[2, 3]`` would result in
2419+
2420+
- ary[:2]
2421+
- ary[2:3]
2422+
- ary[3:]
2423+
2424+
If an index exceeds the dimension of the array along axis 0, an error will be thrown.
2425+
2426+
Returns
2427+
-------
2428+
sub-arrays : list of ndarrays
2429+
A list of sub-arrays.
2430+
2431+
See Also
2432+
--------
2433+
split : Split an array into multiple sub-arrays of equal size.
2434+
2435+
Notes
2436+
-------
2437+
This function differs from the original `numpy.vsplit
2438+
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.vsplit.html>`_ in
2439+
the following aspects:
2440+
2441+
- Currently parameter ``indices_or_sections`` does not support ndarray, but supports scalar,
2442+
tuple, list and set.
2443+
- In ``indices_or_sections``, if an index exceeds the dimension of the array along axis 0,
2444+
an error will be thrown.
2445+
2446+
Examples
2447+
--------
2448+
>>> x = np.arange(16.0).reshape(4, 4)
2449+
>>> x
2450+
array([[ 0., 1., 2., 3.],
2451+
[ 4., 5., 6., 7.],
2452+
[ 8., 9., 10., 11.],
2453+
[ 12., 13., 14., 15.]])
2454+
>>> np.vsplit(x, 2)
2455+
[array([[0., 1., 2., 3.],
2456+
[4., 5., 6., 7.]]), array([[ 8., 9., 10., 11.],
2457+
[12., 13., 14., 15.]])]
2458+
2459+
With a higher dimensional array the split is still along the first axis.
2460+
2461+
>>> x = np.arange(8.0).reshape(2, 2, 2)
2462+
>>> x
2463+
array([[[ 0., 1.],
2464+
[ 2., 3.]],
2465+
[[ 4., 5.],
2466+
[ 6., 7.]]])
2467+
>>> np.vsplit(x, 2)
2468+
[array([[[0., 1.],
2469+
[2., 3.]]]), array([[[4., 5.],
2470+
[6., 7.]]])]
2471+
2472+
"""
2473+
return split(ary, indices_or_sections, 0)
2474+
2475+
24002476
@set_module('mxnet.ndarray.numpy')
24012477
def concatenate(seq, axis=0, out=None):
24022478
"""Join a sequence of arrays along an existing axis.

python/mxnet/numpy/multiarray.py

Lines changed: 84 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,12 @@
5050
'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log',
5151
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
5252
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
53-
'tensordot', 'histogram', 'eye', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
54-
'stack', 'vstack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var',
55-
'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot',
56-
'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
57-
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal']
53+
'tensordot', 'histogram', 'eye', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit',
54+
'concatenate', 'stack', 'vstack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip',
55+
'argmax', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip',
56+
'around', 'arctan2', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
57+
'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal',
58+
'less_equal']
5859

5960

6061
# Return code for dispatching indexing function call
@@ -3951,7 +3952,7 @@ def split(ary, indices_or_sections, axis=0):
39513952
----------
39523953
ary : ndarray
39533954
Array to be divided into sub-arrays.
3954-
indices_or_sections : int or 1-D array
3955+
indices_or_sections : int or 1-D Python tuple, list or set.
39553956
If `indices_or_sections` is an integer, N, the array will be divided
39563957
into N equal arrays along `axis`. If such a split is not possible,
39573958
an error is raised.
@@ -3977,6 +3978,83 @@ def split(ary, indices_or_sections, axis=0):
39773978
return _mx_nd_np.split(ary, indices_or_sections, axis=axis)
39783979

39793980

3981+
@set_module('mxnet.numpy')
3982+
def vsplit(ary, indices_or_sections):
3983+
r"""
3984+
vsplit(ary, indices_or_sections)
3985+
3986+
Split an array into multiple sub-arrays vertically (row-wise).
3987+
3988+
``vsplit`` is equivalent to ``split`` with `axis=0` (default): the array is always split
3989+
along the first axis regardless of the array dimension.
3990+
3991+
Parameters
3992+
----------
3993+
ary : ndarray
3994+
Array to be divided into sub-arrays.
3995+
indices_or_sections : int or 1-D Python tuple, list or set.
3996+
If `indices_or_sections` is an integer, N, the array will be divided into N equal arrays
3997+
along axis 0. If such a split is not possible, an error is raised.
3998+
3999+
If `indices_or_sections` is a 1-D array of sorted integers, the entries indicate where
4000+
along axis 0 the array is split. For example, ``[2, 3]`` would result in
4001+
4002+
- ary[:2]
4003+
- ary[2:3]
4004+
- ary[3:]
4005+
4006+
If an index exceeds the dimension of the array along axis 0, an error will be thrown.
4007+
4008+
Returns
4009+
-------
4010+
sub-arrays : list of ndarrays
4011+
A list of sub-arrays.
4012+
4013+
See Also
4014+
--------
4015+
split : Split an array into multiple sub-arrays of equal size.
4016+
4017+
Notes
4018+
-------
4019+
This function differs from the original `numpy.vsplit
4020+
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.vsplit.html>`_ in
4021+
the following aspects:
4022+
4023+
- Currently parameter ``indices_or_sections`` does not support ndarray, but supports scalar,
4024+
tuple, list and set.
4025+
- In ``indices_or_sections``, if an index exceeds the dimension of the array along axis 0,
4026+
an error will be thrown.
4027+
4028+
Examples
4029+
--------
4030+
>>> x = np.arange(16.0).reshape(4, 4)
4031+
>>> x
4032+
array([[ 0., 1., 2., 3.],
4033+
[ 4., 5., 6., 7.],
4034+
[ 8., 9., 10., 11.],
4035+
[ 12., 13., 14., 15.]])
4036+
>>> np.vsplit(x, 2)
4037+
[array([[0., 1., 2., 3.],
4038+
[4., 5., 6., 7.]]), array([[ 8., 9., 10., 11.],
4039+
[12., 13., 14., 15.]])]
4040+
4041+
With a higher dimensional array the split is still along the first axis.
4042+
4043+
>>> x = np.arange(8.0).reshape(2, 2, 2)
4044+
>>> x
4045+
array([[[ 0., 1.],
4046+
[ 2., 3.]],
4047+
[[ 4., 5.],
4048+
[ 6., 7.]]])
4049+
>>> np.vsplit(x, 2)
4050+
[array([[[0., 1.],
4051+
[2., 3.]]]), array([[[4., 5.],
4052+
[6., 7.]]])]
4053+
4054+
"""
4055+
return split(ary, indices_or_sections, 0)
4056+
4057+
39804058
@set_module('mxnet.numpy')
39814059
def concatenate(seq, axis=0, out=None):
39824060
"""Join a sequence of arrays along an existing axis.

python/mxnet/symbol/numpy/_symbol.py

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@
3434
'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
3535
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
3636
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye',
37-
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'dstack',
38-
'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
39-
'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
40-
'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
41-
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal']
37+
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'stack', 'vstack',
38+
'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices',
39+
'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot',
40+
'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot',
41+
'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal']
4242

4343

4444
def _num_outputs(sym):
@@ -2601,7 +2601,7 @@ def split(ary, indices_or_sections, axis=0):
26012601
----------
26022602
ary : ndarray
26032603
Array to be divided into sub-arrays.
2604-
indices_or_sections : int or 1-D array
2604+
indices_or_sections : int or 1-D python tuple, list or set.
26052605
If `indices_or_sections` is an integer, N, the array will be divided
26062606
into N equal arrays along `axis`. If such a split is not possible,
26072607
an error is raised.
@@ -2628,15 +2628,66 @@ def split(ary, indices_or_sections, axis=0):
26282628
sections = 0
26292629
if isinstance(indices_or_sections, int):
26302630
sections = indices_or_sections
2631-
elif isinstance(indices_or_sections, tuple):
2631+
elif isinstance(indices_or_sections, (list, set, tuple)):
26322632
indices = [0] + list(indices_or_sections)
26332633
else:
2634-
raise ValueError('indices_or_sections must either int or tuple of ints')
2634+
raise ValueError('indices_or_sections must either int or tuple / list / set of ints')
26352635
ret = _npi.split(ary, indices, axis, False, sections)
26362636
return ret
26372637
# pylint: enable=redefined-outer-name
26382638

26392639

2640+
@set_module('mxnet.symbol.numpy')
2641+
def vsplit(ary, indices_or_sections):
2642+
r"""
2643+
vsplit(ary, indices_or_sections)
2644+
2645+
Split an array into multiple sub-arrays vertically (row-wise).
2646+
2647+
``vsplit`` is equivalent to ``split`` with `axis=0` (default): the array is always split
2648+
along the first axis regardless of the array dimension.
2649+
2650+
Parameters
2651+
----------
2652+
ary : _Symbol
2653+
Array to be divided into sub-arrays.
2654+
indices_or_sections : int or 1-D Python tuple, list or set.
2655+
If `indices_or_sections` is an integer, N, the array will be divided into N equal arrays
2656+
along axis 0. If such a split is not possible, an error is raised.
2657+
2658+
If `indices_or_sections` is a 1-D array of sorted integers, the entries indicate where
2659+
along axis 0 the array is split. For example, ``[2, 3]`` would result in
2660+
2661+
- ary[:2]
2662+
- ary[2:3]
2663+
- ary[3:]
2664+
2665+
If an index exceeds the dimension of the array along axis 0, an error will be thrown.
2666+
2667+
Returns
2668+
-------
2669+
sub-arrays : list of _Symbols
2670+
A list of sub-arrays.
2671+
2672+
See Also
2673+
--------
2674+
split : Split an array into multiple sub-arrays of equal size.
2675+
2676+
Notes
2677+
-------
2678+
This function differs from the original `numpy.vsplit
2679+
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.vsplit.html>`_ in
2680+
the following aspects:
2681+
2682+
- Currently parameter ``indices_or_sections`` does not support ndarray, but supports scalar,
2683+
tuple, list and set.
2684+
- In ``indices_or_sections``, if an index exceeds the dimension of the array along axis 0,
2685+
an error will be thrown.
2686+
2687+
"""
2688+
return split(ary, indices_or_sections, 0)
2689+
2690+
26402691
@set_module('mxnet.symbol.numpy')
26412692
def concatenate(seq, axis=0, out=None):
26422693
"""Join a sequence of arrays along an existing axis.

tests/python/unittest/test_numpy_op.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1591,6 +1591,61 @@ def get_indices(axis_size):
15911591
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
15921592

15931593

1594+
@with_seed()
1595+
@use_np
1596+
def test_np_vsplit():
1597+
class TestVsplit(HybridBlock):
1598+
def __init__(self, indices_or_sections):
1599+
super(TestVsplit, self).__init__()
1600+
self._indices_or_sections = indices_or_sections
1601+
1602+
def hybrid_forward(self, F, a, *args, **kwargs):
1603+
return F.np.vsplit(a, indices_or_sections=self._indices_or_sections)
1604+
1605+
def get_indices(axis_size):
1606+
if axis_size is 0:
1607+
axis_size = random.randint(3, 6)
1608+
samples = random.randint(1, axis_size - 1)
1609+
indices = sorted(random.sample([i for i in range(1, axis_size)], samples))
1610+
indices = tuple(indices)
1611+
return indices
1612+
1613+
shapes = [
1614+
(2, 1, 2, 9),
1615+
(4, 3, 3),
1616+
(4, 0, 2), # zero-size shape
1617+
(0, 3), # first dim being zero
1618+
]
1619+
for hybridize in [True, False]:
1620+
for shape in shapes:
1621+
axis_size = shape[0]
1622+
indices = get_indices(axis_size)
1623+
sections = 7 if axis_size is 0 else axis_size
1624+
for indices_or_sections in [indices, sections]:
1625+
# test gluon
1626+
test_vsplit = TestVsplit(indices_or_sections=indices_or_sections)
1627+
if hybridize:
1628+
test_vsplit.hybridize()
1629+
a = rand_ndarray(shape).as_np_ndarray() # TODO: check type
1630+
a.attach_grad()
1631+
expected_ret = _np.vsplit(a.asnumpy(), indices_or_sections=indices_or_sections)
1632+
with mx.autograd.record():
1633+
y = test_vsplit(a)
1634+
assert len(y) == len(expected_ret)
1635+
for mx_out, np_out in zip(y, expected_ret):
1636+
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
1637+
1638+
mx.autograd.backward(y)
1639+
1640+
assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5)
1641+
1642+
# test imperative
1643+
mx_outs = np.vsplit(a, indices_or_sections=indices_or_sections)
1644+
np_outs = _np.vsplit(a.asnumpy(), indices_or_sections=indices_or_sections)
1645+
for mx_out, np_out in zip(mx_outs, np_outs):
1646+
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
1647+
1648+
15941649
@with_seed()
15951650
@use_np
15961651
def test_np_concat():

0 commit comments

Comments
 (0)