@@ -269,8 +269,10 @@ struct TransposeParam : public dmlc::Parameter<TransposeParam> {
269
269
* \param out output tensor
270
270
* \param row shape of dim 0 of input
271
271
* \param col shape of dim 1 of input
272
+ * \tparam DType Data type
273
+ * \tparam is_addto
272
274
*/
273
- template <typename DType>
275
+ template <typename DType, bool is_addto >
274
276
MSHADOW_XINLINE void Transpose2D (const DType *in, DType *out, index_t row, index_t col) {
275
277
// ensure cache line hits and prevent cache miss for any configuration
276
278
// L1 cache size to be utilized = 32kb = 2^15
@@ -282,7 +284,7 @@ MSHADOW_XINLINE void Transpose2D(const DType *in, DType *out, index_t row, index
282
284
// Block-size - 2^5 v 2^5 (32 v 32) with potential 4 pragma for loop unrolled
283
285
// blocksize * blocksize * num_threads = cache_size / dtype_size
284
286
// Instead of explicit unroll, let compiler figure out optimal unroll factor
285
- index_t blocksize = 32 ;
287
+ const index_t blocksize = 32 ;
286
288
287
289
// collapse 2 parallelizes 2 for loops
288
290
// inner 2 for loops aren't parallelized to prevent cache miss
@@ -299,14 +301,25 @@ MSHADOW_XINLINE void Transpose2D(const DType *in, DType *out, index_t row, index
299
301
// transpose the block
300
302
for (index_t a = j; (a < blocksize + j) && (a < col); ++a) {
301
303
for (index_t b = i; (b < blocksize + i) && (b < row); ++b) {
302
- out[a * row + b] = in[b * col + a];
304
+ if (!is_addto) {
305
+ out[a * row + b] = in[b * col + a];
306
+ } else {
307
+ out[a * row + b] += in[b * col + a];
308
+ }
303
309
}
304
310
}
305
311
}
306
312
}
307
313
}
308
314
309
- template <typename xpu>
315
+ inline bool IsIdentityTranspose (const TShape& axes) {
316
+ for (dim_t i = 0 ; i < axes.ndim (); i++) {
317
+ if (axes[i] != i) return false ;
318
+ }
319
+ return true ;
320
+ }
321
+
322
+ template <typename xpu, bool is_addto = false >
310
323
void TransposeImpl (RunContext ctx,
311
324
const TBlob& src,
312
325
const TBlob& ret,
@@ -323,62 +336,79 @@ void TransposeImpl(RunContext ctx,
323
336
// Example: (0, 2, 3, 1) or (0, 3, 1, 2), but not (0, 2, 1, 3).
324
337
if (isPseudo2DTranspose (axes)) {
325
338
MSHADOW_TYPE_SWITCH (ret.type_flag_ , DType, {
326
- transpose_pseudo2D<DType>(ret, src, axes, s);
339
+ transpose_pseudo2D<DType, is_addto >(ret, src, axes, s);
327
340
});
328
341
return ;
329
342
}
330
343
#endif
344
+ // Special handle the identity case
345
+ if (IsIdentityTranspose (axes)) {
346
+ MSHADOW_TYPE_SWITCH (ret.type_flag_ , DType, {
347
+ Tensor<xpu, 1 , DType> in = src.get_with_shape <xpu, 1 , DType>(mshadow::Shape1 (src.Size ()), s);
348
+ Tensor<xpu, 1 , DType> out = ret.get_with_shape <xpu, 1 , DType>(mshadow::Shape1 (ret.Size ()), s);
349
+ if (!is_addto) {
350
+ // Use memcpy to accelerate the speed
351
+ Copy (out, in, s);
352
+ } else {
353
+ mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, kAddTo >, xpu>::Launch (
354
+ s, ret.Size (), out.dptr_ , in.dptr_ );
355
+ }
356
+ });
357
+ return ;
358
+ }
359
+ // Handle the general transpose case
331
360
MSHADOW_TYPE_SWITCH (ret.type_flag_ , DType, {
332
361
switch (axes.ndim ()) {
333
- case 0 : {
334
- Tensor<xpu, 1 , DType> in = src.get_with_shape <xpu, 1 , DType>(mshadow::Shape1 (1 ), s);
335
- Tensor<xpu, 1 , DType> out = ret.get_with_shape <xpu, 1 , DType>(mshadow::Shape1 (1 ), s);
336
- Copy (out, in, s);
337
- break ;
338
- }
339
- case 1 : {
340
- Tensor<xpu, 1 , DType> in = src.get <xpu, 1 , DType>(s);
341
- Tensor<xpu, 1 , DType> out = ret.get <xpu, 1 , DType>(s);
342
- Copy (out, in, s);
343
- break ;
344
- }
345
362
case 2 : {
346
- mshadow::Tensor<xpu, 2 , DType> in = src.FlatTo2D <xpu, DType>(s);
347
- mshadow::Tensor<xpu, 2 , DType> out = ret.FlatTo2D <xpu, DType>(s);
348
-
349
- if (axes[0 ] == 1 && axes[1 ] == 0 ) {
350
- if (ctx.get_ctx ().dev_mask () == cpu::kDevMask ) {
351
- Transpose2D<DType>(in.dptr_ , out.dptr_ , in.shape_ [0 ], in.shape_ [1 ]);
352
- } else {
353
- out = in.T ();
354
- }
363
+ Tensor<xpu, 2 , DType> in = src.get <xpu, 2 , DType>(s);
364
+ Tensor<xpu, 2 , DType> out = ret.get <xpu, 2 , DType>(s);
365
+ if (ctx.get_ctx ().dev_mask () == cpu::kDevMask ) {
366
+ Transpose2D<DType, is_addto>(in.dptr_ , out.dptr_ , in.shape_ [0 ], in.shape_ [1 ]);
355
367
} else {
356
- Copy (out, in, s);
368
+ LOG (FATAL) << " Not Implemented. We should never reach here because the 2D case "
369
+ " in GPU has been covered by transpose_pseudo2D."
370
+ " Report an issue in Github." ;
357
371
}
358
372
break ;
359
373
}
360
374
case 3 : {
361
375
Tensor<xpu, 3 , DType> in = src.get <xpu, 3 , DType>(s);
362
376
Tensor<xpu, 3 , DType> out = ret.get <xpu, 3 , DType>(s);
363
- out = transpose (in, axes.get <3 >());
377
+ if (!is_addto) {
378
+ out = transpose (in, axes.get <3 >());
379
+ } else {
380
+ out += transpose (in, axes.get <3 >());
381
+ }
364
382
break ;
365
383
}
366
384
case 4 : {
367
385
Tensor<xpu, 4 , DType> in = src.get <xpu, 4 , DType>(s);
368
386
Tensor<xpu, 4 , DType> out = ret.get <xpu, 4 , DType>(s);
369
- out = transpose (in, axes.get <4 >());
387
+ if (!is_addto) {
388
+ out = transpose (in, axes.get <4 >());
389
+ } else {
390
+ out += transpose (in, axes.get <4 >());
391
+ }
370
392
break ;
371
393
}
372
394
case 5 : {
373
395
Tensor<xpu, 5 , DType> in = src.get <xpu, 5 , DType>(s);
374
396
Tensor<xpu, 5 , DType> out = ret.get <xpu, 5 , DType>(s);
375
- out = transpose (in, axes.get <5 >());
397
+ if (!is_addto) {
398
+ out = transpose (in, axes.get <5 >());
399
+ } else {
400
+ out += transpose (in, axes.get <5 >());
401
+ }
376
402
break ;
377
403
}
378
404
case 6 : {
379
405
Tensor<xpu, 6 , DType> in = src.get <xpu, 6 , DType>(s);
380
406
Tensor<xpu, 6 , DType> out = ret.get <xpu, 6 , DType>(s);
381
- out = transpose (in, axes.get <6 >());
407
+ if (!is_addto) {
408
+ out = transpose (in, axes.get <6 >());
409
+ } else {
410
+ out += transpose (in, axes.get <6 >());
411
+ }
382
412
break ;
383
413
}
384
414
default :
@@ -399,15 +429,21 @@ void Transpose(const nnvm::NodeAttrs& attrs,
399
429
return ;
400
430
}
401
431
const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed );
402
- CHECK_EQ (req[0 ], kWriteTo ) << " Transpose does not support kWriteInplace and kAddTo" ;
432
+ CHECK (req[0 ] == kWriteTo || req[0 ] == kAddTo )
433
+ << " Transpose only supports kNullOp, kWriteTo and kAddTo" ;
434
+ mxnet::TShape axes;
403
435
if (param.axes .ndim () == 0 ) {
404
- mxnet::TShape axes (inputs[0 ].ndim (), -1 );
436
+ axes = mxnet::TShape (inputs[0 ].ndim (), -1 );
405
437
for (int i = 0 ; i < axes.ndim (); ++i) {
406
438
axes[i] = axes.ndim () - 1 - i;
407
439
}
408
- TransposeImpl<xpu>(ctx.run_ctx , inputs[0 ], outputs[0 ], axes);
409
440
} else {
410
- TransposeImpl<xpu>(ctx.run_ctx , inputs[0 ], outputs[0 ], param.axes );
441
+ axes = common::CanonicalizeAxes (param.axes );
442
+ }
443
+ if (req[0 ] == kAddTo ) {
444
+ TransposeImpl<xpu, true >(ctx.run_ctx , inputs[0 ], outputs[0 ], axes);
445
+ } else {
446
+ TransposeImpl<xpu, false >(ctx.run_ctx , inputs[0 ], outputs[0 ], axes);
411
447
}
412
448
}
413
449
0 commit comments