# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

+import pytest
import torch
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXRotaryEmbedding
from transformers.models.gpt_neox.modeling_gpt_neox import apply_rotary_pos_emb as apply_rotary_pos_emb_gptneo
|
@@ -218,6 +219,59 @@ def test_rope_llama_3_2():
    torch.testing.assert_close(theirs_k_rot, ours_k_rot)


+# See https://huggingface.co/google/gemma-3-27b-it/blob/main/config.json for settings
+# TODO: update the HF transformers version to support Gemma3 and fix the errors that arise after the update
+@pytest.mark.skip(reason="This test fails due to the HF transformers version not supporting Gemma3")
+@torch.inference_mode()
+def test_rope_gemma_3():
+    from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3RotaryEmbedding, apply_rotary_pos_emb
+
+    head_dim = 32
+    rope_theta = 50_000
+    their_rope_config = {
+        "factor": 8.0,
+        "rope_type": "linear",
+    }
+
+    our_rope_config = {"factor": 8.0}
+
+    ##################################
+    # Compare cos and sin
+    ##################################
+    # transformer rope
+    config = Gemma3TextConfig(rope_theta=rope_theta, rope_scaling=their_rope_config, head_dim=head_dim)
+    rot_emb = Gemma3RotaryEmbedding(config=config)
+    batch_size, seq_len = 1, 10
+    qk_tensor = torch.randn(batch_size, seq_len, head_dim)
+    position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
+    theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)
+
+    # our rope
+    ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta, extra_config=our_rope_config)
+    ours_cos = ours_cos.unsqueeze(0)
+    ours_sin = ours_sin.unsqueeze(0)
+    torch.testing.assert_close(theirs_cos, ours_cos)
+    torch.testing.assert_close(theirs_sin, ours_sin)
+
+    ##################################
+    # Compare rotated tensors
+    ##################################
+    # Settings
+    num_heads = 4
+
+    # Dummy query and key tensors
+    torch.manual_seed(123)
+    queries = torch.randn(batch_size, num_heads, seq_len, head_dim)
+    keys = torch.randn(batch_size, num_heads, seq_len, head_dim)
+
+    ours_q_rot = apply_rope(queries, ours_cos, ours_sin)
+    ours_k_rot = apply_rope(keys, ours_cos, ours_sin)
+    theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb(queries, keys, theirs_cos, theirs_sin)
+    torch.testing.assert_close(theirs_q_rot, ours_q_rot)
+    torch.testing.assert_close(theirs_k_rot, ours_k_rot)
+
+
@torch.inference_mode()
def test_rope_cos_sin_shapes_if_rope_n_elem_is_odd():
    bs, seq_len, n_head, n_embed = 1, 6, 2, 8
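For context on the `"rope_type": "linear"` / `factor: 8.0` setting exercised in the new test: linear (positional-interpolation) RoPE scaling divides the position indices by the factor before the rotation angles are computed, which is equivalent to dividing the inverse frequencies by the same factor. The sketch below is a plain-PyTorch illustration of that equivalence under the test's settings (`head_dim=32`, `base=50_000`); it is not part of the diff and does not call litgpt or transformers, so the variable names are purely illustrative.

```python
import torch

# Settings mirroring the test above (illustrative only).
head_dim, base, factor, seq_len = 32, 50_000, 8.0, 10

# Standard RoPE inverse frequencies: theta_i = base^(-2i / head_dim).
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
positions = torch.arange(seq_len).float()

# "linear" scaling: rotate with positions / factor ...
angles = torch.outer(positions / factor, inv_freq)   # (seq_len, head_dim // 2)
cos = torch.cat((angles, angles), dim=-1).cos()      # (seq_len, head_dim)
sin = torch.cat((angles, angles), dim=-1).sin()

# ... which is the same as shrinking the frequencies themselves.
angles_alt = torch.outer(positions, inv_freq / factor)
assert torch.allclose(angles, angles_alt)
```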