Commit 1b1168d

CUDA backend: sort
1 parent c360e02 commit 1b1168d

3 files changed: +182 -2 lines
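For context: with sort.cu added and the NO_GPU stubs removed below, Sort and ArgSort run on CUDA devices through CUB segmented sorts. A minimal sketch of exercising the new path through the public C++ ops (assumes a CUDA-enabled build; the mlx.h include and random::uniform call come from the public API, not from this diff):

// Sketch only: drives the new Sort/ArgSort eval_gpu paths via the public ops.
#include "mlx/mlx.h"

int main() {
  namespace mx = mlx::core;
  // Sorting along the last axis of a contiguous array takes the direct
  // segmented path in gpu_sort; axis 0 forces the transpose-and-copy path.
  auto a = mx::random::uniform({4, 8});
  auto sorted = mx::sort(a, /* axis */ -1);
  auto order = mx::argsort(a, /* axis */ 0);
  mx::eval({sorted, order});
  return 0;
}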


mlx/backend/cuda/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -17,6 +17,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)

mlx/backend/cuda/primitives.cu

Lines changed: 0 additions & 2 deletions

@@ -166,7 +166,6 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {

 NO_GPU(AddMM)
 NO_GPU(ArgPartition)
-NO_GPU(ArgSort)
 NO_GPU(BlockMaskedMM)
 NO_GPU_MULTI(Compiled)
 NO_GPU(Convolution)

@@ -191,7 +190,6 @@ NO_GPU(ScatterAxis)
 NO_GPU(Select)
 NO_GPU(SliceUpdate)
 NO_GPU(Softmax)
-NO_GPU(Sort)
 NO_GPU_MULTI(SVD)
 NO_GPU(Inverse)
 NO_GPU(Cholesky)
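The NO_GPU entries are stubs that make a primitive's eval_gpu throw at runtime; deleting ArgSort and Sort here hands those primitives over to the real implementations added in sort.cu below. The macro roughly expands to something like the following (illustrative sketch only, not the actual definition in the backend):

// Illustrative sketch of a NO_GPU-style stub; the real macro may differ.
#define NO_GPU(func)                                                      \
  void func::eval_gpu(const std::vector<array>& inputs, array& out) {    \
    throw std::runtime_error(#func " has no CUDA implementation.");      \
  }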

mlx/backend/cuda/sort.cu

Lines changed: 181 additions & 0 deletions

@@ -0,0 +1,181 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/dtype_utils.cuh"
#include "mlx/backend/cuda/kernels/utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>
#include <cub/device/device_segmented_sort.cuh>

#include <cassert>
#include <numeric>

namespace mlx::core {

namespace {

template <typename T>
struct ModOp {
  T divisor;
  __device__ T operator()(T x) {
    return x % divisor;
  }
};

// We cannot use any op in eval, so make a small transpose utility here.
array swapaxes_in_eval(const array& in, int axis1, int axis2) {
  std::vector<int> axes(in.ndim());
  std::iota(axes.begin(), axes.end(), 0);
  std::swap(axes[axis1], axes[axis2]);
  // TODO: Share the code with Transpose::eval.
  Shape shape(axes.size());
  Strides strides(in.ndim());
  for (size_t ax = 0; ax < axes.size(); ++ax) {
    shape[ax] = in.shape()[axes[ax]];
    strides[ax] = in.strides()[axes[ax]];
  }
  auto flags = in.flags();
  if (flags.contiguous) {
    auto [_, row_contiguous, col_contiguous] = check_contiguity(shape, strides);
    flags.row_contiguous = row_contiguous;
    flags.col_contiguous = col_contiguous;
  }
  array out(shape, in.dtype(), nullptr, {});
  out.copy_shared_buffer(in, strides, flags, in.data_size());
  return out;
}

template <typename... Args>
void segmented_sort_pairs(cu::CommandEncoder& encoder, Args&&... args) {
  // Allocate temporary storage.
  size_t size;
  CHECK_CUDA_ERROR(
      cub::DeviceSegmentedSort::StableSortPairs(nullptr, size, args...));
  array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
  encoder.add_temporary(temp);
  // Run op.
  CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs(
      temp.data<void>(), size, args...));
}

template <typename... Args>
void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
  // Allocate temporary storage.
  size_t size;
  CHECK_CUDA_ERROR(
      cub::DeviceSegmentedSort::StableSortKeys(nullptr, size, args...));
  array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
  encoder.add_temporary(temp);
  // Run op.
  CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys(
      temp.data<void>(), size, args...));
}

void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
  array out = out_;
  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(in);
  encoder.set_output_array(out);

  if (axis < 0) {
    axis += in.ndim();
  }
  int nsort = in.shape(axis);
  int nsegments = in.data_size() / nsort;
  int last_dim = in.ndim() - 1;

  // If we are not sorting the innermost dimension of a contiguous array,
  // transpose and make a copy.
  bool is_segmented_sort = in.flags().contiguous && in.strides()[axis] == 1;
  if (!is_segmented_sort) {
    array trans = swapaxes_in_eval(in, axis, last_dim);
    in = array(trans.shape(), trans.dtype(), nullptr, {});
    copy_gpu(trans, in, CopyType::General, s);
    encoder.add_temporary(in);
    out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
    encoder.add_temporary(out);
  } else {
    out.set_data(allocator::malloc(out.nbytes()));
  }

  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
      if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
        using Type = cuda_type_t<CTYPE>;
        auto offsets = thrust::make_transform_iterator(
            thrust::make_counting_iterator(0),
            [nsort] __device__(int i) { return i * nsort; });
        if (argsort) {
          // Indices in the sorted dimension.
          array indices(
              allocator::malloc(out.nbytes()), in.shape(), out.dtype());
          encoder.add_temporary(indices);
          thrust::transform(
              cu::thrust_policy(stream),
              thrust::counting_iterator<uint32_t>(0),
              thrust::counting_iterator<uint32_t>(indices.data_size()),
              thrust::device_pointer_cast(indices.data<uint32_t>()),
              ModOp<uint32_t>{static_cast<uint32_t>(nsort)});

          // Argsort does not need the sorted values themselves, but the API
          // still requires an output array to store them.
          array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype());
          encoder.add_temporary(discard);

          segmented_sort_pairs(
              encoder,
              in.data<Type>(),
              discard.data<Type>(),
              indices.data<uint32_t>(),
              out.data<uint32_t>(),
              in.data_size(),
              nsegments,
              offsets,
              offsets + 1,
              stream);
        } else {
          segmented_sort(
              encoder,
              in.data<Type>(),
              out.data<Type>(),
              in.data_size(),
              nsegments,
              offsets,
              offsets + 1,
              stream);
        }
      } else {
        throw std::runtime_error(
            "CUDA backend does not support sorting complex numbers");
      }
    });
  });

  if (!is_segmented_sort) {
    // Swap the sorted axis back.
    // TODO: Do in-place transpose instead of using a temporary out array.
    copy_gpu(swapaxes_in_eval(out, axis, last_dim), out_, CopyType::General, s);
  }
}

} // namespace

void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("ArgSort::eval_gpu");
  assert(inputs.size() == 1);
  gpu_sort(stream(), inputs[0], out, axis_, true);
}

void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Sort::eval_gpu");
  assert(inputs.size() == 1);
  gpu_sort(stream(), inputs[0], out, axis_, false);
}

} // namespace mlx::core
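The segmented_sort helpers above follow CUB's standard two-phase calling convention: the first call, with a null temp-storage pointer, only reports how many scratch bytes are needed; the second call performs the sort. A self-contained sketch of that pattern outside MLX (sort_rows, d_offsets, and the raw cudaMalloc are illustrative; the commit instead routes the scratch buffer through MLX's allocator and command encoder, and builds the offsets with a transform iterator):

// Standalone sketch of cub::DeviceSegmentedSort's size-query-then-run pattern.
#include <cub/device/device_segmented_sort.cuh>
#include <cuda_runtime.h>

void sort_rows(
    const float* d_keys_in,
    float* d_keys_out,
    const int* d_offsets, // num_segments + 1 boundaries: 0, nsort, 2*nsort, ...
    int num_items,
    int num_segments,
    cudaStream_t stream) {
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  // Phase 1: with d_temp == nullptr, only temp_bytes is written.
  cub::DeviceSegmentedSort::StableSortKeys(
      d_temp, temp_bytes, d_keys_in, d_keys_out, num_items, num_segments,
      d_offsets, d_offsets + 1, stream);
  cudaMalloc(&d_temp, temp_bytes);
  // Phase 2: the actual stable segmented sort.
  cub::DeviceSegmentedSort::StableSortKeys(
      d_temp, temp_bytes, d_keys_in, d_keys_out, num_items, num_segments,
      d_offsets, d_offsets + 1, stream);
  cudaFree(d_temp);
}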
