@@ -464,58 +464,91 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] / (1.0f + expf(-x[i]));
 }
 
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+    }
+    return a;
+}
+
+template <int block_size>
 static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
     const float eps = 1e-5f;
 
-    float mean = 0.0f;
-    float var  = 0.0f;
+    float2 mean_var = make_float2(0.f, 0.f);
 
-    for (int col = tid; col < ncols; col += WARP_SIZE) {
+    for (int col = tid; col < ncols; col += block_size) {
         const float xi = x[row*ncols + col];
-        mean += xi;
-        var  += xi * xi;
+        mean_var.x += xi;
+        mean_var.y += xi * xi;
     }
 
     // sum up partial sums
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        mean += __shfl_xor_sync(0xffffffff, mean, mask, 32);
-        var  += __shfl_xor_sync(0xffffffff, var,  mask, 32);
+    mean_var = warp_reduce_sum(mean_var);
+    if (block_size > WARP_SIZE) {
+        __shared__ float2 s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = mean_var;
+        }
+        __syncthreads();
+        mean_var = s_sum[lane_id];
+        mean_var = warp_reduce_sum(mean_var);
     }
 
-    mean /= ncols;
-    var = var / ncols - mean * mean;
-    const float inv_var = rsqrtf(var + eps);
+    const float mean = mean_var.x / ncols;
+    const float var  = mean_var.y / ncols - mean * mean;
+    const float inv_std = rsqrtf(var + eps);
 
-    for (int col = tid; col < ncols; col += WARP_SIZE) {
-        dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_var;
+    for (int col = tid; col < ncols; col += block_size) {
+        dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std;
     }
 }
 
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    }
+    return x;
+}
+
+template <int block_size>
 static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
     float tmp = 0.0f; // partial sum for thread in warp
 
-    for (int col = tid; col < ncols; col += WARP_SIZE) {
+    for (int col = tid; col < ncols; col += block_size) {
         const float xi = x[row*ncols + col];
         tmp += xi * xi;
     }
 
     // sum up partial sums
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        __shared__ float s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = s_sum[lane_id];
+        tmp = warp_reduce_sum(tmp);
     }
 
     const float mean = tmp / ncols;
     const float scale = rsqrtf(mean + eps);
 
-    for (int col = tid; col < ncols; col += WARP_SIZE) {
+    for (int col = tid; col < ncols; col += block_size) {
         dst[row*ncols + col] = scale * x[row*ncols + col];
     }
 }
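
Both kernels now share a two-level reduction. norm_f32 folds mean and variance into a single pass by carrying (sum x, sum x^2) in a float2 and using the identity Var(x) = E[x^2] - (E[x])^2, which is why there are two warp_reduce_sum overloads: rms_norm_f32 only needs the scalar sum of squares. Level one reduces within each warp via the __shfl_xor_sync butterfly; when the block is wider than one warp, lane 0 of each warp parks its partial in shared memory, and after __syncthreads() every warp re-reduces the 32 partials. That redundancy is deliberate: it leaves the block total in every thread, which the final normalization loop needs, without a separate broadcast step.

Below is a minimal standalone sketch of the pattern; the file name, kernel name, and test harness are illustrative, not part of the patch. Like the patch, it implicitly assumes block_size is either WARP_SIZE or 1024, so all 32 slots of s_sum are written before they are read.

// block_sum.cu -- illustrative sketch, not part of the patch
#include <cstdio>
#include <cuda_runtime.h>

#define WARP_SIZE 32

static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        // butterfly exchange: after 5 steps every lane holds the warp total
        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
    }
    return x;
}

template <int block_size>
static __global__ void block_sum(const float * x, float * out, const int n) {
    float tmp = 0.0f;
    for (int i = threadIdx.x; i < n; i += block_size) {
        tmp += x[i];
    }
    tmp = warp_reduce_sum(tmp);      // level 1: within each warp
    if (block_size > WARP_SIZE) {    // level 2: across warps, via shared memory
        __shared__ float s_sum[32];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;    // one partial per warp
        }
        __syncthreads();
        tmp = s_sum[lane_id];        // every warp re-reduces the partials,
        tmp = warp_reduce_sum(tmp);  // leaving the total in every thread
    }
    if (threadIdx.x == 0) {
        *out = tmp;
    }
}

int main() {
    const int n = 4096;
    float *x, *out;
    cudaMallocManaged(&x, n * sizeof(float));
    cudaMallocManaged(&out, sizeof(float));
    for (int i = 0; i < n; ++i) x[i] = 1.0f;
    block_sum<1024><<<1, 1024>>>(x, out, n);
    cudaDeviceSynchronize();
    printf("sum = %.1f (expected %d)\n", *out, n);
    cudaFree(x);
    cudaFree(out);
    return 0;
}
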
@@ -4203,14 +4236,24 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
 
 static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+    }
 }
 
 static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+    }
 }
 
 static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream) {
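
The host wrappers pick the instantiation per tensor shape: rows shorter than 1024 elements keep the old one-warp-per-row layout, where block_size == WARP_SIZE makes the whole shared-memory branch (including its __syncthreads()) compile away, while wider rows get a full 1024-thread block, so for example ncols = 8192 means each thread accumulates 8 elements before the reduction.

For checking the kernels, here is a plain CPU reference of what norm_f32 computes per row; it is a sketch written as an aid, and the name norm_f32_ref is mine, not part of the patch.

// norm_f32_ref: CPU reference for norm_f32, one row at a time
#include <math.h>
#include <stddef.h>

static void norm_f32_ref(const float * x, float * dst, int nrows, int ncols) {
    const float eps = 1e-5f;                // matches the hard-coded eps above
    for (int r = 0; r < nrows; ++r) {
        const float * xr = x + (size_t) r * ncols;
        float sum = 0.0f, sum2 = 0.0f;
        for (int c = 0; c < ncols; ++c) {   // one pass over the row: sum x and sum x^2
            sum  += xr[c];
            sum2 += xr[c] * xr[c];
        }
        const float mean    = sum / ncols;
        const float var     = sum2 / ncols - mean * mean;  // E[x^2] - (E[x])^2
        const float inv_std = 1.0f / sqrtf(var + eps);
        for (int c = 0; c < ncols; ++c) {
            dst[(size_t) r * ncols + c] = (xr[c] - mean) * inv_std;
        }
    }
}

rms_norm_f32 differs only in skipping the mean subtraction: its scale is rsqrtf of (sum x^2 / ncols + eps) applied directly to x, with eps arriving as an argument instead of being hard-coded.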