
Commit 9c33c4b
Fixes from review
1 parent: 40c0d00

File tree: 3 files changed, +91 -40 lines

src/common/cuda_utils.cc (52 additions, 0 deletions)
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file cuda_utils.cc
+ * \brief Common CUDA utilities.
+ */
+
+#include <mxnet/base.h>
+#include <mshadow/base.h>
+#include "cuda_utils.h"
+
+#if MXNET_USE_CUDA
+
+namespace mxnet {
+namespace common {
+namespace cuda {
+
+int get_load_type(size_t N) {
+  using namespace mshadow;
+  if (N % 8 == 0) {
+    return kFloat64;
+  } else if (N % 4 == 0) {
+    return kFloat32;
+  } else if (N % 2 == 0) {
+    return kFloat16;
+  } else {
+    return kInt8;
+  }
+}
+}  // namespace cuda
+}  // namespace common
+}  // namespace mxnet
+
+#endif  // MXNET_USE_CUDA
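For context, the mapping is purely arithmetic: return the widest mshadow type whose size in bytes evenly divides N, so each thread can issue the widest possible vectorized load. A standalone sketch of the same logic (the local helper and sample sizes are illustrative, not part of the commit):

#include <cstddef>
#include <cstdio>

// Widest load width (in bytes) usable for a row of n bytes: 8, 4, 2, or 1.
// Mirrors get_load_type, which returns the mshadow type id of that width.
static int load_width(size_t n) {
  if (n % 8 == 0) return 8;   // kFloat64
  if (n % 4 == 0) return 4;   // kFloat32
  if (n % 2 == 0) return 2;   // kFloat16
  return 1;                   // kInt8
}

int main() {
  const size_t rows[] = {4096, 36, 18, 7};  // sample row sizes in bytes
  for (size_t n : rows)
    printf("%4zu-byte row -> %d-byte loads\n", n, load_width(n));
  return 0;
}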

src/common/cuda_utils.h (34 additions, 3 deletions)
@@ -20,7 +20,7 @@
 /*!
  * Copyright (c) 2015 by Contributors
  * \file cuda_utils.h
- * \brief CUDA debugging utilities.
+ * \brief Common CUDA utilities.
  */
 #ifndef MXNET_COMMON_CUDA_UTILS_H_
 #define MXNET_COMMON_CUDA_UTILS_H_
@@ -326,6 +326,15 @@ class DeviceStore {
   bool restore_;
 };
 
+/*! \brief Get the largest datatype suitable to read
+ *  requested number of bytes.
+ *
+ *  \input Number of bytes to be read
+ *  \return mshadow representation of type that could
+ *          be used for reading
+ */
+int get_load_type(size_t N);
+
 }  // namespace cuda
 }  // namespace common
 }  // namespace mxnet
@@ -550,7 +559,7 @@ static inline __device__ void atomicAdd(double *address, double val) {
 // Overload atomicAdd for half precision
 // Taken from:
 // https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
-#if defined(__CUDA_ARCH__)
+#ifdef __CUDACC__
 static inline __device__ void atomicAdd(mshadow::half::half_t *address,
                                         mshadow::half::half_t val) {
   unsigned int *address_as_ui =
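A note on why this guard changes (my reading, not stated in the commit): __CUDA_ARCH__ is defined only during nvcc's device compilation pass, so a declaration guarded by it is invisible to the host pass of the same translation unit, while __CUDACC__ is defined in both passes whenever nvcc compiles the file. A minimal illustration with a placeholder function:

// __CUDACC__   : defined whenever nvcc compiles this file
//                (both the host and the device compilation passes).
// __CUDA_ARCH__: defined only during the device pass; its value is the
//                target architecture, e.g. 700 for sm_70.
#ifdef __CUDACC__
__device__ float twice(float x) { return 2.0f * x; }  // declared in both passes
#endif

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
// Device-pass-only: code needing sm_70+ instructions would go here.
#endif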
@@ -615,6 +624,28 @@ __device__ inline DType ldg(const DType* address) {
   return *address;
 #endif
 }
-#endif
+
+template <typename OP, typename T>
+__device__ inline T warp_reduce(T value, OP redfun) {
+  value = redfun(value, __shfl_down_sync(0xffffffff, value, 16));
+  value = redfun(value, __shfl_down_sync(0xffffffff, value, 8));
+  value = redfun(value, __shfl_down_sync(0xffffffff, value, 4));
+  value = redfun(value, __shfl_down_sync(0xffffffff, value, 2));
+  value = redfun(value, __shfl_down_sync(0xffffffff, value, 1));
+  return value;
+}
+
+template <typename OP>
+__device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t value, OP redfun) {
+  float v = static_cast<float>(value);
+  v = redfun(v, __shfl_down_sync(0xffffffff, v, 16));
+  v = redfun(v, __shfl_down_sync(0xffffffff, v, 8));
+  v = redfun(v, __shfl_down_sync(0xffffffff, v, 4));
+  v = redfun(v, __shfl_down_sync(0xffffffff, v, 2));
+  v = redfun(v, __shfl_down_sync(0xffffffff, v, 1));
+  return mshadow::half::half_t(v);
+}
+
+#endif  // __CUDACC__
 
 #endif  // MXNET_COMMON_CUDA_UTILS_H_
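To show the intended call pattern, a hypothetical kernel that sums one float per lane across a full warp using the warp_reduce added above (the kernel, lambda, and include path are illustrative, not part of the commit):

#include <cstdio>
#include "common/cuda_utils.h"  // adjust the path to wherever cuda_utils.h lives

// Each of the 32 lanes contributes in[threadIdx.x]; after five
// shuffle-down rounds inside warp_reduce, lane 0 holds the warp-wide sum.
__global__ void warp_sum_kernel(const float* in, float* out) {
  float v = in[threadIdx.x];
  v = warp_reduce(v, [](float a, float b) { return a + b; });
  if (threadIdx.x == 0) *out = v;
}

// Launch with exactly one warp: warp_sum_kernel<<<1, 32>>>(d_in, d_out);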

src/operator/nn/softmax-inl.h (5 additions, 37 deletions)
@@ -34,6 +34,7 @@
 #include "../mxnet_op.h"
 #include "../operator_common.h"
 #include "../tensor/broadcast_reduce_op.h"
+#include "../../common/cuda_utils.h"
 
 namespace mxnet {
 namespace op {
@@ -312,27 +313,6 @@ __global__ void softmax_compute_kernel(DType *in, OType *out, IType *length,
 
 const int softmax_threads_per_block = 512;
 
-template <typename OP, typename T>
-__device__ inline T warp_reduce(T value, OP redfun) {
-  value = redfun(value, __shfl_down_sync(0xffffffff, value, 16));
-  value = redfun(value, __shfl_down_sync(0xffffffff, value, 8));
-  value = redfun(value, __shfl_down_sync(0xffffffff, value, 4));
-  value = redfun(value, __shfl_down_sync(0xffffffff, value, 2));
-  value = redfun(value, __shfl_down_sync(0xffffffff, value, 1));
-  return value;
-}
-
-template <typename OP>
-__device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t value, OP redfun) {
-  float v = static_cast<float>(value);
-  v = redfun(v, __shfl_down_sync(0xffffffff, v, 16));
-  v = redfun(v, __shfl_down_sync(0xffffffff, v, 8));
-  v = redfun(v, __shfl_down_sync(0xffffffff, v, 4));
-  v = redfun(v, __shfl_down_sync(0xffffffff, v, 2));
-  v = redfun(v, __shfl_down_sync(0xffffffff, v, 1));
-  return mshadow::half::half_t(v);
-}
-
 template<typename OP, bool negate, typename AType, typename LType,
          typename DType, typename OType, typename IType>
 __global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, IType *length,
@@ -356,7 +336,7 @@ __global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, ITyp
   // the division by zero warning generated for such invalid cases.
   const int row_length = entries_per_load > 0 ? M / entries_per_load : 0;
 
-  const LType * in_aligned = reinterpret_cast<const LType *>(in);
+  const LType* in_aligned = reinterpret_cast<const LType*>(in);
   size_t base = my_row * row_length;
 
   for (index_t i = my_id; i < row_length; i += threads_per_row) {
@@ -420,7 +400,7 @@ __global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, ITyp
   }
   __syncthreads();
 
-  LType * out_aligned = reinterpret_cast<LType *>(out);
+  LType* out_aligned = reinterpret_cast<LType*>(out);
 
   for (index_t i = my_id; i < row_length; i += threads_per_row) {
     out_aligned[base + i] = persistent_storage[my_local_row * row_length + i];
@@ -429,18 +409,6 @@ __global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, ITyp
 
 namespace {
 
-int get_load_type(size_t N) {
-  if (N % 8 == 0) {
-    return kFloat64;
-  } else if (N % 4 == 0) {
-    return kFloat32;
-  } else if (N % 2 == 0) {
-    return kFloat16;
-  } else {
-    return kInt8;
-  }
-}
-
 int get_rows_per_block(size_t N) {
   const int warp_size = 32;
   // How many read instructions should 1 thread at least do
@@ -479,9 +447,9 @@ inline void Softmax(Stream<gpu> *s, DType *in, OType *out, IType *length,
   // Using 20 kB of shared memory for persistent storage in the optimized case
   const size_t max_opt_M = 20 * 1024 / DSize;
   if (stride[axis] == 1 &&
-      M <= max_opt_M &&
+      static_cast<size_t>(M) <= max_opt_M &&
       std::is_same<DType, OType>::value) {
-    int ltype = get_load_type(M * sizeof(DType));
+    int ltype = mxnet::common::cuda::get_load_type(M * sizeof(DType));
     MSHADOW_TYPE_SWITCH(ltype, LType, {
       int rows_per_block = get_rows_per_block(M * sizeof(DType) / sizeof(LType));
       int nblocks = (N + rows_per_block - 1) / rows_per_block;