Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 18792d5

Browse files
author
Mike Mao
committed
Modify ndarray slice to have numpy compatbile behaviou
1 parent a3babc4 commit 18792d5

File tree

3 files changed

+127
-25
lines changed

3 files changed

+127
-25
lines changed

src/operator/tensor/matrix_op-inl.h

Lines changed: 54 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -685,13 +685,13 @@ inline void GetIndexRange(const mxnet::TShape& dshape,
685685
<< "Static array size=" << ndim
686686
<< " is not equal to data shape ndim=" << dshape.ndim();
687687

688-
if (param_step.ndim() != 0) {
688+
if (param_step.ndim() > 0) {
689689
CHECK_EQ(param_step.ndim(), param_begin.ndim())
690690
<< "step and begin must have the same length";
691691
}
692692

693693
for (int i = 0; i < param_begin.ndim(); ++i) {
694-
index_t s = param_step.ndim() != 0U && param_step[i].has_value() ? param_step[i].value() : 1;
694+
index_t s = param_step.ndim() > 0 && param_step[i].has_value() ? param_step[i].value() : 1;
695695
CHECK_NE(s, 0) << "slice op step[" << i << "] cannot be 0";
696696

697697
index_t b = 0, e = 0;
@@ -703,29 +703,44 @@ inline void GetIndexRange(const mxnet::TShape& dshape,
703703
// checking upper and lower bounds for begin
704704
if (b < 0) {
705705
b += len;
706-
CHECK_GE(b, 0) << "slicing with begin[" << i << "]=" << b - len
707-
<< " exceeds limit of input dimension[" << i << "]=" << len;
706+
if (!Imperative::Get()->is_np_shape()) {
707+
CHECK_GE(b, 0) << "slicing with begin[" << i << "]=" << b - len
708+
<< " exceeds limit of input dimension[" << i << "]=" << len;
709+
}
710+
}
711+
if (!Imperative::Get()->is_np_shape()) {
712+
CHECK_LT(b, len) << "slicing with begin[" << i << "]=" << b
713+
<< " exceeds limit of input dimension[" << i << "]=" << len;
708714
}
709-
CHECK_LT(b, len) << "slicing with begin[" << i << "]=" << b
710-
<< " exceeds limit of input dimension[" << i << "]=" << len;
711-
712715
// checking upper and lower bounds for end
713716
if (e < 0 && param_end[i].has_value()) {
714-
if (!(s < 0 && e == -1)) {
715-
// Keep end=-1 as one-beyond-limits index for negative stride
716-
e += len;
717+
e += len;
718+
if (!Imperative::Get()->is_np_shape()) {
719+
CHECK_GE(e, 0) << "slicing with end[" << i << "]=" << e - len
720+
<< " exceeds limit of input dimension[" << i << "]=" << len;
717721
}
718-
CHECK_GE(e, 0) << "slicing with end[" << i << "]=" << e - len
719-
<< " exceeds limit of input dimension[" << i << "]=" << len;
720722
}
721-
CHECK_LE(e, len) << "slicing with end[" << i << "]=" << e
722-
<< " exceeds limit of input dimension[" << i << "]=" << len;
723+
if (!Imperative::Get()->is_np_shape()) {
724+
CHECK_LE(e, len) << "slicing with end[" << i << "]=" << e
725+
<< " exceeds limit of input dimension[" << i << "]=" << len;
726+
}
723727

724728
// checking begin==end case which is not supported
725-
CHECK_NE(b, e) << "slicing with begin[" << i << "]=end[" << i << "]="
726-
<< e << " results in an empty tensor and is not supported";
729+
if (!Imperative::Get()->is_np_shape()) {
730+
CHECK_NE(b, e) << "slicing with begin[" << i << "]=end[" << i << "]="
731+
<< e << " results in an empty tensor and is not supported";
732+
}
727733
}
728734

735+
if (Imperative::Get()->is_np_shape()) {
736+
// move the begin and end to correct position for calculating dim size
737+
b = b < 0 && s > 0 ? 0 : b;
738+
b = b > len-1 && s < 0 ? len-1 : b;
739+
// if the start value lead to empty tensor under step s, use -1 for indication
740+
b = b < 0 || b > len-1 ? -1 : b;
741+
e = e > -1 ? e : -1;
742+
e = e > len ? len : e;
743+
}
729744
(*begin)[i] = b;
730745
(*end)[i] = e;
731746
(*step)[i] = s;
@@ -741,17 +756,29 @@ inline void GetIndexRange(const mxnet::TShape& dshape,
741756
inline void SetSliceOpOutputDimSize(const index_t i, const int b,
742757
const int e, const int s,
743758
mxnet::TShape* oshape) {
744-
if (e != b) {
745-
if (s > 0) {
746-
CHECK_LT(b, e) << "slicing with begin[" << i << "]=" << b << ", end[" << i << "]="
747-
<< e << ", and step[" << i << "]=" << s << " is invalid";
748-
(*oshape)[i] = (e - b - 1) / s + 1;
759+
if (!Imperative::Get()->is_np_shape()) { //handle as ndarray
760+
if (e != b) {
761+
if (s > 0) {
762+
CHECK_LT(b, e) << "slicing with begin=[" << i << "]=" << b << ", end[" << i << "]="
763+
<< e << ", and step[" << i << "]=" << s << " is invalid";
764+
(*oshape)[i] = (e - b - 1) / s + 1;
765+
} else {
766+
CHECK_LT(e, b) << "slicing with begin=[" << i << "]=" << b << ", end[" << i << "]="
767+
<< e << ", and step[" << i << "]=" << s << " is invalid";
768+
(*oshape)[i] = (b - e - 1) / (-s) + 1;
769+
}
770+
} // else leave oshape[i] as 0 for partial infer
771+
} else { //handle as numpy compatible array
772+
if (e != b && b >= 0) {
773+
if (s > 0) {
774+
(*oshape)[i] = e > b ? (e - b - 1) / s + 1 : 0;
775+
} else {
776+
(*oshape)[i] = e < b ? (b - e - 1) / (-s) + 1 : 0;
777+
}
749778
} else {
750-
CHECK_LT(e, b) << "slicing with begin[" << i << "]=" << b << ", end[" << i << "]="
751-
<< e << ", and step[" << i << "]=" << s << " is invalid";
752-
(*oshape)[i] = (b - e - 1) / (-s) + 1;
779+
(*oshape)[i] = 0;
753780
}
754-
} // else leave oshape[i] as 0 for partial infer
781+
}
755782
}
756783

757784
inline bool SliceOpShape(const nnvm::NodeAttrs& attrs,
@@ -852,6 +879,7 @@ void SliceOpForward(const nnvm::NodeAttrs& attrs,
852879
Stream<xpu>* s = ctx.get_stream<xpu>();
853880
const TBlob& data = inputs[0];
854881
const TBlob& out = outputs[0];
882+
if (Imperative::Get()->is_np_shape() && out.Size() == 0) return;
855883
const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
856884
MXNET_NDIM_SWITCH(data.ndim(), ndim, {
857885
common::StaticArray<index_t, ndim> begin, end, step;
@@ -951,6 +979,7 @@ void SliceOpBackward(const nnvm::NodeAttrs& attrs,
951979
} else if (req[0] == kWriteInplace) {
952980
LOG(FATAL) << "_slice_backward does not support kWriteInplace";
953981
}
982+
if (Imperative::Get()->is_np_shape() && ograd.Size() == 0) return;
954983
MXNET_NDIM_SWITCH(ograd.ndim(), ndim, {
955984
common::StaticArray<index_t, ndim> begin, end, step;
956985
GetIndexRange(igrad.shape_, param.begin, param.end, param.step, &begin, &end, &step);

src/operator/tensor/matrix_op.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,7 @@ Example::
506506
[5., 7.],
507507
[1., 3.]]
508508
)code" ADD_FILELINE)
509+
.add_alias("_npx_slice")
509510
.set_attr_parser(ParamParser<SliceParam>)
510511
.set_attr<mxnet::FInferShape>("FInferShape", SliceOpShape)
511512
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)

tests/python/unittest/test_numpy_op.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,78 @@ def is_int(dtype):
9292
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
9393

9494

95+
@with_seed()
96+
@use_np
97+
def test_npx_slice():
98+
class TestSlice(HybridBlock):
99+
def __init__(self, begin, end, step):
100+
super(TestSlice, self).__init__()
101+
self._begin = begin
102+
self._end = end
103+
self._step = step
104+
105+
def hybrid_forward(self, F, a, *args, **kwargs):
106+
return F.npx.slice(a, begin=self._begin, end=self._end, step=self._step)
107+
108+
def get_start_end_step(shape):
109+
start = []
110+
end = []
111+
step_switch = random.randint(-1,1)
112+
step = None if step_switch == 0 else []
113+
for i in range(len(shape)):
114+
s = random.randint(0, shape[i]-1)
115+
e = random.randint(s+1, shape[i])
116+
if step_switch == 1:
117+
step.append(1)
118+
start.append(s)
119+
end.append(e)
120+
elif step_switch == -1:
121+
step.append(-1)
122+
if e == shape[i]:
123+
e -= 1
124+
s -= 1
125+
if s == -1:
126+
s = None
127+
start.append(e)
128+
end.append(s)
129+
else:
130+
start.append(s)
131+
end.append(e)
132+
return start, end, step
133+
134+
for hybridize in [True, False]:
135+
for i in range(10):
136+
dim = random.randint(1,4)
137+
shape = [random.randint(1,5) for i in range(dim)]
138+
139+
# test gluon
140+
start, end, step = get_start_end_step(shape)
141+
test_slice = TestSlice(begin=start, end=end, step=step)
142+
if hybridize:
143+
test_slice.hybridize()
144+
145+
a = mx.nd.random.uniform(shape=shape).as_np_ndarray()
146+
a.attach_grad()
147+
if step is not None:
148+
expected_ret = a.as_nd_ndarray().slice(start, end, step)
149+
else:
150+
expected_ret = a.as_nd_ndarray().slice(start, end)
151+
with mx.autograd.record():
152+
y = test_slice(a)
153+
154+
assert_almost_equal(y.asnumpy(), expected_ret.asnumpy(), rtol=1e-3, atol=1e-5)
155+
156+
# test backward
157+
mx.autograd.backward(y)
158+
expected_grad = _np.zeros(shape)
159+
basic_index = tuple([
160+
slice(start[i], end[i], step[i]) if step is not None else slice(start[i], end[i])
161+
for i in range(len(start))
162+
])
163+
expected_grad[basic_index] = 1
164+
assert_almost_equal(a.grad.asnumpy(), expected_grad, rtol=1e-3, atol=1e-5)
165+
166+
95167
if __name__ == '__main__':
96168
import nose
97169
nose.runmodule()

0 commit comments

Comments
 (0)