@@ -685,13 +685,13 @@ inline void GetIndexRange(const mxnet::TShape& dshape,
685
685
<< " Static array size=" << ndim
686
686
<< " is not equal to data shape ndim=" << dshape.ndim ();
687
687
688
- if (param_step.ndim () != 0 ) {
688
+ if (param_step.ndim () > 0 ) {
689
689
CHECK_EQ (param_step.ndim (), param_begin.ndim ())
690
690
<< " step and begin must have the same length" ;
691
691
}
692
692
693
693
for (int i = 0 ; i < param_begin.ndim (); ++i) {
694
- index_t s = param_step.ndim () != 0U && param_step[i].has_value () ? param_step[i].value () : 1 ;
694
+ index_t s = param_step.ndim () > 0 && param_step[i].has_value () ? param_step[i].value () : 1 ;
695
695
CHECK_NE (s, 0 ) << " slice op step[" << i << " ] cannot be 0" ;
696
696
697
697
index_t b = 0 , e = 0 ;
@@ -703,29 +703,44 @@ inline void GetIndexRange(const mxnet::TShape& dshape,
703
703
// checking upper and lower bounds for begin
704
704
if (b < 0 ) {
705
705
b += len;
706
- CHECK_GE (b, 0 ) << " slicing with begin[" << i << " ]=" << b - len
707
- << " exceeds limit of input dimension[" << i << " ]=" << len;
706
+ if (!Imperative::Get ()->is_np_shape ()) {
707
+ CHECK_GE (b, 0 ) << " slicing with begin[" << i << " ]=" << b - len
708
+ << " exceeds limit of input dimension[" << i << " ]=" << len;
709
+ }
710
+ }
711
+ if (!Imperative::Get ()->is_np_shape ()) {
712
+ CHECK_LT (b, len) << " slicing with begin[" << i << " ]=" << b
713
+ << " exceeds limit of input dimension[" << i << " ]=" << len;
708
714
}
709
- CHECK_LT (b, len) << " slicing with begin[" << i << " ]=" << b
710
- << " exceeds limit of input dimension[" << i << " ]=" << len;
711
-
712
715
// checking upper and lower bounds for end
713
716
if (e < 0 && param_end[i].has_value ()) {
714
- if (!(s < 0 && e == -1 )) {
715
- // Keep end=-1 as one-beyond-limits index for negative stride
716
- e += len;
717
+ e += len;
718
+ if (!Imperative::Get ()->is_np_shape ()) {
719
+ CHECK_GE (e, 0 ) << " slicing with end[" << i << " ]=" << e - len
720
+ << " exceeds limit of input dimension[" << i << " ]=" << len;
717
721
}
718
- CHECK_GE (e, 0 ) << " slicing with end[" << i << " ]=" << e - len
719
- << " exceeds limit of input dimension[" << i << " ]=" << len;
720
722
}
721
- CHECK_LE (e, len) << " slicing with end[" << i << " ]=" << e
722
- << " exceeds limit of input dimension[" << i << " ]=" << len;
723
+ if (!Imperative::Get ()->is_np_shape ()) {
724
+ CHECK_LE (e, len) << " slicing with end[" << i << " ]=" << e
725
+ << " exceeds limit of input dimension[" << i << " ]=" << len;
726
+ }
723
727
724
728
// checking begin==end case which is not supported
725
- CHECK_NE (b, e) << " slicing with begin[" << i << " ]=end[" << i << " ]="
726
- << e << " results in an empty tensor and is not supported" ;
729
+ if (!Imperative::Get ()->is_np_shape ()) {
730
+ CHECK_NE (b, e) << " slicing with begin[" << i << " ]=end[" << i << " ]="
731
+ << e << " results in an empty tensor and is not supported" ;
732
+ }
727
733
}
728
734
735
+ if (Imperative::Get ()->is_np_shape ()) {
736
+ // move the begin and end to correct position for calculating dim size
737
+ b = b < 0 && s > 0 ? 0 : b;
738
+ b = b > len-1 && s < 0 ? len-1 : b;
739
+ // if the start value lead to empty tensor under step s, use -1 for indication
740
+ b = b < 0 || b > len-1 ? -1 : b;
741
+ e = e > -1 ? e : -1 ;
742
+ e = e > len ? len : e;
743
+ }
729
744
(*begin)[i] = b;
730
745
(*end)[i] = e;
731
746
(*step)[i] = s;
@@ -741,17 +756,29 @@ inline void GetIndexRange(const mxnet::TShape& dshape,
741
756
inline void SetSliceOpOutputDimSize (const index_t i, const int b,
742
757
const int e, const int s,
743
758
mxnet::TShape* oshape) {
744
- if (e != b) {
745
- if (s > 0 ) {
746
- CHECK_LT (b, e) << " slicing with begin[" << i << " ]=" << b << " , end[" << i << " ]="
747
- << e << " , and step[" << i << " ]=" << s << " is invalid" ;
748
- (*oshape)[i] = (e - b - 1 ) / s + 1 ;
759
+ if (!Imperative::Get ()->is_np_shape ()) { // handle as ndarray
760
+ if (e != b) {
761
+ if (s > 0 ) {
762
+ CHECK_LT (b, e) << " slicing with begin=[" << i << " ]=" << b << " , end[" << i << " ]="
763
+ << e << " , and step[" << i << " ]=" << s << " is invalid" ;
764
+ (*oshape)[i] = (e - b - 1 ) / s + 1 ;
765
+ } else {
766
+ CHECK_LT (e, b) << " slicing with begin=[" << i << " ]=" << b << " , end[" << i << " ]="
767
+ << e << " , and step[" << i << " ]=" << s << " is invalid" ;
768
+ (*oshape)[i] = (b - e - 1 ) / (-s) + 1 ;
769
+ }
770
+ } // else leave oshape[i] as 0 for partial infer
771
+ } else { // handle as numpy compatible array
772
+ if (e != b && b >= 0 ) {
773
+ if (s > 0 ) {
774
+ (*oshape)[i] = e > b ? (e - b - 1 ) / s + 1 : 0 ;
775
+ } else {
776
+ (*oshape)[i] = e < b ? (b - e - 1 ) / (-s) + 1 : 0 ;
777
+ }
749
778
} else {
750
- CHECK_LT (e, b) << " slicing with begin[" << i << " ]=" << b << " , end[" << i << " ]="
751
- << e << " , and step[" << i << " ]=" << s << " is invalid" ;
752
- (*oshape)[i] = (b - e - 1 ) / (-s) + 1 ;
779
+ (*oshape)[i] = 0 ;
753
780
}
754
- } // else leave oshape[i] as 0 for partial infer
781
+ }
755
782
}
756
783
757
784
inline bool SliceOpShape (const nnvm::NodeAttrs& attrs,
@@ -852,6 +879,7 @@ void SliceOpForward(const nnvm::NodeAttrs& attrs,
852
879
Stream<xpu>* s = ctx.get_stream <xpu>();
853
880
const TBlob& data = inputs[0 ];
854
881
const TBlob& out = outputs[0 ];
882
+ if (Imperative::Get ()->is_np_shape () && out.Size () == 0 ) return ;
855
883
const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed );
856
884
MXNET_NDIM_SWITCH (data.ndim (), ndim, {
857
885
common::StaticArray<index_t , ndim> begin, end, step;
@@ -951,6 +979,7 @@ void SliceOpBackward(const nnvm::NodeAttrs& attrs,
951
979
} else if (req[0 ] == kWriteInplace ) {
952
980
LOG (FATAL) << " _slice_backward does not support kWriteInplace" ;
953
981
}
982
+ if (Imperative::Get ()->is_np_shape () && ograd.Size () == 0 ) return ;
954
983
MXNET_NDIM_SWITCH (ograd.ndim (), ndim, {
955
984
common::StaticArray<index_t , ndim> begin, end, step;
956
985
GetIndexRange (igrad.shape_ , param.begin , param.end , param.step , &begin, &end, &step);
0 commit comments