Skip to content

Commit c68b719

Browse files
committed
CUDA backend
1 parent 5580b47 commit c68b719

37 files changed

+4045
-5
lines changed

CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
3434
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
3535
option(MLX_BUILD_METAL "Build metal backend" ON)
3636
option(MLX_BUILD_CPU "Build cpu backend" ON)
37+
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
3738
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
3839
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
3940
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
@@ -83,6 +84,10 @@ if(MLX_BUILD_METAL)
8384
set(QUARTZ_LIB "-framework QuartzCore")
8485
endif()
8586

87+
if(MLX_BUILD_CUDA)
88+
enable_language(CUDA)
89+
endif()
90+
8691
if(MLX_BUILD_METAL AND NOT METAL_LIB)
8792
message(STATUS "Metal not found. Unable to build GPU")
8893
set(MLX_BUILD_METAL OFF)

mlx/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ target_sources(
55
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
66
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
77
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
8+
${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp
89
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
910
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
1011
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
@@ -47,6 +48,8 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
4748

4849
if(MLX_BUILD_METAL)
4950
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
51+
elseif(MLX_BUILD_CUDA)
52+
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
5053
else()
5154
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
5255
endif()

mlx/backend/cuda/CMakeLists.txt

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Filename rules in CUDA backend:
#
# * Use .cu/.cuh if code contains device code, and .cpp/.h if not.
# * Device-only kernel code should be put in kernels/ subdir.
# * Files in kernels/ subdir should not include files outside.
target_sources(
  mlx
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)

target_compile_definitions(mlx PUBLIC MLX_USE_CUDA)

# Enable defining device lambda functions.
target_compile_options(mlx
                       PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")

set_target_properties(
  mlx
  PROPERTIES CUDA_STANDARD 17
             CUDA_SEPARABLE_COMPILATION ON
             # Compute capability 7 is required for synchronization between
             # CPU/GPU with managed memory.
             # TODO: Add more architectures for release build.
             CUDA_ARCHITECTURES "70;80")

# Use fixed version of CCCL. Make sure the FetchContent module is loaded even
# if this directory is configured before the top-level file includes it.
include(FetchContent)
FetchContent_Declare(
  cccl
  URL "https://github.com/NVIDIA/cccl/releases/download/v2.8.1/cccl-v2.8.1.zip")
FetchContent_MakeAvailable(cccl)
# NOTE: BEFORE must come before the scope keyword (PRIVATE); writing
# "PRIVATE BEFORE" makes CMake treat BEFORE as an include directory.
target_include_directories(mlx BEFORE PRIVATE "${cccl_SOURCE_DIR}/include")

# Make CUDA APIs visible in .cpp files.
find_package(CUDAToolkit REQUIRED)
target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS})

# Suppress nvcc warnings on MLX headers.
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
                                   --diag_suppress=997>)

mlx/backend/cuda/allocator.cpp

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
// Copyright © 2025 Apple Inc.
2+
3+
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/utils.h"

#include <algorithm>
#include <stdexcept>

