
Commit d521cbb

sxjscience authored and ptrendx committed
[Bugfix] [Numpy] Add kAddTo and kNullOp to Transpose (apache#16979)
* update Check for repeated axes enable addto to transpose fix fix fix fix remove unused ndim Update pseudo2DTranspose_op-inl.cuh Update pseudo2DTranspose_op-inl.cuh Update pseudo2DTranspose_op-inl.cuh fix Update pseudo2DTranspose_op-inl.cuh try to fix Update pseudo2DTranspose_op-inl.cuh Update pseudo2DTranspose_op-inl.cuh Update pseudo2DTranspose_op-inl.cuh fix Update np_matrix_op.cc Update test_numpy_op.py update test case fix implementation fix bug update fix bug Update pseudo2DTranspose_op-inl.cuh fix fix Update test_numpy_op.py
* Fix bug
* fix docstring
* try to address comment
* no need to change this line
* Fix bug
* address comments
* address comment
1 parent fd0ea16 commit d521cbb

6 files changed: +206, -177 lines changed


src/operator/numpy/np_matrix_op-inl.h

Lines changed: 11 additions & 5 deletions
@@ -119,16 +119,22 @@ void NumpyTranspose(const nnvm::NodeAttrs& attrs,
                     const std::vector<OpReqType>& req,
                     const std::vector<TBlob>& outputs) {
   const NumpyTransposeParam& param = nnvm::get<NumpyTransposeParam>(attrs.parsed);
-  CHECK_EQ(req[0], kWriteTo) << "Transpose does not support inplace";
+  if (req[0] == kNullOp) return;
+  CHECK(req[0] == kWriteTo || req[0] == kAddTo)
+      << "Transpose only supports kWriteTo, kNullOp and kAddTo";
+  mxnet::TShape axes;
   if (ndim_is_known(param.axes)) {
-    mxnet::TShape axes = common::CanonicalizeAxes(param.axes);
-    TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes);
+    axes = common::CanonicalizeAxes(param.axes);
   } else {
-    mxnet::TShape axes(inputs[0].ndim(), -1);
+    axes = mxnet::TShape(inputs[0].ndim(), -1);
     for (int i = 0; i < axes.ndim(); ++i) {
       axes[i] = axes.ndim() - 1 - i;
     }
-    TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes);
+  }
+  if (req[0] == kAddTo) {
+    TransposeImpl<xpu, true>(ctx.run_ctx, inputs[0], outputs[0], axes);
+  } else {
+    TransposeImpl<xpu, false>(ctx.run_ctx, inputs[0], outputs[0], axes);
   }
 }
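
For orientation, the pattern above decides the output request once at runtime (kNullOp returns early, kWriteTo/kAddTo are checked) and then bakes the write-vs-accumulate choice into TransposeImpl as a compile-time boolean, so the inner loops carry no per-element branch. A minimal standalone sketch of that dispatch; the names write_or_accumulate and run_copy are illustrative only and not part of the MXNet code:

#include <cstddef>

enum OpReq { kNullOp, kWriteTo, kAddTo };   // simplified stand-in for mxnet::OpReqType

// is_addto selects "out = val" vs "out += val" at compile time.
template <bool is_addto>
inline void write_or_accumulate(float* out, float val) {
  if (is_addto) {
    *out += val;   // accumulate into the existing output (kAddTo)
  } else {
    *out = val;    // overwrite the output (kWriteTo)
  }
}

void run_copy(OpReq req, const float* in, float* out, std::size_t n) {
  if (req == kNullOp) return;               // nothing requested, skip the kernel entirely
  if (req == kAddTo) {
    for (std::size_t i = 0; i < n; ++i) write_or_accumulate<true>(&out[i], in[i]);
  } else {
    for (std::size_t i = 0; i < n; ++i) write_or_accumulate<false>(&out[i], in[i]);
  }
}

Because the boolean is a template parameter, the compiler emits separate specializations for the write and accumulate cases instead of branching inside the loop.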

src/operator/numpy/np_matrix_op.cc

Lines changed: 9 additions & 3 deletions
@@ -24,6 +24,7 @@
  */

 #include <vector>
+#include <set>
 #include "./np_matrix_op-inl.h"
 #include "../nn/concat-inl.h"

@@ -65,8 +66,13 @@ bool NumpyTransposeShape(const nnvm::NodeAttrs& attrs,
   mxnet::TShape ret(ndim, -1);

   if (ndim_is_known(param.axes)) {
-    CHECK_EQ(ndim, param.axes.ndim());
+    CHECK_EQ(ndim, param.axes.ndim())
+        << "The number of axes does not match the dimension of the tensor. axes = "
+        << param.axes << ", input tensor shape = " << shp;
     mxnet::TShape axes = common::CanonicalizeAxes(param.axes);
+    std::set<dim_t> axes_set(axes.begin(), axes.end());
+    CHECK_EQ(axes_set.size(), axes.ndim()) << "Repeated axis in transpose. param.axes = "
+                                           << param.axes;
     if (ndim_is_known(shp)) {
       for (int i = 0; i < ndim; ++i) {
         ret[i] = shp[axes[i]];
@@ -115,9 +121,9 @@ NNVM_REGISTER_OP(_np_transpose)
       }
       std::ostringstream os;
       os << axes;
-      return MakeNonlossGradNode("transpose", n, ograds, {}, {{"axes", os.str()}});
+      return MakeNonlossGradNode("_np_transpose", n, ograds, {}, {{"axes", os.str()}});
     } else {
-      return MakeNonlossGradNode("transpose", n, ograds, {},
+      return MakeNonlossGradNode("_np_transpose", n, ograds, {},
                                  std::unordered_map<std::string, std::string>());
     }
   })
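
The repeated-axis check added above relies on std::set deduplicating its elements: if any axis occurs twice, the set ends up smaller than the axes list and the CHECK_EQ fires. A minimal sketch of the same idea outside MXNet; has_repeated_axis is an illustrative name, not part of the codebase:

#include <cstdint>
#include <set>
#include <vector>

// Returns true if any axis value occurs more than once.
bool has_repeated_axis(const std::vector<int64_t>& axes) {
  std::set<int64_t> axes_set(axes.begin(), axes.end());
  return axes_set.size() != axes.size();   // duplicates collapse inside the set
}

For example, axes = (0, 2, 2) produces a set of size 2 against 3 entries, which is the situation the new check rejects.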

src/operator/tensor/matrix_op-inl.h

Lines changed: 71 additions & 35 deletions
@@ -269,8 +269,10 @@ struct TransposeParam : public dmlc::Parameter<TransposeParam> {
  * \param out output tensor
  * \param row shape of dim 0 of input
  * \param col shape of dim 1 of input
+ * \tparam DType Data type
+ * \tparam is_addto
  */
-template<typename DType>
+template<typename DType, bool is_addto>
 MSHADOW_XINLINE void Transpose2D(const DType *in, DType *out, index_t row, index_t col) {
   // ensure cache line hits and prevent cache miss for any configuration
   // L1 cache size to be utilized = 32kb = 2^15
@@ -282,7 +284,7 @@ MSHADOW_XINLINE void Transpose2D(const DType *in, DType *out, index_t row, index
   // Block-size - 2^5 v 2^5 (32 v 32) with potential 4 pragma for loop unrolled
   // blocksize * blocksize * num_threads = cache_size / dtype_size
   // Instead of explicit unroll, let compiler figure out optimal unroll factor
-  index_t blocksize = 32;
+  const index_t blocksize = 32;

   // collapse 2 parallelizes 2 for loops
   // inner 2 for loops aren't parallelized to prevent cache miss
@@ -299,14 +301,25 @@ MSHADOW_XINLINE void Transpose2D(const DType *in, DType *out, index_t row, index
       // transpose the block
       for (index_t a = j; (a < blocksize + j) && (a < col); ++a) {
         for (index_t b = i; (b < blocksize + i) && (b < row); ++b) {
-          out[a * row + b] = in[b * col + a];
+          if (!is_addto) {
+            out[a * row + b] = in[b * col + a];
+          } else {
+            out[a * row + b] += in[b * col + a];
+          }
         }
       }
     }
   }
 }

-template<typename xpu>
+inline bool IsIdentityTranspose(const TShape& axes) {
+  for (dim_t i = 0; i < axes.ndim(); i++) {
+    if (axes[i] != i) return false;
+  }
+  return true;
+}
+
+template<typename xpu, bool is_addto = false>
 void TransposeImpl(RunContext ctx,
                    const TBlob& src,
                    const TBlob& ret,
@@ -323,62 +336,79 @@ void TransposeImpl(RunContext ctx,
   // Example: (0, 2, 3, 1) or (0, 3, 1, 2), but not (0, 2, 1, 3).
   if (isPseudo2DTranspose(axes)) {
     MSHADOW_TYPE_SWITCH(ret.type_flag_, DType, {
-      transpose_pseudo2D<DType>(ret, src, axes, s);
+      transpose_pseudo2D<DType, is_addto>(ret, src, axes, s);
     });
     return;
   }
 #endif
+  // Special handle the identity case
+  if (IsIdentityTranspose(axes)) {
+    MSHADOW_TYPE_SWITCH(ret.type_flag_, DType, {
+      Tensor<xpu, 1, DType> in = src.get_with_shape<xpu, 1, DType>(mshadow::Shape1(src.Size()), s);
+      Tensor<xpu, 1, DType> out = ret.get_with_shape<xpu, 1, DType>(mshadow::Shape1(ret.Size()), s);
+      if (!is_addto) {
+        // Use memcpy to accelerate the speed
+        Copy(out, in, s);
+      } else {
+        mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, kAddTo>, xpu>::Launch(
+            s, ret.Size(), out.dptr_, in.dptr_);
+      }
+    });
+    return;
+  }
+  // Handle the general transpose case
   MSHADOW_TYPE_SWITCH(ret.type_flag_, DType, {
     switch (axes.ndim()) {
-      case 0: {
-        Tensor<xpu, 1, DType> in = src.get_with_shape<xpu, 1, DType>(mshadow::Shape1(1), s);
-        Tensor<xpu, 1, DType> out = ret.get_with_shape<xpu, 1, DType>(mshadow::Shape1(1), s);
-        Copy(out, in, s);
-        break;
-      }
-      case 1: {
-        Tensor<xpu, 1, DType> in = src.get<xpu, 1, DType>(s);
-        Tensor<xpu, 1, DType> out = ret.get<xpu, 1, DType>(s);
-        Copy(out, in, s);
-        break;
-      }
       case 2: {
-        mshadow::Tensor<xpu, 2, DType> in = src.FlatTo2D<xpu, DType>(s);
-        mshadow::Tensor<xpu, 2, DType> out = ret.FlatTo2D<xpu, DType>(s);
-
-        if (axes[0] == 1 && axes[1] == 0) {
-          if (ctx.get_ctx().dev_mask() == cpu::kDevMask) {
-            Transpose2D<DType>(in.dptr_, out.dptr_, in.shape_[0], in.shape_[1]);
-          } else {
-            out = in.T();
-          }
+        Tensor<xpu, 2, DType> in = src.get<xpu, 2, DType>(s);
+        Tensor<xpu, 2, DType> out = ret.get<xpu, 2, DType>(s);
+        if (ctx.get_ctx().dev_mask() == cpu::kDevMask) {
+          Transpose2D<DType, is_addto>(in.dptr_, out.dptr_, in.shape_[0], in.shape_[1]);
         } else {
-          Copy(out, in, s);
+          LOG(FATAL) << "Not Implemented. We should never reach here because the 2D case "
+                        "in GPU has been covered by transpose_pseudo2D."
+                        " Report an issue in Github.";
         }
         break;
       }
       case 3: {
         Tensor<xpu, 3, DType> in = src.get<xpu, 3, DType>(s);
         Tensor<xpu, 3, DType> out = ret.get<xpu, 3, DType>(s);
-        out = transpose(in, axes.get<3>());
+        if (!is_addto) {
+          out = transpose(in, axes.get<3>());
+        } else {
+          out += transpose(in, axes.get<3>());
+        }
         break;
       }
       case 4: {
         Tensor<xpu, 4, DType> in = src.get<xpu, 4, DType>(s);
         Tensor<xpu, 4, DType> out = ret.get<xpu, 4, DType>(s);
-        out = transpose(in, axes.get<4>());
+        if (!is_addto) {
+          out = transpose(in, axes.get<4>());
+        } else {
+          out += transpose(in, axes.get<4>());
+        }
         break;
       }
       case 5: {
         Tensor<xpu, 5, DType> in = src.get<xpu, 5, DType>(s);
         Tensor<xpu, 5, DType> out = ret.get<xpu, 5, DType>(s);
-        out = transpose(in, axes.get<5>());
+        if (!is_addto) {
+          out = transpose(in, axes.get<5>());
+        } else {
+          out += transpose(in, axes.get<5>());
+        }
         break;
       }
       case 6: {
         Tensor<xpu, 6, DType> in = src.get<xpu, 6, DType>(s);
         Tensor<xpu, 6, DType> out = ret.get<xpu, 6, DType>(s);
-        out = transpose(in, axes.get<6>());
+        if (!is_addto) {
+          out = transpose(in, axes.get<6>());
+        } else {
+          out += transpose(in, axes.get<6>());
+        }
         break;
       }
       default:
@@ -399,15 +429,21 @@ void Transpose(const nnvm::NodeAttrs& attrs,
     return;
   }
   const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);
-  CHECK_EQ(req[0], kWriteTo) << "Transpose does not support kWriteInplace and kAddTo";
+  CHECK(req[0] == kWriteTo || req[0] == kAddTo)
+      << "Transpose only supports kNullOp, kWriteTo and kAddTo";
+  mxnet::TShape axes;
   if (param.axes.ndim() == 0) {
-    mxnet::TShape axes(inputs[0].ndim(), -1);
+    axes = mxnet::TShape(inputs[0].ndim(), -1);
     for (int i = 0; i < axes.ndim(); ++i) {
       axes[i] = axes.ndim() - 1 - i;
     }
-    TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes);
   } else {
-    TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], param.axes);
+    axes = common::CanonicalizeAxes(param.axes);
+  }
+  if (req[0] == kAddTo) {
+    TransposeImpl<xpu, true>(ctx.run_ctx, inputs[0], outputs[0], axes);
+  } else {
+    TransposeImpl<xpu, false>(ctx.run_ctx, inputs[0], outputs[0], axes);
   }
 }
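
The Transpose2D changes above keep the 32x32 cache-blocking scheme described in the comments and only template the innermost store on is_addto, so the write-vs-accumulate decision costs nothing at runtime. A self-contained sketch of that blocked transpose, without the OpenMP collapse(2) parallelization and mshadow types of the real kernel (transpose2d_blocked is an illustrative name):

#include <cstdint>

// Blocked 2D transpose of a row x col matrix `in` into a col x row matrix `out`.
// is_addto == true accumulates into `out` instead of overwriting it.
template <typename DType, bool is_addto>
void transpose2d_blocked(const DType* in, DType* out, int64_t row, int64_t col) {
  const int64_t blocksize = 32;                        // 32x32 tile sized to stay in L1 cache
  for (int64_t i = 0; i < row; i += blocksize) {
    for (int64_t j = 0; j < col; j += blocksize) {
      // transpose one tile; the bounds checks handle ragged edges
      for (int64_t a = j; a < j + blocksize && a < col; ++a) {
        for (int64_t b = i; b < i + blocksize && b < row; ++b) {
          if (is_addto) {
            out[a * row + b] += in[b * col + a];
          } else {
            out[a * row + b] = in[b * col + a];
          }
        }
      }
    }
  }
}

Since is_addto is a compile-time constant, the compiler generates two specializations and the branch disappears from the hot loop.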

src/operator/tensor/matrix_op.cc

Lines changed: 3 additions & 2 deletions
@@ -283,11 +283,12 @@ static void TransposeComputeExCPU(const nnvm::NodeAttrs& attrs,
     return;
   }
   const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);
-  CHECK_EQ(req[0], kWriteTo) << "Transpose does not support kWriteInplace and kAddTo";
+  CHECK(req[0] == kWriteTo || req[0] == kAddTo) <<
+    "Transpose only supports kNullOp, kWriteTo and kAddTo";
   CHECK_EQ(inputs.size(), 1U);
   CHECK_EQ(outputs.size(), 1U);

-  if (SupportMKLDNNTranspose(param, inputs[0])) {
+  if (SupportMKLDNNTranspose(param, inputs[0]) && req[0] == kWriteTo) {
     MKLDNNTransposeForward(attrs, ctx, inputs[0], req[0], outputs[0]);
     return;
   }
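
The added req[0] == kWriteTo condition keeps the MKLDNN fast path for plain overwrites only; a kAddTo request now skips it and is handled by the fallback implementation, which supports accumulation. A schematic of that routing decision (the names below are illustrative, not MXNet APIs):

enum OpReq { kNullOp, kWriteTo, kAddTo };   // simplified stand-in for mxnet::OpReqType

// The backend-specific kernel only implements "out = transpose(in)",
// so any other request type must use the generic implementation.
inline bool use_backend_transpose(bool backend_supported, OpReq req) {
  return backend_supported && req == kWriteTo;
}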
