@@ -154,8 +154,19 @@ def forward(
         if self.config.scale_embeddings:
             x = x * torch.tensor(self.config.n_embd**0.5, dtype=x.dtype)
 
-        for block in self.transformer.h:
-            x = block(x, cos, sin, mask, input_pos, input_pos_maxp1)
+        for block_idx, block in enumerate(self.transformer.h):
+            if self.config.rope_indices is not None:
+                x = block(
+                    x,
+                    cos[..., self.config.rope_indices[block_idx]],
+                    sin[..., self.config.rope_indices[block_idx]],
+                    mask,
+                    input_pos,
+                    input_pos_maxp1,
+                )
+            else:
+                x = block(x, cos, sin, mask, input_pos, input_pos_maxp1)
+
         x = self.transformer.ln_f(x)
         clamp_head = (
             partial(do_softcapping, thresh=self.config.final_logit_softcapping)
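The per-block indexing above suggests that several precomputed RoPE tables are stacked along the last dimension of `cos`/`sin`, with `config.rope_indices` mapping each transformer block to one of them. A minimal sketch of that selection, with assumed, illustrative shapes and names (`seq_len`, `head_dim`, two stacked variants):

```python
import torch

seq_len, head_dim = 8, 4

# Two hypothetical RoPE cos tables (e.g. different rope base for local vs.
# global attention layers), stacked on the last dim -> (seq_len, head_dim, 2).
cos_a = torch.cos(torch.randn(seq_len, head_dim))
cos_b = torch.cos(torch.randn(seq_len, head_dim))
cos = torch.stack((cos_a, cos_b), dim=-1)

# One entry per transformer block: which variant that block should use.
rope_indices = [0, 1, 0, 1]

for block_idx in range(len(rope_indices)):
    # Mirrors `cos[..., self.config.rope_indices[block_idx]]` in the diff:
    # selecting along the last axis recovers a per-block (seq_len, head_dim) table.
    cos_block = cos[..., rope_indices[block_idx]]
    assert cos_block.shape == (seq_len, head_dim)
```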
@@ -215,7 +226,10 @@ def set_kv_cache(
         dtype: Optional[torch.dtype] = None,
     ) -> None:
         if rope_cache_length is None:
-            rope_cache_length = self.cos.size(-1)
+            if len(self.cos.shape) == 2:
+                rope_cache_length = self.cos.size(-1)
+            else:
+                rope_cache_length = self.cos.size(-2)
 
         if max_seq_length is None:
             max_seq_length = self.max_seq_length
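The `set_kv_cache` change follows from that same layout: with a single RoPE table the cache is 2-D and the rotary dimension is the last axis, while with stacked per-block variants it is 3-D and the rotary dimension moves to the second-to-last axis. A small illustration with assumed shapes:

```python
import torch

cos_2d = torch.ones(16, 32)     # (seq_len, rope_n_elem): single RoPE table
cos_3d = torch.ones(16, 32, 2)  # (seq_len, rope_n_elem, n_variants): stacked tables

def rope_cache_length(cos: torch.Tensor) -> int:
    # Mirrors the diff: pick whichever axis holds the rotary dimension.
    return cos.size(-1) if cos.dim() == 2 else cos.size(-2)

assert rope_cache_length(cos_2d) == 32
assert rope_cache_length(cos_3d) == 32
```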