@@ -670,13 +670,13 @@ inline void GetIndexRange(const mxnet::TShape& dshape,
670
670
<< " Static array size=" << ndim
671
671
<< " is not equal to data shape ndim=" << dshape.ndim ();
672
672
673
- if (param_step.ndim () != 0 ) {
673
+ if (param_step.ndim () > 0 ) {
674
674
CHECK_EQ (param_step.ndim (), param_begin.ndim ())
675
675
<< " step and begin must have the same length" ;
676
676
}
677
677
678
678
for (int i = 0 ; i < param_begin.ndim (); ++i) {
679
- index_t s = param_step.ndim () != 0U && param_step[i].has_value () ? param_step[i].value () : 1 ;
679
+ index_t s = param_step.ndim () > 0 && param_step[i].has_value () ? param_step[i].value () : 1 ;
680
680
CHECK_NE (s, 0 ) << " slice op step[" << i << " ] cannot be 0" ;
681
681
682
682
index_t b = 0 , e = 0 ;
@@ -685,58 +685,54 @@ inline void GetIndexRange(const mxnet::TShape& dshape,
685
685
b = param_begin[i].has_value () ? param_begin[i].value () : (s < 0 ? len - 1 : 0 );
686
686
e = param_end[i].has_value () ? param_end[i].value () : (s < 0 ? -1 : len);
687
687
688
- // checking upper and lower bounds for begin
689
688
if (b < 0 ) {
690
689
b += len;
691
- CHECK_GE (b, 0 ) << " slicing with begin[" << i << " ]=" << b - len
692
- << " exceeds limit of input dimension[" << i << " ]=" << len;
693
690
}
694
- CHECK_LT (b, len) << " slicing with begin[" << i << " ]=" << b
695
- << " exceeds limit of input dimension[" << i << " ]=" << len;
696
-
697
- // checking upper and lower bounds for end
698
691
if (e < 0 && param_end[i].has_value ()) {
699
- if (!(s < 0 && e == -1 )) {
700
- // Keep end=-1 as one-beyond-limits index for negative stride
701
- e += len;
702
- }
703
- CHECK_GE (e, 0 ) << " slicing with end[" << i << " ]=" << e - len
704
- << " exceeds limit of input dimension[" << i << " ]=" << len;
692
+ e += len;
705
693
}
706
- CHECK_LE (e, len) << " slicing with end[" << i << " ]=" << e
707
- << " exceeds limit of input dimension[" << i << " ]=" << len;
708
694
709
- // checking begin==end case which is not supported
710
- CHECK_NE (b, e) << " slicing with begin[" << i << " ]=end[" << i << " ]="
711
- << e << " results in an empty tensor and is not supported" ;
695
+ // move the begin and end to correct position for calculating dim size
696
+ b = (b < 0 && s > 0 ) ? 0 : b;
697
+ b = (b > len - 1 && s < 0 ) ? len - 1 : b;
698
+ // if the start value lead to empty tensor under step s, use -1 for indication
699
+ b = (b < 0 || b > len - 1 ) ? -1 : b;
700
+ e = e > -1 ? e : -1 ;
701
+ e = e > len ? len : e;
702
+ } else if (len == 0 ) {
703
+ b = 0 ;
704
+ e = 0 ;
712
705
}
713
706
714
707
(*begin)[i] = b;
715
708
(*end)[i] = e;
716
709
(*step)[i] = s;
717
710
}
718
711
719
- for (index_t i = param_begin.ndim (); i < dshape.ndim (); ++i) {
712
+ for (int i = param_begin.ndim (); i < dshape.ndim (); ++i) {
720
713
(*begin)[i] = 0 ;
721
714
(*end)[i] = dshape[i];
722
715
(*step)[i] = 1 ;
723
716
}
724
717
}
725
718
726
- inline void SetSliceOpOutputDimSize (const index_t i, const int b,
719
+ inline void SetSliceOpOutputDimSize (const mxnet::TShape& dshape,
720
+ const index_t i, const int b,
727
721
const int e, const int s,
728
722
mxnet::TShape* oshape) {
729
- if (e != b) {
723
+ if (!mxnet::dim_size_is_known (dshape, i)) {
724
+ (*oshape)[i] = -1 ;
725
+ return ;
726
+ }
727
+ if (e != b && b >= 0 ) {
730
728
if (s > 0 ) {
731
- CHECK_LT (b, e) << " slicing with begin[" << i << " ]=" << b << " , end[" << i << " ]="
732
- << e << " , and step[" << i << " ]=" << s << " is invalid" ;
733
- (*oshape)[i] = (e - b - 1 ) / s + 1 ;
729
+ (*oshape)[i] = e > b ? (e - b - 1 ) / s + 1 : 0 ;
734
730
} else {
735
- CHECK_LT (e, b) << " slicing with begin[" << i << " ]=" << b << " , end[" << i << " ]="
736
- << e << " , and step[" << i << " ]=" << s << " is invalid" ;
737
- (*oshape)[i] = (b - e - 1 ) / (-s) + 1 ;
731
+ (*oshape)[i] = e < b ? (b - e - 1 ) / (-s) + 1 : 0 ;
738
732
}
739
- } // else leave oshape[i] as 0 for partial infer
733
+ } else {
734
+ (*oshape)[i] = 0 ;
735
+ }
740
736
}
741
737
742
738
inline bool SliceOpShape (const nnvm::NodeAttrs& attrs,
@@ -746,6 +742,7 @@ inline bool SliceOpShape(const nnvm::NodeAttrs& attrs,
746
742
CHECK_EQ (out_attrs->size (), 1U );
747
743
const mxnet::TShape& dshape = (*in_attrs)[0 ];
748
744
if (!mxnet::ndim_is_known (dshape)) return false ;
745
+ CHECK_GT (dshape.ndim (), 0 ) << " slice only works for ndim > 0" ;
749
746
const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed );
750
747
mxnet::TShape oshape = dshape;
751
748
@@ -754,12 +751,12 @@ inline bool SliceOpShape(const nnvm::NodeAttrs& attrs,
754
751
GetIndexRange (dshape, param.begin , param.end , param.step , &begin, &end, &step);
755
752
for (int i = 0 ; i < param.begin .ndim (); ++i) {
756
753
const int b = begin[i], e = end[i], s = step[i];
757
- SetSliceOpOutputDimSize (i, b, e, s, &oshape);
754
+ SetSliceOpOutputDimSize (dshape, i, b, e, s, &oshape);
758
755
}
759
756
})
760
757
761
758
SHAPE_ASSIGN_CHECK (*out_attrs, 0 , oshape);
762
- return shape_is_known (oshape);
759
+ return shape_is_known (dshape) && shape_is_known ( oshape);
763
760
}
764
761
765
762
template <int ndim, int req, typename xpu>
@@ -837,6 +834,7 @@ void SliceOpForward(const nnvm::NodeAttrs& attrs,
837
834
Stream<xpu>* s = ctx.get_stream <xpu>();
838
835
const TBlob& data = inputs[0 ];
839
836
const TBlob& out = outputs[0 ];
837
+ if (out.Size () == 0 ) return ;
840
838
const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed );
841
839
MXNET_NDIM_SWITCH (data.ndim (), ndim, {
842
840
common::StaticArray<index_t , ndim> begin, end, step;
@@ -936,6 +934,7 @@ void SliceOpBackward(const nnvm::NodeAttrs& attrs,
936
934
} else if (req[0 ] == kWriteInplace ) {
937
935
LOG (FATAL) << " _slice_backward does not support kWriteInplace" ;
938
936
}
937
+ if (ograd.Size () == 0 ) return ;
939
938
MXNET_NDIM_SWITCH (ograd.ndim (), ndim, {
940
939
common::StaticArray<index_t , ndim> begin, end, step;
941
940
GetIndexRange (igrad.shape_ , param.begin , param.end , param.step , &begin, &end, &step);
@@ -967,7 +966,7 @@ inline bool SliceAssignOpShape(const nnvm::NodeAttrs& attrs,
967
966
GetIndexRange (dshape, param.begin , param.end , param.step , &begin, &end, &step);
968
967
for (int i = 0 ; i < param.begin .ndim (); ++i) {
969
968
const int b = begin[i], e = end[i], s = step[i];
970
- SetSliceOpOutputDimSize (i, b, e, s, &vshape);
969
+ SetSliceOpOutputDimSize (dshape, i, b, e, s, &vshape);
971
970
}
972
971
})
973
972
SHAPE_ASSIGN_CHECK (*in_attrs, 1 , vshape);
@@ -1106,7 +1105,7 @@ void SliceAssignScalarOpForward(const nnvm::NodeAttrs& attrs,
1106
1105
GetIndexRange (data.shape_ , param.begin , param.end , param.step , &begin, &end, &step);
1107
1106
for (index_t i = 0 ; i < param.begin .ndim (); ++i) {
1108
1107
const int b = begin[i], e = end[i], s = step[i];
1109
- SetSliceOpOutputDimSize (i, b, e, s, &vshape);
1108
+ SetSliceOpOutputDimSize (data. shape_ , i, b, e, s, &vshape);
1110
1109
}
1111
1110
MSHADOW_TYPE_SWITCH (out.type_flag_ , DType, {
1112
1111
mxnet_op::Kernel<slice_assign_scalar<ndim>, xpu>::Launch (s, vshape.FlatTo2D ()[0 ],
0 commit comments