Skip to content

Commit 5368ddd

Browse files
authored
SYCL: Add non-contiguous support in ROPE (ggml-org#12993)
ggml-ci
1 parent 84a9bf2 commit 5368ddd

File tree

3 files changed

+96
-110
lines changed

3 files changed

+96
-110
lines changed

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3168,11 +3168,6 @@ static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor
31683168
ggml_sycl_op_diag_mask_inf(ctx, dst);
31693169
}
31703170

3171-
static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3172-
GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented
3173-
ggml_sycl_op_rope(ctx, dst);
3174-
}
3175-
31763171
static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
31773172
ggml_sycl_op_pool2d(ctx, dst);
31783173
}
@@ -4002,7 +3997,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
40023997
if (mode == GGML_ROPE_TYPE_MROPE) {
40033998
return false;
40043999
}
4005-
return ggml_is_contiguous(op->src[0]);
4000+
return true;
40064001
}
40074002
case GGML_OP_IM2COL:
40084003
return true;

ggml/src/ggml-sycl/rope.cpp

Lines changed: 94 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -34,90 +34,92 @@ static void rope_yarn(
3434
*sin_theta = sycl::sin(theta) * mscale;
3535
}
3636

37-
template<typename T, bool has_ff>
38-
static void rope_norm(
39-
const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
40-
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors,
41-
const sycl::nd_item<3> &item_ct1) {
42-
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
43-
item_ct1.get_local_id(1));
37+
template <typename T, bool has_ff>
38+
static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
39+
const int32_t * pos, float freq_scale, float ext_factor, float attn_factor,
40+
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
41+
const sycl::nd_item<3> & item_ct1) {
42+
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
4443

4544
if (i0 >= ne0) {
4645
return;
4746
}
4847

49-
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
50-
item_ct1.get_local_id(2);
48+
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
5149

5250
if (i0 >= n_dims) {
53-
const int i = row*ne0 + i0;
51+
const int i = row * ne0 + i0;
5452

5553
dst[i + 0] = x[i + 0];
5654
dst[i + 1] = x[i + 1];
5755

5856
return;
5957
}
6058

61-
const int i = row*ne0 + i0;
62-
const int i2 = row/p_delta_rows;
59+
const int row0 = row % ne1;
60+
const int channel0 = row / ne1;
61+
62+
const int i = row * ne0 + i0;
63+
const int i2 = channel0 * s2 + row0 * s1 + i0;
6364

64-
const float theta_base = pos[i2] * sycl::pow(theta_scale, i0 / 2.0f);
65+
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
6566

66-
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
67+
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
6768

6869
float cos_theta;
6970
float sin_theta;
7071

71-
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
72+
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
7273

73-
const float x0 = x[i + 0];
74-
const float x1 = x[i + 1];
74+
const float x0 = x[i2 + 0];
75+
const float x1 = x[i2 + 1];
7576

76-
dst[i + 0] = x0*cos_theta - x1*sin_theta;
77-
dst[i + 1] = x0*sin_theta + x1*cos_theta;
77+
dst[i + 0] = x0 * cos_theta - x1 * sin_theta;
78+
dst[i + 1] = x0 * sin_theta + x1 * cos_theta;
7879
}
7980

