
Commit 5cc6a6c

sroy745 authored and WoosukKwon committed

[V1] Adding min tokens/repetition/presence/frequency penalties to V1 sampler (vllm-project#10681)

Signed-off-by: Sourashis Roy <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Co-authored-by: Woosuk Kwon <[email protected]>
1 parent 767ccff commit 5cc6a6c
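
For context, the knobs this change wires into the V1 sampler are the existing SamplingParams fields. Below is a minimal usage sketch, assuming the offline LLM/SamplingParams API and the VLLM_USE_V1=1 flag used in the tests in this commit; the model name, prompt, and parameter values are placeholders, not part of the change itself.

import os

# Assumption: select the V1 engine the same way the tests below do.
os.environ["VLLM_USE_V1"] = "1"

from vllm import LLM, SamplingParams

# Penalty-related parameters exercised by this commit's tests.
params = SamplingParams(
    min_tokens=4,            # suppress stop tokens until 4 output tokens exist
    presence_penalty=1.0,    # penalize tokens already present in the output
    frequency_penalty=1.0,   # penalize tokens by their output frequency
    repetition_penalty=1.1,  # >1.0 discourages repeating prompt/output tokens
    stop_token_ids=[1001, 1002],
)

llm = LLM(model="facebook/opt-125m")  # placeholder model
outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)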

File tree

11 files changed: +879 -49 lines changed


tests/v1/engine/test_engine_core.py

Lines changed: 38 additions & 0 deletions
@@ -139,3 +139,41 @@ def test_engine_core(monkeypatch):
        engine_core.abort_requests([req2.request_id, req0.request_id])
        assert len(engine_core.scheduler.waiting) == 0
        assert len(engine_core.scheduler.running) == 0


def test_engine_core_advanced_sampling(monkeypatch):
    """
    A basic end-to-end test to verify that the engine functions correctly
    when additional sampling parameters, such as min_tokens and
    presence_penalty, are set.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        """Setup the EngineCore."""
        engine_args = EngineArgs(model=MODEL_NAME)
        vllm_config = engine_args.create_engine_config(
            usage_context=UsageContext.UNKNOWN_CONTEXT)
        executor_class = AsyncLLM._get_executor_cls(vllm_config)

        engine_core = EngineCore(vllm_config=vllm_config,
                                 executor_class=executor_class,
                                 usage_context=UsageContext.UNKNOWN_CONTEXT)
        """Test basic request lifecycle."""
        # First request.
        request: EngineCoreRequest = make_request()
        request.sampling_params = SamplingParams(
            min_tokens=4,
            presence_penalty=1.0,
            frequency_penalty=1.0,
            repetition_penalty=0.1,
            stop_token_ids=[1001, 1002],
        )
        engine_core.add_request(request)
        assert len(engine_core.scheduler.waiting) == 1
        assert len(engine_core.scheduler.running) == 0
        # Loop through until they are all done.
        while len(engine_core.step()) > 0:
            pass

        assert len(engine_core.scheduler.waiting) == 0
        assert len(engine_core.scheduler.running) == 0
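
The min_tokens behaviour this test exercises end to end, and which tests/v1/sample/test_sampler.py below checks in isolation, is: while a request has generated fewer than min_tokens output tokens, the logits of its stop token IDs are forced to -inf so they cannot be sampled. A standalone sketch of that masking step follows; the function and argument names are illustrative, not the commit's actual internals.

from typing import List, Set

import torch


def mask_stop_tokens_until_min_tokens(
    logits: torch.Tensor,            # [num_requests, vocab_size]
    num_output_tokens: List[int],    # tokens generated so far, per request
    min_tokens: List[int],           # SamplingParams.min_tokens, per request
    stop_token_ids: List[Set[int]],  # stop token IDs, per request
) -> torch.Tensor:
    """Illustrative sketch: block stop tokens until min_tokens is reached."""
    for req_idx, (generated, required) in enumerate(
            zip(num_output_tokens, min_tokens)):
        if generated < required:
            for stop_id in stop_token_ids[req_idx]:
                logits[req_idx, stop_id] = -float("inf")
    return logits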

tests/v1/sample/__init__.py

Whitespace-only changes.

tests/v1/sample/test_sampler.py

Lines changed: 331 additions & 0 deletions
@@ -0,0 +1,331 @@
from typing import List, Set, Tuple

import numpy as np
import pytest
import torch

from vllm.utils import make_tensor_with_pad
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import Sampler

VOCAB_SIZE = 1024
NUM_OUTPUT_TOKENS = 20
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
MAX_NUM_PROMPT_TOKENS = 64


def _create_fake_logits(batch_size: int, vocab_size: int) -> torch.Tensor:
    fake_logits = torch.full((batch_size, vocab_size), 1e-2, dtype=torch.float)
    return fake_logits


def _create_penalty_tensor(batch_size: int, penalty_value: float,
                           device: torch.device) -> torch.Tensor:
    return torch.full((batch_size, ),
                      fill_value=penalty_value,
                      dtype=torch.float,
                      device=device)


def _create_prompt_tokens_tensor(
    prompt_token_ids: List[List[int]],
    vocab_size: int,
    device: torch.device,
) -> torch.Tensor:
    return make_tensor_with_pad(
        prompt_token_ids,
        pad=vocab_size,
        device=device,
        dtype=torch.int64,
        pin_memory=False,
    )


def _create_default_sampling_metadata(
    num_output_tokens: int,
    batch_size: int,
    vocab_size: int,
    device: torch.device,
) -> SamplingMetadata:
    output_token_ids: List[List[int]] = []
    prompt_token_ids: List[List[int]] = []
    for _ in range(batch_size):
        output_token_ids.append(
            np.random.randint(0, vocab_size, size=num_output_tokens).tolist())
        prompt_token_ids.append(
            np.random.randint(0,
                              vocab_size,
                              size=np.random.randint(
                                  1, MAX_NUM_PROMPT_TOKENS)).tolist())
    fake_sampling_metadata = SamplingMetadata(
        temperature=torch.full((batch_size, ), 0.0),
        all_greedy=True,
        all_random=False,
        top_p=torch.empty(batch_size, ),
        top_k=torch.empty(batch_size, ),
        no_top_p=True,
        no_top_k=True,
        generators={},
        max_num_logprobs=VOCAB_SIZE,
        prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
                                                      vocab_size, device),
        output_token_ids=output_token_ids,
        frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device),
        presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
        repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),
        no_penalties=True,
        min_tokens=[],
        stop_token_ids=[],
    )
    return fake_sampling_metadata


def _generate_min_token_penalties_and_stop_tokens(
        num_output_tokens: int, batch_size: int, vocab_size: int,
        batch_indices_for_min_token_penalty: List[int]
) -> Tuple[List[int], List[Set[int]]]:
    """
    Generates and returns a list of minimum token penalties (`min_tokens`)
    and a corresponding list of stop token IDs (`stop_token_ids`) for each
    batch.

    If a batch index is included in `batch_indices_for_min_token_penalty`,
    a higher `min_tokens` value is assigned (within a randomized range),
    and a random set of stop token IDs is created. Otherwise, a lower
    `min_tokens` value is assigned, and the stop token IDs set is empty.
    """
    stop_token_ids: List[Set[int]] = []
    min_tokens: List[int] = []
    for index in range(batch_size):
        if index in batch_indices_for_min_token_penalty:
            min_tokens.append(
                np.random.randint(num_output_tokens + 1,
                                  2 * num_output_tokens))
            stop_token_ids.append(
                set(
                    np.random.randint(0, vocab_size - 1)
                    for _ in range(np.random.randint(0, vocab_size))))
        else:
            min_tokens.append(np.random.randint(0, num_output_tokens))
            stop_token_ids.append(set())
    return (min_tokens, stop_token_ids)


def _create_weighted_output_token_list(
        batch_size: int,
        vocab_size: int) -> Tuple[List[List[int]], List[List[int]]]:
    """
    Creates an output token list where each token occurs a distinct
    number of times.

    For each batch, a random subset of token IDs is selected from the
    vocabulary. The selected tokens are then added to the output token
    list, each with a different frequency.

    Returns:
        Tuple[List[List[int]], List[List[int]]]:
            - The first element is the output token list, where each sublist
              corresponds to a batch and contains tokens with weighted
              frequencies.
            - The second element is a list of distinct token IDs for each
              batch, ordered by their frequency in the corresponding output
              list.
    """
    output_token_ids: List[List[int]] = []
    sorted_token_ids_in_output: List[List[int]] = []
    for _ in range(batch_size):
        distinct_token_ids = np.random.choice(vocab_size,
                                              size=np.random.randint(1, 10),
                                              replace=False).tolist()
        sorted_token_ids_in_output.append(distinct_token_ids)
        output_token_ids_for_batch = []
        for index, token_id in enumerate(distinct_token_ids):
            output_token_ids_for_batch.extend(
                [token_id for _ in range(index + 1)])
        output_token_ids.append(output_token_ids_for_batch)
    return (output_token_ids, sorted_token_ids_in_output)

@pytest.mark.parametrize("device", CUDA_DEVICES)
153+
@pytest.mark.parametrize("batch_size", [1, 2, 32])
154+
def test_sampler_min_tokens_penalty(device: str, batch_size: int):
155+
"""
156+
Tests that if the number of output tokens is less than
157+
SamplingParams.min_tokens then we will set the logits for
158+
the stop token ids to -inf.
159+
"""
160+
torch.set_default_device(device)
161+
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
162+
sampling_metadata = _create_default_sampling_metadata(
163+
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
164+
batch_indices_for_min_token_penalty = np.random.randint(
165+
0, batch_size - 1, size=np.random.randint(0, batch_size)).tolist()
166+
min_tokens, stop_token_ids = _generate_min_token_penalties_and_stop_tokens(
167+
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE,
168+
batch_indices_for_min_token_penalty)
169+
sampling_metadata.min_tokens = min_tokens
170+
sampling_metadata.stop_token_ids = stop_token_ids
171+
sampler = Sampler()
172+
sampler_output = sampler(fake_logits, sampling_metadata)
173+
for batch_idx in range(batch_size):
174+
for vocab in range(VOCAB_SIZE):
175+
# Verify that the logprobs for stop token ids is set
176+
# to -inf.
177+
logprob_index = torch.where(
178+
sampler_output.logprob_token_ids[batch_idx] ==
179+
vocab)[0].item()
180+
if vocab in stop_token_ids[batch_idx]:
181+
assert sampler_output.logprobs[batch_idx][
182+
logprob_index] == -float("inf")
183+
else:
184+
assert sampler_output.logprobs[batch_idx][
185+
logprob_index] != -float("inf")
186+
187+
188+
@pytest.mark.parametrize("device", CUDA_DEVICES)
189+
@pytest.mark.parametrize("batch_size", [1, 2, 32])
190+
@pytest.mark.parametrize("presence_penalty", [-2.0, 2.0])
191+
def test_sampler_presence_penalty(device: str, batch_size: int,
192+
presence_penalty: float):
193+
"""
194+
Test to verify that if presence penalty is enabled then tokens
195+
are penalized as per their presence in the existing output.
196+
"""
197+
torch.set_default_device(device)
198+
# Create fake logits where each token is assigned the same
199+
# logit value.
200+
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
201+
sampling_metadata = _create_default_sampling_metadata(
202+
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
203+
output_token_ids = sampling_metadata.output_token_ids
204+
sampling_metadata.presence_penalties = _create_penalty_tensor(
205+
batch_size, presence_penalty, torch.device(device))
206+
sampling_metadata.no_penalties = False
207+
sampler = Sampler()
208+
sampler_output = sampler(fake_logits, sampling_metadata)
209+
for batch_idx in range(batch_size):
210+
# The logprobs in the SamplerOutput are arranged in descending order.
211+
# Since all tokens initially have the same logprobs, the non-penalized
212+
# tokens will appear at the beginning, while the penalized tokens
213+
# will appear at the end of the list.
214+
penalized_token_id = sampler_output.logprob_token_ids[batch_idx][
215+
VOCAB_SIZE - 1]
216+
penalized_log_prod = sampler_output.logprobs[batch_idx][VOCAB_SIZE - 1]
217+
non_penalized_token_id = sampler_output.logprob_token_ids[batch_idx][0]
218+
non_penalized_log_prod = sampler_output.logprobs[batch_idx][0]
219+
assert non_penalized_log_prod > penalized_log_prod
220+
if presence_penalty > 0:
221+
# If `presence_penalty` is set to a value greater than 0, it
222+
# indicates a preference for new tokens over those already
223+
# present in the output.
224+
# Verify that the penalized token ID exists in the output, while the
225+
# non-penalized token ID does not.
226+
assert penalized_token_id in output_token_ids[batch_idx]
227+
assert non_penalized_token_id not in output_token_ids[batch_idx]
228+
elif presence_penalty < 0:
229+
# If `presence_penalty` is set to a value less than 0, it indicates
230+
# a preference for existing tokens over new ones. Verify that the
231+
# non-penalized token ID exists in the output, while the penalized
232+
# token ID does not.
233+
assert non_penalized_token_id in output_token_ids[batch_idx]
234+
assert penalized_token_id not in output_token_ids[batch_idx]
235+
236+
237+
@pytest.mark.parametrize("device", CUDA_DEVICES)
238+
@pytest.mark.parametrize("batch_size", [1, 2, 32])
239+
@pytest.mark.parametrize("frequency_penalty", [-2.0, 2.0])
240+
def test_sampler_frequency_penalty(device: str, batch_size: int,
241+
frequency_penalty: float):
242+
"""
243+
Test to verify that if frequency penalty is enabled then tokens are
244+
penalized as per their frequency of occurrence.
245+
"""
246+
torch.set_default_device(device)
247+
# Create fake logits where each token is assigned the same
248+
# logit value.
249+
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
250+
sampling_metadata = _create_default_sampling_metadata(
251+
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
252+
sampling_metadata.frequency_penalties = _create_penalty_tensor(
253+
batch_size, frequency_penalty, torch.device(device))
254+
output_token_ids, sorted_token_ids_in_output = \
255+
_create_weighted_output_token_list(batch_size, VOCAB_SIZE)
256+
sampling_metadata.output_token_ids = output_token_ids
257+
sampling_metadata.no_penalties = False
258+
sampler = Sampler()
259+
sampler_output = sampler(fake_logits, sampling_metadata)
260+
for batch_idx in range(batch_size):
261+
logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx]
262+
non_penalized_token_id = logprobs_token_ids[0]
263+
penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1]
264+
distinct_sorted_token_ids_in_output = \
265+
sorted_token_ids_in_output[batch_idx]
266+
most_frequent_token_id = distinct_sorted_token_ids_in_output[
267+
len(distinct_sorted_token_ids_in_output) - 1]
268+
if frequency_penalty > 0:
269+
# If `frequency_penalty` is set to > 0, it indicates
270+
# a preference for new tokens over existing ones. Verify that the
271+
# non-penalized token ID is not present in the output, while the
272+
# most penalized token is the one that occurs most frequently in
273+
# the output.
274+
assert non_penalized_token_id \
275+
not in distinct_sorted_token_ids_in_output
276+
assert penalized_token_id == most_frequent_token_id
277+
elif frequency_penalty < 0:
278+
# If `frequency_penalty` is set to < 0, it indicates
279+
# a preference for existing tokens over new ones. Verify that the
280+
# non-penalized token ID is the one that occurs most frequently
281+
# in the output, while the penalized token ID is one that has not
282+
# yet appeared.
283+
assert non_penalized_token_id == most_frequent_token_id
284+
assert penalized_token_id \
285+
not in distinct_sorted_token_ids_in_output
286+
287+
288+
@pytest.mark.parametrize("device", CUDA_DEVICES)
289+
@pytest.mark.parametrize("batch_size", [1, 2, 32])
290+
@pytest.mark.parametrize("repetition_penalty", [0.1, 1.9])
291+
def test_sampler_repetition_penalty(device: str, batch_size: int,
292+
repetition_penalty: float):
293+
"""
294+
Test to verify that when the repetition penalty is enabled, tokens
295+
are penalized based on their presence in the prompt or the existing
296+
output.
297+
"""
298+
torch.set_default_device(device)
299+
# Create fake logits where each token is assigned the same
300+
# logit value.
301+
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
302+
sampling_metadata = _create_default_sampling_metadata(
303+
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
304+
sampling_metadata.repetition_penalties = _create_penalty_tensor(
305+
batch_size, repetition_penalty, torch.device(device))
306+
sampling_metadata.no_penalties = False
307+
sampler = Sampler()
308+
sampler_output = sampler(fake_logits, sampling_metadata)
309+
for batch_idx in range(batch_size):
310+
logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx]
311+
non_penalized_token_id = logprobs_token_ids[0]
312+
penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1]
313+
prompt_tokens = sampling_metadata.prompt_token_ids[
314+
batch_idx][:].tolist()
315+
output_tokens = sampling_metadata.output_token_ids[batch_idx]
316+
if repetition_penalty > 1.0:
317+
# If `repetition_penalty` > 1.0, verify that the non-penalized
318+
# token ID has not been seen before, while the penalized token ID
319+
# exists either in the prompt or the output.
320+
assert (non_penalized_token_id not in prompt_tokens and \
321+
non_penalized_token_id not in output_tokens)
322+
assert (penalized_token_id in prompt_tokens or \
323+
penalized_token_id in output_tokens)
324+
elif repetition_penalty < 1.0:
325+
# If `repetition_penalty` < 1.0, verify that the penalized
326+
# token ID has not been seen before, while the non-penalized
327+
# token ID exists either in the prompt or the output.
328+
assert (penalized_token_id not in prompt_tokens and \
329+
penalized_token_id not in output_tokens)
330+
assert (non_penalized_token_id in prompt_tokens or \
331+
non_penalized_token_id in output_tokens)
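
As a reading aid for the three penalty tests above: presence and frequency penalties are additive adjustments computed from the tokens already generated, while the repetition penalty is a multiplicative adjustment computed over both prompt and output tokens (which is why test_sampler_repetition_penalty checks prompt_tokens as well). The sketch below shows only the conventional math the tests probe; names such as prompt_mask and output_counts are hypothetical, and the commit's actual logic lives in the vllm/v1/sample modules it touches.

import torch


def apply_penalties_sketch(
    logits: torch.Tensor,                # [batch_size, vocab_size]
    prompt_mask: torch.Tensor,           # bool [batch_size, vocab_size]
    output_counts: torch.Tensor,         # per-token output counts, same shape
    presence_penalties: torch.Tensor,    # [batch_size]
    frequency_penalties: torch.Tensor,   # [batch_size]
    repetition_penalties: torch.Tensor,  # [batch_size]
) -> torch.Tensor:
    """Illustrative math only, not the commit's implementation."""
    output_mask = output_counts > 0

    # Repetition penalty: tokens seen in the prompt or output have positive
    # logits divided by the penalty and non-positive logits multiplied by it
    # (penalty > 1.0 discourages repetition, < 1.0 encourages it).
    seen = prompt_mask | output_mask
    rep = repetition_penalties.unsqueeze(1).expand_as(logits)
    logits = torch.where(seen & (logits > 0), logits / rep, logits)
    logits = torch.where(seen & (logits <= 0), logits * rep, logits)

    # Frequency penalty: subtract penalty * (count of the token in the output).
    logits = logits - frequency_penalties.unsqueeze(1) * output_counts

    # Presence penalty: subtract the penalty once for any token in the output.
    logits = logits - presence_penalties.unsqueeze(1) * output_mask.float()

    return logits

With equal starting logits, as in the fake logits used by these tests, a positive penalty pushes every previously seen token strictly below the unseen ones (and a negative penalty does the opposite), which is exactly the ordering the assertions above rely on.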

tests/v1/worker/__init__.py

Whitespace-only changes.

0 commit comments
