Skip to content

Commit 1d78061

Browse files
committed
Revert "CUDA: use MMQ instead of cuBLAS by default (ggml-org#8075)"
This reverts commit a818f30.
1 parent e575da5 commit 1d78061

File tree

8 files changed

+115
-117
lines changed

8 files changed

+115
-117
lines changed

CMakeLists.txt

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,7 @@ option(LLAMA_LLAMAFILE "llama: use llamafile SGEMM"
102102
option(LLAMA_CUDA "llama: use CUDA" OFF)
103103
option(LLAMA_CUBLAS "llama: use CUDA (deprecated, use LLAMA_CUDA)" OFF)
104104
option(LLAMA_CUDA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF)
105-
option(LLAMA_CUDA_FORCE_MMQ "llama: always use mmq kernels instead of cuBLAS" OFF)
106-
option(LLAMA_CUDA_FORCE_CUBLAS "llama: always use cuBLAS instead of mmq kernels" OFF)
105+
option(LLAMA_CUDA_FORCE_MMQ "llama: use mmq kernels instead of cuBLAS" OFF)
107106
set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
108107
set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels")
109108
option(LLAMA_CUDA_F16 "llama: use 16 bit floats for some calculations" OFF)
@@ -417,14 +416,13 @@ if (LLAMA_CUDA)
417416

418417
if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
419418
# 52 == lowest CUDA 12 standard
420-
# 60 == FP16 CUDA intrinsics
419+
# 60 == f16 CUDA intrinsics
421420
# 61 == integer CUDA intrinsics
422-
# 70 == FP16 tensor cores
423-
# 75 == int8 tensor cores
421+
# 70 == compute capability at which unrolling a loop in mul_mat_q kernels is faster
424422
if (LLAMA_CUDA_F16 OR LLAMA_CUDA_DMMV_F16)
425-
set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75")
423+
set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics
426424
else()
427-
set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75")
425+
set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics
428426
#set(CMAKE_CUDA_ARCHITECTURES "OFF") # use this to compile much faster, but only F16 models work
429427
endif()
430428
endif()
@@ -449,9 +447,6 @@ if (LLAMA_CUDA)
449447
if (LLAMA_CUDA_FORCE_MMQ)
450448
add_compile_definitions(GGML_CUDA_FORCE_MMQ)
451449
endif()
452-
if (LLAMA_CUDA_FORCE_CUBLAS)
453-
add_compile_definitions(GGML_CUDA_FORCE_CUBLAS)
454-
endif()
455450
if (LLAMA_CUDA_NO_VMM)
456451
add_compile_definitions(GGML_CUDA_NO_VMM)
457452
endif()

Makefile

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -537,9 +537,6 @@ endif # LLAMA_CUDA_FORCE_DMMV
537537
ifdef LLAMA_CUDA_FORCE_MMQ
538538
MK_NVCCFLAGS += -DGGML_CUDA_FORCE_MMQ
539539
endif # LLAMA_CUDA_FORCE_MMQ
540-
ifdef LLAMA_CUDA_FORCE_CUBLAS
541-
MK_NVCCFLAGS += -DGGML_CUDA_FORCE_CUBLAS
542-
endif # LLAMA_CUDA_FORCE_CUBLAS
543540
ifdef LLAMA_CUDA_DMMV_X
544541
MK_NVCCFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
545542
else

README.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -510,9 +510,8 @@ Building the program with BLAS support may lead to some performance improvements
510510
|--------------------------------|------------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
511511
| LLAMA_CUDA_FORCE_DMMV | Boolean | false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. |
512512
| LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. |
513-
| LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. |
514-
| LLAMA_CUDA_FORCE_MMQ | Boolean | false | Force the use of custom matrix multiplication kernels for quantized models instead of FP16 cuBLAS even if there is no int8 tensor core implementation available (affects V100, RDNA3). Speed for large batch sizes will be worse but VRAM consumption will be lower. |
515-
| LLAMA_CUDA_FORCE_CUBLAS | Boolean | false | Force the use of FP16 cuBLAS instead of custom matrix multiplication kernels for quantized models |
513+
| LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. |
514+
| LLAMA_CUDA_FORCE_MMQ | Boolean | false | Force the use of dequantization + matrix multiplication kernels instead of leveraging Math libraries. | |
516515
| LLAMA_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. |
517516
| LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. |
518517
| LLAMA_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. |

ggml-cuda.cu

Lines changed: 62 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -152,16 +152,16 @@ static ggml_cuda_device_info ggml_cuda_init() {
152152
GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES);
153153

154154
int64_t total_vram = 0;
155-
#ifdef GGML_CUDA_FORCE_MMQ
156-
GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__);
155+
#if defined(GGML_CUDA_FORCE_MMQ)
156+
GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__);
157157
#else
158-
GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: no\n", __func__);
159-
#endif // GGML_CUDA_FORCE_MMQ
160-
#ifdef GGML_CUDA_FORCE_CUBLAS
161-
GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: yes\n", __func__);
158+
GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: no\n", __func__);
159+
#endif
160+
#if defined(CUDA_USE_TENSOR_CORES)
161+
GGML_CUDA_LOG_INFO("%s: CUDA_USE_TENSOR_CORES: yes\n", __func__);
162162
#else
163-
GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__);
164-
#endif // GGML_CUDA_FORCE_CUBLAS
163+
GGML_CUDA_LOG_INFO("%s: CUDA_USE_TENSOR_CORES: no\n", __func__);
164+
#endif
165165
GGML_CUDA_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
166166
for (int id = 0; id < info.device_count; ++id) {
167167
int device_vmm = 0;
@@ -1873,17 +1873,9 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
18731873
static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
18741874
const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer);
18751875

