
Commit 6eba533

feat: add linear rope type (#1982)
1 parent e484a83 commit 6eba533

2 files changed (+69, -12 lines)


litgpt/model.py

Lines changed: 15 additions & 12 deletions
@@ -588,19 +588,22 @@ def build_rope_cache(
     theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))
 
     if extra_config is not None:
-        orig_context_len = extra_config["original_max_seq_len"]
         factor = extra_config["factor"]
-        low_freq_factor = extra_config["low_freq_factor"]
-        high_freq_factor = extra_config["high_freq_factor"]
-
-        wavelen = 2 * torch.pi / theta
-        ratio = orig_context_len / wavelen
-        smooth_factor = (ratio - low_freq_factor) / (high_freq_factor - low_freq_factor)
-        smooth_factor = torch.clamp(smooth_factor, min=0.0, max=1.0)
-
-        # Compute adjusted_theta without masked indexing
-        adjusted_theta = (1 - smooth_factor) * (theta / factor) + smooth_factor * theta
-        theta = adjusted_theta
+        if "original_max_seq_len" in extra_config:
+            orig_context_len = extra_config["original_max_seq_len"]
+            low_freq_factor = extra_config["low_freq_factor"]
+            high_freq_factor = extra_config["high_freq_factor"]
+
+            wavelen = 2 * torch.pi / theta
+            ratio = orig_context_len / wavelen
+            smooth_factor = (ratio - low_freq_factor) / (high_freq_factor - low_freq_factor)
+            smooth_factor = torch.clamp(smooth_factor, min=0.0, max=1.0)
+
+            # Compute adjusted_theta without masked indexing
+            adjusted_theta = (1 - smooth_factor) * (theta / factor) + smooth_factor * theta
+            theta = adjusted_theta
+        else:
+            theta = theta / factor
 
     # Create position indices `[0, 1, ..., seq_len - 1]`
     seq_idx = torch.arange(seq_len, device=device) / condense_ratio
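
A note on the new else branch: when extra_config carries only a "factor" (the "linear" rope type this commit adds), the inverse frequencies are divided by that factor, which is equivalent to compressing the position indices by the same factor. The snippet below is a standalone illustration of that equivalence, not litgpt code; the names n_elem, base, and factor simply mirror the arguments above.

import torch

# Minimal sketch of linear RoPE scaling: dividing theta (the inverse
# frequencies) by `factor` yields the same rotation angles as dividing
# the position indices by `factor`.
n_elem, base, factor, seq_len = 32, 50_000, 8.0, 10

theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
positions = torch.arange(seq_len).float()

angles_scaled_theta = torch.outer(positions, theta / factor)      # what the else branch computes
angles_scaled_positions = torch.outer(positions / factor, theta)  # positional-interpolation view

torch.testing.assert_close(angles_scaled_theta, angles_scaled_positions)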

tests/test_rope.py

Lines changed: 54 additions & 0 deletions
@@ -1,5 +1,6 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
 
+import pytest
 import torch
 from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXRotaryEmbedding
 from transformers.models.gpt_neox.modeling_gpt_neox import apply_rotary_pos_emb as apply_rotary_pos_emb_gptneo

@@ -218,6 +219,59 @@ def test_rope_llama_3_2():
     torch.testing.assert_close(theirs_k_rot, ours_k_rot)
 
 
+# See https://huggingface.co/google/gemma-3-27b-it/blob/main/config.json for settings
+# TODO: update HF transformers version to support Gemma3 and fix errors that causes after the update
+@pytest.mark.skip(reason="This test fails due to the HF transformers version not supporting Gemma3")
+@torch.inference_mode()
+def test_rope_gemma_3():
+    from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3RotaryEmbedding, apply_rotary_pos_emb
+
+    head_dim = 32
+    rope_theta = 50_000
+    their_rope_config = {
+        "factor": 8.0,
+        "rope_type": "linear",
+    }
+
+    our_rope_config = {"factor": 8.0}
+
+    ##################################
+    # Compare cos and sin
+    ##################################
+    # transformer rope
+    config = Gemma3TextConfig(rope_theta=rope_theta, rope_scaling=their_rope_config, head_dim=head_dim)
+    rot_emb = Gemma3RotaryEmbedding(config=config)
+    batch_size, seq_len = 1, 10
+    qk_tensor = torch.randn(batch_size, seq_len, head_dim)
+    position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
+    theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)
+
+    # our rope
+    ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta, extra_config=our_rope_config)
+    ours_cos = ours_cos.unsqueeze(0)
+    ours_sin = ours_sin.unsqueeze(0)
+    torch.testing.assert_close(theirs_cos, ours_cos)
+    torch.testing.assert_close(theirs_sin, ours_sin)
+
+    ##################################
+    # Compare rotated tensors
+    ##################################
+    # Settings
+    num_heads = 4
+
+    # Dummy query and key tensors
+    torch.manual_seed(123)
+    queries = torch.randn(batch_size, num_heads, seq_len, head_dim)
+    keys = torch.randn(batch_size, num_heads, seq_len, head_dim)
+
+    ours_q_rot = apply_rope(queries, ours_cos, ours_sin)
+    ours_k_rot = apply_rope(keys, ours_cos, ours_sin)
+    theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb(queries, keys, theirs_cos, theirs_sin)
+    torch.testing.assert_close(theirs_q_rot, ours_q_rot)
+    torch.testing.assert_close(theirs_k_rot, ours_k_rot)
+
+
 @torch.inference_mode()
 def test_rope_cos_sin_shapes_if_rope_n_elem_is_odd():
     bs, seq_len, n_head, n_embed = 1, 6, 2, 8
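
While the test stays skipped, the reference values it targets are easy to approximate by hand: the "linear" rope type is plain positional interpolation, i.e. positions are divided by the scaling factor before the cos/sin tables are built. Below is a rough standalone sketch of that computation, written for this note rather than taken from the HF or litgpt implementations; head_dim, rope_theta, and factor mirror the test's settings.

import torch

# Rough reference for the cos/sin tables compared in test_rope_gemma_3:
# "linear" RoPE scaling divides the positions by `factor`.
head_dim, rope_theta, factor, seq_len = 32, 50_000, 8.0, 10

inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
positions = torch.arange(seq_len).float() / factor  # linear scaling step

angles = torch.outer(positions, inv_freq)        # (seq_len, head_dim // 2)
emb = torch.cat((angles, angles), dim=-1)        # (seq_len, head_dim)
cos, sin = emb.cos(), emb.sin()
print(cos.shape, sin.shape)  # torch.Size([10, 32]) torch.Size([10, 32])

Once the TODO above is resolved, the test itself can be run with: pytest tests/test_rope.py -k test_rope_gemma_3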
