Skip to content

Commit 120abba

Browse files
committed
CUDA backend: gather
1 parent 1b1168d commit 120abba

File tree

2 files changed

+206
-1
lines changed

2 files changed

+206
-1
lines changed

mlx/backend/cuda/gather.cu

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
// Copyright © 2025 Apple Inc.
2+
3+
#include "mlx/backend/cuda/device.h"
4+
#include "mlx/backend/cuda/dtype_utils.cuh"
5+
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
6+
#include "mlx/backend/cuda/kernels/utils.cuh"
7+
#include "mlx/dtype_utils.h"
8+
#include "mlx/primitives.h"
9+
10+
#include <nvtx3/nvtx3.hpp>
11+
#include <thrust/gather.h>
12+
13+
namespace mlx::core {
14+
15+
namespace cu {
16+
17+
// Dispatch dynamic nidx to constexpr.
18+
#define MLX_SWITCH_NIDX(nidx, NIDX, ...) \
19+
if (nidx <= 2) { \
20+
constexpr uint32_t NIDX = 2; \
21+
__VA_ARGS__; \
22+
} else if (nidx <= 16) { \
23+
constexpr uint32_t NIDX = 16; \
24+
__VA_ARGS__; \
25+
} else { \
26+
throw std::runtime_error( \
27+
fmt::format("Indices array can not have more than {} items", nidx)); \
28+
}
29+
30+
// Dispatch dynamic idx_ndim to constexpr.
31+
#define MORE_THAN_ONE MAX_NDIM
32+
#define MLX_SWITCH_IDX_NDIM(idx_ndim, IDX_NDIM, ...) \
33+
if (idx_ndim == 0) { \
34+
constexpr uint32_t IDX_NDIM = 0; \
35+
__VA_ARGS__; \
36+
} else if (idx_ndim == 1) { \
37+
constexpr uint32_t IDX_NDIM = 1; \
38+
__VA_ARGS__; \
39+
} else { \
40+
constexpr uint32_t IDX_NDIM = MORE_THAN_ONE; \
41+
__VA_ARGS__; \
42+
}
43+
44+
// Convert an absolute index to positions in a 3d grid.
45+
template <typename T>
46+
struct IndexToDims {
47+
T dim0;
48+
T dim1;
49+
T dim2;
50+
51+
__device__ cuda::std::tuple<T, T, T> index_to_dims(T index) {
52+
T x = index / (dim1 * dim2);
53+
T y = (index % (dim1 * dim2)) / dim2;
54+
T z = index % dim2;
55+
return cuda::std::make_tuple(x, y, z);
56+
}
57+
};
58+
59+
// Get absolute index from possible negative index.
60+
template <typename IdxT>
61+
inline __device__ auto absolute_index(IdxT idx, int32_t size) {
62+
if constexpr (cuda::std::is_unsigned_v<IdxT>) {
63+
return idx;
64+
} else {
65+
return static_cast<int32_t>(idx < 0 ? idx + size : idx);
66+
}
67+
}
68+
69+
template <typename T, size_t NIDX, size_t IDX_NDIM>
70+
struct Indices {
71+
size_t size;
72+
size_t ndim;
73+
cuda::std::array<const T*, NIDX> buffers;
74+
cuda::std::array<bool, NIDX> row_contiguous;
75+
cuda::std::array<int32_t, NIDX * IDX_NDIM> shapes;
76+
cuda::std::array<int64_t, NIDX * IDX_NDIM> strides;
77+
78+
template <typename Iter>
79+
Indices(Iter begin, Iter end) {
80+
size = end - begin;
81+
ndim = size > 0 ? begin->ndim() : 0;
82+
for (size_t i = 0; i < size; ++i) {
83+
const array& arr = *(begin + i);
84+
buffers[i] = arr.data<T>();
85+
row_contiguous[i] = arr.flags().row_contiguous;
86+
std::copy_n(arr.shape().begin(), ndim, shapes.begin() + i * ndim);
87+
std::copy_n(arr.strides().begin(), ndim, strides.begin() + i * ndim);
88+
}
89+
}
90+
};
91+
92+
template <typename IdxT, size_t NIDX, size_t IDX_NDIM, typename LocT = int64_t>
93+
struct IndexingOp {
94+
IndexToDims<size_t> dims;
95+
size_t ndim;
96+
Shape shape;
97+
Strides strides;
98+
Shape slice_sizes;
99+
Shape axes;
100+
Indices<IdxT, NIDX, IDX_NDIM> indices;
101+
102+
__device__ LocT operator()(size_t idx) {
103+
auto [x, y, z] = dims.index_to_dims(idx);
104+
105+
LocT src_idx = 0;
106+
for (size_t i = 0; i < indices.size; ++i) {
107+
LocT idx_loc;
108+
if constexpr (IDX_NDIM == 0) {
109+
idx_loc = 0;
110+
} else {
111+
idx_loc = x * indices.strides[indices.ndim * i];
112+
if constexpr (IDX_NDIM == MORE_THAN_ONE) {
113+
if (indices.row_contiguous[i]) {
114+
idx_loc += y;
115+
} else {
116+
size_t offset = indices.ndim * i + 1;
117+
idx_loc += elem_to_loc(
118+
y,
119+
indices.shapes.data() + offset,
120+
indices.strides.data() + offset,
121+
indices.ndim - 1);
122+
}
123+
}
124+
}
125+
auto ax = axes[i];
126+
auto idx_val = absolute_index(indices.buffers[i][idx_loc], shape[ax]);
127+
src_idx += static_cast<LocT>(idx_val) * strides[ax];
128+
}
129+
130+
LocT src_offset = elem_to_loc(z, slice_sizes.data(), strides.data(), ndim);
131+
return src_offset + src_idx;
132+
}
133+
};
134+
135+
} // namespace cu
136+
137+
void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
138+
nvtx3::scoped_range r("Gather::eval_gpu");
139+
out.set_data(allocator::malloc(out.nbytes()));
140+
if (out.size() == 0) {
141+
return;
142+
}
143+
144+
auto& s = stream();
145+
auto& encoder = cu::get_command_encoder(s);
146+
for (const auto& in : inputs) {
147+
encoder.set_input_array(in);
148+
}
149+
encoder.set_output_array(out);
150+
151+
const auto& src = inputs[0];
152+
bool has_indices = inputs.size() > 1;
153+
auto idx_dtype = has_indices ? inputs[1].dtype() : bool_;
154+
auto idx_ndim = has_indices ? inputs[1].ndim() : 0;
155+
156+
size_t dim0 = 1;
157+
size_t dim1 = 1;
158+
if (has_indices) {
159+
if (inputs[1].ndim() >= 1) {
160+
dim0 = inputs[1].shape(0);
161+
}
162+
if (inputs[1].ndim() >= 2) {
163+
dim1 = inputs[1].size() / dim0;
164+
}
165+
}
166+
size_t dim2 = 1;
167+
for (size_t s : slice_sizes_) {
168+
dim2 *= s;
169+
}
170+
171+
encoder.launch_kernel([&](cudaStream_t stream) {
172+
MLX_SWITCH_ALL_TYPES(idx_dtype, CTYPE_IDX, {
173+
using IndexType = cuda_type_t<CTYPE_IDX>;
174+
if constexpr (cuda::std::is_integral_v<IndexType>) {
175+
MLX_SWITCH_NIDX(inputs.size() - 1, NIDX, {
176+
MLX_SWITCH_IDX_NDIM(idx_ndim, IDX_NDIM, {
177+
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_DATA, {
178+
using DataType = cuda_type_t<CTYPE_DATA>;
179+
auto map_begin = thrust::make_transform_iterator(
180+
thrust::make_counting_iterator(0),
181+
cu::IndexingOp<IndexType, NIDX, IDX_NDIM>{
182+
{dim0, dim1, dim2},
183+
src.ndim(),
184+
cu::const_param(src.shape()),
185+
cu::const_param(src.strides()),
186+
cu::const_param(slice_sizes_),
187+
cu::const_param(axes_),
188+
{inputs.begin() + 1, inputs.end()}});
189+
thrust::gather(
190+
cu::thrust_policy(stream),
191+
map_begin,
192+
map_begin + out.size(),
193+
src.data<DataType>(),
194+
out.data<DataType>());
195+
});
196+
});
197+
});
198+
} else {
199+
throw std::runtime_error(fmt::format(
200+
"Can not use dtype {} as index.", dtype_to_string(idx_dtype)));
201+
}
202+
});
203+
});
204+
}
205+
206+
} // namespace mlx::core

mlx/backend/cuda/primitives.cu

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,6 @@ NO_GPU_MULTI(DivMod)
173173
NO_GPU(DynamicSlice)
174174
NO_GPU(DynamicSliceUpdate)
175175
NO_GPU(FFT)
176-
NO_GPU(Gather)
177176
NO_GPU(GatherAxis)
178177
NO_GPU(GatherMM)
179178
NO_GPU(GatherQMM)

0 commit comments

Comments
 (0)