Skip to content

Commit 688ad97

Browse files
Mike
authored and
Rohit Kumar Srivastava
committed
[Numpy] Numpy compatible slicing (apache#15798)
* Modify ndarray slice to have numpy compatible behaviour * Minor syntax fix * Fix slice inconsistency * Allow empty outputs after slicing ndarrays * Fix
1 parent 9e12ca1 commit 688ad97

File tree

3 files changed

+34
-43
lines changed

3 files changed

+34
-43
lines changed

src/operator/tensor/matrix_op-inl.h

Lines changed: 33 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -670,13 +670,13 @@ inline void GetIndexRange(const mxnet::TShape& dshape,
670670
<< "Static array size=" << ndim
671671
<< " is not equal to data shape ndim=" << dshape.ndim();
672672

673-
if (param_step.ndim() != 0) {
673+
if (param_step.ndim() > 0) {
674674
CHECK_EQ(param_step.ndim(), param_begin.ndim())
675675
<< "step and begin must have the same length";
676676
}
677677

678678
for (int i = 0; i < param_begin.ndim(); ++i) {
679-
index_t s = param_step.ndim() != 0U && param_step[i].has_value() ? param_step[i].value() : 1;
679+
index_t s = param_step.ndim() > 0 && param_step[i].has_value() ? param_step[i].value() : 1;
680680
CHECK_NE(s, 0) << "slice op step[" << i << "] cannot be 0";
681681

682682
index_t b = 0, e = 0;
@@ -685,58 +685,54 @@ inline void GetIndexRange(const mxnet::TShape& dshape,
685685
b = param_begin[i].has_value() ? param_begin[i].value() : (s < 0 ? len - 1 : 0);
686686
e = param_end[i].has_value() ? param_end[i].value() : (s < 0 ? -1 : len);
687687

688-
// checking upper and lower bounds for begin
689688
if (b < 0) {
690689
b += len;
691-
CHECK_GE(b, 0) << "slicing with begin[" << i << "]=" << b - len
692-
<< " exceeds limit of input dimension[" << i << "]=" << len;
693690
}
694-
CHECK_LT(b, len) << "slicing with begin[" << i << "]=" << b
695-
<< " exceeds limit of input dimension[" << i << "]=" << len;
696-
697-
// checking upper and lower bounds for end
698691
if (e < 0 && param_end[i].has_value()) {
699-
if (!(s < 0 && e == -1)) {
700-
// Keep end=-1 as one-beyond-limits index for negative stride
701-
e += len;
702-
}
703-
CHECK_GE(e, 0) << "slicing with end[" << i << "]=" << e - len
704-
<< " exceeds limit of input dimension[" << i << "]=" << len;
692+
e += len;
705693
}
706-
CHECK_LE(e, len) << "slicing with end[" << i << "]=" << e
707-
<< " exceeds limit of input dimension[" << i << "]=" << len;
708694

709-
// checking begin==end case which is not supported
710-
CHECK_NE(b, e) << "slicing with begin[" << i << "]=end[" << i << "]="
711-
<< e << " results in an empty tensor and is not supported";
695+
// move the begin and end to correct position for calculating dim size
696+
b = (b < 0 && s > 0) ? 0 : b;
697+
b = (b > len - 1 && s < 0) ? len - 1 : b;
698+
// if the start value lead to empty tensor under step s, use -1 for indication
699+
b = (b < 0 || b > len - 1) ? -1 : b;
700+
e = e > -1 ? e : -1;
701+
e = e > len ? len : e;
702+
} else if (len == 0) {
703+
b = 0;
704+
e = 0;
712705
}
713706

714707
(*begin)[i] = b;
715708
(*end)[i] = e;
716709
(*step)[i] = s;
717710
}
718711

719-
for (index_t i = param_begin.ndim(); i < dshape.ndim(); ++i) {
712+
for (int i = param_begin.ndim(); i < dshape.ndim(); ++i) {
720713
(*begin)[i] = 0;
721714
(*end)[i] = dshape[i];
722715
(*step)[i] = 1;
723716
}
724717
}
725718

726-
inline void SetSliceOpOutputDimSize(const index_t i, const int b,
719+
inline void SetSliceOpOutputDimSize(const mxnet::TShape& dshape,
720+
const index_t i, const int b,
727721
const int e, const int s,
728722
mxnet::TShape* oshape) {
729-
if (e != b) {
723+
if (!mxnet::dim_size_is_known(dshape, i)) {
724+
(*oshape)[i] = -1;
725+
return;
726+
}
727+
if (e != b && b >= 0) {
730728
if (s > 0) {
731-
CHECK_LT(b, e) << "slicing with begin[" << i << "]=" << b << ", end[" << i << "]="
732-
<< e << ", and step[" << i << "]=" << s << " is invalid";
733-
(*oshape)[i] = (e - b - 1) / s + 1;
729+
(*oshape)[i] = e > b ? (e - b - 1) / s + 1 : 0;
734730
} else {
735-
CHECK_LT(e, b) << "slicing with begin[" << i << "]=" << b << ", end[" << i << "]="
736-
<< e << ", and step[" << i << "]=" << s << " is invalid";
737-
(*oshape)[i] = (b - e - 1) / (-s) + 1;
731+
(*oshape)[i] = e < b ? (b - e - 1) / (-s) + 1 : 0;
738732
}
739-
} // else leave oshape[i] as 0 for partial infer
733+
} else {
734+
(*oshape)[i] = 0;
735+
}
740736
}
741737

742738
inline bool SliceOpShape(const nnvm::NodeAttrs& attrs,
@@ -746,6 +742,7 @@ inline bool SliceOpShape(const nnvm::NodeAttrs& attrs,
746742
CHECK_EQ(out_attrs->size(), 1U);
747743
const mxnet::TShape& dshape = (*in_attrs)[0];
748744
if (!mxnet::ndim_is_known(dshape)) return false;
745+
CHECK_GT(dshape.ndim(), 0) << "slice only works for ndim > 0";
749746
const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
750747
mxnet::TShape oshape = dshape;
751748

@@ -754,12 +751,12 @@ inline bool SliceOpShape(const nnvm::NodeAttrs& attrs,
754751
GetIndexRange(dshape, param.begin, param.end, param.step, &begin, &end, &step);
755752
for (int i = 0; i < param.begin.ndim(); ++i) {
756753
const int b = begin[i], e = end[i], s = step[i];
757-
SetSliceOpOutputDimSize(i, b, e, s, &oshape);
754+
SetSliceOpOutputDimSize(dshape, i, b, e, s, &oshape);
758755
}
759756
})
760757

761758
SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
762-
return shape_is_known(oshape);
759+
return shape_is_known(dshape) && shape_is_known(oshape);
763760
}
764761

765762
template<int ndim, int req, typename xpu>
@@ -837,6 +834,7 @@ void SliceOpForward(const nnvm::NodeAttrs& attrs,
837834
Stream<xpu>* s = ctx.get_stream<xpu>();
838835
const TBlob& data = inputs[0];
839836
const TBlob& out = outputs[0];
837+
if (out.Size() == 0) return;
840838
const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
841839
MXNET_NDIM_SWITCH(data.ndim(), ndim, {
842840
common::StaticArray<index_t, ndim> begin, end, step;
@@ -936,6 +934,7 @@ void SliceOpBackward(const nnvm::NodeAttrs& attrs,
936934
} else if (req[0] == kWriteInplace) {
937935
LOG(FATAL) << "_slice_backward does not support kWriteInplace";
938936
}
937+
if (ograd.Size() == 0) return;
939938
MXNET_NDIM_SWITCH(ograd.ndim(), ndim, {
940939
common::StaticArray<index_t, ndim> begin, end, step;
941940
GetIndexRange(igrad.shape_, param.begin, param.end, param.step, &begin, &end, &step);
@@ -967,7 +966,7 @@ inline bool SliceAssignOpShape(const nnvm::NodeAttrs& attrs,
967966
GetIndexRange(dshape, param.begin, param.end, param.step, &begin, &end, &step);
968967
for (int i = 0; i < param.begin.ndim(); ++i) {
969968
const int b = begin[i], e = end[i], s = step[i];
970-
SetSliceOpOutputDimSize(i, b, e, s, &vshape);
969+
SetSliceOpOutputDimSize(dshape, i, b, e, s, &vshape);
971970
}
972971
})
973972
SHAPE_ASSIGN_CHECK(*in_attrs, 1, vshape);
@@ -1106,7 +1105,7 @@ void SliceAssignScalarOpForward(const nnvm::NodeAttrs& attrs,
11061105
GetIndexRange(data.shape_, param.begin, param.end, param.step, &begin, &end, &step);
11071106
for (index_t i = 0; i < param.begin.ndim(); ++i) {
11081107
const int b = begin[i], e = end[i], s = step[i];
1109-
SetSliceOpOutputDimSize(i, b, e, s, &vshape);
1108+
SetSliceOpOutputDimSize(data.shape_, i, b, e, s, &vshape);
11101109
}
11111110
MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
11121111
mxnet_op::Kernel<slice_assign_scalar<ndim>, xpu>::Launch(s, vshape.FlatTo2D()[0],

src/operator/tensor/matrix_op.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,7 @@ Example::
506506
[5., 7.],
507507
[1., 3.]]
508508
)code" ADD_FILELINE)
509+
.add_alias("_npx_slice")
509510
.set_attr_parser(ParamParser<SliceParam>)
510511
.set_attr<mxnet::FInferShape>("FInferShape", SliceOpShape)
511512
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)

tests/python/unittest/test_operator.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7336,15 +7336,6 @@ def test_slice_forward_backward(a, index):
73367336
for index in index_list:
73377337
test_slice_forward_backward(arr, index)
73387338

7339-
def test_begin_equals_end(shape, begin, end, step):
7340-
in_arr = mx.nd.arange(np.prod(shape)).reshape(shape=shape)
7341-
out_arr = mx.nd.slice(in_arr, begin=begin, end=end, step=step)
7342-
7343-
assertRaises(MXNetError, test_begin_equals_end, (4,), (2,), (2,), (1,))
7344-
assertRaises(MXNetError, test_begin_equals_end, (1, 5), (None, 3), (None, 3), (-1, 1))
7345-
assertRaises(MXNetError, test_begin_equals_end, (3, 4, 5), (1, 3, 1), (3, 3, 1), (1, -3, 2))
7346-
assertRaises(MXNetError, test_begin_equals_end, (2, 4), (None, 2), (None, 2), (1, -1))
7347-
73487339
# check numeric gradient
73497340
in_data = np.arange(36).reshape(2, 2, 3, 3)
73507341
data = mx.sym.Variable('data')

0 commit comments

Comments (0)