@@ -34,90 +34,92 @@ static void rope_yarn(
34
34
*sin_theta = sycl::sin (theta) * mscale;
35
35
}
36
36
37
- template <typename T, bool has_ff>
38
- static void rope_norm (
39
- const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
40
- float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors,
41
- const sycl::nd_item<3 > &item_ct1) {
42
- const int i0 = 2 * (item_ct1.get_local_range (1 ) * item_ct1.get_group (1 ) +
43
- item_ct1.get_local_id (1 ));
37
+ template <typename T, bool has_ff>
38
+ static void rope_norm (const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
39
+ const int32_t * pos, float freq_scale, float ext_factor, float attn_factor,
40
+ const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
41
+ const sycl::nd_item<3 > & item_ct1) {
42
+ const int i0 = 2 * (item_ct1.get_local_range (1 ) * item_ct1.get_group (1 ) + item_ct1.get_local_id (1 ));
44
43
45
44
if (i0 >= ne0) {
46
45
return ;
47
46
}
48
47
49
- const int row = item_ct1.get_local_range (2 ) * item_ct1.get_group (2 ) +
50
- item_ct1.get_local_id (2 );
48
+ const int row = item_ct1.get_local_range (2 ) * item_ct1.get_group (2 ) + item_ct1.get_local_id (2 );
51
49
52
50
if (i0 >= n_dims) {
53
- const int i = row* ne0 + i0;
51
+ const int i = row * ne0 + i0;
54
52
55
53
dst[i + 0 ] = x[i + 0 ];
56
54
dst[i + 1 ] = x[i + 1 ];
57
55
58
56
return ;
59
57
}
60
58
61
- const int i = row*ne0 + i0;
62
- const int i2 = row/p_delta_rows;
59
+ const int row0 = row % ne1;
60
+ const int channel0 = row / ne1;
61
+
62
+ const int i = row * ne0 + i0;
63
+ const int i2 = channel0 * s2 + row0 * s1 + i0;
63
64
64
- const float theta_base = pos[i2 ] * sycl::pow (theta_scale, i0 / 2 .0f );
65
+ const float theta_base = pos[channel0 ] * sycl::pow (theta_scale, i0 / 2 .0f );
65
66
66
- const float freq_factor = has_ff ? freq_factors[i0/ 2 ] : 1 .0f ;
67
+ const float freq_factor = has_ff ? freq_factors[i0 / 2 ] : 1 .0f ;
67
68
68
69
float cos_theta;
69
70
float sin_theta;
70
71
71
- rope_yarn (theta_base/ freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
72
+ rope_yarn (theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
72
73
73
- const float x0 = x[i + 0 ];
74
- const float x1 = x[i + 1 ];
74
+ const float x0 = x[i2 + 0 ];
75
+ const float x1 = x[i2 + 1 ];
75
76
76
- dst[i + 0 ] = x0* cos_theta - x1* sin_theta;
77
- dst[i + 1 ] = x0* sin_theta + x1* cos_theta;
77
+ dst[i + 0 ] = x0 * cos_theta - x1 * sin_theta;
78
+ dst[i + 1 ] = x0 * sin_theta + x1 * cos_theta;
78
79
}
79
80
80
- template <typename T, bool has_ff>
81
- static void rope_neox (
82
- const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
83
- float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors,
84
- const sycl::nd_item<3 > &item_ct1) {
85
- const int i0 = 2 * (item_ct1.get_local_range (1 ) * item_ct1.get_group (1 ) +
86
- item_ct1.get_local_id (1 ));
81
+ template <typename T, bool has_ff>
82
+ static void rope_neox (const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
83
+ const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
84
+ const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
85
+ const sycl::nd_item<3 > & item_ct1) {
86
+ const int i0 = 2 * (item_ct1.get_local_range (1 ) * item_ct1.get_group (1 ) + item_ct1.get_local_id (1 ));
87
87
88
88
if (i0 >= ne0) {
89
89
return ;
90
90
}
91
91
92
- const int row = item_ct1.get_local_range (2 ) * item_ct1.get_group (2 ) +
93
- item_ct1.get_local_id (2 );
92
+ const int row = item_ct1.get_local_range (2 ) * item_ct1.get_group (2 ) + item_ct1.get_local_id (2 );
94
93
95
94
if (i0 >= n_dims) {
96
- const int i = row* ne0 + i0;
95
+ const int i = row * ne0 + i0;
97
96
98
97
dst[i + 0 ] = x[i + 0 ];
99
98
dst[i + 1 ] = x[i + 1 ];
100
99
101
100
return ;
102
101
}
103
102
104
- const int i = row*ne0 + i0/2 ;
105
- const int i2 = row/p_delta_rows;
103
+ const int row0 = row % ne1;
104
+ const int channel0 = row / ne1;
105
+
106
+ const int i = row * ne0 + i0 / 2 ;
107
+ const int i2 = channel0 * s2 + row0 * s1 + i0 / 2 ;
106
108
107
- const float theta_base = pos[i2 ] * sycl::pow (theta_scale, i0 / 2 .0f );
109
+ const float theta_base = pos[channel0 ] * sycl::pow (theta_scale, i0 / 2 .0f );
108
110
109
- const float freq_factor = has_ff ? freq_factors[i0/ 2 ] : 1 .0f ;
111
+ const float freq_factor = has_ff ? freq_factors[i0 / 2 ] : 1 .0f ;
110
112
111
113
float cos_theta;
112
114
float sin_theta;
113
115
114
- rope_yarn (theta_base/ freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
116
+ rope_yarn (theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
115
117
116
- const float x0 = x[i + 0 ];
117
- const float x1 = x[i + n_dims/ 2 ];
118
+ const float x0 = x[i2 + 0 ];
119
+ const float x1 = x[i2 + n_dims / 2 ];
118
120
119
- dst[i + 0 ] = x0* cos_theta - x1* sin_theta;
120
- dst[i + n_dims/ 2 ] = x0* sin_theta + x1* cos_theta;
121
+ dst[i + 0 ] = x0 * cos_theta - x1 * sin_theta;
122
+ dst[i + n_dims / 2 ] = x0 * sin_theta + x1 * cos_theta;
121
123
}
122
124
123
125
template <typename T, bool has_ff>
@@ -163,80 +165,66 @@ static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, cons
163
165
}
164
166
165
167
template <typename T>
166
- static void rope_norm_sycl (
167
- const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows,
168
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
168
+ static void rope_norm_sycl (const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
169
+ const int n_dims, int nr, const int32_t * pos, const float freq_scale, const float freq_base,
170
+ const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
171
+ const float * freq_factors, queue_ptr stream) {
169
172
GGML_ASSERT (ne0 % 2 == 0 );
170
173
const sycl::range<3 > block_dims (1 , SYCL_ROPE_BLOCK_SIZE, 1 );
171
- const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1 ) / (2 * SYCL_ROPE_BLOCK_SIZE);
174
+ const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1 ) / (2 * SYCL_ROPE_BLOCK_SIZE);
172
175
const sycl::range<3 > block_nums (1 , num_blocks_x, nr);
173
176
174
- const float theta_scale = powf (freq_base, -2 .0f / n_dims);
177
+ const float theta_scale = powf (freq_base, -2 .0f / n_dims);
175
178
176
- dpct::has_capability_or_fail (stream->get_device (),
177
- {sycl::aspect::fp16});
179
+ dpct::has_capability_or_fail (stream->get_device (), { sycl::aspect::fp16 });
178
180
179
181
if (freq_factors == nullptr ) {
180
182
/*
181
183
DPCT1049:40: The work-group size passed to the SYCL kernel may exceed
182
184
the limit. To get the device limit, query
183
185
info::device::max_work_group_size. Adjust the work-group size if needed.
184
186
*/
185
- stream->parallel_for (
186
- sycl::nd_range<3 >(block_nums * block_dims, block_dims),
187
- [=](sycl::nd_item<3 > item_ct1) {
188
- rope_norm<T, false >(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
189
- ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
190
- item_ct1);
191
- });
187
+ stream->parallel_for (sycl::nd_range<3 >(block_nums * block_dims, block_dims), [=](sycl::nd_item<3 > item_ct1) {
188
+ rope_norm<T, false >(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
189
+ theta_scale, freq_factors, item_ct1);
190
+ });
192
191
} else {
193
192
/*
194
193
DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
195
194
the limit. To get the device limit, query
196
195
info::device::max_work_group_size. Adjust the work-group size if needed.
197
196
*/
198
- stream->parallel_for (
199
- sycl::nd_range<3 >(block_nums * block_dims, block_dims),
200
- [=](sycl::nd_item<3 > item_ct1) {
201
- rope_norm<T, true >(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
202
- ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
203
- item_ct1);
204
- });
197
+ stream->parallel_for (sycl::nd_range<3 >(block_nums * block_dims, block_dims), [=](sycl::nd_item<3 > item_ct1) {
198
+ rope_norm<T, true >(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
199
+ theta_scale, freq_factors, item_ct1);
200
+ });
205
201
}
206
202
}
207
203
208
204
template <typename T>
209
- static void rope_neox_sycl (
210
- const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows,
211
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
205
+ static void rope_neox_sycl (const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
206
+ const int n_dims, const int nr, const int32_t * pos, const float freq_scale,
207
+ const float freq_base, const float ext_factor, const float attn_factor,
208
+ const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
212
209
GGML_ASSERT (ne0 % 2 == 0 );
213
210
const sycl::range<3 > block_dims (1 , SYCL_ROPE_BLOCK_SIZE, 1 );
214
- const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1 ) / (2 * SYCL_ROPE_BLOCK_SIZE);
211
+ const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1 ) / (2 * SYCL_ROPE_BLOCK_SIZE);
215
212
const sycl::range<3 > block_nums (1 , num_blocks_x, nr);
216
213
217
- const float theta_scale = powf (freq_base, -2 .0f / n_dims);
214
+ const float theta_scale = powf (freq_base, -2 .0f / n_dims);
218
215
219
- dpct::has_capability_or_fail (stream->get_device (),
220
- {sycl::aspect::fp16});
216
+ dpct::has_capability_or_fail (stream->get_device (), { sycl::aspect::fp16 });
221
217
222
218
if (freq_factors == nullptr ) {
223
- stream->parallel_for (
224
- sycl::nd_range<3 >(block_nums * block_dims, block_dims),
225
- [=](sycl::nd_item<3 > item_ct1) {
226
- rope_neox<T, false >(x, dst, ne0, n_dims, pos, freq_scale,
227
- p_delta_rows, ext_factor, attn_factor,
228
- corr_dims, theta_scale, freq_factors,
229
- item_ct1);
230
- });
219
+ stream->parallel_for (sycl::nd_range<3 >(block_nums * block_dims, block_dims), [=](sycl::nd_item<3 > item_ct1) {
220
+ rope_neox<T, false >(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
221
+ theta_scale, freq_factors, item_ct1);
222
+ });
231
223
} else {
232
- stream->parallel_for (
233
- sycl::nd_range<3 >(block_nums * block_dims, block_dims),
234
- [=](sycl::nd_item<3 > item_ct1) {
235
- rope_neox<T, true >(x, dst, ne0, n_dims, pos, freq_scale,
236
- p_delta_rows, ext_factor, attn_factor,
237
- corr_dims, theta_scale, freq_factors,
238
- item_ct1);
239
- });
224
+ stream->parallel_for (sycl::nd_range<3 >(block_nums * block_dims, block_dims), [=](sycl::nd_item<3 > item_ct1) {
225
+ rope_neox<T, true >(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
226
+ theta_scale, freq_factors, item_ct1);
227
+ });
240
228
}
241
229
}
242
230
@@ -272,7 +260,7 @@ static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1,
272
260
}
273
261
}
274
262
275
- void ggml_sycl_op_rope (ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
263
+ inline void ggml_sycl_op_rope (ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
276
264
277
265
GGML_ASSERT (dst->src [0 ]->type == GGML_TYPE_F32 || dst->src [0 ]->type == GGML_TYPE_F16);
278
266
GGML_ASSERT ( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
@@ -329,43 +317,46 @@ void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
329
317
if (is_neox) {
330
318
GGML_SYCL_DEBUG (" %s: neox path\n " , __func__);
331
319
if (dst->src [0 ]->type == GGML_TYPE_F32) {
332
- rope_neox_sycl (
333
- (const float *)dst->src [0 ]->data , (float *)dst->data , ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
334
- attn_factor, corr_dims, freq_factors, main_stream
335
- );
320
+ rope_neox_sycl ((const float *) dst->src [0 ]->data , (float *) dst->data , ne00, ne01, s01, s02, n_dims, nr,
321
+ pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
336
322
} else if (dst->src [0 ]->type == GGML_TYPE_F16) {
337
- rope_neox_sycl (
338
- (const sycl::half *)dst->src [0 ]->data , (sycl::half *)dst->data , ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
339
- attn_factor, corr_dims, freq_factors, main_stream
340
- );
323
+ rope_neox_sycl ((const sycl::half *) dst->src [0 ]->data , (sycl::half *) dst->data , ne00, ne01, s01, s02,
324
+ n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
325
+ main_stream);
341
326
} else {
342
327
GGML_ABORT (" fatal error" );
343
328
}
344
329
} else if (is_vision) {
345
330
GGML_SYCL_DEBUG (" %s: vision path\n " , __func__);
346
331
if (dst->src [0 ]->type == GGML_TYPE_F16) {
347
- rope_vision_sycl ((const sycl::half *)dst->src [0 ]->data , (sycl::half *)dst->data , ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
348
- freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, main_stream);
332
+ rope_vision_sycl ((const sycl::half *) dst->src [0 ]->data , (sycl::half *) dst->data , ne00, ne01, ne02, s01,
333
+ s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
334
+ freq_factors, sections, main_stream);
349
335
} else if (dst->src [0 ]->type == GGML_TYPE_F32) {
350
- rope_vision_sycl ((const float *) dst->src [0 ]->data , (float *)dst->data , ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
351
- freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, main_stream);
336
+ rope_vision_sycl ((const float *) dst->src [0 ]->data , (float *) dst->data , ne00, ne01, ne02, s01, s02, n_dims,
337
+ nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
338
+ main_stream);
352
339
} else {
353
340
GGML_ABORT (" Fatal error: Tensor type unsupported!" );
354
341
}
355
342
} else {
356
343
GGML_SYCL_DEBUG (" %s: norm path\n " , __func__);
357
344
if (dst->src [0 ]->type == GGML_TYPE_F32) {
358
- rope_norm_sycl (
359
- (const float *)dst->src [0 ]->data , (float *)dst->data , ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
360
- attn_factor, corr_dims, freq_factors, main_stream
361
- );
345
+ rope_norm_sycl ((const float *) dst->src [0 ]->data , (float *) dst->data , ne00, ne01, s01, s02, n_dims, nr,
346
+ pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
362
347
} else if (dst->src [0 ]->type == GGML_TYPE_F16) {
363
- rope_norm_sycl (
364
- (const sycl::half *)dst->src [0 ]->data , (sycl::half *)dst->data , ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
365
- attn_factor, corr_dims, freq_factors, main_stream
366
- );
348
+ rope_norm_sycl ((const sycl::half *) dst->src [0 ]->data , (sycl::half *) dst->data , ne00, ne01, s01, s02,
349
+ n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
350
+ main_stream);
367
351
} else {
368
352
GGML_ABORT (" fatal error" );
369
353
}
370
354
}
371
355
}
356
+
357
+ void ggml_sycl_rope (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
358
+ GGML_SYCL_DEBUG (" call %s\n " , __func__);
359
+ ggml_sycl_op_rope (ctx, dst);
360
+ GGML_SYCL_DEBUG (" call %s done\n " , __func__);
361
+ }
362
+
0 commit comments