80-
template<typename T, bool has_ff>
81-
static void rope_neox(
82-
const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
83-
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors,
84-
const sycl::nd_item<3> &item_ct1) {
85-
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
86-
item_ct1.get_local_id(1));
81+
template <typename T, bool has_ff>
82+
static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
83+
const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
84+
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
85+
const sycl::nd_item<3> & item_ct1) {
86+
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
8787

8888
if (i0 >= ne0) {
8989
return;
9090
}
9191

92-
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
93-
item_ct1.get_local_id(2);
92+
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
9493

9594
if (i0 >= n_dims) {
96-
const int i = row*ne0 + i0;
95+
const int i = row * ne0 + i0;
9796

9897
dst[i + 0] = x[i + 0];
9998
dst[i + 1] = x[i + 1];
10099

101100
return;
102101
}
103102

104-
const int i = row*ne0 + i0/2;
105-
const int i2 = row/p_delta_rows;
103+
const int row0 = row % ne1;
104+
const int channel0 = row / ne1;
105+
106+
const int i = row * ne0 + i0 / 2;
107+
const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;
106108

107-
const float theta_base = pos[i2] * sycl::pow(theta_scale, i0 / 2.0f);
109+
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
108110

109-
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
111+
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
110112

111113
float cos_theta;
112114
float sin_theta;
113115

114-
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
116+
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
115117

116-
const float x0 = x[i + 0];
117-
const float x1 = x[i + n_dims/2];
118+
const float x0 = x[i2 + 0];
119+
const float x1 = x[i2 + n_dims / 2];
118120

119-
dst[i + 0] = x0*cos_theta - x1*sin_theta;
120-
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
121+
dst[i + 0] = x0 * cos_theta - x1 * sin_theta;
122+
dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
121123
}
122124

123125
template <typename T, bool has_ff>
@@ -163,80 +165,66 @@ static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, cons
163165
}
164166

165167
template <typename T>
166-
static void rope_norm_sycl(
167-
const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows,
168-
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
168+
static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
169+
const int n_dims, int nr, const int32_t * pos, const float freq_scale, const float freq_base,
170+
const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
171+
const float * freq_factors, queue_ptr stream) {
169172
GGML_ASSERT(ne0 % 2 == 0);
170173
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
171-
const int num_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
174+
const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
172175
const sycl::range<3> block_nums(1, num_blocks_x, nr);
173176

174-
const float theta_scale = powf(freq_base, -2.0f/n_dims);
177+
const float theta_scale = powf(freq_base, -2.0f / n_dims);
175178

176-
dpct::has_capability_or_fail(stream->get_device(),
177-
{sycl::aspect::fp16});
179+
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
178180

179181
if (freq_factors == nullptr) {
180182
/*
181183
DPCT1049:40: The work-group size passed to the SYCL kernel may exceed
182184
the limit. To get the device limit, query
183185
info::device::max_work_group_size. Adjust the work-group size if needed.
184186
*/
185-
stream->parallel_for(
186-
sycl::nd_range<3>(block_nums * block_dims, block_dims),
187-
[=](sycl::nd_item<3> item_ct1) {
188-
rope_norm<T, false>(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
189-
ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
190-
item_ct1);
191-
});
187+
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
188+
rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
189+
theta_scale, freq_factors, item_ct1);
190+
});
192191
} else {
193192
/*
194193
DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
195194
the limit. To get the device limit, query
196195
info::device::max_work_group_size. Adjust the work-group size if needed.
197196
*/
198-
stream->parallel_for(
199-
sycl::nd_range<3>(block_nums * block_dims, block_dims),
200-
[=](sycl::nd_item<3> item_ct1) {
201-
rope_norm<T, true>(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
202-
ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
203-
item_ct1);
204-
});
197+
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
198+
rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
199+
theta_scale, freq_factors, item_ct1);
200+
});
205201
}
206202
}
207203

