
Commit 5bf8d00

Authored by: k223kim, pre-commit-ci[bot], t-vi
feat: add rope indices (#1997)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Thomas Viehmann <[email protected]>
1 parent 2c4ff02 commit 5bf8d00

File tree: 2 files changed (+18, −3 lines)


litgpt/config.py

Lines changed: 1 addition & 0 deletions
@@ -82,6 +82,7 @@ class Config:
     # The base period of the RoPE embeddings for local attention.
     # If not provided, rope_theta will be used for both local and global attention.
     rope_local_base_freq: Optional[float] = None
+    rope_indices: Optional[List] = None

     def __post_init__(self):
         if not self.name:
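The new rope_indices field holds one index per transformer block and tells forward() which slice of the cached RoPE tables that block should use. Purely for illustration (not part of the commit), a minimal sketch of how such a list could be built for a hypothetical model whose blocks alternate between a global and a local RoPE variant; the layer count and pattern below are assumptions:

# Hypothetical sketch: build a per-block index list. Index 0 stands for the
# global frequencies (rope_theta), index 1 for the local ones
# (rope_local_base_freq); the sizes and pattern are invented for the example.
n_layer = 6                 # assumed number of transformer blocks
global_every = 2            # assumed pattern: every second block uses global attention
rope_indices = [0 if i % global_every == 0 else 1 for i in range(n_layer)]
print(rope_indices)         # [0, 1, 0, 1, 0, 1]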

litgpt/model.py

Lines changed: 17 additions & 3 deletions
@@ -154,8 +154,19 @@ def forward(
         if self.config.scale_embeddings:
             x = x * torch.tensor(self.config.n_embd**0.5, dtype=x.dtype)

-        for block in self.transformer.h:
-            x = block(x, cos, sin, mask, input_pos, input_pos_maxp1)
+        for block_idx, block in enumerate(self.transformer.h):
+            if self.config.rope_indices is not None:
+                x = block(
+                    x,
+                    cos[..., self.config.rope_indices[block_idx]],
+                    sin[..., self.config.rope_indices[block_idx]],
+                    mask,
+                    input_pos,
+                    input_pos_maxp1,
+                )
+            else:
+                x = block(x, cos, sin, mask, input_pos, input_pos_maxp1)
+
         x = self.transformer.ln_f(x)
         clamp_head = (
             partial(do_softcapping, thresh=self.config.final_logit_softcapping)
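Here each block selects its own cos/sin slice via rope_indices[block_idx], which presumes the cached RoPE tables carry an extra trailing axis with one slice per frequency variant. A small illustrative sketch under that assumption (the shapes, sizes, and index list below are made up):

# Illustrative sketch (assumes PyTorch): indexing the last axis of a 3-D
# RoPE cache with an integer from rope_indices drops that axis and returns
# the 2-D table a given block should use.
import torch

seq_len, rope_n_elem, n_variants = 8, 16, 2           # assumed sizes
cos = torch.randn(seq_len, rope_n_elem, n_variants)   # stand-in for the cached cos table
rope_indices = [0, 1, 0, 1]                           # assumed: one entry per block

block_idx = 1
cos_for_block = cos[..., rope_indices[block_idx]]     # shape: (seq_len, rope_n_elem)
print(cos_for_block.shape)                            # torch.Size([8, 16])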
@@ -215,7 +226,10 @@ def set_kv_cache(
         dtype: Optional[torch.dtype] = None,
     ) -> None:
         if rope_cache_length is None:
-            rope_cache_length = self.cos.size(-1)
+            if len(self.cos.shape) == 2:
+                rope_cache_length = self.cos.size(-1)
+            else:
+                rope_cache_length = self.cos.size(-2)

         if max_seq_length is None:
             max_seq_length = self.max_seq_length
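With the plain 2-D cache the RoPE length sits on the last axis, but once a per-variant axis is appended it moves to the second-to-last one, so the fallback reads a different dimension. A short sketch of that shape logic, with assumed sizes:

# Illustrative sketch (assumes PyTorch): the same fallback value is recovered
# from either cache layout by switching the dimension that is read.
import torch

cos_2d = torch.empty(4096, 64)       # (max_seq_length, rope_n_elem)
cos_3d = torch.empty(4096, 64, 2)    # (max_seq_length, rope_n_elem, n_variants)

for cos in (cos_2d, cos_3d):
    rope_cache_length = cos.size(-1) if len(cos.shape) == 2 else cos.size(-2)
    print(rope_cache_length)         # 64 for both shapes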
