Commit 87c963d

CUDA backend: softmax

1 parent 99c33d0
File tree

4 files changed, +321 -2 lines changed

mlx/backend/cuda/CMakeLists.txt

Lines changed: 2 additions & 0 deletions

@@ -18,9 +18,11 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/random.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp

mlx/backend/cuda/logsumexp.cu

Lines changed: 159 additions & 0 deletions
@@ -0,0 +1,159 @@

// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/cast_op.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>
#include <cub/block/block_load.cuh>

#include <cassert>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

template <typename T>
inline __device__ T softmax_exp(T x) {
  // Softmax doesn't need high precision exponential cause x is gonna be in
  // (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).
  return __expf(x);
}

template <typename T, typename AccT, int BLOCK_DIM, int N_READS = 4>
__global__ void logsumexp(const T* in, T* out, int axis_size) {
  auto grid = cg::this_grid();
  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition<WARP_SIZE>(block);

  in += grid.block_rank() * axis_size;

  cg::greater<AccT> max_op;
  cg::plus<AccT> plus_op;

  // Thread reduce.
  AccT prevmax;
  AccT maxval = Limits<AccT>::finite_min();
  AccT normalizer = 0;
  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
    AccT vals[N_READS];
    cub::LoadDirectBlocked(
        r * BLOCK_DIM + block.thread_rank(),
        make_cast_iterator<AccT>(in),
        vals,
        axis_size,
        Limits<AccT>::min());
    prevmax = maxval;
    maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
    // Online normalizer calculation for softmax:
    // https://github.com/NVIDIA/online-softmax
    normalizer = normalizer * softmax_exp(prevmax - maxval);
    for (int i = 0; i < N_READS; i++) {
      normalizer = normalizer + softmax_exp(vals[i] - maxval);
    }
  }

  // First warp reduce.
  prevmax = maxval;
  maxval = cg::reduce(warp, maxval, max_op);
  normalizer = normalizer * softmax_exp(prevmax - maxval);
  normalizer = cg::reduce(warp, normalizer, plus_op);

  __shared__ AccT local_max[WARP_SIZE];
  __shared__ AccT local_normalizer[WARP_SIZE];

  // Write to shared memory and do second warp reduce.
  prevmax = maxval;
  if (warp.thread_rank() == 0) {
    local_max[warp.meta_group_rank()] = maxval;
  }
  block.sync();
  maxval = warp.thread_rank() < warp.meta_group_size()
      ? local_max[warp.thread_rank()]
      : Limits<AccT>::finite_min();
  maxval = cg::reduce(warp, maxval, max_op);
  normalizer = normalizer * softmax_exp(prevmax - maxval);
  if (warp.thread_rank() == 0) {
    local_normalizer[warp.meta_group_rank()] = normalizer;
  }
  block.sync();
  normalizer = warp.thread_rank() < warp.meta_group_size()
      ? local_normalizer[warp.thread_rank()]
      : AccT{};
  normalizer = cg::reduce(warp, normalizer, plus_op);

  // Write output.
  if (block.thread_rank() == 0) {
    out[grid.block_rank()] = isinf(maxval) ? maxval : log(normalizer) + maxval;
  }
}

} // namespace cu

void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("LogSumExp::eval_gpu");
  assert(inputs.size() == 1);
  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);

  // Make sure that the last dimension is contiguous.
  auto ensure_contiguous = [&s, &encoder](const array& x) {
    if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
      return x;
    } else {
      auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
      copy_gpu(x, x_copy, CopyType::General, s);
      encoder.add_temporary(x_copy);
      return x_copy;
    }
  };

  auto in = ensure_contiguous(inputs[0]);
  if (in.flags().row_contiguous) {
    out.set_data(allocator::malloc(out.nbytes()));
  } else {
    auto n = in.shape(-1);
    auto flags = in.flags();
    auto strides = in.strides();
    for (auto& s : strides) {
      s /= n;
    }
    bool col_contig = strides[0] == 1;
    for (int i = 1; col_contig && i < strides.size(); ++i) {
      col_contig &=
          (out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
    }
    flags.col_contiguous = col_contig;
    out.set_data(
        allocator::malloc(in.nbytes() / n),
        in.data_size() / n,
        std::move(strides),
        flags);
  }

  int axis_size = in.shape().back();
  int n_rows = in.data_size() / axis_size;

  encoder.set_input_array(in);
  encoder.set_output_array(out);
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "logsumexp", CTYPE, {
      using DataType = cuda_type_t<CTYPE>;
      constexpr int N_READS = 4;
      MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
        auto kernel = cu::logsumexp<DataType, float, BLOCK_DIM, N_READS>;
        kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
            in.data<DataType>(), out.data<DataType>(), axis_size);
      });
    });
  });
}

} // namespace mlx::core
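
Note: the kernel above computes the row maximum and the running sum of exponentials in a single pass, rescaling the partial sum whenever a larger maximum is found (the online-softmax update cited in the kernel comments). A minimal scalar sketch of that update, written in plain C++ purely for illustration and not part of this commit:

// Illustrative sketch only (not part of the commit): scalar form of the
// online max/normalizer update used by cu::logsumexp above.
#include <algorithm>
#include <cmath>
#include <vector>

float logsumexp_online(const std::vector<float>& x) {
  float maxval = -INFINITY;   // running maximum of the row
  float normalizer = 0.0f;    // running sum of exp(x_i - maxval)
  for (float v : x) {
    float prevmax = maxval;
    maxval = std::max(maxval, v);
    // Rescale the partial sum so it stays relative to the new maximum,
    // then add the current element's contribution.
    normalizer = normalizer * std::exp(prevmax - maxval) + std::exp(v - maxval);
  }
  return std::isinf(maxval) ? maxval : std::log(normalizer) + maxval;
}

The kernel performs the same update, except that each thread consumes N_READS elements per iteration, and the per-thread partials are then combined with a warp reduce followed by a shared-memory pass across warps.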

mlx/backend/cuda/primitives.cu

Lines changed: 0 additions & 2 deletions
@@ -86,7 +86,6 @@ NO_GPU(GatherMM)
 NO_GPU(GatherQMM)
 NO_GPU(Hadamard)
 NO_GPU(Load)
-NO_GPU(LogSumExp)
 NO_GPU_MULTI(LUF)
 NO_GPU(Partition)
 NO_GPU_MULTI(QRF)
@@ -97,7 +96,6 @@ NO_GPU(Scatter)
 NO_GPU(ScatterAxis)
 NO_GPU(Select)
 NO_GPU(SliceUpdate)
-NO_GPU(Softmax)
 NO_GPU_MULTI(SVD)
 NO_GPU(Inverse)
 NO_GPU(Cholesky)

mlx/backend/cuda/softmax.cu

Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@

// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/cast_op.cuh"
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>
#include <cub/block/block_load.cuh>

#include <cassert>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

template <typename T>
inline __device__ T softmax_exp(T x) {
  // Softmax doesn't need high precision exponential cause x is gonna be in
  // (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).
  return __expf(x);
}

template <typename T, typename AccT, int BLOCK_DIM, int N_READS = 4>
__global__ void softmax(const T* in, T* out, int axis_size) {
  auto grid = cg::this_grid();
  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition<WARP_SIZE>(block);

  in += grid.block_rank() * axis_size;
  out += grid.block_rank() * axis_size;

  cg::greater<AccT> max_op;
  cg::plus<AccT> plus_op;

  // Thread reduce.
  AccT prevmax;
  AccT maxval = Limits<AccT>::finite_min();
  AccT normalizer = 0;
  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
    AccT vals[N_READS];
    cub::LoadDirectBlocked(
        r * BLOCK_DIM + block.thread_rank(),
        make_cast_iterator<AccT>(in),
        vals,
        axis_size,
        Limits<AccT>::finite_min());
    prevmax = maxval;
    maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
    // Online normalizer calculation for softmax:
    // https://github.com/NVIDIA/online-softmax
    normalizer = normalizer * softmax_exp(prevmax - maxval);
    for (int i = 0; i < N_READS; i++) {
      normalizer = normalizer + softmax_exp(vals[i] - maxval);
    }
  }

  // First warp reduce.
  prevmax = maxval;
  maxval = cg::reduce(warp, maxval, max_op);
  normalizer = normalizer * softmax_exp(prevmax - maxval);
  normalizer = cg::reduce(warp, normalizer, plus_op);

  __shared__ AccT local_max[WARP_SIZE];
  __shared__ AccT local_normalizer[WARP_SIZE];

  // Write to shared memory and do second warp reduce.
  prevmax = maxval;
  if (warp.thread_rank() == 0) {
    local_max[warp.meta_group_rank()] = maxval;
  }
  block.sync();
  maxval = warp.thread_rank() < warp.meta_group_size()
      ? local_max[warp.thread_rank()]
      : Limits<AccT>::finite_min();
  maxval = cg::reduce(warp, maxval, max_op);
  normalizer = normalizer * softmax_exp(prevmax - maxval);
  if (warp.thread_rank() == 0) {
    local_normalizer[warp.meta_group_rank()] = normalizer;
  }
  block.sync();
  normalizer = warp.thread_rank() < warp.meta_group_size()
      ? local_normalizer[warp.thread_rank()]
      : AccT{};
  normalizer = cg::reduce(warp, normalizer, plus_op);
  normalizer = 1 / normalizer;

  // Write output.
  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
    auto index = r * BLOCK_DIM + block.thread_rank();
    T vals[N_READS];
    cub::LoadDirectBlocked(index, in, vals, axis_size);
    for (int i = 0; i < N_READS; i++) {
      vals[i] = softmax_exp(static_cast<AccT>(vals[i]) - maxval) * normalizer;
    }
    cub::StoreDirectBlocked(index, out, vals, axis_size);
  }
}

} // namespace cu

void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Softmax::eval_gpu");
  assert(inputs.size() == 1);
  auto& s = stream();

  // Make sure that the last dimension is contiguous.
  auto set_output = [&s, &out](const array& x) {
    if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
      if (x.is_donatable()) {
        out.copy_shared_buffer(x);
      } else {
        out.set_data(
            allocator::malloc(x.data_size() * x.itemsize()),
            x.data_size(),
            x.strides(),
            x.flags());
      }
      return x;
    } else {
      auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
      copy_gpu(x, x_copy, CopyType::General, s);
      out.copy_shared_buffer(x_copy);
      return x_copy;
    }
  };

  array in = set_output(inputs[0]);
  bool precise = in.dtype() != float32 && precise_;

  int axis_size = in.shape().back();
  int n_rows = in.data_size() / axis_size;

  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(in);
  encoder.set_output_array(out);
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "softmax", CTYPE, {
      using DataType = cuda_type_t<CTYPE>;
      constexpr int N_READS = 4;
      MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
        auto kernel = cu::softmax<DataType, DataType, BLOCK_DIM, N_READS>;
        if (precise) {
          kernel = cu::softmax<DataType, float, BLOCK_DIM, N_READS>;
        }
        kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
            in.data<DataType>(), out.data<DataType>(), axis_size);
      });
    });
  });
}

} // namespace mlx::core
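
Note: the softmax kernel reuses the same online reduction as the logsumexp kernel and then makes a second pass over the row, rewriting each element as exp(x - max) * (1 / normalizer). When precise_ is set for non-float32 inputs the kernel is dispatched with AccT = float, otherwise it accumulates in the input type. A scalar reference for that two-pass structure, again illustrative only and not part of this commit:

// Illustrative sketch only (not part of the commit): scalar form of the
// two-pass structure used by cu::softmax above.
#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> softmax_reference(const std::vector<float>& x) {
  float maxval = -INFINITY;
  float normalizer = 0.0f;
  // Pass 1: online maximum and normalizer, as in the kernel's reduction phase.
  for (float v : x) {
    float prevmax = maxval;
    maxval = std::max(maxval, v);
    normalizer = normalizer * std::exp(prevmax - maxval) + std::exp(v - maxval);
  }
  float scale = 1.0f / normalizer;
  // Pass 2: reload the inputs and store the normalized exponentials, mirroring
  // the kernel's final LoadDirectBlocked/StoreDirectBlocked loop.
  std::vector<float> y(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    y[i] = std::exp(x[i] - maxval) * scale;
  }
  return y;
}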

0 commit comments
