| 1 | +from typing import List, Set, Tuple |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +import pytest |
| 5 | +import torch |
| 6 | + |
| 7 | +from vllm.utils import make_tensor_with_pad |
| 8 | +from vllm.v1.sample.metadata import SamplingMetadata |
| 9 | +from vllm.v1.sample.sampler import Sampler |
| 10 | + |
| 11 | +VOCAB_SIZE = 1024 |
| 12 | +NUM_OUTPUT_TOKENS = 20 |
| 13 | +CUDA_DEVICES = [ |
| 14 | + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) |
| 15 | +] |
| 16 | +MAX_NUM_PROMPT_TOKENS = 64 |
| 17 | + |
| 18 | + |
| 19 | +def _create_fake_logits(batch_size: int, vocab_size: int) -> torch.Tensor: |
| 20 | + fake_logits = torch.full((batch_size, vocab_size), 1e-2, dtype=torch.float) |
| 21 | + return fake_logits |
| 22 | + |
| 23 | + |
| 24 | +def _create_penalty_tensor(batch_size: int, penalty_value: float, |
| 25 | + device: torch.device) -> torch.Tensor: |
| 26 | + return torch.full((batch_size, ), |
| 27 | + fill_value=penalty_value, |
| 28 | + dtype=torch.float, |
| 29 | + device=device) |
| 30 | + |
| 31 | + |
| 32 | +def _create_prompt_tokens_tensor( |
| 33 | + prompt_token_ids: List[List[int]], |
| 34 | + vocab_size: int, |
| 35 | + device: torch.device, |
| 36 | +) -> torch.Tensor: |
| 37 | + return make_tensor_with_pad( |
| 38 | + prompt_token_ids, |
| 39 | + pad=vocab_size, |
| 40 | + device=device, |
| 41 | + dtype=torch.int64, |
| 42 | + pin_memory=False, |
| 43 | + ) |
| 44 | + |
| 45 | + |
| 46 | +def _create_default_sampling_metadata( |
| 47 | + num_output_tokens: int, |
| 48 | + batch_size: int, |
| 49 | + vocab_size: int, |
| 50 | + device: torch.device, |
| 51 | +) -> SamplingMetadata: |
| 52 | + output_token_ids: List[List[int]] = [] |
| 53 | + prompt_token_ids: List[List[int]] = [] |
| 54 | + for _ in range(batch_size): |
| 55 | + output_token_ids.append( |
| 56 | + np.random.randint(0, vocab_size, size=num_output_tokens).tolist()) |
| 57 | + prompt_token_ids.append( |
| 58 | + np.random.randint(0, |
| 59 | + vocab_size, |
| 60 | + size=np.random.randint( |
| 61 | + 1, MAX_NUM_PROMPT_TOKENS)).tolist()) |
| 62 | + fake_sampling_metadata = SamplingMetadata( |
| 63 | + temperature=torch.full((batch_size, ), 0.0), |
| 64 | + all_greedy=True, |
| 65 | + all_random=False, |
| 66 | + top_p=torch.empty(batch_size, ), |
| 67 | + top_k=torch.empty(batch_size, ), |
| 68 | + no_top_p=True, |
| 69 | + no_top_k=True, |
| 70 | + generators={}, |
| 71 | +        max_num_logprobs=vocab_size, |
| 72 | + prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids, |
| 73 | + vocab_size, device), |
| 74 | + output_token_ids=output_token_ids, |
| 75 | + frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device), |
| 76 | + presence_penalties=_create_penalty_tensor(batch_size, 0.0, device), |
| 77 | + repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device), |
| 78 | + no_penalties=True, |
| 79 | + min_tokens=[], |
| 80 | + stop_token_ids=[], |
| 81 | + ) |
| 82 | + return fake_sampling_metadata |
| 83 | + |
| 84 | + |
| 85 | +def _generate_min_token_penalties_and_stop_tokens( |
| 86 | + num_output_tokens: int, batch_size: int, vocab_size: int, |
| 87 | + batch_indices_for_min_token_penalty: List[int] |
| 88 | +) -> Tuple[List[int], List[Set[int]]]: |
| 89 | + """ |
| 90 | + Generates and returns a list of minimum token penalties (`min_tokens`) |
| 91 | + and a corresponding list of stop token IDs (`stop_token_ids`) for each |
| 92 | + batch. |
| 93 | + |
| 94 | + If a batch index is included in `batch_indices_for_min_token_penalty`, |
| 95 | + a higher `min_tokens` value is assigned (within a randomized range), |
| 96 | + and a random set of stop token IDs is created. Otherwise, a lower |
| 97 | + `min_tokens` value is assigned, and the stop token IDs set is empty. |
| 98 | + """ |
| 99 | + stop_token_ids: List[Set[int]] = [] |
| 100 | + min_tokens: List[int] = [] |
| 101 | + for index in range(batch_size): |
| 102 | + if index in batch_indices_for_min_token_penalty: |
| 103 | + min_tokens.append( |
| 104 | + np.random.randint(num_output_tokens + 1, |
| 105 | + 2 * num_output_tokens)) |
| 106 | + stop_token_ids.append( |
| 107 | + set( |
| 108 | + np.random.randint(0, vocab_size - 1) |
| 109 | + for _ in range(np.random.randint(0, vocab_size)))) |
| 110 | + |
| 111 | + else: |
| 112 | + min_tokens.append(np.random.randint(0, num_output_tokens)) |
| 113 | + stop_token_ids.append(set()) |
| 114 | + return (min_tokens, stop_token_ids) |
| 115 | + |
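For illustration, here is a made-up example of what `_generate_min_token_penalties_and_stop_tokens` might return for `batch_size=3` and `num_output_tokens=20` when only batch index 1 is in `batch_indices_for_min_token_penalty` (the real values are randomized; these are shown only to convey the shape of the result):

```python
# Hypothetical output of the helper, values chosen arbitrarily.
min_tokens = [4, 33, 11]  # index 1 gets a value above num_output_tokens
stop_token_ids = [set(), {17, 256, 803}, set()]  # only index 1 has stop ids
```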
| 116 | + |
| 117 | +def _create_weighted_output_token_list( |
| 118 | + batch_size: int, |
| 119 | + vocab_size: int) -> Tuple[List[List[int]], List[List[int]]]: |
| 120 | + """ |
| 121 | + Creates an output token list where each token occurs a distinct |
| 122 | + number of times. |
| 123 | + |
| 124 | + For each batch, a random subset of token IDs is selected from the |
| 125 | + vocabulary. The selected tokens are then added to the output token |
| 126 | + list, each with a different frequency. |
| 127 | + |
| 128 | + Returns: |
| 129 | + Tuple[List[List[int]], List[List[int]]]: |
| 130 | + - The first element is the output token list, where each sublist |
| 131 | + corresponds to a batch and contains tokens with weighted |
| 132 | + frequencies. |
| 133 | + - The second element is a list of distinct token IDs for each |
| 134 | + batch, ordered by their frequency in the corresponding output |
| 135 | + list. |
| 136 | + """ |
| 137 | + output_token_ids: List[List[int]] = [] |
| 138 | + sorted_token_ids_in_output: List[List[int]] = [] |
| 139 | + for _ in range(batch_size): |
| 140 | + distinct_token_ids = np.random.choice(vocab_size, |
| 141 | + size=np.random.randint(1, 10), |
| 142 | + replace=False).tolist() |
| 143 | + sorted_token_ids_in_output.append(distinct_token_ids) |
| 144 | + output_token_ids_for_batch = [] |
| 145 | + for index, token_id in enumerate(distinct_token_ids): |
| 146 | +            output_token_ids_for_batch.extend( |
| 147 | +                [token_id] * (index + 1)) |
| 148 | + output_token_ids.append(output_token_ids_for_batch) |
| 149 | + return (output_token_ids, sorted_token_ids_in_output) |
| 150 | + |
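Similarly, a made-up example of one batch entry produced by `_create_weighted_output_token_list`: if the distinct ids drawn were `[7, 3, 42]`, the id at position `index` is repeated `index + 1` times, so the last id in the sorted list is always the most frequent one in the output:

```python
# Hypothetical per-batch output of the helper (ids chosen arbitrarily).
output_token_ids_for_batch = [7, 3, 3, 42, 42, 42]
sorted_token_ids_in_output_for_batch = [7, 3, 42]  # frequencies 1, 2, 3
```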
| 151 | + |
| 152 | +@pytest.mark.parametrize("device", CUDA_DEVICES) |
| 153 | +@pytest.mark.parametrize("batch_size", [1, 2, 32]) |
| 154 | +def test_sampler_min_tokens_penalty(device: str, batch_size: int): |
| 155 | + """ |
| 156 | +    Tests that if the number of output tokens is less than |
| 157 | +    SamplingParams.min_tokens, the logits for the stop token |
| 158 | +    ids are set to -inf. |
| 159 | + """ |
| 160 | + torch.set_default_device(device) |
| 161 | + fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) |
| 162 | + sampling_metadata = _create_default_sampling_metadata( |
| 163 | + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) |
| 164 | + batch_indices_for_min_token_penalty = np.random.randint( |
| 165 | +        0, batch_size, size=np.random.randint(0, batch_size)).tolist() |
| 166 | + min_tokens, stop_token_ids = _generate_min_token_penalties_and_stop_tokens( |
| 167 | + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, |
| 168 | + batch_indices_for_min_token_penalty) |
| 169 | + sampling_metadata.min_tokens = min_tokens |
| 170 | + sampling_metadata.stop_token_ids = stop_token_ids |
| 171 | + sampler = Sampler() |
| 172 | + sampler_output = sampler(fake_logits, sampling_metadata) |
| 173 | + for batch_idx in range(batch_size): |
| 174 | + for vocab in range(VOCAB_SIZE): |
| 175 | +            # Verify that the logprobs for stop token ids are set |
| 176 | +            # to -inf. |
| 177 | + logprob_index = torch.where( |
| 178 | + sampler_output.logprob_token_ids[batch_idx] == |
| 179 | + vocab)[0].item() |
| 180 | + if vocab in stop_token_ids[batch_idx]: |
| 181 | + assert sampler_output.logprobs[batch_idx][ |
| 182 | + logprob_index] == -float("inf") |
| 183 | + else: |
| 184 | + assert sampler_output.logprobs[batch_idx][ |
| 185 | + logprob_index] != -float("inf") |
| 186 | + |
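For orientation, the behavior this test exercises can be summarized by the sketch below. The function name is illustrative and this is not vLLM's actual implementation (the real masking happens inside the sampler); it only restates the expected semantics:

```python
from typing import List, Set

import torch


def _min_tokens_penalty_sketch(logits: torch.Tensor,
                               output_token_ids: List[List[int]],
                               min_tokens: List[int],
                               stop_token_ids: List[Set[int]]) -> torch.Tensor:
    # While a request has produced fewer than `min_tokens` tokens, its stop
    # token ids are masked to -inf so they cannot be sampled, which is why
    # the assertions above expect -inf logprobs for exactly those ids.
    for i, generated in enumerate(output_token_ids):
        if len(generated) < min_tokens[i] and stop_token_ids[i]:
            logits[i, list(stop_token_ids[i])] = -float("inf")
    return logits
```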
| 187 | + |
| 188 | +@pytest.mark.parametrize("device", CUDA_DEVICES) |
| 189 | +@pytest.mark.parametrize("batch_size", [1, 2, 32]) |
| 190 | +@pytest.mark.parametrize("presence_penalty", [-2.0, 2.0]) |
| 191 | +def test_sampler_presence_penalty(device: str, batch_size: int, |
| 192 | + presence_penalty: float): |
| 193 | + """ |
| 194 | +    Test to verify that if the presence penalty is enabled, tokens |
| 195 | +    are penalized based on their presence in the existing output. |
| 196 | + """ |
| 197 | + torch.set_default_device(device) |
| 198 | + # Create fake logits where each token is assigned the same |
| 199 | + # logit value. |
| 200 | + fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) |
| 201 | + sampling_metadata = _create_default_sampling_metadata( |
| 202 | + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) |
| 203 | + output_token_ids = sampling_metadata.output_token_ids |
| 204 | + sampling_metadata.presence_penalties = _create_penalty_tensor( |
| 205 | + batch_size, presence_penalty, torch.device(device)) |
| 206 | + sampling_metadata.no_penalties = False |
| 207 | + sampler = Sampler() |
| 208 | + sampler_output = sampler(fake_logits, sampling_metadata) |
| 209 | + for batch_idx in range(batch_size): |
| 210 | + # The logprobs in the SamplerOutput are arranged in descending order. |
| 211 | + # Since all tokens initially have the same logprobs, the non-penalized |
| 212 | + # tokens will appear at the beginning, while the penalized tokens |
| 213 | + # will appear at the end of the list. |
| 214 | + penalized_token_id = sampler_output.logprob_token_ids[batch_idx][ |
| 215 | + VOCAB_SIZE - 1] |
| 216 | +        penalized_log_prob = sampler_output.logprobs[batch_idx][VOCAB_SIZE - 1] |
| 217 | +        non_penalized_token_id = sampler_output.logprob_token_ids[batch_idx][0] |
| 218 | +        non_penalized_log_prob = sampler_output.logprobs[batch_idx][0] |
| 219 | +        assert non_penalized_log_prob > penalized_log_prob |
| 220 | + if presence_penalty > 0: |
| 221 | + # If `presence_penalty` is set to a value greater than 0, it |
| 222 | + # indicates a preference for new tokens over those already |
| 223 | + # present in the output. |
| 224 | + # Verify that the penalized token ID exists in the output, while the |
| 225 | + # non-penalized token ID does not. |
| 226 | + assert penalized_token_id in output_token_ids[batch_idx] |
| 227 | + assert non_penalized_token_id not in output_token_ids[batch_idx] |
| 228 | + elif presence_penalty < 0: |
| 229 | + # If `presence_penalty` is set to a value less than 0, it indicates |
| 230 | + # a preference for existing tokens over new ones. Verify that the |
| 231 | + # non-penalized token ID exists in the output, while the penalized |
| 232 | + # token ID does not. |
| 233 | + assert non_penalized_token_id in output_token_ids[batch_idx] |
| 234 | + assert penalized_token_id not in output_token_ids[batch_idx] |
| 235 | + |
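As background for the assertions above, a schematic version of the presence penalty in its standard (OpenAI-style) formulation; the function name and the dense boolean mask are illustrative, not vLLM's internal API:

```python
from typing import List

import torch


def _presence_penalty_sketch(logits: torch.Tensor,
                             output_token_ids: List[List[int]],
                             presence_penalties: torch.Tensor) -> torch.Tensor:
    # Subtract the penalty once for each token id that already appears in the
    # output, regardless of how often it occurs. With equal initial logits,
    # penalized tokens therefore sort to the end of the logprob ranking (or
    # to the front when the penalty is negative).
    seen = torch.zeros_like(logits, dtype=torch.bool)
    for i, tokens in enumerate(output_token_ids):
        seen[i, tokens] = True
    return logits - presence_penalties.unsqueeze(dim=1) * seen.float()
```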
| 236 | + |
| 237 | +@pytest.mark.parametrize("device", CUDA_DEVICES) |
| 238 | +@pytest.mark.parametrize("batch_size", [1, 2, 32]) |
| 239 | +@pytest.mark.parametrize("frequency_penalty", [-2.0, 2.0]) |
| 240 | +def test_sampler_frequency_penalty(device: str, batch_size: int, |
| 241 | + frequency_penalty: float): |
| 242 | + """ |
| 243 | +    Test to verify that if the frequency penalty is enabled, tokens are |
| 244 | +    penalized based on how often they occur in the existing output. |
| 245 | + """ |
| 246 | + torch.set_default_device(device) |
| 247 | + # Create fake logits where each token is assigned the same |
| 248 | + # logit value. |
| 249 | + fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) |
| 250 | + sampling_metadata = _create_default_sampling_metadata( |
| 251 | + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) |
| 252 | + sampling_metadata.frequency_penalties = _create_penalty_tensor( |
| 253 | + batch_size, frequency_penalty, torch.device(device)) |
| 254 | + output_token_ids, sorted_token_ids_in_output = \ |
| 255 | + _create_weighted_output_token_list(batch_size, VOCAB_SIZE) |
| 256 | + sampling_metadata.output_token_ids = output_token_ids |
| 257 | + sampling_metadata.no_penalties = False |
| 258 | + sampler = Sampler() |
| 259 | + sampler_output = sampler(fake_logits, sampling_metadata) |
| 260 | + for batch_idx in range(batch_size): |
| 261 | + logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx] |
| 262 | + non_penalized_token_id = logprobs_token_ids[0] |
| 263 | + penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1] |
| 264 | + distinct_sorted_token_ids_in_output = \ |
| 265 | + sorted_token_ids_in_output[batch_idx] |
| 266 | + most_frequent_token_id = distinct_sorted_token_ids_in_output[ |
| 267 | + len(distinct_sorted_token_ids_in_output) - 1] |
| 268 | + if frequency_penalty > 0: |
| 269 | + # If `frequency_penalty` is set to > 0, it indicates |
| 270 | + # a preference for new tokens over existing ones. Verify that the |
| 271 | + # non-penalized token ID is not present in the output, while the |
| 272 | + # most penalized token is the one that occurs most frequently in |
| 273 | + # the output. |
| 274 | + assert non_penalized_token_id \ |
| 275 | + not in distinct_sorted_token_ids_in_output |
| 276 | + assert penalized_token_id == most_frequent_token_id |
| 277 | + elif frequency_penalty < 0: |
| 278 | + # If `frequency_penalty` is set to < 0, it indicates |
| 279 | + # a preference for existing tokens over new ones. Verify that the |
| 280 | + # non-penalized token ID is the one that occurs most frequently |
| 281 | + # in the output, while the penalized token ID is one that has not |
| 282 | + # yet appeared. |
| 283 | + assert non_penalized_token_id == most_frequent_token_id |
| 284 | + assert penalized_token_id \ |
| 285 | + not in distinct_sorted_token_ids_in_output |
| 286 | + |
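The frequency penalty differs only in that the subtraction scales with the occurrence count, which is why the most frequent token id ends up most (or, for a negative penalty, least) penalized. Again a schematic of the standard formulation, not vLLM's implementation:

```python
from typing import List

import torch


def _frequency_penalty_sketch(logits: torch.Tensor,
                              output_token_ids: List[List[int]],
                              frequency_penalties: torch.Tensor) -> torch.Tensor:
    # Subtract penalty * count for every token id, where count is how many
    # times that id occurs in the output generated so far.
    counts = torch.zeros_like(logits)
    for i, tokens in enumerate(output_token_ids):
        for token_id in tokens:
            counts[i, token_id] += 1
    return logits - frequency_penalties.unsqueeze(dim=1) * counts
```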
| 287 | + |
| 288 | +@pytest.mark.parametrize("device", CUDA_DEVICES) |
| 289 | +@pytest.mark.parametrize("batch_size", [1, 2, 32]) |
| 290 | +@pytest.mark.parametrize("repetition_penalty", [0.1, 1.9]) |
| 291 | +def test_sampler_repetition_penalty(device: str, batch_size: int, |
| 292 | + repetition_penalty: float): |
| 293 | + """ |
| 294 | + Test to verify that when the repetition penalty is enabled, tokens |
| 295 | + are penalized based on their presence in the prompt or the existing |
| 296 | + output. |
| 297 | + """ |
| 298 | + torch.set_default_device(device) |
| 299 | + # Create fake logits where each token is assigned the same |
| 300 | + # logit value. |
| 301 | + fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) |
| 302 | + sampling_metadata = _create_default_sampling_metadata( |
| 303 | + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) |
| 304 | + sampling_metadata.repetition_penalties = _create_penalty_tensor( |
| 305 | + batch_size, repetition_penalty, torch.device(device)) |
| 306 | + sampling_metadata.no_penalties = False |
| 307 | + sampler = Sampler() |
| 308 | + sampler_output = sampler(fake_logits, sampling_metadata) |
| 309 | + for batch_idx in range(batch_size): |
| 310 | + logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx] |
| 311 | + non_penalized_token_id = logprobs_token_ids[0] |
| 312 | + penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1] |
| 313 | + prompt_tokens = sampling_metadata.prompt_token_ids[ |
| 314 | + batch_idx][:].tolist() |
| 315 | + output_tokens = sampling_metadata.output_token_ids[batch_idx] |
| 316 | + if repetition_penalty > 1.0: |
| 317 | + # If `repetition_penalty` > 1.0, verify that the non-penalized |
| 318 | + # token ID has not been seen before, while the penalized token ID |
| 319 | + # exists either in the prompt or the output. |
| 320 | + assert (non_penalized_token_id not in prompt_tokens and \ |
| 321 | + non_penalized_token_id not in output_tokens) |
| 322 | + assert (penalized_token_id in prompt_tokens or \ |
| 323 | + penalized_token_id in output_tokens) |
| 324 | + elif repetition_penalty < 1.0: |
| 325 | + # If `repetition_penalty` < 1.0, verify that the penalized |
| 326 | + # token ID has not been seen before, while the non-penalized |
| 327 | + # token ID exists either in the prompt or the output. |
| 328 | + assert (penalized_token_id not in prompt_tokens and \ |
| 329 | + penalized_token_id not in output_tokens) |
| 330 | + assert (non_penalized_token_id in prompt_tokens or \ |
| 331 | + non_penalized_token_id in output_tokens) |
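For completeness, a schematic repetition penalty in the usual HF/CTRL-style multiplicative form assumed by this test: values greater than 1.0 discourage tokens already seen in the prompt or output, while values below 1.0 favor them. The helper name is illustrative only:

```python
from typing import List

import torch


def _repetition_penalty_sketch(
        logits: torch.Tensor, prompt_token_ids: List[List[int]],
        output_token_ids: List[List[int]],
        repetition_penalties: torch.Tensor) -> torch.Tensor:
    # For every token id seen in the prompt or the output, divide positive
    # logits by the penalty and multiply negative logits by it.
    penalized = logits.clone()
    for i, (prompt, output) in enumerate(
            zip(prompt_token_ids, output_token_ids)):
        seen = list(set(prompt) | set(output))
        values = penalized[i, seen]
        penalty = repetition_penalties[i]
        penalized[i, seen] = torch.where(values > 0, values / penalty,
                                         values * penalty)
    return penalized
```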