Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit fe6336d

Browse files
committed
Fix unit test failure
1 parent 4234412 commit fe6336d

File tree

3 files changed

+14
-2
lines changed

3 files changed

+14
-2
lines changed

python/mxnet/ndarray/ndarray.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,10 @@ def __setitem__(self, key, value):
442442
array([[ 6., 5., 5.],
443443
[ 6., 0., 4.]], dtype=float32)
444444
"""
445+
if self.ndim == 0 and key == ():
446+
_internal._full(shape=self.shape, value=float(value), ctx=self.context,
447+
dtype=self.dtype, out=self)
448+
return
445449
key = _indexing_key_expand_implicit_axes(key, self.shape)
446450
slc_key = tuple(idx for idx in key if idx is not None)
447451

@@ -602,6 +606,8 @@ def __getitem__(self, key):
602606
array([[[4., 5.],
603607
[6., 7.]]], dtype=float32)
604608
"""
609+
if self.ndim == 0 and key == ():
610+
return self
605611
key = _indexing_key_expand_implicit_axes(key, self.shape)
606612
if len(key) == 0:
607613
raise ValueError('indexing key cannot be an empty tuple')
@@ -2741,6 +2747,8 @@ def _get_dim_size(start, stop, step):
27412747
"""Given start, stop, and stop, calculate the number of elements
27422748
of this slice."""
27432749
assert step != 0
2750+
if stop == start:
2751+
return 0
27442752
if step > 0:
27452753
assert start < stop
27462754
dim_size = (stop - start - 1) // step + 1

python/mxnet/test_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1062,7 +1062,11 @@ def check_symbolic_forward(sym, location, expected, rtol=1E-4, atol=None,
10621062

10631063
executor = sym.bind(ctx=ctx, args=location, args_grad=args_grad_data, aux_states=aux_states)
10641064
for g in executor.grad_arrays:
1065-
g[:] = 0
1065+
print(g.shape)
1066+
if g.ndim == 0:
1067+
g[()] = 0
1068+
else:
1069+
g[:] = 0
10661070

10671071
executor.forward(is_train=False)
10681072

tests/python/unittest/test_operator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8862,7 +8862,7 @@ def test_index_array_default():
88628862

88638863
@mx.use_np_shape
88648864
def test_index_array_default_zero_dim():
8865-
data = mx.symbol.Variable("data")
8865+
data = mx.symbol.Variable("data")
88668866
index_array = mx.sym.contrib.index_array(data)
88678867

88688868
input_array = np.ones(())

0 commit comments

Comments
 (0)