@@ -34,6 +34,9 @@ DMLC_REGISTER_PARAMETER(NumpyTransposeParam);
DMLC_REGISTER_PARAMETER(NumpyRollParam);
DMLC_REGISTER_PARAMETER(NumpyMoveaxisParam);
DMLC_REGISTER_PARAMETER(NumpyRot90Param);
+DMLC_REGISTER_PARAMETER(NumpyReshapeParam);
+DMLC_REGISTER_PARAMETER(NumpyXReshapeParam);
+

bool NumpyTransposeShape(const nnvm::NodeAttrs& attrs,
                         mxnet::ShapeVector *in_attrs,
@@ -126,26 +129,6 @@ NNVM_REGISTER_OP(_np_transpose)
.add_argument("a", "NDArray-or-Symbol", "Source input")
.add_arguments(NumpyTransposeParam::__FIELDS__());

-struct NumpyReshapeParam : public dmlc::Parameter<NumpyReshapeParam> {
-  mxnet::TShape newshape;
-  std::string order;
-  DMLC_DECLARE_PARAMETER(NumpyReshapeParam) {
-      DMLC_DECLARE_FIELD(newshape)
-      .describe("The new shape should be compatible with the original shape."
-                " If an integer, then the result will be a 1-D array of that length."
-                " One shape dimension can be -1. In this case, the value is inferred"
-                " from the length of the array and remaining dimensions.");
-      DMLC_DECLARE_FIELD(order)
-      .set_default("C")
-      .describe("Read the elements of a using this index order, and place the elements into"
-                " the reshaped array using this index order. 'C' means to read/write the elements"
-                " using C-like index order, with the last axis index changing fastest, back to the"
-                " first axis index changing slowest. Note that currently only C-like order is"
-                " supported");
-  }
-};
-
-DMLC_REGISTER_PARAMETER(NumpyReshapeParam);

bool NumpyReshapeInferShape(const mxnet::TShape& src, mxnet::TShape* dst) {
  if (shape_is_known(src) && shape_is_known(*dst)) {
@@ -202,6 +185,164 @@ bool NumpyReshapeShape(const nnvm::NodeAttrs& attrs,
  return success;
}

+bool NumpyXReshapeInferShape(const mxnet::TShape& src,
+                             const mxnet::TShape& target,
+                             mxnet::TShape* output,
+                             const std::string& default_error_msg) {
+  bool target_shape_is_known = true;
+  dim_t target_size = 1;
+  for (int i = 0; i < target.ndim(); ++i) {
+    if (target[i] < 0) {
+      target_shape_is_known = false;
+      target_size = -1;
+      break;
+    } else {
+      target_size *= target[i];
+    }
+  }
+  if (shape_is_known(src) && target_shape_is_known) {
+    CHECK_EQ(src.Size(), target_size) << default_error_msg;
+    *output = TShape(target.begin(), target.end());
+    return true;
+  } else if (!shape_is_known(src) || target.ndim() == -1) {
+    return false;
+  } else {
+    int unknown_axis = -1;
+    dim_t known_dim_size_prod = 1;
+    std::vector<dim_t> output_shape_vector;
+    int src_inx = 0;
+    for (int i = 0; i < target.ndim(); ++i) {
+      dim_t proposed_dim = target[i];
+      CHECK(proposed_dim >= -6)
+          << "Dimension size must be greater than or equal to -6, received " << proposed_dim;
+      if (proposed_dim == -1) {
+        // infer the unknown dimension
+        CHECK_LT(unknown_axis, 0)
+            << "One and only one dim can be inferred";
+        unknown_axis = output_shape_vector.size();
+        output_shape_vector.push_back(-1);
+        src_inx++;
+      } else if (proposed_dim == -2) {
+        // copy the dimension from src to output
+        CHECK_LT(src_inx, src.ndim())
+            << "Mismatched dimension in proposed new shape";
+        known_dim_size_prod *= src[src_inx];
+        output_shape_vector.push_back(src[src_inx++]);
+      } else if (proposed_dim == -3) {
+        // skip the source dimension if and only if it is one
+        CHECK_EQ(src[src_inx], 1)
+            << "-3 index should only be used to skip dimension size 1";
+        src_inx++;
+      } else if (proposed_dim == -4) {
+        // copy all remaining dims from source
+        while (src_inx < src.ndim()) {
+          known_dim_size_prod *= src[src_inx];
+          const dim_t dn = src[src_inx++];
+          output_shape_vector.push_back(dn);
+        }
+      } else if (proposed_dim == -5) {
+        // merge two dims from source
+        CHECK_LT(src_inx, src.ndim() - 1)
+            << "Not enough dimensions left for the product";
+        const dim_t d1 = src[src_inx++];
+        const dim_t d2 = src[src_inx++];
+        if (!mxnet::dim_size_is_known(d1) || !mxnet::dim_size_is_known(d2)) {
+          CHECK_LT(unknown_axis, 0)
+              << "One and only one dim can be inferred";
+          unknown_axis = output_shape_vector.size();
+          output_shape_vector.push_back(-1);
+        } else {
+          known_dim_size_prod *= d1 * d2;
+          output_shape_vector.push_back(d1 * d2);
+        }
+      } else if (proposed_dim == -6) {
+        // split the source dim into two dims;
+        // read the left dim and then the right dim (either can be -1)
+        CHECK_LT(i + 2, target.ndim());
+        CHECK_LT(src_inx, src.ndim());
+        const dim_t d0 = src[src_inx++];
+        dim_t d1 = target[++i];
+        dim_t d2 = target[++i];
+        CHECK(d1 != -1 || d2 != -1) << "Split dims cannot both be -1.";
+        if (d1 == -1 && d0 >= 0) d1 = d0 / d2;  // d0 must be known to do this
+        if (d2 == -1 && d0 >= 0) d2 = d0 / d1;  // d0 must be known to do this
+        CHECK(d1 * d2 == static_cast<dim_t>(d0) || static_cast<dim_t>(d0) == dim_t(-1))
+            << "Split dims " << d1 << ", " << d2 << " do not divide original dim " << d0;
+        if (d1 == -1) {
+          CHECK_LT(unknown_axis, 0)
+              << "One and only one dim can be inferred";
+          unknown_axis = output_shape_vector.size();
+        } else if (d2 == -1) {
+          CHECK_LT(unknown_axis, 0)
+              << "One and only one dim can be inferred";
+          unknown_axis = output_shape_vector.size() + 1;
+        }
+        known_dim_size_prod *= d0 == -1 ? 1 : d0;
+        output_shape_vector.push_back(d1);
+        output_shape_vector.push_back(d2);
+      } else {
+        // proposed_dim > 0: take the new dimension as given
+        known_dim_size_prod *= proposed_dim;
+        output_shape_vector.push_back(proposed_dim);
+        src_inx++;
+      }
+    }
+
+    if (unknown_axis > -1) {
+      // if the input is a zero-size tensor, the output must have a known shape of zero size
+      CHECK_NE(known_dim_size_prod, 0) << default_error_msg;
+      CHECK(src.Size() % known_dim_size_prod == 0) << default_error_msg;
+      output_shape_vector[unknown_axis] = src.Size() / known_dim_size_prod;
+    }
+
+    *output = mxnet::TShape(output_shape_vector.begin(), output_shape_vector.end());
+    CHECK_EQ((*output).Size(), src.Size()) << default_error_msg;
+    return true;
+  }
+}
+
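The special codes above follow the legacy MXNet reshape convention: -1 infers one output dimension, -2 copies a dimension from the input, -3 drops a size-1 input dimension, -4 copies all remaining input dimensions, -5 merges two consecutive input dimensions, and -6 splits one input dimension in two. A minimal sketch of what the helper produces (not part of the patch; the harness function and shapes are made up, and the calls assume they sit in this translation unit):

    void SketchXReshapeInference() {  // hypothetical harness
      mxnet::TShape out;
      // -1 infers a dim:         (2,3,4) with (-1, 4)        -> (6, 4)
      NumpyXReshapeInferShape(mxnet::TShape({2, 3, 4}), mxnet::TShape({-1, 4}), &out, "err");
      // -2 copies, -5 merges:    (2,3,4) with (-2, -5)       -> (2, 12)
      NumpyXReshapeInferShape(mxnet::TShape({2, 3, 4}), mxnet::TShape({-2, -5}), &out, "err");
      // -3 skips a 1, -2 copies: (1,8) with (-3, -2)         -> (8)
      NumpyXReshapeInferShape(mxnet::TShape({1, 8}), mxnet::TShape({-3, -2}), &out, "err");
      // -6 splits (with -1 inside the split), -4 copies the rest:
      //                          (6,4) with (-6, 2, -1, -4)  -> (2, 3, 4)
      NumpyXReshapeInferShape(mxnet::TShape({6, 4}), mxnet::TShape({-6, 2, -1, -4}), &out, "err");
    }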
+bool NumpyXReshapeShape(const nnvm::NodeAttrs& attrs,
+                        mxnet::ShapeVector* in_attrs,
+                        mxnet::ShapeVector* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U) << "Input: [data]";
+  CHECK_EQ(out_attrs->size(), 1U);
+  const NumpyXReshapeParam& param = nnvm::get<NumpyXReshapeParam>(attrs.parsed);
+  // sanity check
+  bool has_unknown_dim_size = false;
+  for (int i = 0; i < param.newshape.ndim(); ++i) {
+    if (param.newshape[i] < 0) {
+      CHECK_GE(param.newshape[i], -6)
+          << "Dimension size must be greater than or equal to -6";
+      if (param.newshape[i] == -1) {
+        CHECK(!has_unknown_dim_size) << "Can only specify one unknown dimension";
+        has_unknown_dim_size = true;
+      }
+    }
+  }
+
+  mxnet::TShape output_shape;
+  bool success;
+  std::stringstream ss;
+  ss << "Cannot reshape array of shape " << in_attrs->at(0)
+     << " into shape " << param.newshape
+     << ", reverse = " << param.reverse;
+  std::string err_msg = ss.str();
+  if (!param.reverse) {
+    success = NumpyXReshapeInferShape(in_attrs->at(0),
+                                      param.newshape, &output_shape, err_msg);
+  } else {
+    mxnet::TShape rev_in_shape = in_attrs->at(0);
+    mxnet::TShape rev_newshape = param.newshape;
+    std::reverse(rev_in_shape.begin(), rev_in_shape.end());
+    std::reverse(rev_newshape.begin(), rev_newshape.end());
+    success = NumpyXReshapeInferShape(rev_in_shape,
+                                      rev_newshape, &output_shape, err_msg);
+    std::reverse(output_shape.begin(), output_shape.end());
+  }
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, output_shape);
+  return success;
+}
+
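With reverse set, the special codes bind to trailing rather than leading input dimensions: both shapes are reversed, the helper runs as usual, and the result is reversed back. A sketch of the difference (not part of the patch; the harness and shapes are made up):

    void SketchReverseInference() {  // hypothetical harness
      mxnet::TShape src = {2, 3, 4};
      mxnet::TShape fwd, rev;
      // forward: -1 pairs with the leading dims, -2 copies src[1] -> fwd == (8, 3)
      NumpyXReshapeInferShape(src, mxnet::TShape({-1, -2}), &fwd, "err");
      // reverse = true, exactly as NumpyXReshapeShape does it:
      mxnet::TShape rsrc = src;
      mxnet::TShape rtgt = {-1, -2};
      std::reverse(rsrc.begin(), rsrc.end());
      std::reverse(rtgt.begin(), rtgt.end());
      // now -2 copies the trailing dim 4 and -1 absorbs 2*3       -> rev == (6, 4)
      NumpyXReshapeInferShape(rsrc, rtgt, &rev, "err");
      std::reverse(rev.begin(), rev.end());
    }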
NNVM_REGISTER_OP(_np_reshape)
.describe(R"code()code" ADD_FILELINE)
.add_alias("_npi_reshape")
@@ -227,6 +368,31 @@ NNVM_REGISTER_OP(_np_reshape)
.add_argument("a", "NDArray-or-Symbol", "Array to be reshaped.")
.add_arguments(NumpyReshapeParam::__FIELDS__());

+
+NNVM_REGISTER_OP(_npx_reshape)
+.describe(R"code()code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NumpyXReshapeParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", NumpyXReshapeShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_reshape"})
+.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
+  [](const NodeAttrs& attrs) {
+    return std::vector<bool>{true};
+  })
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"a"};
+  })
+.add_argument("a", "NDArray-or-Symbol", "Array to be reshaped.")
+.add_arguments(NumpyXReshapeParam::__FIELDS__());
+
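Since reshape only rewrites shape metadata, the registration pairs the shape function with UnaryOp::IdentityCompute and marks the input/output pair as an in-place identity, letting the runtime alias the buffers instead of copying. A sketch of driving the registered op end-to-end (not part of the patch; the harness and the kwargs-style dict entries are hypothetical, and newshape/reverse are the param fields used by NumpyXReshapeShape above):

    void SketchNpxReshapeOp() {  // hypothetical harness
      nnvm::NodeAttrs attrs;
      attrs.op = nnvm::Op::Get("_npx_reshape");
      attrs.dict["newshape"] = "(-2, -5)";   // keep dim 0, merge the last two dims
      attrs.dict["reverse"] = "false";
      attrs.op->attr_parser(&attrs);         // ParamParser fills attrs.parsed
      mxnet::ShapeVector in{mxnet::TShape({2, 3, 4})};
      mxnet::ShapeVector out{mxnet::TShape()};
      NumpyXReshapeShape(attrs, &in, &out);  // out[0] == (2, 12)
    }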
bool NumpySqueezeShape(const nnvm::NodeAttrs& attrs,
                       mxnet::ShapeVector *in_attrs,
                       mxnet::ShapeVector *out_attrs) {