#include <cuda_runtime.h>
#include <fmt/format.h>
8+
9+
namespace mlx::core {
10+
11+
namespace mxcuda {
12+
13+
// Initialize the soft memory limit to 80% of the device's total memory,
// as reported by cudaMemGetInfo at construction time.
CudaAllocator::CudaAllocator() {
  size_t free_mem = 0;
  size_t total_mem = 0;
  CHECK_CUDA_ERROR(cudaMemGetInfo(&free_mem, &total_mem));
  memory_limit_ = total_mem * 0.8;
}
18+
19+
// Allocate |size| bytes of cuda-managed memory and return a Buffer wrapping
// the bookkeeping struct. On cudaErrorMemoryAllocation the returned buffer's
// data pointer is null (callers see it via Buffer::raw_ptr); any other CUDA
// error throws std::runtime_error.
Buffer CudaAllocator::malloc(size_t size) {
  // TODO: Check memory limit.
  auto* buf = new CudaBuffer{nullptr, size};
  cudaError_t err = cudaMallocManaged(&buf->data, size);
  if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
    // Release the bookkeeping struct before throwing so it is not leaked.
    delete buf;
    throw std::runtime_error(
        fmt::format("cudaMallocManaged failed: {}", cudaGetErrorString(err)));
  }
  active_memory_ += size;
  peak_memory_ = std::max(active_memory_, peak_memory_);
  return Buffer{buf};
}
31+
32+
// Release a buffer previously returned by malloc(). A buffer holding a null
// pointer is silently ignored.
void CudaAllocator::free(Buffer buffer) {
  if (auto* cuda_buf = static_cast<CudaBuffer*>(buffer.ptr())) {
    active_memory_ -= cuda_buf->size;
    cudaFree(cuda_buf->data);
    delete cuda_buf;
  }
}
41+
42+
// Return the byte size recorded for |buffer|, or 0 for a null buffer.
size_t CudaAllocator::size(Buffer buffer) const {
  auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
  if (!buf) {
    return 0;
  }
  // Reuse the already-cast pointer instead of casting buffer.ptr() a
  // second time.
  return buf->size;
}
49+
50+
// Return the process-wide CUDA allocator. The instance is intentionally
// heap-allocated and never deleted, so its destructor does not run at
// process exit; this saves time on shutdown and lets the OS reclaim any
// outstanding allocations.
CudaAllocator& allocator() {
  static CudaAllocator* singleton = new CudaAllocator;
  return *singleton;
}
57+
58+
} // namespace mxcuda
59+
60+
namespace allocator {
61+
62+
Allocator& allocator() {
63+
return mxcuda::allocator();
64+
}
65+
66+
void* Buffer::raw_ptr() {
67+
if (!ptr_) {
68+
return nullptr;
69+
}
70+
return static_cast<mxcuda::CudaBuffer*>(ptr_)->data;
71+
}
72+
73+
} // namespace allocator
74+
75+
size_t get_active_memory() {
76+
return mxcuda::allocator().get_active_memory();
77+
}
78+
size_t get_peak_memory() {
79+
return mxcuda::allocator().get_peak_memory();
80+
}
81+
void reset_peak_memory() {
82+
return mxcuda::allocator().reset_peak_memory();
83+
}
84+
size_t set_memory_limit(size_t limit) {
85+
return mxcuda::allocator().set_memory_limit(limit);
86+
}
87+
size_t get_memory_limit() {
88+
return mxcuda::allocator().get_memory_limit();
89+
}
90+
91+
// No-ops for common allocator
92+
size_t get_cache_memory() {
93+
return 0;
94+
}
95+
size_t set_cache_limit(size_t) {
96+
return 0;
97+
}
98+
size_t set_wired_limit(size_t) {
99+
return 0;
100+
}
101+
void clear_cache() {}
102+
103+
} // namespace mlx::core

mlx/backend/cuda/allocator.h

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// Copyright © 2025 Apple Inc.
2+
3+
#pragma once
4+
5+
#include "mlx/allocator.h"
6+
7+
#include <utility>
8+
9+
namespace mlx::core::mxcuda {
10+
11+
using allocator::Buffer;
12+
13+
// Stores cuda-managed memory.
struct CudaBuffer {
  void* data; // Pointer from cudaMallocManaged; may be null if allocation
              // failed with cudaErrorMemoryAllocation.
  size_t size; // Allocation size in bytes as requested from malloc().
};
18+
19+
class CudaAllocator : public allocator::Allocator {
20+
public:
21+
Buffer malloc(size_t size) override;
22+
void free(Buffer buffer) override;
23+
size_t size(Buffer buffer) const override;
24+
25+
size_t get_active_memory() const {
26+
return active_memory_;
27+
};
28+
size_t get_peak_memory() const {
29+
return peak_memory_;
30+
};
31+
void reset_peak_memory() {
32+
peak_memory_ = 0;
33+
};
34+
size_t get_memory_limit() {
35+
return memory_limit_;
36+
}
37+
size_t set_memory_limit(size_t limit) {
38+
std::swap(memory_limit_, limit);
39+
return limit;
40+
}
41+
42+
private:
43+
CudaAllocator();
44+
friend CudaAllocator& allocator();
45+
46+
size_t memory_limit_;
47+
size_t active_memory_{0};
48+
size_t peak_memory_{0};
49+
};
50+
51+
CudaAllocator& allocator();
52+
53+
} // namespace mlx::core::mxcuda

0 commit comments

Comments
 (0)