Commit 3519568

2x faster (rms) norm cuda kernels (3.7% e2e improvement) (#2985)
* 2x faster (rms) norm cuda kernels
* Fix code style
1 parent cf9b084 commit 3519568

File tree

1 file changed: +66, -23 lines

ggml-cuda.cu

Lines changed: 66 additions & 23 deletions
@@ -464,58 +464,91 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] / (1.0f + expf(-x[i]));
 }
 
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+    }
+    return a;
+}
+
+template <int block_size>
 static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
     const float eps = 1e-5f;
 
-    float mean = 0.0f;
-    float var = 0.0f;
+    float2 mean_var = make_float2(0.f, 0.f);
 
-    for (int col = tid; col < ncols; col += WARP_SIZE) {
+    for (int col = tid; col < ncols; col += block_size) {
         const float xi = x[row*ncols + col];
-        mean += xi;
-        var += xi * xi;
+        mean_var.x += xi;
+        mean_var.y += xi * xi;
     }
 
     // sum up partial sums
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        mean += __shfl_xor_sync(0xffffffff, mean, mask, 32);
-        var += __shfl_xor_sync(0xffffffff, var, mask, 32);
+    mean_var = warp_reduce_sum(mean_var);
+    if (block_size > WARP_SIZE) {
+        __shared__ float2 s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = mean_var;
+        }
+        __syncthreads();
+        mean_var = s_sum[lane_id];
+        mean_var = warp_reduce_sum(mean_var);
     }
 
-    mean /= ncols;
-    var = var / ncols - mean * mean;
-    const float inv_var = rsqrtf(var + eps);
+    const float mean = mean_var.x / ncols;
+    const float var = mean_var.y / ncols - mean * mean;
+    const float inv_std = rsqrtf(var + eps);
 
-    for (int col = tid; col < ncols; col += WARP_SIZE) {
-        dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_var;
+    for (int col = tid; col < ncols; col += block_size) {
+        dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std;
     }
 }
 
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    }
+    return x;
+}
+
+template <int block_size>
 static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
     float tmp = 0.0f; // partial sum for thread in warp
 
-    for (int col = tid; col < ncols; col += WARP_SIZE) {
+    for (int col = tid; col < ncols; col += block_size) {
         const float xi = x[row*ncols + col];
         tmp += xi * xi;
     }
 
     // sum up partial sums
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        __shared__ float s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = s_sum[lane_id];
+        tmp = warp_reduce_sum(tmp);
     }
 
     const float mean = tmp / ncols;
     const float scale = rsqrtf(mean + eps);
 
-    for (int col = tid; col < ncols; col += WARP_SIZE) {
+    for (int col = tid; col < ncols; col += block_size) {
         dst[row*ncols + col] = scale * x[row*ncols + col];
     }
 }
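
Note: the pattern introduced above generalizes the old single-warp reduction to a whole thread block. Each warp first reduces its own partial sums with __shfl_xor_sync, lane 0 of every warp publishes its result to shared memory, and a second warp-level pass combines the per-warp partials. The following standalone sketch is not part of the commit; the helper name block_reduce_sum is illustrative, and it assumes block_size is either WARP_SIZE or 1024 (as in the launchers below), so either the shared-memory step is skipped or all 32 s_sum slots are written.

// Sketch (not from the commit) of the two-level sum used by the templated kernels above.
#define WARP_SIZE 32

static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, mask, 32); // butterfly exchange within the warp
    }
    return x;
}

template <int block_size>
static __device__ float block_reduce_sum(float x) {
    x = warp_reduce_sum(x);                   // each warp reduces its own 32 values
    if (block_size > WARP_SIZE) {
        __shared__ float s_sum[32];           // one slot per warp (1024/32 = 32 warps max)
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = x;               // lane 0 publishes its warp's partial sum
        }
        __syncthreads();
        x = warp_reduce_sum(s_sum[lane_id]);  // one more warp pass combines the partials
    }
    return x;
}

Because every warp re-reads the 32 per-warp partials and reduces them again, every thread in the block ends up holding the full row sum, so no extra broadcast is needed before the final normalization loop.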
@@ -4203,14 +4236,24 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
 
 static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+    }
 }
 
 static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+    }
 }
 
 static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream) {
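
Note: each row gets one block; rows narrower than 1024 columns stay on the single-warp instantiation (no shared memory, no __syncthreads()), while wider rows use a full 1024-thread block and the cross-warp reduction path. These launchers are static to ggml-cuda.cu and are invoked by the file's op dispatch, so the snippet below is a standalone illustration only: a simplified kernel (rms_norm_sketch, a hypothetical name) with the same one-block-per-row layout, launched on the 1024-thread path and checked against a CPU reference. It is a sketch under those assumptions, not code from the commit.

// Standalone illustration (not from the commit). Build with: nvcc rms_sketch.cu (file name arbitrary).
#include <cstdio>
#include <cmath>
#include <vector>
#include <cuda_runtime.h>

#define WARP_SIZE 32

static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
    }
    return x;
}

template <int block_size>
static __global__ void rms_norm_sketch(const float * x, float * dst, const int ncols, const float eps) {
    const int row = blockIdx.x; // one block per row (the commit uses blockIdx.x*blockDim.y + threadIdx.y with blockDim.y == 1)
    const int tid = threadIdx.x;

    float tmp = 0.0f;           // per-thread partial sum of squares
    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[row*ncols + col];
        tmp += xi * xi;
    }

    tmp = warp_reduce_sum(tmp);
    if (block_size > WARP_SIZE) { // combine per-warp partials through shared memory
        __shared__ float s_sum[32];
        if (threadIdx.x % WARP_SIZE == 0) {
            s_sum[threadIdx.x / WARP_SIZE] = tmp;
        }
        __syncthreads();
        tmp = warp_reduce_sum(s_sum[threadIdx.x % WARP_SIZE]);
    }

    const float scale = rsqrtf(tmp/ncols + eps);
    for (int col = tid; col < ncols; col += block_size) {
        dst[row*ncols + col] = scale * x[row*ncols + col];
    }
}

int main() {
    const int nrows = 4, ncols = 4096; // ncols >= 1024, so take the 1024-thread path
    const float eps = 1e-6f;

    std::vector<float> h_x(nrows*ncols), h_y(nrows*ncols);
    for (size_t i = 0; i < h_x.size(); ++i) h_x[i] = 0.01f*(float)(i % 97) - 0.5f;

    float *d_x, *d_y;
    cudaMalloc((void **) &d_x, h_x.size()*sizeof(float));
    cudaMalloc((void **) &d_y, h_y.size()*sizeof(float));
    cudaMemcpy(d_x, h_x.data(), h_x.size()*sizeof(float), cudaMemcpyHostToDevice);

    rms_norm_sketch<1024><<<nrows, 1024>>>(d_x, d_y, ncols, eps);
    cudaMemcpy(h_y.data(), d_y, h_y.size()*sizeof(float), cudaMemcpyDeviceToHost);

    // CPU reference for element (0, 0)
    double ss = 0.0;
    for (int col = 0; col < ncols; ++col) ss += (double)h_x[col]*h_x[col];
    const float ref = h_x[0] / sqrtf((float)(ss/ncols) + eps);
    printf("gpu: %f  cpu: %f\n", h_y[0], ref);

    cudaFree(d_x);
    cudaFree(d_y);
    return 0;
}

The ncols < 1024 cutoff keeps narrow rows on the cheaper single-warp path with no shared-memory traffic; wide rows trade that overhead for 32x more threads per row, which is where the speedup reported in the commit title comes from.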
