Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit f01bcaa

Browse files
xidulureminisce
authored andcommitted
[Numpy] More numpy dispatch tests (#16426)
* tests added * remove not equal * fix tiny bug * remove meshgrid test * modify meshgrid return type, add test
1 parent 63fbfb1 commit f01bcaa

File tree

3 files changed

+25
-1
lines changed

3 files changed

+25
-1
lines changed

python/mxnet/numpy/stride_tricks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,6 @@ def broadcast_arrays(*args):
5151

5252
if all(array.shape == shape for array in args):
5353
# Common case where nothing needs to be broadcasted.
54-
return args
54+
return list(args)
5555

5656
return [_mx_np_op.broadcast_to(array, shape) for array in args]

python/mxnet/numpy_dispatch_protocol.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
113113
'transpose',
114114
'var',
115115
'zeros_like',
116+
'meshgrid',
117+
'outer'
116118
]
117119

118120

@@ -196,6 +198,7 @@ def _register_array_function():
196198
'ceil',
197199
'trunc',
198200
'floor',
201+
'logical_not',
199202
]
200203

201204

tests/python/unittest/test_numpy_interoperability.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ def _prepare_workloads():
7777
OpArgMngr.add_workload('min', array_pool['4x1'])
7878
OpArgMngr.add_workload('mean', array_pool['4x1'])
7979
OpArgMngr.add_workload('mean', array_pool['4x1'], axis=0, keepdims=True)
80+
OpArgMngr.add_workload('mean', np.array([[1, 2, 3], [4, 5, 6]]))
81+
OpArgMngr.add_workload('mean', np.array([[1, 2, 3], [4, 5, 6]]), axis=0)
82+
OpArgMngr.add_workload('mean', np.array([[1, 2, 3], [4, 5, 6]]), axis=1)
8083
OpArgMngr.add_workload('ones_like', array_pool['4x1'])
8184
OpArgMngr.add_workload('prod', array_pool['4x1'])
8285

@@ -157,6 +160,10 @@ def _prepare_workloads():
157160
OpArgMngr.add_workload('transpose', array_pool['4x1'])
158161
OpArgMngr.add_workload('var', array_pool['4x1'])
159162
OpArgMngr.add_workload('zeros_like', array_pool['4x1'])
163+
OpArgMngr.add_workload('outer', np.ones((5)), np.ones((2)))
164+
OpArgMngr.add_workload('meshgrid', np.array([1, 2, 3]))
165+
OpArgMngr.add_workload('meshgrid', np.array([1, 2, 3]), np.array([4, 5, 6, 7]))
166+
OpArgMngr.add_workload('meshgrid', np.array([1, 2, 3]), np.array([4, 5, 6, 7]), indexing='ij')
160167

161168
# workloads for array ufunc protocol
162169
OpArgMngr.add_workload('add', array_pool['4x1'], array_pool['1x2'])
@@ -175,6 +182,9 @@ def _prepare_workloads():
175182
OpArgMngr.add_workload('power', array_pool['4x1'], 2)
176183
OpArgMngr.add_workload('power', 2, array_pool['4x1'])
177184
OpArgMngr.add_workload('power', array_pool['4x1'], array_pool['1x1x0'])
185+
OpArgMngr.add_workload('power', np.array([1, 2, 3], np.int32), 2.00001)
186+
OpArgMngr.add_workload('power', np.array([15, 15], np.int64), np.array([15, 15], np.int64))
187+
OpArgMngr.add_workload('power', 0, np.arange(1, 10))
178188
OpArgMngr.add_workload('mod', array_pool['4x1'], array_pool['1x2'])
179189
OpArgMngr.add_workload('mod', array_pool['4x1'], 2)
180190
OpArgMngr.add_workload('mod', 2, array_pool['4x1'])
@@ -256,6 +266,12 @@ def _signs(dt):
256266
OpArgMngr.add_workload('exp', array_pool['4x1'])
257267
OpArgMngr.add_workload('log', array_pool['4x1'])
258268
OpArgMngr.add_workload('log2', array_pool['4x1'])
269+
OpArgMngr.add_workload('log2', np.array(2.**65))
270+
OpArgMngr.add_workload('log2', np.array(np.inf))
271+
OpArgMngr.add_workload('log2', np.array(1.))
272+
OpArgMngr.add_workload('log1p', np.array(-1.))
273+
OpArgMngr.add_workload('log1p', np.array(np.inf))
274+
OpArgMngr.add_workload('log1p', np.array(1e-6))
259275
OpArgMngr.add_workload('log10', array_pool['4x1'])
260276
OpArgMngr.add_workload('expm1', array_pool['4x1'])
261277
OpArgMngr.add_workload('sqrt', array_pool['4x1'])
@@ -282,6 +298,11 @@ def _signs(dt):
282298
OpArgMngr.add_workload('ceil', array_pool['4x1'])
283299
OpArgMngr.add_workload('trunc', array_pool['4x1'])
284300
OpArgMngr.add_workload('floor', array_pool['4x1'])
301+
OpArgMngr.add_workload('logical_not', np.ones(10, dtype=np.int32))
302+
OpArgMngr.add_workload('logical_not', array_pool['4x1'])
303+
OpArgMngr.add_workload('logical_not', np.array([True, False, True, False], dtype=np.bool))
304+
305+
285306

286307

287308
_prepare_workloads()

0 commit comments

Comments
 (0)