1876-
bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
1877-
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
1878-
&& src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->ne[1] == 1;
1879-
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
1880-
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
1881-
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
1882-
bool use_mul_mat_q = ggml_is_quantized(src0->type)
1883-
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
1884-
1885-
bool any_gpus_with_slow_fp16 = false;
1876+
int64_t min_compute_capability = INT_MAX;
18861877

1878+
bool any_pascal_with_slow_fp16 = false;
18871879
if (split) {
18881880
ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
18891881
auto & tensor_split = buft_ctx->tensor_split;
@@ -1893,18 +1885,55 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
18931885
continue;
18941886
}
18951887

1896-
const int cc = ggml_cuda_info().devices[id].cc;
1897-
use_mul_mat_vec_q = use_mul_mat_vec_q && cc >= MIN_CC_DP4A;
1898-
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
1899-
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
1888+
if (min_compute_capability > ggml_cuda_info().devices[id].cc) {
1889+
min_compute_capability = ggml_cuda_info().devices[id].cc;
1890+
}
1891+
if (ggml_cuda_info().devices[id].cc == 610) {
1892+
any_pascal_with_slow_fp16 = true;
1893+
}
19001894
}
19011895
} else {
1902-
const int cc = ggml_cuda_info().devices[ctx.device].cc;
1903-
use_mul_mat_vec_q = use_mul_mat_vec_q && cc >= MIN_CC_DP4A;
1904-
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
1905-
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
1896+
min_compute_capability = ggml_cuda_info().devices[ctx.device].cc;
1897+
any_pascal_with_slow_fp16 = ggml_cuda_info().devices[ctx.device].cc == 610;
19061898
}
19071899

1900+
// check data types and tensor shapes for custom matrix multiplication kernels:
1901+
bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
1902+
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
1903+
&& src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->ne[1] == 1;
1904+
1905+
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
1906+
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
1907+
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
1908+
1909+
bool use_mul_mat_q = ggml_cuda_supports_mmq(src0->type)
1910+
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
1911+
1912+
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
1913+
1914+
const bool fp16_performance_good = min_compute_capability >= CC_RDNA1;
1915+
1916+
#ifdef CUDA_USE_TENSOR_CORES
1917+
use_mul_mat_q = use_mul_mat_q && min_compute_capability < CC_RDNA3;
1918+
#endif // CUDA_USE_TENSOR_CORES
1919+
1920+
#else
1921+
1922+
// fp16 performance is good on Volta or newer and on P100 (compute capability 6.0)
1923+
const bool fp16_performance_good = min_compute_capability >= CC_PASCAL && !any_pascal_with_slow_fp16;
1924+
1925+
// mmvq and mmq need the __dp4a instruction which on NVIDIA is only available for CC >= 6.1
1926+
use_mul_mat_vec_q = use_mul_mat_vec_q && min_compute_capability >= MIN_CC_DP4A;
1927+
use_mul_mat_q = use_mul_mat_q && min_compute_capability >= MIN_CC_DP4A;
1928+
1929+
#ifdef CUDA_USE_TENSOR_CORES
1930+
// when tensor cores are available, use them for large batch size
1931+
// ref: https://github.com/ggerganov/llama.cpp/pull/3776
1932+
use_mul_mat_q = use_mul_mat_q && (!fp16_performance_good || src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
1933+
#endif // CUDA_USE_TENSOR_CORES
1934+
1935+
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
1936+
19081937
// if mmvq is available it's a better choice than dmmv:
19091938
#ifndef GGML_CUDA_FORCE_DMMV
19101939
use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
@@ -1918,22 +1947,21 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
19181947
//printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
19191948
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
19201949

1921-
if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
1922-
// FP32 precision KQ single-batch for batch size 1 without FlashAttention
1950+
if (!split && !fp16_performance_good && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
1951+
// KQ single-batch
19231952
ggml_cuda_mul_mat_vec_p021(ctx, src0, src1, dst);
1924-
} else if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
1925-
// FP32 precision KQV single-batch for batch size 1 without FlashAttention
1953+
} else if (!split && !fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
1954+
// KQV single-batch
19261955
ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst);
1956+
} else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || fp16_performance_good) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
1957+
// KQ + KQV multi-batch
1958+
ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
19271959
} else if (use_dequantize_mul_mat_vec) {
19281960
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr);
19291961
} else if (use_mul_mat_vec_q) {
19301962
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
19311963
} else if (use_mul_mat_q) {
19321964
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda);
1933-
} else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
1934-
&& !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
1935-
// KQ + KQV multi-batch without FlashAttention
1936-
ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
19371965
} else {
19381966
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
19391967
}

