From e89597c062e7119bcfbca60585563272d128a1c4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 29 Nov 2023 12:44:47 +0200 Subject: [PATCH 01/14] metal : implement soft_max_ext --- ggml-metal.m | 20 +++++++++++++++----- ggml-metal.metal | 26 ++++++++++++++++---------- ggml.c | 18 ++++++++++++++++-- ggml.h | 8 ++++++++ llama.cpp | 32 +++++++++++++++++++------------- 5 files changed, 74 insertions(+), 30 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index d52a1c3c48210..0b468bea027a4 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1036,11 +1036,21 @@ void ggml_metal_graph_compute( nth /= 2; [encoder setComputePipelineState:ctx->pipeline_soft_max]; } - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + + const float scale = ((float *) dst->op_params)[0]; + + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (id_src1) { + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + } else { + [encoder setBuffer:nil offset:0 atIndex:1]; + } + + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&scale length:sizeof(scale) atIndex:6]; [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; diff --git a/ggml-metal.metal b/ggml-metal.metal index 5d1357cd72d45..61eb0c403a0f3 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -180,10 +180,12 @@ kernel void kernel_gelu( kernel void kernel_soft_max( device const float * src0, + device const float * src1, device float * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, + constant float & scale, threadgroup float * buf [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], @@ -194,14 +196,15 @@ kernel void kernel_soft_max( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + device const float * pmask = src1 ? src1 + i01*ne00 : nullptr; + device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; // parallel max - float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY; + float lmax = (tpitg < ne00) ? (psrc0[tpitg]*scale + (pmask ? pmask[tpitg] : 0.0f)) : -INFINITY; for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) { - lmax = MAX(lmax, psrc0[i00]); + lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)); } float max = simd_max(lmax); @@ -225,7 +228,7 @@ kernel void kernel_soft_max( // parallel sum float lsum = 0.0f; for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - const float exp_psrc0 = exp(psrc0[i00] - max); + const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max); lsum += exp_psrc0; // Remember the result of exp here. exp is expensive, so we really do not // wish to compute it twice. 
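For reference, the fused op computes softmax(a*scale + mask) in a single pass per row, which is why the encoder above now binds the optional mask buffer (src1) and passes the scale as a kernel constant instead of relying on separate scale/add nodes. A minimal single-row sketch of the same math (plain C, also valid as CUDA host code; the function and parameter names are illustrative and not part of the patch):

    #include <math.h>

    // softmax(src*scale + mask) over one row of ne00 values; mask may be NULL,
    // in which case this degenerates to a plain scaled softmax
    static void soft_max_ext_row_ref(const float * src, const float * mask,
                                     float * dst, int ne00, float scale) {
        float max_val = -INFINITY;
        for (int i = 0; i < ne00; i++) {
            const float v = src[i]*scale + (mask ? mask[i] : 0.0f);
            max_val = v > max_val ? v : max_val;
        }

        float sum = 0.0f;
        for (int i = 0; i < ne00; i++) {
            const float e = expf((src[i]*scale + (mask ? mask[i] : 0.0f)) - max_val);
            dst[i] = e; // keep exp(x) so it is not recomputed in the final pass
            sum += e;
        }

        const float inv_sum = 1.0f/sum;
        for (int i = 0; i < ne00; i++) {
            dst[i] *= inv_sum;
        }
    }

The llama.cpp hunk further down uses exactly this shape: in the non-ALiBi branch the separate scale, mask-add and softmax nodes collapse into a single ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf(float(n_embd_head))) call.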
@@ -257,10 +260,12 @@ kernel void kernel_soft_max( kernel void kernel_soft_max_4( device const float * src0, + device const float * src1, device float * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, + constant float & scale, threadgroup float * buf [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], @@ -271,14 +276,15 @@ kernel void kernel_soft_max_4( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr; + device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); // parallel max - float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY; + float4 lmax4 = tpitg < ne00/4 ? (psrc4[tpitg]*scale + (pmask ? pmask[tpitg] : 0.0f)) : -INFINITY; for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) { - lmax4 = fmax(lmax4, psrc4[i00]); + lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)); } const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); @@ -303,7 +309,7 @@ kernel void kernel_soft_max_4( // parallel sum float4 lsum4 = 0.0f; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - const float4 exp_psrc4 = exp(psrc4[i00] - max); + const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max); lsum4 += exp_psrc4; pdst4[i00] = exp_psrc4; } diff --git a/ggml.c b/ggml.c index c522a101f1552..a0b04cbeb4f67 100644 --- a/ggml.c +++ b/ggml.c @@ -4826,6 +4826,8 @@ struct ggml_tensor * ggml_diag_mask_zero_inplace( static struct ggml_tensor * ggml_soft_max_impl( struct ggml_context * ctx, struct ggml_tensor * a, + struct ggml_tensor * mask, + float scale, bool inplace) { bool is_node = false; @@ -4835,9 +4837,13 @@ static struct ggml_tensor * ggml_soft_max_impl( struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + float params[] = { scale }; + ggml_set_op_params(result, params, sizeof(params)); + result->op = GGML_OP_SOFT_MAX; result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; + result->src[1] = mask; return result; } @@ -4845,13 +4851,21 @@ static struct ggml_tensor * ggml_soft_max_impl( struct ggml_tensor * ggml_soft_max( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_soft_max_impl(ctx, a, false); + return ggml_soft_max_impl(ctx, a, NULL, 1.0f, false); } struct ggml_tensor * ggml_soft_max_inplace( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_soft_max_impl(ctx, a, true); + return ggml_soft_max_impl(ctx, a, NULL, 1.0f, true); +} + +struct ggml_tensor * ggml_soft_max_ext( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * mask, + float scale) { + return ggml_soft_max_impl(ctx, a, mask, scale, false); } // ggml_soft_max_back diff --git a/ggml.h b/ggml.h index 4d6d4edfd933c..2f6787d4e4219 100644 --- a/ggml.h +++ b/ggml.h @@ -1282,6 +1282,14 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + // fused soft_max(a*scale + mask) + // mask is optional + GGML_API struct ggml_tensor * ggml_soft_max_ext( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * mask, + float scale); + GGML_API struct ggml_tensor * ggml_soft_max_back( struct ggml_context * ctx, struct ggml_tensor * a, diff --git a/llama.cpp b/llama.cpp index cb544228b9f02..ba837e26ff1c2 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3705,22 +3705,28 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); cb(kq, "kq", il); - kq = ggml_scale(ctx, kq, kq_scale); - cb(kq, "kq_scaled", il); - + // TODO: !!!!!!!!! if (max_alibi_bias > 0.0f) { - // TODO: n_head or n_head_kv - // TODO: K-shift is likely not working - // TODO: change to ggml_add - kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias); - cb(kq, "kq_scaled_alibi", il); - } + kq = ggml_scale(ctx, kq, kq_scale); + cb(kq, "kq_scaled", il); + + if (max_alibi_bias > 0.0f) { + // TODO: n_head or n_head_kv + // TODO: K-shift is likely not working + // TODO: change to ggml_add + kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias); + cb(kq, "kq_scaled_alibi", il); + } - kq = ggml_add(ctx, kq, kq_mask); - cb(kq, "kq_masked", il); + kq = ggml_add(ctx, kq, kq_mask); + cb(kq, "kq_masked", il); - kq = ggml_soft_max(ctx, kq); - cb(kq, "kq_soft_max", il); + kq = ggml_soft_max(ctx, kq); + cb(kq, "kq_soft_max", il); + } else { + kq = ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf(float(n_embd_head))); + cb(kq, "kq_soft_max_ext", il); + } // split cached v into n_head heads struct ggml_tensor * v = From 88519fbf97abcbf6b23de1027f4d2ac76cf50166 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 29 Nov 2023 15:34:20 +0200 Subject: [PATCH 02/14] cuda : implement soft_max_ext --- ggml-cuda.cu | 35 +++++++++++++++++++++-------------- ggml.c | 6 ++++++ llama.cpp | 1 + 3 files changed, 28 insertions(+), 14 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 5b80e4ae31329..628f2dcbcd60b 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -4719,16 +4719,18 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int // the CUDA soft max implementation differs from the CPU implementation // instead of doubles floats are used -static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) { - const int row = blockDim.x*blockIdx.x + threadIdx.x; +static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) { + const int rowx = 
blockDim.x*blockIdx.x + threadIdx.x; + const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension const int block_size = blockDim.y; const int tid = threadIdx.y; float max_val = -INFINITY; for (int col = tid; col < ncols; col += block_size) { - const int i = row*ncols + col; - max_val = max(max_val, x[i]); + const int ix = rowx*ncols + col; + const int iy = rowy*ncols + col; + max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f)); } // find the max value in the block @@ -4740,10 +4742,11 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol float tmp = 0.f; for (int col = tid; col < ncols; col += block_size) { - const int i = row*ncols + col; - const float val = expf(x[i] - max_val); + const int ix = rowx*ncols + col; + const int iy = rowy*ncols + col; + const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val); tmp += val; - dst[i] = val; + dst[ix] = val; } // sum up partial sums @@ -4755,7 +4758,7 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol const float inv_tmp = 1.f / tmp; for (int col = tid; col < ncols; col += block_size) { - const int i = row*ncols + col; + const int i = rowx*ncols + col; dst[i] *= inv_tmp; } } @@ -5792,10 +5795,10 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols diag_mask_inf_f32<<>>(x, dst, ncols_x, rows_per_channel, n_past); } -static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) { +static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) { const dim3 block_dims(1, WARP_SIZE, 1); const dim3 block_nums(nrows_x, 1, 1); - soft_max_f32<<>>(x, dst, ncols_x); + soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale); } static void im2col_f32_f16_cuda(const float * x, half * dst, @@ -6846,14 +6849,18 @@ inline void ggml_cuda_op_soft_max( GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + const int64_t ne00 = src0->ne[0]; - const int64_t nrows = ggml_nrows(src0); + const int64_t nrows_x = ggml_nrows(src0); + const int64_t nrows_y = src1 ? ggml_nrows(src1) : 0; - soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream); + float scale = 1.0f; + memcpy(&scale, dst->op_params, sizeof(float)); + + soft_max_f32_cuda(src0_dd, src1 ? 
src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); - (void) src1; (void) dst; - (void) src1_dd; } inline void ggml_cuda_op_scale( diff --git a/ggml.c b/ggml.c index a0b04cbeb4f67..788cabd84504f 100644 --- a/ggml.c +++ b/ggml.c @@ -4829,6 +4829,12 @@ static struct ggml_tensor * ggml_soft_max_impl( struct ggml_tensor * mask, float scale, bool inplace) { + if (mask) { + GGML_ASSERT(mask->ne[2] == 1); + GGML_ASSERT(mask->ne[3] == 1); + GGML_ASSERT(ggml_can_repeat_rows(mask, a)); + } + bool is_node = false; if (a->grad) { diff --git a/llama.cpp b/llama.cpp index ba837e26ff1c2..2c13aeb5091c5 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5048,6 +5048,7 @@ static const std::unordered_map k_offload_map { "kq_scaled_alibi", OFFLOAD_FUNC_KQ }, { "kq_masked", OFFLOAD_FUNC_KQ }, { "kq_soft_max", OFFLOAD_FUNC_V }, + { "kq_soft_max_ext", OFFLOAD_FUNC_V }, { "v", OFFLOAD_FUNC_V }, { "kqv", OFFLOAD_FUNC_V }, { "kqv_merged", OFFLOAD_FUNC_V }, From 6a66f69f9f26e0d4c1d016d3c9d37a5422072a5b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 29 Nov 2023 17:07:07 +0200 Subject: [PATCH 03/14] ggml : implement soft_max_ext (CPU) --- ggml.c | 52 ++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 14 deletions(-) diff --git a/ggml.c b/ggml.c index 788cabd84504f..9dc4678cbd894 100644 --- a/ggml.c +++ b/ggml.c @@ -4829,7 +4829,9 @@ static struct ggml_tensor * ggml_soft_max_impl( struct ggml_tensor * mask, float scale, bool inplace) { + GGML_ASSERT(ggml_is_contiguous(a)); if (mask) { + GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(mask->ne[2] == 1); GGML_ASSERT(mask->ne[3] == 1); GGML_ASSERT(ggml_can_repeat_rows(mask, a)); @@ -10571,20 +10573,25 @@ static void ggml_compute_forward_diag_mask_zero( static void ggml_compute_forward_soft_max_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(dst)); - GGML_ASSERT(ggml_are_same_shape(src0, dst)); + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(ggml_is_contiguous(dst)); + assert(ggml_are_same_shape(src0, dst)); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } + float scale = 1.0f; + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + // TODO: handle transposed/permuted matrices const int ith = params->ith; const int nth = params->nth; + const int64_t ne11 = src1 ? src1->ne[1] : 1; + const int nc = src0->ne[0]; const int nr = ggml_nrows(src0); @@ -10595,29 +10602,39 @@ static void ggml_compute_forward_soft_max_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); + float * wdata = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; + for (int i1 = ir0; i1 < ir1; i1++) { - float *sp = (float *)((char *) src0->data + i1*src0->nb[1]); - float *dp = (float *)((char *) dst->data + i1*dst->nb[1]); + float * sp = (float *)((char *) src0->data + i1*src0->nb[1]); + float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); + + // broadcast the mask across rows + float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL; + + float * wp = wdata; + for (int i = 0; i < nc; i++) { + wp[i] = sp[i]*scale + (mp ? 
mp[i] : 0.0f); + } #ifndef NDEBUG for (int i = 0; i < nc; ++i) { //printf("p[%d] = %f\n", i, p[i]); - assert(!isnan(sp[i])); + assert(!isnan(wp[i])); } #endif float max = -INFINITY; - ggml_vec_max_f32(nc, &max, sp); + ggml_vec_max_f32(nc, &max, wp); ggml_float sum = 0.0; uint16_t scvt; for (int i = 0; i < nc; i++) { - if (sp[i] == -INFINITY) { + if (wp[i] == -INFINITY) { dp[i] = 0.0f; } else { - // const float val = (sp[i] == -INFINITY) ? 0.0 : exp(sp[i] - max); - ggml_fp16_t s = GGML_FP32_TO_FP16(sp[i] - max); + // const float val = (wp[i] == -INFINITY) ? 0.0 : exp(wp[i] - max); + ggml_fp16_t s = GGML_FP32_TO_FP16(wp[i] - max); memcpy(&scvt, &s, sizeof(scvt)); const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]); sum += (ggml_float)val; @@ -10642,11 +10659,12 @@ static void ggml_compute_forward_soft_max_f32( static void ggml_compute_forward_soft_max( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - struct ggml_tensor * dst) { + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_soft_max_f32(params, src0, dst); + ggml_compute_forward_soft_max_f32(params, src0, src1, dst); } break; default: { @@ -13883,7 +13901,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_SOFT_MAX: { - ggml_compute_forward_soft_max(params, tensor->src[0], tensor); + ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor); } break; case GGML_OP_SOFT_MAX_BACK: { @@ -15919,6 +15937,12 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks; } } break; + case GGML_OP_SOFT_MAX: + { + n_tasks = MIN(n_threads, ggml_nrows(node->src[0])); + + cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks; + } break; case GGML_OP_CONV_TRANSPOSE_1D: { GGML_ASSERT(node->src[0]->ne[3] == 1); From 390a4459067cc7da6b41379379856bdfebc8d243 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 29 Nov 2023 17:26:12 +0200 Subject: [PATCH 04/14] batched-bench : print threads ggml-ci --- examples/batched-bench/batched-bench.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 533c55c17aad1..57596ed986050 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -155,7 +155,7 @@ int main(int argc, char ** argv) { } LOG_TEE("\n"); - LOG_TEE("%s: n_kv_max = %d, is_pp_shared = %d, n_gpu_layers = %d, mmq = %d\n", __func__, n_kv_max, is_pp_shared, n_gpu_layers, mmq); + LOG_TEE("%s: n_kv_max = %d, is_pp_shared = %d, n_gpu_layers = %d, mmq = %d, n_threads = %d, n_threads_batch = %d\n", __func__, n_kv_max, is_pp_shared, n_gpu_layers, mmq, ctx_params.n_threads, ctx_params.n_threads_batch); LOG_TEE("\n"); LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s"); From 580fe2064cc439a588c56b791a2ecbe07d35bcba Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 29 Nov 2023 17:30:19 +0200 Subject: [PATCH 05/14] metal : simplify soft_max encoding ggml-ci --- ggml-metal.m | 7 +------ llama.cpp | 2 +- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 0b468bea027a4..58149a487559f 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1040,12 +1040,7 @@ void 
ggml_metal_graph_compute( const float scale = ((float *) dst->op_params)[0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - if (id_src1) { - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - } else { - [encoder setBuffer:nil offset:0 atIndex:1]; - } - + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; diff --git a/llama.cpp b/llama.cpp index 2c13aeb5091c5..7b261b73e2210 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3705,8 +3705,8 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); cb(kq, "kq", il); - // TODO: !!!!!!!!! if (max_alibi_bias > 0.0f) { + // temporary branch until we figure out how to handle ggml_alibi through ggml_add kq = ggml_scale(ctx, kq, kq_scale); cb(kq, "kq_scaled", il); From ebd062bc19b92ff9860a12e4f789b305015fb18b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 30 Nov 2023 17:19:29 +0200 Subject: [PATCH 06/14] cuda : use 512 threads for soft_max instead of 32 --- ggml-cuda.cu | 51 ++++++++++++++++++++++++++++++++++----------------- 1 file changed, 34 insertions(+), 17 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 628f2dcbcd60b..53a478aa910ae 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -443,6 +443,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_ #define CUDA_SCALE_BLOCK_SIZE 256 #define CUDA_CLAMP_BLOCK_SIZE 256 #define CUDA_ROPE_BLOCK_SIZE 256 +#define CUDA_SOFT_MAX_BLOCK_SIZE 512 #define CUDA_ALIBI_BLOCK_SIZE 32 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32 #define CUDA_QUANTIZE_BLOCK_SIZE 256 @@ -4717,26 +4718,32 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU } -// the CUDA soft max implementation differs from the CPU implementation -// instead of doubles floats are used +// TODO: maybe can be improved with some warp-based primitives static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) { - const int rowx = blockDim.x*blockIdx.x + threadIdx.x; + const int tid = threadIdx.x; + const int rowx = blockIdx.x; const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension - const int block_size = blockDim.y; - const int tid = threadIdx.y; - float max_val = -INFINITY; + const int block_size = blockDim.x; + + __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE]; + + buf[tid] = -INFINITY; for (int col = tid; col < ncols; col += block_size) { const int ix = rowx*ncols + col; const int iy = rowy*ncols + col; - max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f)); + buf[tid] = max(buf[tid], x[ix]*scale + (y ? y[iy] : 0.0f)); } + __syncthreads(); + // find the max value in the block -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32)); + for (int i = block_size/2; i > 0; i >>= 1) { + if (tid < i) { + buf[tid] = max(buf[tid], buf[tid + i]); + } + __syncthreads(); } float tmp = 0.f; @@ -4744,18 +4751,26 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds for (int col = tid; col < ncols; col += block_size) { const int ix = rowx*ncols + col; const int iy = rowy*ncols + col; - const float val = expf((x[ix]*scale + (y ? 
y[iy] : 0.0f)) - max_val); + const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - buf[0]); tmp += val; dst[ix] = val; } + __syncthreads(); + + buf[tid] = tmp; + + __syncthreads(); + // sum up partial sums -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + for (int i = block_size/2; i > 0; i >>= 1) { + if (tid < i) { + buf[tid] += buf[tid + i]; + } + __syncthreads(); } - const float inv_tmp = 1.f / tmp; + const float inv_tmp = 1.f / buf[0]; for (int col = tid; col < ncols; col += block_size) { const int i = rowx*ncols + col; @@ -5796,7 +5811,9 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols } static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) { - const dim3 block_dims(1, WARP_SIZE, 1); + int nth = WARP_SIZE; + while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; + const dim3 block_dims(nth, 1, 1); const dim3 block_nums(nrows_x, 1, 1); soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale); } @@ -6853,7 +6870,7 @@ inline void ggml_cuda_op_soft_max( const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); - const int64_t nrows_y = src1 ? ggml_nrows(src1) : 0; + const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1; float scale = 1.0f; memcpy(&scale, dst->op_params, sizeof(float)); From c7c8dabcf74da8e66577948d2498e140d299e7e8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 30 Nov 2023 20:05:41 +0200 Subject: [PATCH 07/14] ggml : update soft max cpu --- ggml.c | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/ggml.c b/ggml.c index 9dc4678cbd894..e2687ef4f072f 100644 --- a/ggml.c +++ b/ggml.c @@ -10602,7 +10602,7 @@ static void ggml_compute_forward_soft_max_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - float * wdata = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; + float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; for (int i1 = ir0; i1 < ir1; i1++) { float * sp = (float *)((char *) src0->data + i1*src0->nb[1]); @@ -10611,9 +10611,10 @@ static void ggml_compute_forward_soft_max_f32( // broadcast the mask across rows float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL; - float * wp = wdata; - for (int i = 0; i < nc; i++) { - wp[i] = sp[i]*scale + (mp ? 
mp[i] : 0.0f); + ggml_vec_cpy_f32 (nc, wp, sp); + ggml_vec_scale_f32(nc, wp, scale); + if (mp) { + ggml_vec_acc_f32(nc, wp, mp); } #ifndef NDEBUG @@ -15939,7 +15940,7 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { } break; case GGML_OP_SOFT_MAX: { - n_tasks = MIN(n_threads, ggml_nrows(node->src[0])); + n_tasks = MIN(MIN(4, n_threads), ggml_nrows(node->src[0])); cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks; } break; From 62532c05aa3d54f9ecbb5d0cc158c27a01ac2f59 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 30 Nov 2023 20:36:08 +0200 Subject: [PATCH 08/14] cuda : do warp-based block reduce --- ggml-cuda.cu | 87 +++++++++++++++++++++++++++++----------------------- 1 file changed, 48 insertions(+), 39 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 53a478aa910ae..080193cbd81eb 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -502,6 +502,31 @@ static size_t g_scratch_offset = 0; static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr}; +static __device__ __forceinline__ float warp_reduce_sum(float x) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x += __shfl_xor_sync(0xffffffff, x, mask, 32); + } + return x; +} + +static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32); + a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32); + } + return a; +} + +static __device__ __forceinline__ float warp_reduce_max(float x) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); + } + return x; +} + static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) { const int i = blockDim.x*blockIdx.x + threadIdx.x; @@ -578,15 +603,6 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) { dst[i] = x[i] * x[i]; } -static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32); - a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32); - } - return a; -} - template static __global__ void norm_f32(const float * x, float * dst, const int ncols) { const int row = blockIdx.x*blockDim.y + threadIdx.y; @@ -625,14 +641,6 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) { } } -static __device__ __forceinline__ float warp_reduce_sum(float x) { -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - x += __shfl_xor_sync(0xffffffff, x, mask, 32); - } - return x; -} - template static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) { const int row = blockIdx.x*blockDim.y + threadIdx.y; @@ -4718,7 +4726,6 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU } -// TODO: maybe can be improved with some warp-based primitives static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) { const int tid = threadIdx.x; const int rowx = blockIdx.x; @@ -4726,24 +4733,26 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds const int block_size = blockDim.x; - __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE]; - - 
buf[tid] = -INFINITY; + float max_val = -INFINITY; for (int col = tid; col < ncols; col += block_size) { const int ix = rowx*ncols + col; const int iy = rowy*ncols + col; - buf[tid] = max(buf[tid], x[ix]*scale + (y ? y[iy] : 0.0f)); + max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f)); } - __syncthreads(); - // find the max value in the block - for (int i = block_size/2; i > 0; i >>= 1) { - if (tid < i) { - buf[tid] = max(buf[tid], buf[tid + i]); + max_val = warp_reduce_max(max_val); + if (block_size > WARP_SIZE) { + __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE]; + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + buf[warp_id] = max_val; } __syncthreads(); + max_val = buf[lane_id]; + max_val = warp_reduce_max(max_val); } float tmp = 0.f; @@ -4751,26 +4760,26 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds for (int col = tid; col < ncols; col += block_size) { const int ix = rowx*ncols + col; const int iy = rowy*ncols + col; - const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - buf[0]); + const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val); tmp += val; dst[ix] = val; } - __syncthreads(); - - buf[tid] = tmp; - - __syncthreads(); - - // sum up partial sums - for (int i = block_size/2; i > 0; i >>= 1) { - if (tid < i) { - buf[tid] += buf[tid + i]; + // find the sum of exps in the block + tmp = warp_reduce_sum(tmp); + if (block_size > WARP_SIZE) { + __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE]; + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + buf[warp_id] = tmp; } __syncthreads(); + tmp = buf[lane_id]; + tmp = warp_reduce_sum(tmp); } - const float inv_tmp = 1.f / buf[0]; + const float inv_tmp = 1.f / tmp; for (int col = tid; col < ncols; col += block_size) { const int i = rowx*ncols + col; From 6b86bcffac842f35e03e608ce62065a361ffac0c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 30 Nov 2023 20:40:47 +0200 Subject: [PATCH 09/14] cuda : increase max block size to 1024 --- ggml-cuda.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 080193cbd81eb..98343d2083e57 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -443,7 +443,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_ #define CUDA_SCALE_BLOCK_SIZE 256 #define CUDA_CLAMP_BLOCK_SIZE 256 #define CUDA_ROPE_BLOCK_SIZE 256 -#define CUDA_SOFT_MAX_BLOCK_SIZE 512 +#define CUDA_SOFT_MAX_BLOCK_SIZE 1024 #define CUDA_ALIBI_BLOCK_SIZE 32 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32 #define CUDA_QUANTIZE_BLOCK_SIZE 256 From 68e02c0d584c2aa3d13925102cde1c54d1784114 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 30 Nov 2023 21:39:48 +0200 Subject: [PATCH 10/14] cuda : fix warp reduction initialization of shared mem --- ggml-cuda.cu | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 98343d2083e57..9019a849f0bff 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -4733,6 +4733,11 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds const int block_size = blockDim.x; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + + __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE]; + float max_val = -INFINITY; for (int col = tid; col < ncols; col += block_size) { @@ -4744,13 +4749,16 @@ static __global__ void 
soft_max_f32(const float * x, const float * y, float * ds // find the max value in the block max_val = warp_reduce_max(max_val); if (block_size > WARP_SIZE) { - __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE]; - int warp_id = threadIdx.x / WARP_SIZE; - int lane_id = threadIdx.x % WARP_SIZE; + if (warp_id == 0) { + buf[lane_id] = -INFINITY; + } + __syncthreads(); + if (lane_id == 0) { buf[warp_id] = max_val; } __syncthreads(); + max_val = buf[lane_id]; max_val = warp_reduce_max(max_val); } @@ -4768,13 +4776,16 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds // find the sum of exps in the block tmp = warp_reduce_sum(tmp); if (block_size > WARP_SIZE) { - __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE]; - int warp_id = threadIdx.x / WARP_SIZE; - int lane_id = threadIdx.x % WARP_SIZE; + if (warp_id == 0) { + buf[lane_id] = 0.f; + } + __syncthreads(); + if (lane_id == 0) { buf[warp_id] = tmp; } __syncthreads(); + tmp = buf[lane_id]; tmp = warp_reduce_sum(tmp); } From 55717c98c457aa1c6def7185dad049383143aad3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 30 Nov 2023 21:52:32 +0200 Subject: [PATCH 11/14] metal : warp-based reduction for soft max kernel --- ggml-metal.m | 10 ++-- ggml-metal.metal | 119 ++++++++++++++++++++++++----------------------- 2 files changed, 68 insertions(+), 61 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 58149a487559f..febf7f97e2709 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1028,12 +1028,14 @@ void ggml_metal_graph_compute( int nth = 32; // SIMD width if (ne00%4 == 0) { + while (nth < ne00/4 && nth < 256) { + nth *= 2; + } [encoder setComputePipelineState:ctx->pipeline_soft_max_4]; } else { - do { + while (nth < ne00 && nth < 1024) { nth *= 2; - } while (nth <= ne00 && nth <= 1024); - nth /= 2; + } [encoder setComputePipelineState:ctx->pipeline_soft_max]; } @@ -1046,7 +1048,7 @@ void ggml_metal_graph_compute( [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; [encoder setBytes:&scale length:sizeof(scale) atIndex:6]; - [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; diff --git a/ggml-metal.metal b/ggml-metal.metal index 61eb0c403a0f3..29fec7ef39f4f 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -39,6 +39,8 @@ typedef struct { int8_t qs[QK8_0]; // quants } block_q8_0; +#define N_SIMDWIDTH 32 // assuming SIMD group size is 32 + // general-purpose kernel for addition of two tensors // pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3 // cons: not very efficient @@ -207,54 +209,55 @@ kernel void kernel_soft_max( lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? 
pmask[i00] : 0.0f)); } - float max = simd_max(lmax); - if (tiisg == 0) { - buf[sgitg] = max; - } + // find the max value in the block + float max_val = simd_max(lmax); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = -INFINITY; + } - threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_threadgroup); - // broadcast, simd group number is ntg / 32 - for (uint i = ntg / 32 / 2; i > 0; i /= 2) { - if (tpitg < i) { - buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]); - } - } + if (tiisg == 0) { + buf[sgitg] = max_val; + } - threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_threadgroup); - max = buf[0]; + max_val = buf[tiisg]; + max_val = simd_max(max_val); + } // parallel sum float lsum = 0.0f; for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max); + const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val); lsum += exp_psrc0; - // Remember the result of exp here. exp is expensive, so we really do not - // wish to compute it twice. pdst[i00] = exp_psrc0; } float sum = simd_sum(lsum); - if (tiisg == 0) { - buf[sgitg] = sum; - } + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } - threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_threadgroup); - // broadcast, simd group number is ntg / 32 - for (uint i = ntg / 32 / 2; i > 0; i /= 2) { - if (tpitg < i) { - buf[tpitg] += buf[tpitg + i]; - } - } + if (tiisg == 0) { + buf[sgitg] = sum; + } - threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_threadgroup); + + sum = buf[tiisg]; + sum = simd_sum(sum); + } - sum = buf[0]; + const float inv_sum = 1.0f/sum; for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - pdst[i00] /= sum; + pdst[i00] *= inv_sum; } } @@ -288,53 +291,56 @@ kernel void kernel_soft_max_4( } const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); - float max = simd_max(lmax); - if (tiisg == 0) { - buf[sgitg] = max; - } - threadgroup_barrier(mem_flags::mem_threadgroup); + float max_val = simd_max(lmax); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = -INFINITY; + } - // broadcast, simd group number is ntg / 32 - for (uint i = ntg / 32 / 2; i > 0; i /= 2) { - if (tpitg < i) { - buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]); - } - } + threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup_barrier(mem_flags::mem_threadgroup); + if (tiisg == 0) { + buf[sgitg] = max_val; + } - max = buf[0]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_val = buf[tiisg]; + max_val = simd_max(max_val); + } // parallel sum float4 lsum4 = 0.0f; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max); + const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? 
pmask[i00] : 0.0f)) - max_val); lsum4 += exp_psrc4; pdst4[i00] = exp_psrc4; } const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; float sum = simd_sum(lsum); - if (tiisg == 0) { - buf[sgitg] = sum; - } + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } - threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_threadgroup); - // broadcast, simd group number is ntg / 32 - for (uint i = ntg / 32 / 2; i > 0; i /= 2) { - if (tpitg < i) { - buf[tpitg] += buf[tpitg + i]; - } - } + if (tiisg == 0) { + buf[sgitg] = sum; + } - threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_threadgroup); + + sum = buf[tiisg]; + sum = simd_sum(sum); + } - sum = buf[0]; + const float inv_sum = 1.0f/sum; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - pdst4[i00] /= sum; + pdst4[i00] *= inv_sum; } } @@ -582,7 +588,6 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre // putting them in the kernel cause a significant performance penalty #define N_DST 4 // each SIMD group works on 4 rows #define N_SIMDGROUP 2 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 32 // assuming SIMD group size is 32 //Note: This is a template, but strictly speaking it only applies to // quantizations where the block size is 32. It also does not // giard against the number of rows not being divisible by From c4db59230da972ce698c31d871f59713baebe835 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 30 Nov 2023 22:21:30 +0200 Subject: [PATCH 12/14] metal : warp-based reduce for rms_norm --- ggml-metal.m | 18 +++++++++++------- ggml-metal.metal | 41 +++++++++++++++-------------------------- 2 files changed, 26 insertions(+), 33 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index febf7f97e2709..6cfacf64fcd9f 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1358,15 +1358,19 @@ void ggml_metal_graph_compute( float eps; memcpy(&eps, dst->op_params, sizeof(float)); - const int nth = MIN(512, ne00); + int nth = 32; // SIMD width + + while (nth < ne00/4 && nth < 1024) { + nth *= 2; + } [encoder setComputePipelineState:ctx->pipeline_rms_norm]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; + [encoder setBytes:&eps length:sizeof( float) atIndex:4]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; const int64_t nrows = ggml_nrows(src0); diff --git a/ggml-metal.metal b/ggml-metal.metal index 29fec7ef39f4f..e152cc53c0b97 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -447,14 +447,13 @@ kernel void kernel_rms_norm( constant int64_t & ne00, constant uint64_t & nb01, constant float & eps, - threadgroup float * sum [[threadgroup(0)]], + threadgroup float * buf [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float4 
* x = (device const float4 *) ((device const char *) src0 + tgpig*nb01); - device const float * x_scalar = (device const float *) x; + device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01); float4 sumf = 0; float all_sum = 0; @@ -465,40 +464,30 @@ kernel void kernel_rms_norm( } all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3]; all_sum = simd_sum(all_sum); - if (tiisg == 0) { - sum[sgitg] = all_sum; - } + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } - threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_threadgroup); - // broadcast, simd group number is ntg / 32 - for (uint i = ntg / 32 / 2; i > 0; i /= 2) { - if (tpitg < i) { - sum[tpitg] += sum[tpitg + i]; - } - } - if (tpitg == 0) { - for (int i = 4 * (ne00 / 4); i < ne00; i++) { - sum[0] += x_scalar[i]; + if (tiisg == 0) { + buf[sgitg] = all_sum; } - sum[0] /= ne00; - } - threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_threadgroup); - const float mean = sum[0]; + all_sum = buf[tiisg]; + all_sum = simd_sum(all_sum); + } + + const float mean = all_sum/ne00; const float scale = 1.0f/sqrt(mean + eps); device float4 * y = (device float4 *) (dst + tgpig*ne00); - device float * y_scalar = (device float *) y; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { y[i00] = x[i00] * scale; } - if (tpitg == 0) { - for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) { - y_scalar[i00] = x_scalar[i00] * scale; - } - } } // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) From d9c8fa3bced0f9974d6ba39aa586e236de4d5f02 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 1 Dec 2023 10:31:21 +0200 Subject: [PATCH 13/14] metal : simplify soft max kernel ggml-ci --- ggml-metal.metal | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index e152cc53c0b97..9a79f815f3a72 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -203,9 +203,9 @@ kernel void kernel_soft_max( device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; // parallel max - float lmax = (tpitg < ne00) ? (psrc0[tpitg]*scale + (pmask ? pmask[tpitg] : 0.0f)) : -INFINITY; + float lmax = -INFINITY; - for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) { + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)); } @@ -284,9 +284,9 @@ kernel void kernel_soft_max_4( device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); // parallel max - float4 lmax4 = tpitg < ne00/4 ? (psrc4[tpitg]*scale + (pmask ? pmask[tpitg] : 0.0f)) : -INFINITY; + float4 lmax4 = -INFINITY; - for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) { + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? 
pmask[i00] : 0.0f)); } From eb594c0f7d9528b942e50fae508dbcada597ac06 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 1 Dec 2023 10:45:54 +0200 Subject: [PATCH 14/14] alloc : fix build with debug --- ggml-alloc.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml-alloc.c b/ggml-alloc.c index cdfe4caf69613..0d4e12ae99d3d 100644 --- a/ggml-alloc.c +++ b/ggml-alloc.c @@ -137,7 +137,7 @@ void ggml_tallocr_alloc(ggml_tallocr_t alloc, struct ggml_tensor * tensor) { #ifdef GGML_ALLOCATOR_DEBUG add_allocated_tensor(alloc, tensor); - size_t cur_max = (char*)addr - (char*)alloc->data + size; + size_t cur_max = (char*)addr - (char*)alloc->base + size; if (cur_max > alloc->max_size) { printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0); for (int i = 0; i < 1024; i++) {
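Taken together, patches 08-13 replace the earlier shared-memory tree reductions with a two-level scheme: reduce within each warp/SIMD group using shuffle intrinsics (warp_reduce_max / warp_reduce_sum on CUDA, simd_max / simd_sum on Metal), then combine one partial result per warp through a small shared buffer that warp 0 pre-initializes, which is exactly the bug addressed in patch 10. A condensed CUDA sketch of the block-wide max reduction, factored into a standalone helper for readability (the helper itself is illustrative; the patch keeps this logic inlined in soft_max_f32):

    #include <math.h>

    #define WARP_SIZE 32

    // intra-warp max via register shuffles, as added in ggml-cuda.cu above
    static __device__ __forceinline__ float warp_reduce_max(float x) {
    #pragma unroll
        for (int mask = 16; mask > 0; mask >>= 1) {
            x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
        }
        return x;
    }

    // buf points to WARP_SIZE floats of shared memory; the block size is a
    // multiple of WARP_SIZE, as guaranteed by soft_max_f32_cuda
    static __device__ float block_reduce_max(float val, float * buf) {
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;

        val = warp_reduce_max(val); // reduce within each warp first

        if (blockDim.x > WARP_SIZE) {
            if (warp_id == 0) {
                buf[lane_id] = -INFINITY; // initialize every slot, not only those written below
            }
            __syncthreads();

            if (lane_id == 0) {
                buf[warp_id] = val; // one partial maximum per warp
            }
            __syncthreads();

            val = warp_reduce_max(buf[lane_id]); // every warp combines the partials to the same result
        }

        return val;
    }

The sum reduction follows the same pattern with warp_reduce_sum and a 0.0f initializer, and the Metal kernels in patches 11-13 mirror it with threadgroup memory and simdgroup intrinsics.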