Async vllm #693

Merged on May 22, 2025 (17 commits)
Changes from 12 commits
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -92,7 +92,7 @@ nanotron = [
"tensorboardX"
]
tensorboardX = ["tensorboardX"]
vllm = ["vllm>=0.7.0", "ray", "more_itertools"]
vllm = ["vllm==0.8.4", "ray==2.43.0", "more_itertools"]
quality = ["ruff==v0.2.2","pre-commit"]
tests = ["pytest==7.4.0","deepdiff"]
dev = ["lighteval[accelerate,quality,tests,multilingual,math,extended_tasks,vllm]"]
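Note on the pin (not part of the diff): the async backend added below imports vLLM's V1 AsyncLLM engine, and pinning to a known-good release avoids surprises from that still-moving API. A hedged sketch of a quick import guard; the import path is the one used in vllm_model.py below, everything else is illustrative.

import sys

try:
    # Same import the new AsyncVLLMModel relies on (see the vllm_model.py changes below).
    from vllm.v1.engine.async_llm import AsyncLLM  # noqa: F401
except ImportError as err:
    # Assumption: older or future vLLM builds may not ship this module path.
    sys.exit(f"Async lighteval backend needs a vLLM build with the V1 AsyncLLM engine (this PR pins vllm==0.8.4): {err}")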
1 change: 1 addition & 0 deletions src/lighteval/models/abstract_model.py
@@ -56,6 +56,7 @@ class ModelInfo:

class LightevalModel(ABC):
DATASET_SPLITS = 4
is_async = False

"""Abstract model class defining the API that every model to plug into lighteval must follow."""

8 changes: 5 additions & 3 deletions src/lighteval/models/model_loader.py
@@ -40,7 +40,7 @@
from lighteval.models.transformers.adapter_model import AdapterModel, AdapterModelConfig
from lighteval.models.transformers.delta_model import DeltaModel, DeltaModelConfig
from lighteval.models.transformers.transformers_model import TransformersModel, TransformersModelConfig
from lighteval.models.vllm.vllm_model import VLLMModel, VLLMModelConfig
from lighteval.models.vllm.vllm_model import AsyncVLLMModel, VLLMModel, VLLMModelConfig
from lighteval.utils.imports import (
NO_LITELLM_ERROR_MSG,
NO_SGLANG_ERROR_MSG,
@@ -160,8 +160,10 @@ def load_model_with_accelerate_or_default(
elif isinstance(config, VLLMModelConfig):
if not is_vllm_available():
raise ImportError(NO_VLLM_ERROR_MSG)
model = VLLMModel(config=config)
return model
if config.is_async:
model = AsyncVLLMModel(config=config)
else:
model = VLLMModel(config=config)
else:
model = TransformersModel(config=config)

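Usage note (hedged, not part of the PR): switching to the async engine is only a config flag. A sketch assuming VLLMModelConfig needs nothing beyond a model name here; the checkpoint is illustrative and any other required fields are omitted.

from lighteval.models.vllm.vllm_model import VLLMModelConfig

# is_async is the new field (defaults to False), so existing configs keep the sync VLLMModel path.
config = VLLMModelConfig(
    model_name="HuggingFaceTB/SmolLM2-1.7B-Instruct",  # illustrative checkpoint, not from the PR
    is_async=True,  # makes the loader branch above build AsyncVLLMModel
)

Fed through load_model_with_accelerate_or_default, this config would hit the new AsyncVLLMModel branch instead of the sync one.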
9 changes: 9 additions & 0 deletions src/lighteval/models/transformers/adapter_model.py
@@ -21,6 +21,7 @@
# SOFTWARE.

import logging
import shutil
from contextlib import nullcontext

import torch
@@ -105,3 +106,11 @@ def _create_auto_model(self) -> transformers.PreTrainedModel:
)

return model

def cleanup(self):
try:
tmp_weights_dir = f"{self.model_name}-adapter-applied"
shutil.rmtree(tmp_weights_dir)
logger.info(f"Removed {tmp_weights_dir}")
except OSError:
pass
9 changes: 9 additions & 0 deletions src/lighteval/models/transformers/delta_model.py
@@ -21,6 +21,7 @@
# SOFTWARE.

import logging
import shutil
from contextlib import nullcontext

import torch
@@ -87,3 +88,11 @@ def _create_auto_model(
)

return model

def cleanup(self):
try:
tmp_weights_dir = f"{self.model_name}-delta-applied"
shutil.rmtree(tmp_weights_dir)
logger.info(f"Removed {tmp_weights_dir}")
except OSError:
pass
191 changes: 188 additions & 3 deletions src/lighteval/models/vllm/vllm_model.py
@@ -20,11 +20,12 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import asyncio
import gc
import itertools
import logging
import os
from typing import Optional
from typing import Coroutine, Optional

import torch
from pydantic import NonNegativeFloat, PositiveInt
@@ -51,9 +52,13 @@
if is_vllm_available():
import ray
from more_itertools import distribute
from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel
from vllm import LLM, RequestOutput, SamplingParams
from vllm.distributed.parallel_state import (
destroy_distributed_environment,
destroy_model_parallel,
)
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.v1.engine.async_llm import AsyncEngineArgs, AsyncLLM

logging.getLogger("vllm").propagate = True
logging.getLogger("vllm").handlers.clear()
@@ -62,7 +67,9 @@
logging.getLogger("ray").handlers.clear()
else:
LLM = None
AsyncLLM = None
SamplingParams = None
AsyncEngineArgs = None
get_tokenizer = None
ray = None
distribute = None
@@ -93,6 +100,7 @@ class VLLMModelConfig(ModelConfig):
max_num_seqs: PositiveInt = 128 # maximum number of sequences per iteration; this variable and `max_num_batched_tokens` effectively control the batch size at the prefill stage. See https://github.com/vllm-project/vllm/issues/2492 for a detailed explanation.
max_num_batched_tokens: PositiveInt = 2048 # maximum number of tokens per batch
subfolder: str | None = None
is_async: bool = False # Whether to use the async version or sync version of the model


class VLLMModel(LightevalModel):
@@ -411,3 +419,180 @@ def loglikelihood_rolling():

def loglikelihood_single_token():
pass


class AsyncVLLMModel(VLLMModel):
"""VLLM models which deploy async natively (no ray). Supports DP and PP/TP but not batch size > 1"""

DATASET_SPLITS = 1
is_async = True

def cleanup(self):
gc.collect()
destroy_distributed_environment()
torch.cuda.empty_cache()

def _create_auto_model(self, config: VLLMModelConfig) -> Optional[AsyncLLM]:
"""
Creates an instance of the async vLLM model loaded from HF. Requires vLLM's V1 engine.

Returns:
AsyncLLM: The created async VLLM instance
"""
self.model_args = {
"model": config.model_name,
"gpu_memory_utilization": config.gpu_memory_utilization,
"revision": config.revision + (f"/{config.subfolder}" if config.subfolder is not None else ""),
"dtype": config.dtype,
"trust_remote_code": config.trust_remote_code,
"tensor_parallel_size": config.tensor_parallel_size,
"data_parallel_size": config.data_parallel_size,
"pipeline_parallel_size": config.pipeline_parallel_size,
"max_model_len": self._max_length,
"swap_space": 4,
"seed": int(config.seed),
"max_num_seqs": int(config.max_num_seqs),
"max_num_batched_tokens": int(config.max_num_batched_tokens),
"enforce_eager": True,
}

if config.data_parallel_size > 1:
self._batch_size = "auto"

model = AsyncLLM.from_engine_args(AsyncEngineArgs(**self.model_args))

# If the max_length can't get extracted from the config, it will be inferred from the model
if self._max_length is None:
self._max_length = model.model_config.max_seq_len_to_capture

return model

async def _async_one_item(
self,
index: int,
request: GreedyUntilRequest | LoglikelihoodRequest,
) -> Coroutine[None, list, str]:
"""Contains the actual logic of the generation."""
sampling_params = SamplingParams(**self._config.generation_parameters.to_vllm_dict())

if isinstance(request, LoglikelihoodRequest):
sampling_params.temperature = 0
sampling_params.prompt_logprobs = 1
sampling_params.max_tokens = 1
sampling_params.detokenize = False
prompt = request.context + request.choice
index = f"logprob_{index}"
elif isinstance(request, GreedyUntilRequest):
sampling_params.n = request.num_samples
sampling_params.max_tokens = self._config.generation_parameters.max_new_tokens or request.generation_size
sampling_params.stop = [] if self.use_chat_template else request.stop_sequence
sampling_params.logprobs = int(request.use_logits)
prompt = request.context
index = f"generative_{index}"

generator = self.model.generate(request_id=str(index), prompt=prompt, sampling_params=sampling_params)
try:
while output := await anext(generator):
continue
except StopAsyncIteration:
pass

return output

async def _async_batch(self, requests: list[GreedyUntilRequest | LoglikelihoodRequest]) -> list:
processed_requests = [
self._async_one_item(index=index, request=request) for index, request in enumerate(requests)
]
results = await asyncio.gather(*processed_requests)
return results

async def greedy_until(
self,
requests: list[GreedyUntilRequest],
**kwargs,
) -> list[GenerativeResponse]:
"""
Generates responses using a greedy decoding strategy until certain ending conditions are met.

Args:
requests (list[Request]): list of requests containing the context and ending conditions.

Returns:
list[GenerativeResponse]: list of generated responses.
"""
for request in requests:
request.stop_sequence = as_list(request.stop_sequence) + [self.tokenizer.eos_token]
request.tokenized_context = self.tok_encode(request.context)

results = []

responses: list[RequestOutput] = await self._async_batch(requests=requests)

for response in responses:
output_token_ids = [outputs.token_ids for outputs in response.outputs]
full_logprobs = [output.logprobs for output in response.outputs] or []
logprobs = [logprob[token_id].logprob for token_id, logprob in zip(output_token_ids[0], full_logprobs[0])]
result = [output.text for output in response.outputs]
input_token_ids = response.prompt_token_ids

cur_response = GenerativeResponse(
result=result,
logits=logprobs,
generated_tokens=list(output_token_ids),
input_tokens=input_token_ids,
)
results.append(cur_response)

return results

async def loglikelihood(
self,
requests: list[LoglikelihoodRequest],
return_bool_score: bool = True,
**kwargs,
) -> list[LoglikelihoodResponse]:
"""
Computes the log-likelihood of each request's continuation given its context, using the
prompt logprobs returned by the engine.

Args:
requests (list[Request]): list of requests containing the context and ending conditions.

Returns:
list[LoglikelihoodResponse]: list of generated responses.
"""

for request in requests:
if request.context == "":
request.tokenized_context = [self.tokenizer.eos_token_id]
request.tokenized_continuation = self.tok_encode(request.choice)
else:
# The following line is mandatory for compatibility with the harness
request.tokenized_context, request.tokenized_continuation = self.tok_encode_pair(
request.context, request.choice, pairwise=self.pairwise_tokenization
)

results = []

responses: list[RequestOutput] = await self._async_batch(requests=requests)

for response, input in zip(responses, requests):
continuation_logprobs = []
for token, logprobs in zip(input.tokenized_continuation[::-1], response.prompt_logprobs[::-1]):
continuation_logprobs.append(logprobs[token])
bool_score = all(logprob.rank == 1 for logprob in continuation_logprobs)
continuation_logprobs = [logprob.logprob for logprob in continuation_logprobs]
answer = LoglikelihoodResponse(
input_tokens=input.tokenized_context + input.tokenized_continuation,
generated_tokens=input.tokenized_continuation,
result=(sum(continuation_logprobs), bool_score if return_bool_score else None),
)
results.append(answer)

return results

def loglikelihood_rolling():
pass

def loglikelihood_single_token():
pass
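To see the concurrency pattern of _async_one_item and _async_batch in isolation, here is a standalone sketch with a fake engine stream (pure asyncio, hypothetical names, no vLLM): each request drains its async generator and keeps only the last output, and asyncio.gather runs all requests concurrently while preserving input order.

import asyncio

async def fake_engine_stream(prompt: str):
    # Stand-in for AsyncLLM.generate(): yields intermediate outputs; the last yield is the final one.
    for step in range(3):
        await asyncio.sleep(0)
        yield f"{prompt}|step{step}"

async def one_item(prompt: str) -> str:
    # Mirrors _async_one_item: keep pulling until the stream is exhausted, keep the last output.
    generator = fake_engine_stream(prompt)
    output = None
    try:
        while output := await anext(generator):
            continue
    except StopAsyncIteration:
        pass
    return output

async def batch(prompts: list[str]) -> list[str]:
    # Mirrors _async_batch: gather preserves input order even when items finish out of order.
    return await asyncio.gather(*(one_item(p) for p in prompts))

print(asyncio.run(batch(["a", "b", "c"])))  # ['a|step2', 'b|step2', 'c|step2']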
40 changes: 27 additions & 13 deletions src/lighteval/pipeline.py
@@ -21,11 +21,11 @@
# SOFTWARE.

import ast
import asyncio
import collections
import os
import random
import re
import shutil
from contextlib import nullcontext
from dataclasses import dataclass
from datetime import timedelta
@@ -295,14 +295,6 @@ def evaluate(self):
self.evaluation_tracker.metrics_logger.aggregate(task_dict=self.task_dict, bootstrap_iters=1000)
self.evaluation_tracker.details_logger.aggregate()

for weights in ["delta", "adapter"]:
try:
tmp_weights_dir = f"{self.evaluation_tracker.general_config_logger.model_name}-{weights}-applied"
shutil.rmtree(tmp_weights_dir)
logger.info(f"Removed {tmp_weights_dir}")
except OSError:
pass

def _unpack(self, x):
if isinstance(x, str):
return x
@@ -455,12 +447,22 @@ def _get_model_response_type(self, request_type):

return model_response_type

def _run_model(self):
# Running all requests depending on the model call type (log likelihood, generative, ...)
# to be able to batch them
logger.info("--- RUNNING MODEL ---")
async def _run_model_async(self):
sample_id_to_responses: dict[(SampleUid, MetricCategory), list[ModelResponse]] = collections.defaultdict(list)
for request_type, requests in self.requests.items():
logger.info(f"Sending {request_type} requests")
run_model = self.model.get_method_from_request_type(request_type=request_type)
responses = await run_model(requests)

# Storing the responses associated to the same samples together
for response, request in zip(responses, requests):
for metric_category in request.metric_categories:
sample_id = SampleUid(request.task_name, request.sample_index)
sample_id_to_responses[(sample_id, metric_category)].append(response)
return sample_id_to_responses

def _run_model_sync(self):
sample_id_to_responses: dict[(SampleUid, MetricCategory), list[ModelResponse]] = collections.defaultdict(list)
for request_type, requests in self.requests.items():
logger.info(f"Running {request_type} requests")
run_model = self.model.get_method_from_request_type(request_type=request_type)
@@ -471,6 +473,18 @@ def _run_model(self):
for metric_category in request.metric_categories:
sample_id = SampleUid(request.task_name, request.sample_index)
sample_id_to_responses[(sample_id, metric_category)].append(response)
return sample_id_to_responses

def _run_model(self):
# Running all requests depending on the model call type (log likelihood, generative, ...)
# to be able to batch them
logger.info("--- RUNNING MODEL ---")

if self.model.is_async:
sample_id_to_responses = asyncio.run(self._run_model_async())

else:
sample_id_to_responses = self._run_model_sync()

# Cleaning up the model before running metrics
self.model.cleanup()
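Standalone sketch (mock classes, no lighteval imports) of the dispatch this hunk introduces: the sync path is unchanged, and the async path is bridged with a single asyncio.run call keyed off the is_async class attribute added in abstract_model.py.

import asyncio

class SyncModel:
    is_async = False

    def greedy_until(self, requests):
        return [f"sync:{r}" for r in requests]

class AsyncModel:
    is_async = True

    async def greedy_until(self, requests):
        # Each request is awaited concurrently, like AsyncVLLMModel._async_batch.
        async def one(r):
            await asyncio.sleep(0)
            return f"async:{r}"
        return await asyncio.gather(*(one(r) for r in requests))

def run_model(model, requests):
    # Mirrors Pipeline._run_model: one asyncio.run bridges the sync pipeline into the async backend.
    if model.is_async:
        return asyncio.run(model.greedy_until(requests))
    return model.greedy_until(requests)

print(run_model(SyncModel(), ["a", "b"]))   # ['sync:a', 'sync:b']
print(run_model(AsyncModel(), ["a", "b"]))  # ['async:a', 'async:b']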