ggml-cuda/common.cuh

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,23 @@
146146
#define CC_RDNA2 (CC_OFFSET_AMD + 1030)
147147
#define CC_RDNA3 (CC_OFFSET_AMD + 1100)
148148

149+
// define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication
150+
// on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant
151+
// for large computational tasks. the drawback is that this requires some extra amount of VRAM:
152+
// - 7B quantum model: +100-200 MB
153+
// - 13B quantum model: +200-400 MB
154+
//
155+
//#define GGML_CUDA_FORCE_MMQ
156+
157+
// TODO: improve this to be correct for more hardware
158+
// for example, currently fails for GeForce GTX 1660 which is TURING arch (> VOLTA) but does not have tensor cores
159+
#if !defined(GGML_CUDA_FORCE_MMQ)
160+
#define CUDA_USE_TENSOR_CORES
161+
#endif
162+
163+
#define MMVQ_MAX_BATCH_SIZE 8 // max batch size to use MMVQ kernels
164+
#define MMQ_MAX_BATCH_SIZE 64 // max batch size to use MMQ kernels when tensor cores are available
165+
149166
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
150167

151168
#if defined(_MSC_VER)
@@ -326,15 +343,15 @@ static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int
326343
#define INT8_MMA_AVAILABLE
327344
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
328345

329-
static constexpr bool fast_fp16_available(const int cc) {
346+
static bool fast_fp16_available(const int cc) {
330347
return cc >= CC_PASCAL && cc != 610;
331348
}
332349

333-
static constexpr bool fp16_mma_available(const int cc) {
350+
static bool fp16_mma_available(const int cc) {
334351
return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
335352
}
336353

337-
static constexpr bool int8_mma_available(const int cc) {
354+
static bool int8_mma_available(const int cc) {
338355
return cc < CC_OFFSET_AMD && cc >= CC_TURING;
339356
}
340357

@@ -626,6 +643,19 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
626643
static constexpr int qi = QI3_S;
627644
};
628645

646+
static constexpr int get_mmq_x_max_host(int cc) {
647+
#ifdef CUDA_USE_TENSOR_CORES
648+
return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64;
649+
#else
650+
return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? 128 : 64;
651+
#endif // CUDA_USE_TENSOR_CORES
652+
}
653+
654+
// Round rows to this value for --split-mode row:
655+
static constexpr int get_mmq_y_host(int cc) {
656+
return cc >= CC_VOLTA ? 128 : 64;
657+
}
658+
629659
//////////////////////
630660

631661
struct ggml_cuda_device_info {

ggml-cuda/mmq.cu

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,7 @@ void ggml_cuda_op_mul_mat_q(
6969
GGML_UNUSED(src1_ddf_i);
7070
}
7171

72-
bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
73-
#ifdef GGML_CUDA_FORCE_CUBLAS
74-
return false;
75-
#endif // GGML_CUDA_FORCE_CUBLAS
76-
77-
bool mmq_supported;
78-
72+
bool ggml_cuda_supports_mmq(enum ggml_type type) {
7973
switch (type) {
8074
case GGML_TYPE_Q4_0:
8175
case GGML_TYPE_Q4_1:
@@ -87,32 +81,8 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
8781
case GGML_TYPE_Q4_K:
8882
case GGML_TYPE_Q5_K:
8983
case GGML_TYPE_Q6_K:
90-
mmq_supported = true;
91-
break;
84+
return true;
9285
default:
93-
mmq_supported = false;
94-
break;
95-
}
96-
97-
if (!mmq_supported) {
98-
return false;
99-
}
100-
101-
if (int8_mma_available(cc)) {
102-
return true;
103-
}
104-
105-
if (cc < MIN_CC_DP4A) {
106-
return false;
86+
return false;
10787
}
108-
109-
#ifdef GGML_CUDA_FORCE_MMQ
110-
return true;
111-
#endif //GGML_CUDA_FORCE_MMQ
112-
113-
if (cc < CC_OFFSET_AMD) {
114-
return cc < CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
115-
}
116-
117-
return cc < CC_RDNA3 || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
11888
}

0 commit comments

Comments
 (0)