|
| 1 | +// Copyright © 2025 Apple Inc. |
| 2 | + |
| 3 | +#include "mlx/backend/cuda/device.h" |
| 4 | +#include "mlx/backend/cuda/kernel_utils.cuh" |
| 5 | +#include "mlx/backend/cuda/kernels/cast_op.cuh" |
| 6 | +#include "mlx/backend/cuda/kernels/fp16_math.cuh" |
| 7 | +#include "mlx/backend/gpu/copy.h" |
| 8 | +#include "mlx/dtype_utils.h" |
| 9 | +#include "mlx/primitives.h" |
| 10 | + |
| 11 | +#include <cooperative_groups.h> |
| 12 | +#include <cooperative_groups/reduce.h> |
| 13 | +#include <nvtx3/nvtx3.hpp> |
| 14 | +#include <cub/block/block_load.cuh> |
| 15 | + |
| 16 | +#include <cassert> |
| 17 | + |
| 18 | +namespace mlx::core { |
| 19 | + |
| 20 | +namespace cu { |
| 21 | + |
| 22 | +namespace cg = cooperative_groups; |
| 23 | + |
| 24 | +template <typename T> |
| 25 | +inline __device__ T softmax_exp(T x) { |
| 26 | + // Softmax doesn't need high precision exponential cause x is gonna be in |
| 27 | + // (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)). |
| 28 | + return __expf(x); |
| 29 | +} |
| 30 | + |
| 31 | +template <typename T, typename AccT, int BLOCK_DIM, int N_READS = 4> |
| 32 | +__global__ void softmax(const T* in, T* out, int axis_size) { |
| 33 | + auto grid = cg::this_grid(); |
| 34 | + auto block = cg::this_thread_block(); |
| 35 | + auto warp = cg::tiled_partition<WARP_SIZE>(block); |
| 36 | + |
| 37 | + in += grid.block_rank() * axis_size; |
| 38 | + out += grid.block_rank() * axis_size; |
| 39 | + |
| 40 | + cg::greater<AccT> max_op; |
| 41 | + cg::plus<AccT> plus_op; |
| 42 | + |
| 43 | + // Thread reduce. |
| 44 | + AccT prevmax; |
| 45 | + AccT maxval = Limits<AccT>::finite_min(); |
| 46 | + AccT normalizer = 0; |
| 47 | + for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) { |
| 48 | + AccT vals[N_READS]; |
| 49 | + cub::LoadDirectBlocked( |
| 50 | + r * BLOCK_DIM + block.thread_rank(), |
| 51 | + make_cast_iterator<AccT>(in), |
| 52 | + vals, |
| 53 | + axis_size, |
| 54 | + Limits<AccT>::finite_min()); |
| 55 | + prevmax = maxval; |
| 56 | + maxval = max_op(maxval, cub::ThreadReduce(vals, max_op)); |
| 57 | + // Online normalizer calculation for softmax: |
| 58 | + // https://github.com/NVIDIA/online-softmax |
| 59 | + normalizer = normalizer * softmax_exp(prevmax - maxval); |
| 60 | + for (int i = 0; i < N_READS; i++) { |
| 61 | + normalizer = normalizer + softmax_exp(vals[i] - maxval); |
| 62 | + } |
| 63 | + } |
| 64 | + |
| 65 | + // First warp reduce. |
| 66 | + prevmax = maxval; |
| 67 | + maxval = cg::reduce(warp, maxval, max_op); |
| 68 | + normalizer = normalizer * softmax_exp(prevmax - maxval); |
| 69 | + normalizer = cg::reduce(warp, normalizer, plus_op); |
| 70 | + |
| 71 | + __shared__ AccT local_max[WARP_SIZE]; |
| 72 | + __shared__ AccT local_normalizer[WARP_SIZE]; |
| 73 | + |
| 74 | + // Write to shared memory and do second warp reduce. |
| 75 | + prevmax = maxval; |
| 76 | + if (warp.thread_rank() == 0) { |
| 77 | + local_max[warp.meta_group_rank()] = maxval; |
| 78 | + } |
| 79 | + block.sync(); |
| 80 | + maxval = warp.thread_rank() < warp.meta_group_size() |
| 81 | + ? local_max[warp.thread_rank()] |
| 82 | + : Limits<AccT>::finite_min(); |
| 83 | + maxval = cg::reduce(warp, maxval, max_op); |
| 84 | + normalizer = normalizer * softmax_exp(prevmax - maxval); |
| 85 | + if (warp.thread_rank() == 0) { |
| 86 | + local_normalizer[warp.meta_group_rank()] = normalizer; |
| 87 | + } |
| 88 | + block.sync(); |
| 89 | + normalizer = warp.thread_rank() < warp.meta_group_size() |
| 90 | + ? local_normalizer[warp.thread_rank()] |
| 91 | + : AccT{}; |
| 92 | + normalizer = cg::reduce(warp, normalizer, plus_op); |
| 93 | + normalizer = 1 / normalizer; |
| 94 | + |
| 95 | + // Write output. |
| 96 | + for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) { |
| 97 | + auto index = r * BLOCK_DIM + block.thread_rank(); |
| 98 | + T vals[N_READS]; |
| 99 | + cub::LoadDirectBlocked(index, in, vals, axis_size); |
| 100 | + for (int i = 0; i < N_READS; i++) { |
| 101 | + vals[i] = softmax_exp(static_cast<AccT>(vals[i]) - maxval) * normalizer; |
| 102 | + } |
| 103 | + cub::StoreDirectBlocked(index, out, vals, axis_size); |
| 104 | + } |
| 105 | +} |
| 106 | + |
| 107 | +} // namespace cu |
| 108 | + |
| 109 | +void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) { |
| 110 | + nvtx3::scoped_range r("Softmax::eval_gpu"); |
| 111 | + assert(inputs.size() == 1); |
| 112 | + auto& s = stream(); |
| 113 | + |
| 114 | + // Make sure that the last dimension is contiguous. |
| 115 | + auto set_output = [&s, &out](const array& x) { |
| 116 | + if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) { |
| 117 | + if (x.is_donatable()) { |
| 118 | + out.copy_shared_buffer(x); |
| 119 | + } else { |
| 120 | + out.set_data( |
| 121 | + allocator::malloc(x.data_size() * x.itemsize()), |
| 122 | + x.data_size(), |
| 123 | + x.strides(), |
| 124 | + x.flags()); |
| 125 | + } |
| 126 | + return x; |
| 127 | + } else { |
| 128 | + auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); |
| 129 | + copy_gpu(x, x_copy, CopyType::General, s); |
| 130 | + out.copy_shared_buffer(x_copy); |
| 131 | + return x_copy; |
| 132 | + } |
| 133 | + }; |
| 134 | + |
| 135 | + array in = set_output(inputs[0]); |
| 136 | + bool precise = in.dtype() != float32 && precise_; |
| 137 | + |
| 138 | + int axis_size = in.shape().back(); |
| 139 | + int n_rows = in.data_size() / axis_size; |
| 140 | + |
| 141 | + auto& encoder = cu::get_command_encoder(s); |
| 142 | + encoder.set_input_array(in); |
| 143 | + encoder.set_output_array(out); |
| 144 | + encoder.launch_kernel([&](cudaStream_t stream) { |
| 145 | + MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "softmax", CTYPE, { |
| 146 | + using DataType = cuda_type_t<CTYPE>; |
| 147 | + constexpr int N_READS = 4; |
| 148 | + MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { |
| 149 | + auto kernel = cu::softmax<DataType, DataType, BLOCK_DIM, N_READS>; |
| 150 | + if (precise) { |
| 151 | + kernel = cu::softmax<DataType, float, BLOCK_DIM, N_READS>; |
| 152 | + } |
| 153 | + kernel<<<n_rows, BLOCK_DIM, 0, stream>>>( |
| 154 | + in.data<DataType>(), out.data<DataType>(), axis_size); |
| 155 | + }); |
| 156 | + }); |
| 157 | + }); |
| 158 | +} |
| 159 | + |
| 160 | +} // namespace mlx::core |
0 commit comments