34
34
#include " ../mxnet_op.h"
35
35
#include " ../operator_common.h"
36
36
#include " ../tensor/broadcast_reduce_op.h"
37
+ #include " ../../common/cuda_utils.h"
37
38
38
39
namespace mxnet {
39
40
namespace op {
@@ -312,27 +313,6 @@ __global__ void softmax_compute_kernel(DType *in, OType *out, IType *length,
312
313
313
314
const int softmax_threads_per_block = 512 ;
314
315
315
- template <typename OP, typename T>
316
- __device__ inline T warp_reduce (T value, OP redfun) {
317
- value = redfun (value, __shfl_down_sync (0xffffffff , value, 16 ));
318
- value = redfun (value, __shfl_down_sync (0xffffffff , value, 8 ));
319
- value = redfun (value, __shfl_down_sync (0xffffffff , value, 4 ));
320
- value = redfun (value, __shfl_down_sync (0xffffffff , value, 2 ));
321
- value = redfun (value, __shfl_down_sync (0xffffffff , value, 1 ));
322
- return value;
323
- }
324
-
325
- template <typename OP>
326
- __device__ inline mshadow::half::half_t warp_reduce (mshadow::half::half_t value, OP redfun) {
327
- float v = static_cast <float >(value);
328
- v = redfun (v, __shfl_down_sync (0xffffffff , v, 16 ));
329
- v = redfun (v, __shfl_down_sync (0xffffffff , v, 8 ));
330
- v = redfun (v, __shfl_down_sync (0xffffffff , v, 4 ));
331
- v = redfun (v, __shfl_down_sync (0xffffffff , v, 2 ));
332
- v = redfun (v, __shfl_down_sync (0xffffffff , v, 1 ));
333
- return mshadow::half::half_t (v);
334
- }
335
-
336
316
template <typename OP, bool negate, typename AType, typename LType,
337
317
typename DType, typename OType, typename IType>
338
318
__global__ void softmax_stride1_compute_kernel (const DType *in, OType *out, IType *length,
@@ -356,7 +336,7 @@ __global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, ITyp
356
336
// the division by zero warning generated for such invalid cases.
357
337
const int row_length = entries_per_load > 0 ? M / entries_per_load : 0 ;
358
338
359
- const LType * in_aligned = reinterpret_cast <const LType *>(in);
339
+ const LType* in_aligned = reinterpret_cast <const LType*>(in);
360
340
size_t base = my_row * row_length;
361
341
362
342
for (index_t i = my_id; i < row_length; i += threads_per_row) {
@@ -420,7 +400,7 @@ __global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, ITyp
420
400
}
421
401
__syncthreads ();
422
402
423
- LType * out_aligned = reinterpret_cast <LType *>(out);
403
+ LType* out_aligned = reinterpret_cast <LType*>(out);
424
404
425
405
for (index_t i = my_id; i < row_length; i += threads_per_row) {
426
406
out_aligned[base + i] = persistent_storage[my_local_row * row_length + i];
@@ -429,18 +409,6 @@ __global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, ITyp
429
409
430
410
namespace {
431
411
432
- int get_load_type (size_t N) {
433
- if (N % 8 == 0 ) {
434
- return kFloat64 ;
435
- } else if (N % 4 == 0 ) {
436
- return kFloat32 ;
437
- } else if (N % 2 == 0 ) {
438
- return kFloat16 ;
439
- } else {
440
- return kInt8 ;
441
- }
442
- }
443
-
444
412
int get_rows_per_block (size_t N) {
445
413
const int warp_size = 32 ;
446
414
// How many read instructions should 1 thread at least do
@@ -479,9 +447,9 @@ inline void Softmax(Stream<gpu> *s, DType *in, OType *out, IType *length,
479
447
// Using 20 kB of shared memory for persistent storage in the optimized case
480
448
const size_t max_opt_M = 20 * 1024 / DSize;
481
449
if (stride[axis] == 1 &&
482
- M <= max_opt_M &&
450
+ static_cast < size_t >(M) <= max_opt_M &&
483
451
std::is_same<DType, OType>::value) {
484
- int ltype = get_load_type (M * sizeof (DType));
452
+ int ltype = mxnet::common::cuda:: get_load_type (M * sizeof (DType));
485
453
MSHADOW_TYPE_SWITCH (ltype, LType, {
486
454
int rows_per_block = get_rows_per_block (M * sizeof (DType) / sizeof (LType));
487
455
int nblocks = (N + rows_per_block - 1 ) / rows_per_block;
0 commit comments