
Commit 03e99b8

tdoublep, kgreenewald, and lallison2 committed
Initial working implementation of a-LoRA.
Co-authored-by: Greenewald <[email protected]>
Co-authored-by: Allison Li <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
1 parent dec66d2 commit 03e99b8

File tree

12 files changed: +401 -11 lines changed

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# After starting server using "vllm serve <model> --enable_lora --lora_modules..."

import time

from openai import OpenAI

model_id = "ibm-granite/granite-3.2-8b-instruct"

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

BASE_NAME = "ibm-granite/granite-3.2-8b-instruct"
ALORA_NAME = "new_alora"  # "ibm-granite/granite-3.2-8b-alora-uncertainty"
invocation_string = "<|start_of_role|>certainty<|end_of_role|>"

###################################################################
prompts = [
    "<|start_of_role|>user<|end_of_role|>What is MIT?<|end_of_text|>",
    "What is MIT?",
    (
        "<|start_of_role|>user<|end_of_role|>What is the capital of "
        "Massachusetts?<|end_of_text|>\n"
    ),
    "<|start_of_role|>user<|end_of_role|>What is MIT?<|end_of_text|>",
    (
        "<|start_of_role|>user<|end_of_role|>What is the capital of "
        "Massachusetts?<|end_of_text|>\n"
    ),
    "<|start_of_role|>user<|end_of_role|>What is MIT?<|end_of_text|>",
]

# Base model call
outputs_base = client.completions.create(
    model=BASE_NAME, prompt=prompts, temperature=0, max_tokens=600
)

choices = outputs_base.choices
generated_text = []
for i in range(len(prompts)):
    prompt = prompts[i]

    generated_text += [outputs_base.choices[i].text]
    print(f"Prompt: {prompt!r}, Generated text: {generated_text[-1]!r}")

prompts_alora = [
    x + y + "<|end_of_text|>\n" + invocation_string
    for x, y in zip(prompts, generated_text)
]

# Base model with aLoRA call
t0 = time.time()
alora_outputs = client.completions.create(
    model=ALORA_NAME, prompt=prompts_alora, temperature=0, max_tokens=10
)
t = time.time() - t0
print(f"Time: {t}")
for i in range(len(prompts_alora)):
    prompt = prompts_alora[i]
    generated_text = alora_outputs.choices[i].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
#!/bin/bash

# More documentation: https://docs.vllm.ai/en/v0.8.3/serving/openai_compatible_server.html#vllm-serve
export VLLM_USE_V1="1"
# Specify base model (and optionally loras) to load in when starting the server.
vllm serve ibm-granite/granite-3.2-8b-instruct \
    --enable-lora \
    --lora-modules '{"name": "new_alora", "path": "/proj/dmfexp/statllm/users/kgreenewald/.cache/huggingface/models/hub/models--ibm-granite--granite-3.2-8b-alora-uncertainty/snapshots/6109ad88201426003e696d023ec67c19e7f3d444", "base_model_name": "ibm-granite/granite-3.2-8b-instruct"}' \
    --dtype bfloat16 \
    --max-lora-rank 64 \
    --enable-prefix-caching
    #--no-enable-prefix-caching
# Check that the lora model is listed along with other models.
#curl localhost:8000/v1/models | jq .

###########################################

# A second option is to enable dynamic adapter loading instead of at start-up.
#export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True

#curl -X POST http://localhost:8000/v1/load_lora_adapter \
#    -H "Content-Type: application/json" \
#    -d '{
#        "lora_name": "new_alora",
#        "lora_path": "/path/to/new_alora"
#    }'
# Should return "200 OK - Success: LoRA adapter 'new_alora' added successfully"

# Example of dynamically unloading an adapter.
#curl -X POST http://localhost:8000/v1/unload_lora_adapter \
#    -H "Content-Type: application/json" \
#    -d '{
#        "lora_name": "new_alora"
#    }'

###########################################

# Send a request using the new aLoRA.
#curl http://localhost:8000/v1/completions \
#    -H "Content-Type: application/json" \
#    -d '{
#        "model": "new_alora",
#        "prompt": "What is MIT?",
#        "max_tokens": 600,
#        "temperature": 0
#    }' | jq

examples/alora/new_alora_testing.py

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import time

import torch
from huggingface_hub import snapshot_download

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

BASE_NAME = "ibm-granite/granite-3.2-8b-instruct"
ALORA_NAME = "ibm-granite/granite-3.2-8b-alora-uncertainty"
invocation_string = "<|start_of_role|>certainty<|end_of_role|>"

os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_V1_USE_DEMO_LOGGING"] = "1"

# download your LoRA adapter to ~/.cache/huggingface/…
alora_path = snapshot_download(repo_id=ALORA_NAME)

print(alora_path)
#######################################


llm = LLM(
    model=BASE_NAME,
    enable_lora=True,
    enforce_eager=True,
    dtype=torch.bfloat16,
    enable_prefix_caching=True,  # enable APC
    max_lora_rank=64,
    enable_chunked_prefill=False,
)

prompts = [
    (
        "<|start_of_role|>user<|end_of_role|>What is MIT?<|end_of_text|>\n"
        "<|start_of_role|>assistant<|end_of_role|>"
    ),
]

sampling_params = SamplingParams(temperature=0, max_tokens=600)

outputsBase = llm.generate(
    prompts,
    sampling_params,
)
generated_text = []
for output in outputsBase:
    prompt = output.prompt
    generated_text += [output.outputs[0].text]
    print(f"Prompt: {prompt!r}, Generated text: {generated_text[-1]!r}")

prompts_alora = [
    x + y + "<|end_of_text|>\n" + invocation_string
    for x, y in zip(prompts, generated_text)
]

sampling_params = SamplingParams(temperature=0, max_tokens=10)

t0 = time.time()
outputs = llm.generate(
    prompts_alora,
    sampling_params,
    lora_request=LoRARequest("UQ_adapter", 1, alora_path),
)
t = time.time() - t0
print(f"Time: {t}")

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

vllm/envs.py

Lines changed: 5 additions & 0 deletions
@@ -120,6 +120,7 @@
     VLLM_USE_DEEP_GEMM: bool = False
     VLLM_XGRAMMAR_CACHE_MB: int = 0
     VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
+    VLLM_V1_USE_DEMO_LOGGING: bool = True
     VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
     VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
     VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
@@ -835,6 +836,10 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_MSGPACK_ZERO_COPY_THRESHOLD":
     lambda: int(os.getenv("VLLM_MSGPACK_ZERO_COPY_THRESHOLD", "256")),

+    # Useful for demo
+    "VLLM_V1_USE_DEMO_LOGGING":
+    lambda: os.environ.get("VLLM_V1_USE_DEMO_LOGGING", "0") == "1",
+
     # If set, allow insecure serialization using pickle.
     # This is useful for environments where it is deemed safe to use the
     # insecure method and it is needed for some reason.
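Note that although the declaration above defaults to `True`, the lookup lambda treats any value other than `"1"` as off, so a demo run has to opt in explicitly. The offline example in this commit does exactly that; a minimal sketch:

```python
# Opt in to the demo logging added above; set before vLLM reads its env vars.
import os

os.environ["VLLM_V1_USE_DEMO_LOGGING"] = "1"
```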

vllm/forward_context.py

Lines changed: 9 additions & 0 deletions
@@ -26,6 +26,12 @@
 batchsize_forward_time: defaultdict = defaultdict(list)
 
 
+@dataclass
+class ALoRAMetadata:
+    k_offsets: torch.Tensor
+    query_start_locs: list[int]
+
+
 @dataclass
 class DPMetadata:
     max_tokens_across_dp_cpu: torch.Tensor
@@ -94,6 +100,7 @@ class ForwardContext:
     virtual_engine: int  # set dynamically for each forward pass
     # set dynamically for each forward pass
     dp_metadata: Optional[DPMetadata] = None
+    alora_metadata: Optional[ALoRAMetadata] = None
     skip_cuda_graphs: bool = False
 
 
@@ -116,6 +123,7 @@ def set_forward_context(
     num_tokens: Optional[int] = None,
     num_tokens_across_dp: Optional[torch.Tensor] = None,
     skip_cuda_graphs: bool = False,
+    alora_metadata: Optional[ALoRAMetadata] = None,
 ):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
@@ -140,6 +148,7 @@ def set_forward_context(
         virtual_engine=virtual_engine,
         attn_metadata=attn_metadata,
         dp_metadata=dp_metadata,
+        alora_metadata=alora_metadata,
         skip_cuda_graphs=skip_cuda_graphs,
     )
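For orientation, a rough sketch of how this metadata could be built and passed in (illustrative only; the real call site is the model runner, which is not part of this diff, and the values below are invented):

```python
# Illustrative sketch with made-up values; the real construction happens in
# the model runner outside this hunk.
import torch

from vllm.forward_context import ALoRAMetadata

# Two requests flattened into one 7 + 8 = 15 token batch; the adapter should
# be active only on the last 3 tokens of each request.
alora_metadata = ALoRAMetadata(
    k_offsets=torch.tensor([3, 3]),
    query_start_locs=[0, 7, 15],
)

# with set_forward_context(attn_metadata, vllm_config,
#                          alora_metadata=alora_metadata):
#     ...  # run the forward pass; layers read it via get_forward_context()
```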

vllm/lora/layers.py

Lines changed: 39 additions & 8 deletions
@@ -19,6 +19,7 @@
                               tensor_model_parallel_all_gather,
                               tensor_model_parallel_all_reduce)
 from vllm.distributed.utils import divide
+from vllm.forward_context import get_forward_context
 # yapf: disable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase,
@@ -418,14 +419,44 @@ def apply(self,
             output = output.flatten(0, 1)
             x = x.flatten(0, 1)
 
-        lora_output: Optional[
-            torch.Tensor] = self.punica_wrapper.add_lora_linear(
-                output, x, self.lora_a_stacked, self.lora_b_stacked,
-                self.lora_bias_stacked, 1.0, self.output_slices)
-        if not current_platform.can_update_inplace():
-            output = lora_output
-
-        return output
+        # Extract aLoRA batch metadata from forward context
+        alora_metadata = get_forward_context().alora_metadata
+        k_offsets = alora_metadata.k_offsets
+        query_start_locs = alora_metadata.query_start_locs
+
+        # Build the 1D “save‐prefix” mask:
+        T = output.size(0)  # total tokens
+        starts = query_start_locs[:-1]  # starts and end index of each request
+        ends = query_start_locs[1:]
+        lengths = ends - starts  # request lengths
+        kept_lens = lengths - k_offsets
+        kept_lens = torch.clamp(
+            kept_lens,
+            min=0)  # portion of request to keep as base model weights
+
+        device = output.device
+        # Create the alora mask
+        delta = torch.zeros(T + 1, device=device, dtype=output.dtype)
+        ends_for_scatter = starts + kept_lens
+        pos_vals = kept_lens.sign().to(output.dtype)
+        neg_vals = -pos_vals
+        delta.scatter_add_(0, starts, pos_vals)
+        delta.scatter_add_(0, ends_for_scatter, neg_vals)
+        cums = torch.cumsum(delta[:-1], dim=0)
+        mask1d = cums > 0  # shape [T], bool
+        mask2d = mask1d.unsqueeze(1).to(output.dtype)
+
+        # Clone base layer output before running LoRA
+        orig_out = output.clone()
+
+        # Apply LoRA in‐place on `output`:
+        self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked,
+                                            self.lora_b_stacked,
+                                            self.lora_bias_stacked, 1.0,
+                                            self.output_slices)
+        # Apply alora mask
+        final_output = orig_out.mul(mask2d) + output.mul(1.0 - mask2d)
+        return final_output
 
     @property
     def weight(self) -> torch.Tensor:
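The new `apply` body decides, row by row of `output`, whether a token keeps the base-layer result or receives the LoRA update: for each request only the final `k_offset` positions are adapted, and the prefix mask is built without a Python loop via `scatter_add_` plus a cumulative sum. A standalone toy reproduction of just the mask construction (plain PyTorch on invented sizes, not the vLLM code path):

```python
# Toy reproduction of the prefix-mask trick with invented sizes.
import torch

# Two requests of 5 and 4 tokens flattened into one batch; adapter weights
# should apply only to the last 2 tokens of each request (k_offset = 2).
query_start_locs = torch.tensor([0, 5, 9])
k_offsets = torch.tensor([2, 2])

T = int(query_start_locs[-1])  # total tokens in the batch
starts = query_start_locs[:-1]
ends = query_start_locs[1:]
kept_lens = torch.clamp(ends - starts - k_offsets, min=0)

# +1 at each request start, -1 where the "keep base output" prefix ends;
# the running sum then flags every prefix position.
delta = torch.zeros(T + 1)
delta.scatter_add_(0, starts, kept_lens.sign().float())
delta.scatter_add_(0, starts + kept_lens, -kept_lens.sign().float())
mask1d = torch.cumsum(delta[:-1], dim=0) > 0

print(mask1d.int().tolist())
# [1, 1, 1, 0, 0, 1, 1, 0, 0]: 1 = keep the base-layer output,
# 0 = use the LoRA-adapted output for the last k_offset tokens.
```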

vllm/lora/request.py

Lines changed: 2 additions & 0 deletions
@@ -33,6 +33,8 @@ class LoRARequest(
     long_lora_max_len: Optional[int] = None
     base_model_name: Optional[str] = msgspec.field(default=None)
     tensorizer_config_dict: Optional[dict] = None
+    invocation_tokens: Optional[list[int]] = None
+    k_offset: Optional[int] = None
 
     def __post_init__(self):
         if self.lora_local_path:
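The diff only declares the two fields; a plausible way to populate them, judging from the examples elsewhere in this commit, is to tokenize the adapter's invocation string and attach the token ids to the request. A sketch under that assumption:

```python
# Assumed usage of the new fields; the commit itself does not show how they
# are filled in.
from transformers import AutoTokenizer

from vllm.lora.request import LoRARequest

tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.2-8b-instruct")
invocation_string = "<|start_of_role|>certainty<|end_of_role|>"
invocation_tokens = tokenizer.encode(invocation_string, add_special_tokens=False)

lora_request = LoRARequest(
    "UQ_adapter",
    1,
    "/path/to/new_alora",  # placeholder adapter path
    # lets vLLM locate the activation point inside each prompt
    invocation_tokens=invocation_tokens,
)
```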

vllm/model_executor/layers/linear.py

Lines changed: 7 additions & 0 deletions
@@ -9,6 +9,7 @@
 import torch.nn as nn
 from torch.nn.parameter import Parameter, UninitializedParameter
 
+from vllm.config import get_current_vllm_config
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               split_tensor_along_last_dim,
@@ -229,6 +230,12 @@ def __init__(
     ):
         super().__init__()
 
+        # tpa -- find out why this is needed
+        compilation_config = get_current_vllm_config().compilation_config
+        if prefix in compilation_config.static_forward_context:
+            raise ValueError(f"Duplicate layer name: {prefix}")
+        compilation_config.static_forward_context[prefix] = self
+
         # Keep input parameters
         self.input_size = input_size
         self.output_size = output_size

vllm/v1/core/kv_cache_utils.py

Lines changed: 16 additions & 0 deletions
@@ -457,6 +457,20 @@ def hash_request_tokens(hash_function: Any, block_size: int,
     token_ids = request.all_token_ids
 
     req_need_extra_keys = need_extra_keys(request)
+    if (request.lora_request is not None
+            and request.lora_request.invocation_tokens is not None):
+        use_alora = True
+        invocation_tokens = request.lora_request.invocation_tokens
+        # scan backward for the last match (faster than full forward scan+max)
+        invocation_start = -1
+        n = len(invocation_tokens)
+        for idx in range(len(token_ids) - n, -1, -1):
+            if token_ids[idx:idx + n] == invocation_tokens:
+                # weights activated 1 token after start
+                invocation_start = idx + 1
+                break
+    else:
+        use_alora = False
     req_extra_keys = None
     curr_mm_idx = 0
 
@@ -473,6 +487,8 @@ def hash_request_tokens(hash_function: Any, block_size: int,
             # MM and LoRA requests need extra keys for block-hash computation.
             req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
                 request, start, end, curr_mm_idx)
+            if use_alora and end <= invocation_start:
+                req_extra_keys = None  # cache is equivalent to base model cache
 
         block_hash = hash_block_tokens(hash_function, parent_block_hash_value,
                                        block_token_ids, req_extra_keys)
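The net effect of these two hunks is that, for an aLoRA request, every full block that ends at or before the invocation point hashes exactly like a base-model block, so its KV cache can be shared with plain base-model requests; only blocks at or past the invocation sequence carry LoRA-specific extra keys. A plain-Python toy version of the backward scan used to locate that point:

```python
# Toy re-implementation of the backward scan above (not the vLLM code).
def find_invocation_start(token_ids: list[int],
                          invocation_tokens: list[int]) -> int:
    """Index right after the start of the last occurrence of
    invocation_tokens in token_ids, or -1 if there is no match."""
    n = len(invocation_tokens)
    for idx in range(len(token_ids) - n, -1, -1):
        if token_ids[idx:idx + n] == invocation_tokens:
            return idx + 1  # weights activate 1 token after the match starts
    return -1

assert find_invocation_start([7, 7, 1, 2, 3, 9], [1, 2, 3]) == 3
assert find_invocation_start([1, 2, 3, 0, 1, 2, 3], [1, 2, 3]) == 5
assert find_invocation_start([4, 5, 6], [1, 2, 3]) == -1
```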