208204
template <typename T>
209-
static void rope_neox_sycl(
210-
const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows,
211-
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
205+
static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
206+
const int n_dims, const int nr, const int32_t * pos, const float freq_scale,
207+
const float freq_base, const float ext_factor, const float attn_factor,
208+
const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
212209
GGML_ASSERT(ne0 % 2 == 0);
213210
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
214-
const int num_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
211+
const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
215212
const sycl::range<3> block_nums(1, num_blocks_x, nr);
216213

217-
const float theta_scale = powf(freq_base, -2.0f/n_dims);
214+
const float theta_scale = powf(freq_base, -2.0f / n_dims);
218215

219-
dpct::has_capability_or_fail(stream->get_device(),
220-
{sycl::aspect::fp16});
216+
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
221217

222218
if (freq_factors == nullptr) {
223-
stream->parallel_for(
224-
sycl::nd_range<3>(block_nums * block_dims, block_dims),
225-
[=](sycl::nd_item<3> item_ct1) {
226-
rope_neox<T, false>(x, dst, ne0, n_dims, pos, freq_scale,
227-
p_delta_rows, ext_factor, attn_factor,
228-
corr_dims, theta_scale, freq_factors,
229-
item_ct1);
230-
});
219+
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
220+
rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
221+
theta_scale, freq_factors, item_ct1);
222+
});
231223
} else {
232-
stream->parallel_for(
233-
sycl::nd_range<3>(block_nums * block_dims, block_dims),
234-
[=](sycl::nd_item<3> item_ct1) {
235-
rope_neox<T, true>(x, dst, ne0, n_dims, pos, freq_scale,
236-
p_delta_rows, ext_factor, attn_factor,
237-
corr_dims, theta_scale, freq_factors,
238-
item_ct1);
239-
});
224+
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
225+
rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
226+
theta_scale, freq_factors, item_ct1);
227+
});
240228
}
241229
}
242230

@@ -272,7 +260,7 @@ static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1,
272260
}
273261
}
274262

275-
void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
263+
inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
276264

277265
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
278266
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
@@ -329,43 +317,46 @@ void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
329317
if (is_neox) {
330318
GGML_SYCL_DEBUG("%s: neox path\n", __func__);
331319
if (dst->src[0]->type == GGML_TYPE_F32) {
332-
rope_neox_sycl(
333-
(const float *)dst->src[0]->data, (float *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
334-
attn_factor, corr_dims, freq_factors, main_stream
335-
);
320+
rope_neox_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
321+
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
336322
} else if (dst->src[0]->type == GGML_TYPE_F16) {
337-
rope_neox_sycl(
338-
(const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
339-
attn_factor, corr_dims, freq_factors, main_stream
340-
);
323+
rope_neox_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
324+
n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
325+
main_stream);
341326
} else {
342327
GGML_ABORT("fatal error");
343328
}
344329
} else if (is_vision) {
345330
GGML_SYCL_DEBUG("%s: vision path\n", __func__);
346331
if (dst->src[0]->type == GGML_TYPE_F16) {
347-
rope_vision_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
348-
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, main_stream);
332+
rope_vision_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, ne02, s01,
333+
s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
334+
freq_factors, sections, main_stream);
349335
} else if (dst->src[0]->type == GGML_TYPE_F32) {
350-
rope_vision_sycl((const float *) dst->src[0]->data, (float *)dst->data, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
351-
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, main_stream);
336+
rope_vision_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
337+
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
338+
main_stream);
352339
} else {
353340
GGML_ABORT("Fatal error: Tensor type unsupported!");
354341
}
355342
} else {
356343
GGML_SYCL_DEBUG("%s: norm path\n", __func__);
357344
if (dst->src[0]->type == GGML_TYPE_F32) {
358-
rope_norm_sycl(
359-
(const float *)dst->src[0]->data, (float *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
360-
attn_factor, corr_dims, freq_factors, main_stream
361-
);
345+
rope_norm_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
346+
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
362347
} else if (dst->src[0]->type == GGML_TYPE_F16) {
363-
rope_norm_sycl(
364-
(const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
365-
attn_factor, corr_dims, freq_factors, main_stream
366-
);
348+
rope_norm_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
349+
n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
350+
main_stream);
367351
} else {
368352
GGML_ABORT("fatal error");
369353
}
370354
}
371355
}
356+
357+
void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
358+
GGML_SYCL_DEBUG("call %s\n", __func__);
359+
ggml_sycl_op_rope(ctx, dst);
360+
GGML_SYCL_DEBUG("call %s done\n", __func__);
361+
}
362+

ggml/src/ggml-sycl/rope.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,6 @@
1515

1616
#include "common.hpp"
1717

18-
void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
18+
void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
1919

2020
#endif // GGML_SYCL_ROPE_HPP

0 commit comments

Comments
 (0)