This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit c130cc9

sxjscience authored and reminisce committed
add npx reshape (#16640)
1 parent 29e467b commit c130cc9

5 files changed: 371 additions, 21 deletions

python/mxnet/_numpy_op_doc.py

Lines changed: 66 additions & 0 deletions

@@ -961,3 +961,69 @@ def _np_broadcast_to(array, shape, out=None):
            [1., 2., 3.]])
     """
     pass
+
+
+def _npx_reshape(a, newshape, reverse=False, order='C'):
+    """
+    Gives a new shape to an array without changing its data.
+    This function always returns a copy of the input array.
+
+    Parameters
+    ----------
+    a : ndarray
+        Array to be reshaped.
+    newshape : int or tuple of ints
+        The new shape should be compatible with the original shape.
+        If an integer, then the result will be a 1-D array of that length.
+        One shape dimension can be -1. In this case, the value is inferred
+        from the length of the array and remaining dimensions.
+        The values -2 to -6 are used for data manipulation:
+
+        - -2: copy this dimension from the input to the output shape.
+        - -3: skip the current dimension if and only if its size is one.
+        - -4: copy all remaining input dimensions to the output shape.
+        - -5: use the product of two consecutive dimensions of the input
+          shape as the output dimension.
+        - -6: split one dimension of the input into the two dimensions passed
+          subsequent to -6 in the new shape.
+
+    reverse : bool, optional
+        If set to true, the special values are inferred from right to left.
+    order : {'C'}, optional
+        Read the elements of `a` using this index order, and place the
+        elements into the reshaped array using this index order. 'C'
+        means to read / write the elements using C-like index order,
+        with the last axis index changing fastest, back to the first
+        axis index changing slowest. Other order types such as 'F'/'A'
+        may be added in the future.
+
+    Returns
+    -------
+    reshaped_array : ndarray
+        It is always a copy of the original array. This behavior differs from
+        the official NumPy ``reshape`` operator, which may return a view of
+        the original array.
+
+    Examples
+    --------
+    >>> x = np.ones((2, 3, 8))
+    >>> npx.reshape(x, (-2, -2, 2, -1)).shape
+    (2, 3, 2, 4)
+    >>> x = np.ones((8, 3, 3, 3, 4, 4))
+    >>> npx.reshape(x, (-6, 2, -1, -4)).shape
+    (2, 4, 3, 3, 3, 4, 4)
+    >>> x = np.ones((8, 3, 3, 3, 4, 4))
+    >>> npx.reshape(x, (-5, -4)).shape
+    (24, 3, 3, 4, 4)
+    >>> x = np.ones((8, 1, 1, 1, 3))
+    >>> npx.reshape(x, (-2, -3, -3, -3, -2)).shape
+    (8, 3)
+    >>> x = np.ones((8, 3, 3, 3, 3, 8))
+    >>> npx.reshape(x, (-4, -5), reverse=True).shape
+    (8, 3, 3, 3, 24)
+    >>> x = np.ones((8, 3, 2, 4, 8))
+    >>> npx.reshape(x, (-4, -1, 2, -6), reverse=True).shape
+    (8, 3, 2, 4, 4, 2)
+    """
+    pass
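The special values are easiest to follow in a small pure-Python model of the left-to-right inference. This is a sketch for illustration only: it assumes a fully known input shape and reverse=False, and infer_npx_reshape is a hypothetical helper, not part of this commit.

import math

def infer_npx_reshape(src, target):
    # Illustration-only model of npx.reshape's special values for fully known
    # shapes; mirrors the left-to-right pass of NumpyXReshapeInferShape below.
    out, i, t = [], 0, 0          # i walks src, t walks target
    unknown, known_prod = -1, 1   # position of the single -1, product of known dims
    while t < len(target):
        d = target[t]
        if d == -1:               # infer this output dim at the end
            unknown = len(out)
            out.append(-1)
            i += 1
        elif d == -2:             # copy one input dim
            out.append(src[i]); known_prod *= src[i]; i += 1
        elif d == -3:             # skip an input dim, which must be 1
            assert src[i] == 1, "-3 only skips size-1 dims"
            i += 1
        elif d == -4:             # copy all remaining input dims
            for dn in src[i:]:
                out.append(dn); known_prod *= dn
            i = len(src)
        elif d == -5:             # merge two consecutive input dims
            m = src[i] * src[i + 1]
            out.append(m); known_prod *= m; i += 2
        elif d == -6:             # split one input dim into the next two targets
            d0, d1, d2 = src[i], target[t + 1], target[t + 2]
            if d1 == -1: d1 = d0 // d2
            if d2 == -1: d2 = d0 // d1
            out += [d1, d2]; known_prod *= d0; i += 1; t += 2
        else:                     # a plain positive dim
            out.append(d); known_prod *= d; i += 1
        t += 1
    if unknown >= 0:              # fill in the single inferred dim
        out[unknown] = math.prod(src) // known_prod
    return tuple(out)

assert infer_npx_reshape((2, 3, 8), (-2, -2, 2, -1)) == (2, 3, 2, 4)
assert infer_npx_reshape((8, 3, 3, 3, 4, 4), (-6, 2, -1, -4)) == (2, 4, 3, 3, 3, 4, 4)
assert infer_npx_reshape((8, 1, 1, 1, 3), (-2, -3, -3, -3, -2)) == (8, 3)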

src/operator/numpy/np_matrix_op-inl.h

Lines changed: 53 additions & 1 deletion

@@ -27,6 +27,7 @@

 #include <vector>
 #include <algorithm>
+#include <string>
 #include "../tensor/matrix_op-inl.h"
 #include "../nn/concat-inl.h"
 #include "../../common/utils.h"
@@ -51,6 +52,58 @@ struct NumpyVstackParam : public dmlc::Parameter<NumpyVstackParam> {
   }
 };

+struct NumpyReshapeParam : public dmlc::Parameter<NumpyReshapeParam> {
+  mxnet::TShape newshape;
+  std::string order;
+  DMLC_DECLARE_PARAMETER(NumpyReshapeParam) {
+    DMLC_DECLARE_FIELD(newshape)
+    .describe("The new shape should be compatible with the original shape."
+              " If an integer, then the result will be a 1-D array of that length."
+              " One shape dimension can be -1. In this case, the value is inferred"
+              " from the length of the array and remaining dimensions.");
+    DMLC_DECLARE_FIELD(order)
+    .set_default("C")
+    .describe("Read the elements of a using this index order, and place the elements into"
+              " the reshaped array using this index order. 'C' means to read/write the elements"
+              " using C-like index order, with the last axis index changing fastest,"
+              " back to the first axis index changing slowest."
+              " Note that currently only C-like order is supported.");
+  }
+};
+
+struct NumpyXReshapeParam : public dmlc::Parameter<NumpyXReshapeParam> {
+  mxnet::TShape newshape;
+  bool reverse;
+  std::string order;
+  DMLC_DECLARE_PARAMETER(NumpyXReshapeParam) {
+    DMLC_DECLARE_FIELD(newshape)
+    .describe("The new shape should be compatible with the original shape."
+              " If an integer, then the result will be a 1-D array of that length."
+              " One shape dimension can be -1. In this case, the value is inferred"
+              " from the length of the array and remaining dimensions."
+              " The values -2 to -6 are used for data manipulation:"
+              " -2 copies this dimension from the input to the output shape;"
+              " -3 skips the current dimension if and only if its size is one;"
+              " -4 copies all remaining input dimensions to the output shape;"
+              " -5 uses the product of two consecutive dimensions of the input"
+              " shape as the output dimension;"
+              " -6 splits one dimension of the input into the two dimensions passed"
+              " subsequent to -6 in the new shape.");
+    DMLC_DECLARE_FIELD(reverse)
+    .set_default(false)
+    .describe("If true, the special values are inferred from right to left.");
+    DMLC_DECLARE_FIELD(order)
+    .set_default("C")
+    .describe("Read the elements of a using this index order, and place the elements into"
+              " the reshaped array using this index order. 'C' means to read/write the elements"
+              " using C-like index order, with the last axis index changing fastest,"
+              " back to the first axis index changing slowest."
+              " Note that currently only C-like order is supported.");
+  }
+};
+
 template<typename xpu>
 void NumpyTranspose(const nnvm::NodeAttrs& attrs,
                     const OpContext& ctx,
@@ -731,7 +784,6 @@ inline void HSplitOpBackward(const nnvm::NodeAttrs &attrs,
   }
   SplitOpBackwardImpl<xpu>(attrs, ctx, inputs, req, outputs, real_axis);
 }
-
 }  // namespace op
 }  // namespace mxnet

src/operator/numpy/np_matrix_op.cc

Lines changed: 186 additions & 20 deletions

@@ -34,6 +34,9 @@ DMLC_REGISTER_PARAMETER(NumpyTransposeParam);
 DMLC_REGISTER_PARAMETER(NumpyRollParam);
 DMLC_REGISTER_PARAMETER(NumpyMoveaxisParam);
 DMLC_REGISTER_PARAMETER(NumpyRot90Param);
+DMLC_REGISTER_PARAMETER(NumpyReshapeParam);
+DMLC_REGISTER_PARAMETER(NumpyXReshapeParam);
+

 bool NumpyTransposeShape(const nnvm::NodeAttrs& attrs,
                          mxnet::ShapeVector *in_attrs,
@@ -126,26 +129,6 @@ NNVM_REGISTER_OP(_np_transpose)
 .add_argument("a", "NDArray-or-Symbol", "Source input")
 .add_arguments(NumpyTransposeParam::__FIELDS__());

-struct NumpyReshapeParam : public dmlc::Parameter<NumpyReshapeParam> {
-  mxnet::TShape newshape;
-  std::string order;
-  DMLC_DECLARE_PARAMETER(NumpyReshapeParam) {
-    DMLC_DECLARE_FIELD(newshape)
-    .describe("The new shape should be compatible with the original shape."
-              " If an integer, then the result will be a 1-D array of that length."
-              " One shape dimension can be -1. In this case, the value is inferred"
-              " from the length of the array and remaining dimensions.");
-    DMLC_DECLARE_FIELD(order)
-    .set_default("C")
-    .describe("Read the elements of a using this index order, and place the elements into"
-              " the reshaped array using this index order. 'C' means to read/write the elements"
-              " using C-like index order, with the last axis index changing fastest, back to the"
-              " first axis index changing slowest. Note that currently only C-like order is"
-              " supported");
-  }
-};
-
-DMLC_REGISTER_PARAMETER(NumpyReshapeParam);

 bool NumpyReshapeInferShape(const mxnet::TShape& src, mxnet::TShape* dst) {
   if (shape_is_known(src) && shape_is_known(*dst)) {
@@ -202,6 +185,164 @@ bool NumpyReshapeShape(const nnvm::NodeAttrs& attrs,
   return success;
 }

+bool NumpyXReshapeInferShape(const mxnet::TShape& src,
+                             const mxnet::TShape& target,
+                             mxnet::TShape* output,
+                             const std::string &default_error_msg) {
+  bool target_shape_is_known = true;
+  dim_t target_size = 1;
+  for (int i = 0; i < target.ndim(); ++i) {
+    if (target[i] < 0) {
+      target_shape_is_known = false;
+      target_size = -1;
+      break;
+    } else {
+      target_size *= target[i];
+    }
+  }
+  if (shape_is_known(src) && target_shape_is_known) {
+    CHECK_EQ(src.Size(), target_size) << default_error_msg;
+    *output = TShape(target.begin(), target.end());
+    return true;
+  } else if (!shape_is_known(src) || target.ndim() == -1) {
+    return false;
+  } else {
+    int unknown_axis = -1;
+    dim_t known_dim_size_prod = 1;
+    std::vector<dim_t> output_shape_vector;
+    int src_inx = 0;
+    for (int i = 0; i < target.ndim(); ++i) {
+      dim_t proposed_dim = target[i];
+      CHECK(proposed_dim >= -6)
+        << "Dimension size must be greater than or equal to -6, received " << proposed_dim;
+      if (proposed_dim == -1) {
+        // infer the unknown dimension
+        CHECK_LT(unknown_axis, 0)
+          << "One and only one dim can be inferred";
+        unknown_axis = output_shape_vector.size();
+        output_shape_vector.push_back(-1);
+        src_inx++;
+      } else if (proposed_dim == -2) {
+        // copy the dimension from src to output
+        CHECK_LT(src_inx, src.ndim())
+          << "Unmatching dimension of proposed new shape";
+        known_dim_size_prod *= src[src_inx];
+        output_shape_vector.push_back(src[src_inx++]);
+      } else if (proposed_dim == -3) {
+        // skip the source dimension if and only if it is one
+        CHECK_EQ(src[src_inx], 1)
+          << "-3 index should only be used to skip dimension size 1";
+        src_inx++;
+      } else if (proposed_dim == -4) {
+        // copy all remaining dims from source
+        while (src_inx < src.ndim()) {
+          known_dim_size_prod *= src[src_inx];
+          const dim_t dn = src[src_inx++];
+          output_shape_vector.push_back(dn);
+        }
+      } else if (proposed_dim == -5) {
+        // merge two dims from source
+        CHECK_LT(src_inx, src.ndim()-1)
+          << "Not enough dimensions left for the product";
+        const dim_t d1 = src[src_inx++];
+        const dim_t d2 = src[src_inx++];
+        if (!mxnet::dim_size_is_known(d1) || !mxnet::dim_size_is_known(d2)) {
+          CHECK_LT(unknown_axis, 0)
+            << "One and only one dim can be inferred";
+          unknown_axis = output_shape_vector.size();
+          output_shape_vector.push_back(-1);
+        } else {
+          known_dim_size_prod *= d1 * d2;
+          output_shape_vector.push_back(d1 * d2);
+        }
+      } else if (proposed_dim == -6) {
+        // split the source dim into two dims;
+        // read the left dim and then the right dim (either can be -1)
+        CHECK_LT(i + 2, target.ndim());
+        CHECK_LT(src_inx, src.ndim());
+        const dim_t d0 = src[src_inx++];
+        dim_t d1 = target[++i];
+        dim_t d2 = target[++i];
+        CHECK(d1 != -1 || d2 != -1) << "Split dims cannot both be -1.";
+        if (d1 == -1 && d0 >= 0) d1 = d0 / d2;  // d0 must be known to do this
+        if (d2 == -1 && d0 >= 0) d2 = d0 / d1;  // d0 must be known to do this
+        CHECK(d1 * d2 == static_cast<dim_t>(d0) || static_cast<dim_t>(d0) == dim_t(-1))
+          << "Split dims " << d1 << ", " << d2 << " do not divide original dim " << d0;
+        if (d1 == -1) {
+          CHECK_LT(unknown_axis, 0)
+            << "One and only one dim can be inferred";
+          unknown_axis = output_shape_vector.size();
+        } else if (d2 == -1) {
+          CHECK_LT(unknown_axis, 0)
+            << "One and only one dim can be inferred";
+          unknown_axis = output_shape_vector.size() + 1;
+        }
+        known_dim_size_prod *= d0 == -1 ? 1 : d0;
+        output_shape_vector.push_back(d1);
+        output_shape_vector.push_back(d2);
+      } else {
+        // greater than 0, a new explicit dimension
+        known_dim_size_prod *= proposed_dim;
+        output_shape_vector.push_back(proposed_dim);
+        src_inx++;
+      }
+    }
+
+    if (unknown_axis > -1) {
+      // if the input is a zero-size tensor, the output must be of known zero size
+      CHECK_NE(known_dim_size_prod, 0) << default_error_msg;
+      CHECK(src.Size() % known_dim_size_prod == 0) << default_error_msg;
+      output_shape_vector[unknown_axis] = src.Size() / known_dim_size_prod;
+    }
+
+    *output = mxnet::TShape(output_shape_vector.begin(), output_shape_vector.end());
+    CHECK_EQ((*output).Size(), src.Size()) << default_error_msg;
+    return true;
+  }
+}
+
+bool NumpyXReshapeShape(const nnvm::NodeAttrs& attrs,
+                        mxnet::ShapeVector* in_attrs,
+                        mxnet::ShapeVector* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U) << "Input: [data]";
+  CHECK_EQ(out_attrs->size(), 1U);
+  const NumpyXReshapeParam& param = nnvm::get<NumpyXReshapeParam>(attrs.parsed);
+  // sanity check
+  bool has_unknown_dim_size = false;
+  for (int i = 0; i < param.newshape.ndim(); ++i) {
+    if (param.newshape[i] < 0) {
+      CHECK_GE(param.newshape[i], -6)
+        << "Dimension size must be greater than or equal to -6";
+      if (param.newshape[i] == -1) {
+        CHECK(!has_unknown_dim_size) << "Can only specify one unknown dimension";
+        has_unknown_dim_size = true;
+      }
+    }
+  }
+
+  mxnet::TShape output_shape;
+  bool success;
+  std::stringstream ss;
+  ss << "Cannot reshape array of shape " << in_attrs->at(0)
+     << " into shape " << param.newshape
+     << ", reverse = " << param.reverse;
+  std::string err_msg = ss.str();
+  if (!param.reverse) {
+    success = NumpyXReshapeInferShape(in_attrs->at(0),
+                                      param.newshape, &output_shape, err_msg);
+  } else {
+    mxnet::TShape rev_in_shape = in_attrs->at(0);
+    mxnet::TShape rev_newshape = param.newshape;
+    std::reverse(rev_in_shape.begin(), rev_in_shape.end());
+    std::reverse(rev_newshape.begin(), rev_newshape.end());
+    success = NumpyXReshapeInferShape(rev_in_shape,
                                      rev_newshape, &output_shape, err_msg);
+    std::reverse(output_shape.begin(), output_shape.end());
+  }
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, output_shape);
+  return success;
+}
+
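For reverse=True, the implementation above simply reverses both shapes, reuses the same left-to-right inference, and reverses the result. In terms of the illustrative infer_npx_reshape sketch from earlier (again, a hypothetical helper, not part of the commit):

def infer_npx_reshape_reverse(src, target):
    # Same trick as the C++ reverse branch: flip, infer, flip back.
    return infer_npx_reshape(src[::-1], target[::-1])[::-1]

# Matches the docstring example:
assert infer_npx_reshape_reverse((8, 3, 3, 3, 3, 8), (-4, -5)) == (8, 3, 3, 3, 24)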
 NNVM_REGISTER_OP(_np_reshape)
 .describe(R"code()code" ADD_FILELINE)
 .add_alias("_npi_reshape")
@@ -227,6 +368,31 @@ NNVM_REGISTER_OP(_np_reshape)
 .add_argument("a", "NDArray-or-Symbol", "Array to be reshaped.")
 .add_arguments(NumpyReshapeParam::__FIELDS__());

+
+NNVM_REGISTER_OP(_npx_reshape)
+.describe(R"code()code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NumpyXReshapeParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", NumpyXReshapeShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_reshape"})
+.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
+  [](const NodeAttrs& attrs) {
+    return std::vector<bool>{true};
+  })
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"a"};
+  })
+.add_argument("a", "NDArray-or-Symbol", "Array to be reshaped.")
+.add_arguments(NumpyXReshapeParam::__FIELDS__());
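Once registered, the operator surfaces in the Python frontend as npx.reshape. A quick end-to-end check, assuming an MXNet build that includes this commit and has the NumPy interface enabled:

from mxnet import np, npx
npx.set_np()  # switch to NumPy-compatible array semantics

x = np.ones((2, 3, 8))
print(npx.reshape(x, (-2, -2, 2, -1)).shape)  # expected: (2, 3, 2, 4)
print(npx.reshape(x, (-1,)).shape)            # expected: (48,)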
 bool NumpySqueezeShape(const nnvm::NodeAttrs& attrs,
                        mxnet::ShapeVector *in_attrs,
                        mxnet::ShapeVector *out_attrs) {

src/operator/numpy/np_matrix_op.cu

Lines changed: 3 additions & 0 deletions

@@ -109,5 +109,8 @@ NNVM_REGISTER_OP(_npi_hsplit)
 NNVM_REGISTER_OP(_npi_hsplit_backward)
 .set_attr<FCompute>("FCompute<gpu>", HSplitOpBackward<gpu>);

+NNVM_REGISTER_OP(_npx_reshape)
+.set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);
+
 }  // namespace op
 }  // namespace mxnet
