
[WIP] CUDA backend #1983

Closed
wants to merge 11 commits into from

Conversation

@zcbenz (Contributor) commented Mar 21, 2025

This PR is an ongoing effort to add a CUDA backend to MLX. Very little works yet, but you can already run the tutorial example.

To build and test:

$ cmake . -Bbuild -DMLX_BUILD_CUDA=ON -DMLX_BUILD_EXAMPLES=ON
$ cmake --build build -j 16
$ ./build/examples/cpp/tutorial
array([[2, 3],
       [4, 5]], dtype=float32)
array([[1, 1],
       [1, 1]], dtype=float32)

For development I usually use:

$ cmake . -Bbuild -DMLX_BUILD_CUDA=ON -DMLX_BUILD_EXAMPLES=ON -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -DCMAKE_CUDA_COMPILER_LAUNCHER=ccache -DCMAKE_BUILD_TYPE=Debug -GNinja

This has only been tested on Ubuntu 22.04 with CUDA 11.6; other environments should work in theory, but they have not been tested.

This PR is not updated frequently; if anyone is interested in the real-time development, please check my forked repo.


There are two main reasons for a CUDA backend:

  • CUDA supports unified memory: some devices support it in hardware, and devices without hardware unified memory get software support.
  • NVIDIA hardware is widely used for academic and large-scale computation. Being able to write and test code locally on a Mac and then deploy to supercomputers would make for a good developer experience.

This work is sponsored by Apple.

@radudiaconu0 commented Mar 22, 2025

I want to add ROCm support based on your CUDA pull request. Would that be OK with you?
@zcbenz

@awni (Member) commented Mar 22, 2025

Awesome progress so far @zcbenz !!

I'm wondering what the best way is to get this incorporated into MLX. I can think of a couple of options:

  • Once this is ready, we can turn it into a cuda branch in MLX and then send PRs to it. This would make it easier from a review / PR management standpoint.
  • Just merge the backbone infra for supporting CUDA and send more incremental PRs over time.

I kind of prefer the latter, but I'm open to suggestions.

@zcbenz (Contributor, Author) commented Mar 22, 2025

I want to add ROCm support based on your CUDA pull request. Would that be OK with you?

@radudiaconu0 Of course I'm OK with it!

Before you begin, you might want to decide how the ROCm backend will live alongside the CUDA backend. I'm not familiar with ROCm, but I've seen two patterns in projects with both backends:

  1. Both backends share the same code, with the help of #defines and name aliases (see the sketch below).
  2. Transpile the CUDA code to HIP on the fly at build time, which is what PyTorch does.
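For illustration only, a minimal sketch of what pattern 1 could look like. The MLX_USE_HIP switch and the gpu* alias names are made up for this sketch; only the underlying CUDA/HIP runtime calls are real:

// Hypothetical illustration of pattern 1: a thin alias layer so the same
// source can compile against either CUDA or HIP. MLX_USE_HIP and the gpu*
// names are assumptions, not actual MLX code.
#ifdef MLX_USE_HIP
#include <hip/hip_runtime.h>
using gpuStream_t = hipStream_t;
using gpuError_t = hipError_t;
#define gpuStreamCreate hipStreamCreate
#define gpuStreamSynchronize hipStreamSynchronize
#define gpuLaunchHostFunc hipLaunchHostFunc
#else
#include <cuda_runtime.h>
using gpuStream_t = cudaStream_t;
using gpuError_t = cudaError_t;
#define gpuStreamCreate cudaStreamCreate
#define gpuStreamSynchronize cudaStreamSynchronize
#define gpuLaunchHostFunc cudaLaunchHostFunc
#endif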

Another thing to note is that this PR is bound to change heavily in the following weeks; I'm still experimenting with what the best interface for integration is.

@angeloskath (Member) commented

Awesome progress indeed!

Just chiming in regarding the best way to incorporate this. IMHO, merging often is the way to go (option 2, basically). Combined with running CUDA tests in CI, it will be the easiest to live with (since we'll know when we break it even if we don't use it). Otherwise the cuda branch would have to be constantly rebased on top of main, which could be annoying.

@radudiaconu0 commented

I would try to make a separate hip folder, or use hipify on your CUDA code to make it use ROCm/HIP.

@zcbenz (Contributor, Author) commented Mar 23, 2025

I'm wondering what the best way is to get this incorporated into MLX.

I find myself constantly refactoring the code as I port new kernels. I think I still need to implement a few more primitives before the backbone code stabilizes, probably a few more weeks of experimenting.

Once the code is ready for review, I can split this PR into a backbone PR and a few small PRs for each primitive. Future work would then be submitted as incremental PRs.

@zcbenz (Contributor, Author) commented Mar 24, 2025

In CUDA, kernel parameter sizes must be known at compile time, i.e. we can't pass dynamically sized shapes/strides via constant memory the way the Metal kernels do.

I'm currently passing shapes/strides to kernels via a fixed-size cuda::std::array, which is what PyTorch does. This imposes a maximum ndim on arrays; PyTorch sets it to 25. I'm using 8 for now, and it can easily be changed if that turns out not to be enough.
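A minimal sketch of the idea, with a hypothetical kernel and constant name (not the PR's actual code):

#include <cstdint>
#include <cuda/std/array>

// Hypothetical example kernel: shape and strides are passed by value as
// fixed-size arrays, so the parameter size is known at compile time.
// MAX_NDIM is the single constant to bump if 8 turns out not to be enough.
constexpr int MAX_NDIM = 8;

__global__ void copy_general(
    const float* in,
    float* out,
    cuda::std::array<int32_t, MAX_NDIM> shape,
    cuda::std::array<int64_t, MAX_NDIM> strides,
    int ndim,
    size_t size) {
  size_t gid = blockIdx.x * blockDim.x + threadIdx.x;
  if (gid >= size) {
    return;
  }
  // Turn the flat output index into an input offset using shape/strides.
  int64_t offset = 0;
  size_t rest = gid;
  for (int i = ndim - 1; i >= 0; --i) {
    offset += static_cast<int64_t>(rest % shape[i]) * strides[i];
    rest /= shape[i];
  }
  out[gid] = in[offset];
}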

@awni (Member) commented Mar 24, 2025

This imposes a maximum ndim on arrays; PyTorch sets it to 25. I'm using 8 for now, and it can easily be changed if that turns out not to be enough.

Sounds great! As long as we can change it by setting one number somewhere I think that's perfectly fine.

@zcbenz (Contributor, Author) commented Apr 5, 2025

The C++ logistic_regression example can run now:

$ ./build/examples/cpp/logistic_regression 
Loss array(0.0344943, dtype=float32), Accuracy, array(1, dtype=float32), Throughput 518.05 (it/s).

@zcbenz (Contributor, Author) commented Apr 7, 2025

The logistic_regression example is slow, so I did some profiling.


[profiler screenshot: step]

Each step takes about 2ms (i.e. 500 it/s), and each step consists of:

  1. Graph building (the MLX calls before invoking eval, which is the empty area before eval_impl and Event::wait).
  2. Kernel launching (the eval_impl part).
  3. Waiting for the results of kernels (the Event::wait part).

We can see that we spend as much time launching kernels as we do waiting for their results.


[profiler screenshot: eval]

Looking closer at the kernel launching part, between the eval_gpu calls there is a very long Event::is_signaled call, which is an atomic read under the hood; we want to cut its time down a lot.

Inside the eval_gpu call, there are some cudaMalloc calls (the red blocks) that can be removed by introducing a buffer cache in the future.


[profiler screenshot: CUDA HW]

Then looking at the "CUDA HW" panel, it indicates that the kernel running time is the same as eval_gpu's, which likely means the kernels are executed synchronously instead of asynchronously; I need to check what went wrong.

There are very large gaps between the kernels: some belong to ops that do not need to launch kernels (like broadcast), and some are the slow Event::is_signaled calls that need to be improved. Once the kernels run asynchronously, there ought to be no gaps between them.


[profiler screenshot, 2025-04-07 11:04:43]

Finally, there are long gaps between the eval_impl calls, which is the main thread waiting for the finish signal from the launched kernels, and it is really slow compared to the actual kernel running time. I think I need to reimplement Event to cut it down; the current implementation uses cuda::std::atomic, which seems very inefficient.


Overall the overhead does not look very big: it is only about 2ms per step, and it should be a fixed cost since it does not depend on the size of the arrays. However, in the case of logistic_regression, where the computation itself is very fast, the overhead dramatically slows things down.

@zcbenz (Contributor, Author) commented Apr 7, 2025

Then looking at the "CUDA HW" panel, it indicates that the kernel running time is the same as eval_gpu's, which likely means the kernels are executed synchronously instead of asynchronously; I need to check what went wrong.

After a closer look, I think the "NVTX" row under "CUDA HW" only marks the event that launched the kernel; it does not indicate that the event started at the same time as the kernel. The kernels are launched asynchronously.

@zcbenz (Contributor, Author) commented Apr 7, 2025

So in the case of logistic_regression, where the computation takes less time than the overhead, how fast it runs depends on how fast we can push ops onto the CUDA stream.

For PyTorch, the duration between two simple ops is 5µs:

[profiler screenshot: 5µs]

And for us it is at least 41µs:

[profiler screenshot: 41µs]


There are a few things I can do to reduce the overhead:

  1. Reimplement Event with cudaEvent, which should be the fastest op provided by CUDA (see the sketch after this list).
  2. Add a buffer cache to reduce cudaMalloc calls.
  3. Record the locations of buffers to avoid unnecessary cudaMemPrefetch calls.
  4. Batch the cleanup of temporary arrays to reduce the latency between two kernels.
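For reference, a minimal sketch of what idea 1 could look like, wrapping a cudaEvent_t; the class and method names are hypothetical, not the PR's actual code:

#include <cuda_runtime.h>

// Hypothetical sketch of an Event backed by cudaEvent_t instead of
// cuda::std::atomic; error checking omitted for brevity.
class CudaEvent {
 public:
  CudaEvent() {
    // Disable timing so the event is as lightweight as possible.
    cudaEventCreateWithFlags(&event_, cudaEventDisableTiming);
  }
  ~CudaEvent() {
    cudaEventDestroy(event_);
  }
  // Record the event on a stream, after the work it should track.
  void signal(cudaStream_t stream) {
    cudaEventRecord(event_, stream);
  }
  // Block the calling host thread until the recorded work has finished.
  void wait() {
    cudaEventSynchronize(event_);
  }
  // Non-blocking check: cudaEventQuery returns cudaSuccess once the
  // recorded work is complete.
  bool is_signaled() const {
    return cudaEventQuery(event_) == cudaSuccess;
  }

 private:
  cudaEvent_t event_;
};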

Lastly, there is still some good news: the time spent running the kernels is the same as PyTorch's, which means we don't need to improve the kernel implementations.

@awni (Member) commented Apr 7, 2025

some of them are the slow Event::is_signaled calls that need to be improved.

Where are those calls coming from? Is it here? We might be able to reduce the number of times we call that if needed. I'm not actually certain it needs to be there; I think it's just a mechanism to eagerly clean up unused events.

@zcbenz (Contributor, Author) commented Apr 7, 2025

Where are those calls coming from? Is it here?

Yes, that is where the calls come from.

Removing it would be great; I think it is an expensive op on all platforms.

@zcbenz (Contributor, Author) commented Apr 8, 2025

I tried the ideas: switching the implementation of Event from cuda::std::atomic to cudaEvent bumped the training speed from 500 it/s to 900 it/s; reducing the prefetch calls increased it from 900 it/s to 1100 it/s.


The next optimization is tricky: after evaluating each op, the operands and temporaries have to be kept alive until the kernel finishes running. In Metal it is done like this:

if (d.command_buffer_needs_commit(s.index)) {
  d.end_encoding(s.index);
  scheduler::notify_new_task(s);
  command_buffer->addCompletedHandler(
      [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
        scheduler::notify_task_completion(s);
        check_error(cbuf);
      });
  d.commit_command_buffer(s.index);
  d.get_command_buffer(s.index);
} else {
  command_buffer->addCompletedHandler(
      [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
        check_error(cbuf);
      });
}

In CUDA there is a cudaLaunchHostFunc API that I used to implement this. However, according to the profiling, it adds at least 20µs of latency to the CUDA stream, which means each kernel has to wait at least 20µs before running.

To get rid of this latency, I improved the CUDA backend by saving the operands and temporaries of the ops until finalize() is called, i.e. when mx::eval_impl() finishes running. This way cudaLaunchHostFunc is only called once per mx::eval() instead of once per op::eval_gpu(), and the duration between two kernels is now under 1µs, which is better than PyTorch and, I believe, the best we can do.
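For illustration, a minimal sketch of this batching idea; RetainedBuffers, retain, release_buffers, and finalize are hypothetical names rather than the PR's actual code, and the only real API used is cudaLaunchHostFunc:

#include <cuda_runtime.h>
#include <memory>
#include <vector>

// Hypothetical sketch: temporaries are collected per eval and released by a
// single host callback enqueued in finalize(), instead of one
// cudaLaunchHostFunc per op.
struct RetainedBuffers {
  std::vector<std::shared_ptr<void>> buffers; // keeps operands/temporaries alive
};

// Runs on a CUDA-internal thread after all prior work on the stream finishes.
static void release_buffers(void* data) {
  delete static_cast<RetainedBuffers*>(data); // drops the last references
}

// Called from each op's eval_gpu: just remember the buffer, no CUDA call.
void retain(RetainedBuffers& retained, std::shared_ptr<void> buf) {
  retained.buffers.push_back(std::move(buf));
}

// Called once per mx::eval(), after all kernels have been enqueued.
void finalize(cudaStream_t stream, std::unique_ptr<RetainedBuffers> retained) {
  cudaLaunchHostFunc(stream, release_buffers, retained.release());
}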

[profiler screenshot: duration between kernels]

The downside is that the arrays take longer to be destroyed, which could increase memory usage. The code also no longer waits if there are more tasks than MAX_ACTIVE_TASKS.

After this optimization the speed increased from 1100 it/s to 1600 it/s.


There were still many kernels that got unusually delayed.

[profiler screenshot: prefetch]

What did the delayed kernels have in common? Before launching the kernel, they all called the same API: cudaMemPrefetch (the green blocks).

In the CUDA backend we use the unified memory APIs, which automatically transfer data between host and device. Since I know the data is going to be used on the GPU, I used the cudaMemPrefetch API to prefetch the memory to the device so the kernel does not have to wait for the implicit memory transfer during execution.

It turns out the prefetching heavily delayed the kernel executions.
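For context, a minimal standalone sketch of managed memory plus an explicit prefetch with the CUDA runtime (cudaMemPrefetchAsync is the underlying call); this is not the PR's allocator code:

#include <cuda_runtime.h>

// Standalone sketch of unified (managed) memory plus an explicit prefetch;
// for context only.
int main() {
  int device = 0;
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  float* data = nullptr;
  size_t bytes = 1 << 20;
  // Managed memory is accessible from both host and device; the driver
  // migrates pages on demand when either side touches them.
  cudaMallocManaged(&data, bytes);

  // The explicit prefetch that turned out to delay kernel launches: it
  // asynchronously migrates the pages to the GPU ahead of the kernels
  // enqueued on the same stream.
  cudaMemPrefetchAsync(data, bytes, device, stream);

  // ... launch kernels on `stream` that read/write `data` ...

  cudaStreamSynchronize(stream);
  cudaFree(data);
  cudaStreamDestroy(stream);
  return 0;
}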

Removing the prefetching increased the speed from 1600 it/s to 2100 it/s, and we now have a really beautiful timeline in the profiler.

[profiler screenshot: no prefetch]


One optimization I haven't done yet is the buffer cache: I will add it once most ops are implemented and there are no more third-party libraries to integrate.
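For reference, a minimal sketch of the kind of size-bucketed buffer cache this refers to; the class and method names are hypothetical, not the PR's code, and real eviction/size limits are omitted:

#include <cuda_runtime.h>
#include <map>
#include <vector>

// Hypothetical sketch of a size-bucketed buffer cache: freed allocations are
// kept and reused for later requests of the same size, so steady-state
// training loops stop calling cudaMalloc/cudaFree every step.
class BufferCache {
 public:
  void* malloc(size_t size) {
    auto it = cache_.find(size);
    if (it != cache_.end() && !it->second.empty()) {
      void* ptr = it->second.back(); // reuse a cached buffer
      it->second.pop_back();
      return ptr;
    }
    void* ptr = nullptr;
    cudaMallocManaged(&ptr, size); // fall back to a real allocation
    return ptr;
  }

  void free(void* ptr, size_t size) {
    cache_[size].push_back(ptr); // recycle instead of cudaFree
  }

 private:
  std::map<size_t, std::vector<void*>> cache_;
};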


Can we do better? What remains is mostly hard work: optimizing the kernels and making the CPU code run faster, which I think should be revisited after we have implemented all ops.

@awni (Member) commented Apr 8, 2025

Very nice @zcbenz !

To get rid of this latency, I improved the CUDA backend by saving operands and temporaries of the op until finalize() is called, i.e. when mx::eval_impl() finishes running.

That one is a bit concerning. For large graphs it can really blow up memory use if you hold the temporaries until the end of the graph eval. I don't think it's worth doing that. We might want to do something in between, like saving them once every ~10 calls to eval_gpu or something like that.

@zcbenz (Contributor, Author) commented Apr 9, 2025

That one is a bit concerning. For large graphs it can really blow up memory use if you hold the temporaries until the end of the graph eval. I don't think it's worth doing that

I agree, this feels like a premature optimization for a special case. I'll make the behavior easy to configure so we can optimize for more cases in the future.

@zcbenz zcbenz closed this Apr 9, 2025
@zcbenz zcbenz reopened this Apr 9, 2025
@zcbenz zcbenz mentioned this pull request Apr 13, 2025
@zcbenz zcbenz force-pushed the cuda branch 2 times, most recently from c3ce6e2 to 582817b on April 25, 2025 at 11:14
@corupta commented Jun 4, 2025

Hi @zcbenz,
NVIDIA Jetson devices have hardware unified memory, and I have both a Xavier (sm72) AGX 32GB and an Orin (sm87) AGX 64GB.
I have tried building your CUDA fork frost-beta/mlx-cuda on the Orin (I needed to change the supported sm list to include sm87, and also to change libcublasLt_static to libcublasLt, otherwise I'd get relocation truncated to fit: R_AARCH64_CALL26 against symbol). A full build takes about 3 hours. When I try to run the examples, I get the following:

root@54712748d876:/opt/mlx# ./build/examples/cpp/tutorial
terminate called after throwing an instance of 'std::runtime_error'
  what():  Device 0 does not support synchronization in managed memory.
Aborted (core dumped)
root@54712748d876:/opt/mlx# ./build/examples/cpp/logistic_regression
terminate called after throwing an instance of 'std::runtime_error'
  what():  Device 0 does not support synchronization in managed memory.
Aborted (core dumped)
root@54712748d876:/opt/mlx# ./build/examples/cpp/metal_capture
terminate called after throwing an instance of 'std::runtime_error'
  what():  Device 0 does not support synchronization in managed memory.

I'm willing to collaborate: if you're interested, I can run builds on Jetson devices, report back, and try tinkering to some degree (I'm not an expert in CUDA).

@zcbenz (Contributor, Author) commented Jun 4, 2025

@corupta Can you remove the throw statement in Device::Device in mlx/backend/cuda/device.cpp and check if the example runs?

As for the slow compilation, you can pass -DMLX_FAST_COMPILE=ON -DMLX_CUDA_ARCHITECTURES=native to cmake to speed it up a lot (by disabling many things). I'm currently working on JIT compilation support, which will solve this.

@corupta commented Jun 5, 2025

tl;dr: it gives a segfault.
With fast compile on, it takes about 15 minutes to build. Also, libcublasLt_static doesn't cause a problem when fast compile is on. I've tried building it in various settings (frost-beta/mlx-cuda at both today's latest commit and yesterday's, with and without libcublasLt set as dynamic); all of them resulted in the output below for all the example files (I replaced the throw with a std::cerr statement):

Device 0 does not support synchronization in managed memory. Ignoring...
Segmentation fault (core dumped)

Some info about the environments I use:
Jetson Orin => Ubuntu 22.04 with CUDA 12.6 native, or Ubuntu 24.04 with CUDA 12.8 in Docker
Jetson Xavier => Ubuntu 20.04 with CUDA 11.4 native, or Ubuntu 20.04 with CUDA 12.2 or CUDA 11.8 in Docker
(By Docker I mean using the NVIDIA-patched Docker runtime on these devices, so Docker should expose the GPU without any issue.)
I get the same result on Orin native/Docker and on Xavier Docker with CUDA 12.2. Xavier native gives build errors about __grid_constant__ not being defined, plus lots of others; I think that's because CUDA 11.4 is too old. Xavier Docker with CUDA 11.8 gave build errors such as the following (similar errors appear in different files for different operators such as +, >, etc.):

/opt/mlx/mlx/backend/cuda/kernels/reduce_ops.cuh(125): error: more than one instance of constructor "__nv_bfloat16::__nv_bfloat16" matches the argument list:
            function "__nv_bfloat16::__nv_bfloat16(float)"
/usr/local/cuda/targets/aarch64-linux/include/cuda_bf16.hpp(174): here
            function "__nv_bfloat16::__nv_bfloat16(double)"
/usr/local/cuda/targets/aarch64-linux/include/cuda_bf16.hpp(175): here
            argument types are: (int)
          detected during instantiation of "auto mlx::core::cu::ReduceInit<mlx::core::cu::Prod, T>::value() [with T=__nv_bfloat16]"
/opt/mlx/mlx/backend/cuda/reduce/segmented_reduce.cu(53): here

/opt/mlx/mlx/backend/cuda/kernels/utils.cuh(66): error: more than one operator "-" matches these operands:
            built-in operator "- arithmetic"
            function "mlx::core::operator-(const mlx::core::complex64_t &)"
/opt/mlx/mlx/types/complex.h(77): here
            operand types are: - __nv_bfloat16
          detected during:
            instantiation of "T mlx::core::cu::Limits<T, cuda::std::__4::enable_if_t<<expression>, void>>::min() [with T=__nv_bfloat16]"
/opt/mlx/mlx/backend/cuda/kernels/reduce_ops.cuh(140): here
            instantiation of "T mlx::core::cu::ReduceInit<mlx::core::cu::Max, T>::value() [with T=__nv_bfloat16]"
/opt/mlx/mlx/backend/cuda/reduce/segmented_reduce.cu(53): here

@zcbenz (Contributor, Author) commented Jun 5, 2025

Thanks for testing the build. Unfortunately this kind of error requires me to work in the actual environment to debug. Currently I'm only testing on a few cloud environments, but I will look into making it work on Jetson once most of the development is done; devices with hardware unified memory definitely need first-class support.

@zcbenz (Contributor, Author) commented Jun 12, 2025

I'm closing this as it has been split into smaller pieces (check the linked PR above); future changes to the CUDA backend will be submitted as incremental PRs.

@zcbenz zcbenz closed this Jun 12, 2025