diff --git a/pyproject.toml b/pyproject.toml
index 931846b74..c2b41ffde 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -94,7 +94,7 @@ nanotron = [
   "tensorboardX"
 ]
 tensorboardX = ["tensorboardX"]
-vllm = ["vllm>=0.7.0", "ray", "more_itertools"]
+vllm = ["vllm>=0.8.4", "ray", "more_itertools"]
 quality = ["ruff>=v0.11.0","pre-commit"]
 tests = ["pytest==7.4.0","deepdiff"]
 dev = ["lighteval[accelerate,quality,tests,multilingual,math,extended_tasks,vllm]"]
diff --git a/src/lighteval/models/abstract_model.py b/src/lighteval/models/abstract_model.py
index 58c27391a..78e1768f1 100644
--- a/src/lighteval/models/abstract_model.py
+++ b/src/lighteval/models/abstract_model.py
@@ -56,6 +56,7 @@ class ModelInfo:
 
 
 class LightevalModel(ABC):
     DATASET_SPLITS = 4
+    is_async = False
 
     """Abstract model class defining the API that every model to plug into lighteval must follow."""
diff --git a/src/lighteval/models/model_loader.py b/src/lighteval/models/model_loader.py
index ef9c77549..ed497165a 100644
--- a/src/lighteval/models/model_loader.py
+++ b/src/lighteval/models/model_loader.py
@@ -44,7 +44,7 @@
 from lighteval.models.transformers.transformers_model import TransformersModel, TransformersModelConfig
 from lighteval.models.transformers.vlm_transformers_model import VLMTransformersModel, VLMTransformersModelConfig
 from lighteval.models.utils import ModelConfig
-from lighteval.models.vllm.vllm_model import VLLMModel, VLLMModelConfig
+from lighteval.models.vllm.vllm_model import AsyncVLLMModel, VLLMModel, VLLMModelConfig
 from lighteval.utils.imports import (
     NO_LITELLM_ERROR_MSG,
     NO_SGLANG_ERROR_MSG,
@@ -189,11 +189,12 @@ def load_model_with_accelerate_or_default(
     elif isinstance(config, VLLMModelConfig):
         if not is_vllm_available():
             raise ImportError(NO_VLLM_ERROR_MSG)
-        model = VLLMModel(config=config)
-        return model
+        if config.is_async:
+            model = AsyncVLLMModel(config=config)
+        else:
+            model = VLLMModel(config=config)
     elif isinstance(config, VLMTransformersModelConfig):
         model = VLMTransformersModel(config=config)
-        return model
     else:
         model = TransformersModel(config=config)
 
diff --git a/src/lighteval/models/transformers/adapter_model.py b/src/lighteval/models/transformers/adapter_model.py
index fd341542c..e6df27cf2 100644
--- a/src/lighteval/models/transformers/adapter_model.py
+++ b/src/lighteval/models/transformers/adapter_model.py
@@ -21,6 +21,7 @@
 # SOFTWARE.
 
 import logging
+import shutil
 from contextlib import nullcontext
 
 import torch
@@ -105,3 +106,11 @@ def _create_auto_model(self) -> transformers.PreTrainedModel:
         )
 
         return model
+
+    def cleanup(self):
+        try:
+            tmp_weights_dir = f"{self.model_name}-adapter-applied"
+            shutil.rmtree(tmp_weights_dir)
+            logger.info(f"Removed {tmp_weights_dir}")
+        except OSError:
+            pass
diff --git a/src/lighteval/models/transformers/delta_model.py b/src/lighteval/models/transformers/delta_model.py
index 51395b424..3638fe5af 100644
--- a/src/lighteval/models/transformers/delta_model.py
+++ b/src/lighteval/models/transformers/delta_model.py
@@ -21,6 +21,7 @@
 # SOFTWARE.
 
 import logging
+import shutil
 from contextlib import nullcontext
 
 import torch
@@ -87,3 +88,11 @@ def _create_auto_model(
         )
 
         return model
+
+    def cleanup(self):
+        try:
+            tmp_weights_dir = f"{self.model_name}-delta-applied"
+            shutil.rmtree(tmp_weights_dir)
+            logger.info(f"Removed {tmp_weights_dir}")
+        except OSError:
+            pass
diff --git a/src/lighteval/models/vllm/vllm_model.py b/src/lighteval/models/vllm/vllm_model.py
index 5be1285f8..0b4892a20 100644
--- a/src/lighteval/models/vllm/vllm_model.py
+++ b/src/lighteval/models/vllm/vllm_model.py
@@ -20,11 +20,12 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
+import asyncio
 import gc
 import itertools
 import logging
 import os
-from typing import Optional
+from typing import Coroutine, Optional
 
 import torch
 from pydantic import NonNegativeFloat, NonNegativeInt, PositiveInt
@@ -51,9 +52,13 @@
 if is_vllm_available():
     import ray
     from more_itertools import distribute
-    from vllm import LLM, SamplingParams
-    from vllm.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel
+    from vllm import LLM, RequestOutput, SamplingParams
+    from vllm.distributed.parallel_state import (
+        destroy_distributed_environment,
+        destroy_model_parallel,
+    )
     from vllm.transformers_utils.tokenizer import get_tokenizer
+    from vllm.v1.engine.async_llm import AsyncEngineArgs, AsyncLLM
 
     logging.getLogger("vllm").propagate = True
     logging.getLogger("vllm").handlers.clear()
@@ -62,7 +67,9 @@
     logging.getLogger("ray").handlers.clear()
 else:
     LLM = None
+    AsyncLLM = None
     SamplingParams = None
+    AsyncEngineArgs = None
     get_tokenizer = None
     ray = None
     distribute = None
@@ -97,6 +104,7 @@ class VLLMModelConfig(ModelConfig):
     max_num_seqs: PositiveInt = 128  # maximum number of sequences per iteration; This variable and `max_num_batched_tokens` effectively control the batch size at prefill stage. See https://github.com/vllm-project/vllm/issues/2492 for detailed explaination.
     max_num_batched_tokens: PositiveInt = 2048  # maximum number of tokens per batch
     subfolder: str | None = None
+    is_async: bool = False  # Whether to use the async version or sync version of the model
 
 
 class VLLMModel(LightevalModel):
@@ -422,3 +430,185 @@ def loglikelihood_rolling():
 
     def loglikelihood_single_token():
         pass
+
+
+class AsyncVLLMModel(VLLMModel):
+    """VLLM models which deploy async natively (no ray). Supports DP and PP/TP but not batch size > 1"""
+
+    DATASET_SPLITS = 1
+    is_async = True
+
+    def cleanup(self):
+        gc.collect()
+        destroy_distributed_environment()
+        torch.cuda.empty_cache()
+
+    def _create_auto_model(self, config: VLLMModelConfig) -> Optional[AsyncLLM]:
+        """
+        Creates an instance of the async vllm model loaded from HF. Requires using the v1 of VLLM.
+
+        Returns:
+            AsyncLLM: The created async VLLM instance
+        """
+        self.model_args = {
+            "model": config.model_name,
+            "gpu_memory_utilization": config.gpu_memory_utilization,
+            "revision": config.revision + (f"/{config.subfolder}" if config.subfolder is not None else ""),
+            "dtype": config.dtype,
+            "trust_remote_code": config.trust_remote_code,
+            "tensor_parallel_size": config.tensor_parallel_size,
+            "data_parallel_size": config.data_parallel_size,
+            "pipeline_parallel_size": config.pipeline_parallel_size,
+            "max_model_len": self._max_length,
+            "swap_space": 4,
+            "seed": int(config.seed),
+            "max_num_seqs": int(config.max_num_seqs),
+            "max_num_batched_tokens": int(config.max_num_batched_tokens),
+            "enforce_eager": True,
+        }
+
+        if config.data_parallel_size > 1:
+            self._batch_size = "auto"
+
+        model = AsyncLLM.from_engine_args(AsyncEngineArgs(**self.model_args))
+
+        # If the max_length can't get extracted from the config, it will be inferred from the model
+        if self._max_length is None:
+            self._max_length = model.model_config.max_seq_len_to_capture
+
+        return model
+
+    async def _async_one_item(
+        self,
+        index: int,
+        request: GreedyUntilRequest | LoglikelihoodRequest,
+    ) -> Coroutine[None, list, str]:
+        """Contains the actual logic of the generation."""
+        sampling_params = SamplingParams(**self._config.generation_parameters.to_vllm_dict())
+
+        if isinstance(request, LoglikelihoodRequest):
+            sampling_params.temperature = 0
+            sampling_params.prompt_logprobs = 1
+            sampling_params.max_tokens = 1
+            sampling_params.detokenize = False
+            prompt = request.context + request.choice
+            index = f"logprob_{index}"
+        elif isinstance(request, GreedyUntilRequest):
+            sampling_params.n = request.num_samples
+            if sampling_params.n > 1:
+                # Todo clementine: investigate more
+                logger.warning(
+                    "Careful, there can be unexpected behavior when using sampling evals with the async vllm model"
+                )
+            sampling_params.max_tokens = self._config.generation_parameters.max_new_tokens or request.generation_size
+            sampling_params.stop = [] if self.use_chat_template else request.stop_sequence
+            sampling_params.logprobs = int(request.use_logits)
+            prompt = request.context
+            index = f"generative_{index}"
+
+        generator = self.model.generate(request_id=str(index), prompt=prompt, sampling_params=sampling_params)
+        try:
+            while output := await anext(generator):
+                continue
+        except StopAsyncIteration:
+            pass
+
+        return output
+
+    async def _async_batch(self, requests: list[GreedyUntilRequest | LoglikelihoodRequest]) -> list:
+        processed_requests = [
+            self._async_one_item(index=index, request=request) for index, request in enumerate(requests)
+        ]
+        results = await asyncio.gather(*processed_requests)
+        return results
+
+    async def greedy_until(
+        self,
+        requests: list[GreedyUntilRequest],
+        **kwargs,
+    ) -> list[GenerativeResponse]:
+        """
+        Generates responses using a greedy decoding strategy until certain ending conditions are met.
+
+        Args:
+            requests (list[Request]): list of requests containing the context and ending conditions.
+
+        Returns:
+            list[GenerateReturn]: list of generated responses.
+        """
+        for request in requests:
+            request.stop_sequence = as_list(request.stop_sequence) + [self.tokenizer.eos_token]
+            request.tokenized_context = self.tok_encode(request.context)
+
+        results = []
+
+        responses: list[RequestOutput] = await self._async_batch(requests=requests)
+
+        for response in responses:
+            output_token_ids = [outputs.token_ids for outputs in response.outputs]
+            full_logprobs = [output.logprobs for output in response.outputs] or []
+            logprobs = [logprob[token_id].logprob for token_id, logprob in zip(output_token_ids[0], full_logprobs[0])]
+            result = [output.text for output in response.outputs]
+            input_token_ids = response.prompt_token_ids
+
+            cur_response = GenerativeResponse(
+                result=result,
+                logits=logprobs,
+                generated_tokens=list(output_token_ids),
+                input_tokens=input_token_ids,
+            )
+            results.append(cur_response)
+
+        return results
+
+    async def loglikelihood(
+        self,
+        requests: list[LoglikelihoodRequest],
+        return_bool_score: bool = True,
+        **kwargs,
+    ) -> list[LoglikelihoodResponse]:
+        """
+        Generates responses using a greedy decoding strategy until certain ending conditions are met and
+        stores the logprobs.
+
+        Args:
+            requests (list[Request]): list of requests containing the context and ending conditions.
+
+        Returns:
+            list[LoglikelihoodResponse]: list of generated responses.
+        """
+
+        for request in requests:
+            if request.context == "":
+                request.tokenized_context = [self.tokenizer.eos_token_id]
+                request.tokenized_continuation = self.tok_encode(request.choice)
+            else:
+                # The following line is mandatory for compatibility with the harness
+                request.tokenized_context, request.tokenized_continuation = self.tok_encode_pair(
+                    request.context, request.choice, pairwise=self.pairwise_tokenization
+                )
+
+        results = []
+
+        responses: list[RequestOutput] = await self._async_batch(requests=requests)
+
+        for response, input in zip(responses, requests):
+            continuation_logprobs = []
+            for token, logprobs in zip(input.tokenized_continuation[::-1], response.prompt_logprobs[::-1]):
+                continuation_logprobs.append(logprobs[token])
+            bool_score = all(logprob.rank == 1 for logprob in continuation_logprobs)
+            continuation_logprobs = [logprob.logprob for logprob in continuation_logprobs]
+            answer = LoglikelihoodResponse(
+                input_tokens=input.tokenized_context + input.tokenized_continuation,
+                generated_tokens=input.tokenized_continuation,
+                result=(sum(continuation_logprobs), bool_score if return_bool_score else None),
+            )
+            results.append(answer)
+
+        return results
+
+    def loglikelihood_rolling(self):
+        pass
+
+    def loglikelihood_single_token():
+        pass
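Note (commentary, not part of the patch): the core idea of AsyncVLLMModel above is to turn every request into its own coroutine and let asyncio.gather collect the results in input order, leaving batching and scheduling entirely to the async engine. The standalone Python sketch below illustrates only that fan-out pattern; it does not use vllm, and all names in it are illustrative placeholders.

import asyncio


async def _one_item(index: int, prompt: str) -> str:
    # Stand-in for the per-request engine call; each request completes independently.
    await asyncio.sleep(0)
    return f"request {index} -> {prompt!r}"


async def _batch(prompts: list[str]) -> list[str]:
    # One coroutine per request; asyncio.gather preserves the order of its inputs in the results.
    return await asyncio.gather(*(_one_item(i, p) for i, p in enumerate(prompts)))


print(asyncio.run(_batch(["first prompt", "second prompt"])))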
diff --git a/src/lighteval/pipeline.py b/src/lighteval/pipeline.py
index e9459e0e1..50e4ca3ef 100644
--- a/src/lighteval/pipeline.py
+++ b/src/lighteval/pipeline.py
@@ -21,11 +21,11 @@
 # SOFTWARE.
 
 import ast
+import asyncio
 import collections
 import os
 import random
 import re
-import shutil
 from contextlib import nullcontext
 from dataclasses import dataclass
 from datetime import timedelta
@@ -296,14 +296,6 @@ def evaluate(self):
             self.evaluation_tracker.metrics_logger.aggregate(task_dict=self.task_dict, bootstrap_iters=1000)
             self.evaluation_tracker.details_logger.aggregate()
 
-            for weights in ["delta", "adapter"]:
-                try:
-                    tmp_weights_dir = f"{self.evaluation_tracker.general_config_logger.model_name}-{weights}-applied"
-                    shutil.rmtree(tmp_weights_dir)
-                    logger.info(f"Removed {tmp_weights_dir}")
-                except OSError:
-                    pass
-
     def _unpack(self, x):
         if isinstance(x, str):
             return x
@@ -456,12 +448,22 @@ def _get_model_response_type(self, request_type):
 
         return model_response_type
 
-    def _run_model(self):
-        # Running all requests depending on the model call type (log likelihood, generative, ...)
-        # to be able to batch them
-        logger.info("--- RUNNING MODEL ---")
+    async def _run_model_async(self):
         sample_id_to_responses: dict[(SampleUid, MetricCategory), list[ModelResponse]] = collections.defaultdict(list)
 
+        for request_type, requests in self.requests.items():
+            logger.info(f"Sending {request_type} requests")
+            run_model = self.model.get_method_from_request_type(request_type=request_type)
+            responses = await run_model(requests)
+
+            # Storing the responses associated to the same samples together
+            for response, request in zip(responses, requests):
+                for metric_category in request.metric_categories:
+                    sample_id = SampleUid(request.task_name, request.sample_index)
+                    sample_id_to_responses[(sample_id, metric_category)].append(response)
+        return sample_id_to_responses
+
+    def _run_model_sync(self):
+        sample_id_to_responses: dict[(SampleUid, MetricCategory), list[ModelResponse]] = collections.defaultdict(list)
         for request_type, requests in self.requests.items():
             logger.info(f"Running {request_type} requests")
             run_model = self.model.get_method_from_request_type(request_type=request_type)
@@ -472,6 +474,18 @@ def _run_model(self):
                 for metric_category in request.metric_categories:
                     sample_id = SampleUid(request.task_name, request.sample_index)
                     sample_id_to_responses[(sample_id, metric_category)].append(response)
+        return sample_id_to_responses
+
+    def _run_model(self):
+        # Running all requests depending on the model call type (log likelihood, generative, ...)
+        # to be able to batch them
+        logger.info("--- RUNNING MODEL ---")
+
+        if self.model.is_async:
+            sample_id_to_responses = asyncio.run(self._run_model_async())
+
+        else:
+            sample_id_to_responses = self._run_model_sync()
 
         # Cleaning up the model before running metrics
        self.model.cleanup()
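Usage note (a minimal sketch, not taken from the patch): with these changes the async engine is opted into through the new is_async flag on VLLMModelConfig; load_model_with_accelerate_or_default then returns an AsyncVLLMModel, and Pipeline._run_model drives its awaitable greedy_until/loglikelihood calls via asyncio.run. The model id below is an arbitrary placeholder, and the remaining config fields are assumed to keep the defaults shown in the diff above.

from lighteval.models.vllm.vllm_model import VLLMModelConfig

# Placeholder model id; any HF model servable by vLLM would do.
config = VLLMModelConfig(model_name="HuggingFaceTB/SmolLM2-135M-Instruct", is_async=True)
# Passing this config to load_model_with_accelerate_or_default(config) yields an AsyncVLLMModel;
# with is_async=False (the default), the synchronous VLLMModel is used as before.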