Commit 6316472: Implement logsumexp
Parent: 823fb81

4 files changed, +165 -6 lines
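For context, the new kernel evaluates logsumexp in its numerically stable form, shifting by the running maximum before exponentiating. This is the standard identity (not stated in the commit itself, but it is exactly what the kernel's final store, log(normalizer) + maxval, computes):

\mathrm{logsumexp}(x) = m + \log \sum_i e^{x_i - m}, \qquad m = \max_i x_i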

mlx/backend/cuda/CMakeLists.txt (1 addition, 0 deletions)

@@ -17,6 +17,7 @@ target_sources(
     ${CMAKE_CURRENT_SOURCE_DIR}/gather_axis.cu
     ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cu
     ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
+    ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
     ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
     ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
     ${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu

mlx/backend/cuda/logsumexp.cu (new file, 159 additions)
@@ -0,0 +1,159 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/iterators/cast_iterator.cuh"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/backend/gpu/copy.h"
+#include "mlx/dtype_utils.h"
+#include "mlx/primitives.h"
+
+#include <cooperative_groups.h>
+#include <cooperative_groups/reduce.h>
+#include <nvtx3/nvtx3.hpp>
+#include <cub/block/block_load.cuh>
+
+#include <cassert>
+
+namespace mlx::core {
+
+namespace cu {
+
+namespace cg = cooperative_groups;
+
+template <typename T>
+inline __device__ T softmax_exp(T x) {
+  // Softmax doesn't need high precision exponential cause x is gonna be in
+  // (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).
+  return __expf(x);
+}
+
+template <typename T, typename AccT, int BLOCK_DIM, int N_READS = 4>
+__global__ void logsumexp(const T* in, T* out, int axis_size) {
+  auto grid = cg::this_grid();
+  auto block = cg::this_thread_block();
+  auto warp = cg::tiled_partition<WARP_SIZE>(block);
+
+  in += grid.block_rank() * axis_size;
+
+  cg::greater<AccT> max_op;
+  cg::plus<AccT> plus_op;
+
+  // Thread reduce.
+  AccT prevmax;
+  AccT maxval = Limits<AccT>::finite_min;
+  AccT normalizer = zero_value<AccT>();
+  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
+    AccT vals[N_READS];
+    cub::LoadDirectBlocked(
+        r * BLOCK_DIM + block.thread_rank(),
+        make_cast_iterator<AccT>(in),
+        vals,
+        axis_size,
+        Limits<AccT>::min);
+    prevmax = maxval;
+    maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
+    // Online normalizer calculation for softmax:
+    // https://github.com/NVIDIA/online-softmax
+    normalizer = normalizer * softmax_exp(prevmax - maxval);
+    for (int i = 0; i < N_READS; i++) {
+      normalizer = normalizer + softmax_exp(vals[i] - maxval);
+    }
+  }
+
+  // First warp reduce.
+  prevmax = maxval;
+  maxval = cg::reduce(warp, maxval, max_op);
+  normalizer = normalizer * softmax_exp(prevmax - maxval);
+  normalizer = cg::reduce(warp, normalizer, plus_op);
+
+  __shared__ AccT local_max[WARP_SIZE];
+  __shared__ AccT local_normalizer[WARP_SIZE];
+
+  // Write to shared memory and do second warp reduce.
+  prevmax = maxval;
+  if (warp.thread_rank() == 0) {
+    local_max[warp.meta_group_rank()] = maxval;
+  }
+  block.sync();
+  maxval = warp.thread_rank() < warp.meta_group_size()
+      ? local_max[warp.thread_rank()]
+      : Limits<AccT>::finite_min;
+  maxval = cg::reduce(warp, maxval, max_op);
+  normalizer = normalizer * softmax_exp(prevmax - maxval);
+  if (warp.thread_rank() == 0) {
+    local_normalizer[warp.meta_group_rank()] = normalizer;
+  }
+  block.sync();
+  normalizer = warp.thread_rank() < warp.meta_group_size()
+      ? local_normalizer[warp.thread_rank()]
+      : AccT{};
+  normalizer = cg::reduce(warp, normalizer, plus_op);
+
+  // Write output.
+  if (block.thread_rank() == 0) {
+    out[grid.block_rank()] = isinf(maxval) ? maxval : log(normalizer) + maxval;
+  }
+}
+
+} // namespace cu
+
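The kernel fuses the max and sum reductions into a single pass using the online-softmax recurrence cited in the source comment. A minimal scalar sketch of the same update in plain C++, assuming finite inputs (the kernel handles infinities separately at the final store); the names are illustrative, not from the MLX sources:

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Streaming logsumexp: rescale the running sum whenever the max grows.
float logsumexp_online(const std::vector<float>& x) {
  float maxval = -std::numeric_limits<float>::infinity();
  float normalizer = 0.0f;
  for (float v : x) {
    float prevmax = maxval;
    maxval = std::max(maxval, v);
    // exp(prevmax - maxval) re-bases the partial sum to the new max;
    // on the first iteration it is exp(-inf) == 0.
    normalizer = normalizer * std::exp(prevmax - maxval) + std::exp(v - maxval);
  }
  return std::log(normalizer) + maxval;
}

Each device thread runs this recurrence over its N_READS-element chunks, and the per-thread (maxval, normalizer) pairs are then merged across the warp and block with the same rescale-then-add step. The host-side dispatch continues the same file: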
+void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("LogSumExp::eval_gpu");
+  assert(inputs.size() == 1);
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+
+  // Make sure that the last dimension is contiguous.
+  auto ensure_contiguous = [&s, &encoder](const array& x) {
+    if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
+      return x;
+    } else {
+      auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
+      copy_gpu(x, x_copy, CopyType::General, s);
+      encoder.add_temporary(x_copy);
+      return x_copy;
+    }
+  };
+
+  auto in = ensure_contiguous(inputs[0]);
+  if (in.flags().row_contiguous) {
+    out.set_data(allocator::malloc(out.nbytes()));
+  } else {
+    auto n = in.shape(-1);
+    auto flags = in.flags();
+    auto strides = in.strides();
+    for (auto& s : strides) {
+      s /= n;
+    }
+    bool col_contig = strides[0] == 1;
+    for (int i = 1; col_contig && i < strides.size(); ++i) {
+      col_contig &=
+          (out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
+    }
+    flags.col_contiguous = col_contig;
+    out.set_data(
+        allocator::malloc(in.nbytes() / n),
+        in.data_size() / n,
+        std::move(strides),
+        flags);
+  }
+
+  int axis_size = in.shape().back();
+  int n_rows = in.data_size() / axis_size;
+
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "logsumexp", CTYPE, {
+      using DataType = cuda_type_t<CTYPE>;
+      constexpr int N_READS = 4;
+      MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
+        auto kernel = cu::logsumexp<DataType, float, BLOCK_DIM, N_READS>;
+        kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
+            in.data<DataType>(), out.data<DataType>(), axis_size);
+      });
+    });
+  });
+}
+
+} // namespace mlx::core
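For a sense of how this path is exercised, a hedged usage sketch through the MLX C++ API, assuming the usual logsumexp op in mlx/ops.h routes to LogSumExp::eval_gpu on a CUDA stream (illustrative, not part of the commit):

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  array x = random::normal({16, 1024});
  // Reduces over the last axis; the kernel above gets one block per row.
  array y = logsumexp(x, /* axis = */ -1);
  eval(y);
}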

mlx/backend/cuda/primitives.cu (0 additions, 1 deletion)

@@ -120,7 +120,6 @@ NO_GPU(GatherMM)
 NO_GPU(GatherQMM)
 NO_GPU(Hadamard)
 NO_GPU(Load)
-NO_GPU(LogSumExp)
 NO_GPU_MULTI(LUF)
 NO_GPU(Partition)
 NO_GPU_MULTI(QRF)
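Deleting the stub is what actually enables the new path: NO_GPU expands to an eval_gpu override that throws. A paraphrased sketch of the macro, assuming it matches the pattern used elsewhere in primitives.cu:

#define NO_GPU(func)                                                   \
  void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
    throw std::runtime_error(#func " has no CUDA implementation.");   \
  }

With the stub gone, the definition in logsumexp.cu is the sole eval_gpu for LogSumExp.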

mlx/backend/cuda/softmax.cu (5 additions, 5 deletions)

@@ -27,8 +27,8 @@ inline __device__ T softmax_exp(T x) {
   return __expf(x);
 }

-template <typename T, typename AccT, uint32_t BLOCK_DIM, uint32_t N_READS = 4>
-__global__ void softmax(const T* in, T* out, const uint32_t axis_size) {
+template <typename T, typename AccT, int BLOCK_DIM, int N_READS = 4>
+__global__ void softmax(const T* in, T* out, int axis_size) {
   auto grid = cg::this_grid();
   auto block = cg::this_thread_block();
   auto warp = cg::tiled_partition<WARP_SIZE>(block);
@@ -134,16 +134,16 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
   array in = set_output(inputs[0]);
   bool precise = in.dtype() != float32 && precise_;

-  uint32_t axis_size = in.shape().back();
-  uint32_t n_rows = in.data_size() / axis_size;
+  int axis_size = in.shape().back();
+  int n_rows = in.data_size() / axis_size;

   auto& encoder = cu::get_command_encoder(s);
   encoder.set_input_array(in);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
     MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "softmax", CTYPE, {
       using DataType = cuda_type_t<CTYPE>;
-      constexpr uint32_t N_READS = 4;
+      constexpr int N_READS = 4;
       MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
         auto kernel = cu::softmax<DataType, DataType, BLOCK_DIM, N_READS>;
         if (precise) {
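The softmax edits are pure type cleanup: uint32_t indices become int, matching the template signature of the new logsumexp kernel. The commit does not state a motivation, but a common one for preferring signed sizes in index arithmetic, as a standalone illustration:

#include <cstdint>
#include <cstdio>

int main() {
  uint32_t un = 0;
  int sn = 0;
  printf("%u\n", un - 1); // unsigned arithmetic wraps: prints 4294967295
  printf("%d\n", sn - 1); // signed stays intuitive: prints -1
}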
