Skip to content

fix: vllm missing logprobs #5279

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions .github/workflows/test-extra.yml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,26 @@ jobs:
make --jobs=5 --output-sync=target -C backend/python/diffusers
make --jobs=5 --output-sync=target -C backend/python/diffusers test

tests-vllm:
runs-on: ubuntu-latest
steps:
- name: Clone
uses: actions/checkout@v4
with:
submodules: true
- name: Dependencies
run: |
sudo apt-get update
sudo apt-get install -y build-essential ffmpeg
sudo apt-get install -y ca-certificates cmake curl patch python3-pip
sudo apt-get install -y libopencv-dev
# Install UV
curl -LsSf https://astral.sh/uv/install.sh | sh
pip install --user --no-cache-dir grpcio-tools==1.64.1
- name: Test vllm backend
run: |
make --jobs=5 --output-sync=target -C backend/python/vllm
make --jobs=5 --output-sync=target -C backend/python/vllm test
# tests-transformers-musicgen:
# runs-on: ubuntu-latest
# steps:
Expand Down
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -598,10 +598,12 @@ prepare-extra-conda-environments: protogen-python
prepare-test-extra: protogen-python
$(MAKE) -C backend/python/transformers
$(MAKE) -C backend/python/diffusers
$(MAKE) -C backend/python/vllm

test-extra: prepare-test-extra
$(MAKE) -C backend/python/transformers test
$(MAKE) -C backend/python/diffusers test
$(MAKE) -C backend/python/vllm test

backend-assets:
mkdir -p backend-assets
Expand Down
51 changes: 32 additions & 19 deletions backend/python/vllm/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,27 +194,40 @@ async def PredictStream(self, request, context):
await iterations.aclose()

async def _predict(self, request, context, streaming=False):
# Build the sampling parameters
# NOTE: this must stay in sync with the vllm backend
request_to_sampling_params = {
"N": "n",
"PresencePenalty": "presence_penalty",
"FrequencyPenalty": "frequency_penalty",
"RepetitionPenalty": "repetition_penalty",
"Temperature": "temperature",
"TopP": "top_p",
"TopK": "top_k",
"MinP": "min_p",
"Seed": "seed",
"StopPrompts": "stop",
"StopTokenIds": "stop_token_ids",
"BadWords": "bad_words",
"IncludeStopStrInOutput": "include_stop_str_in_output",
"IgnoreEOS": "ignore_eos",
"Tokens": "max_tokens",
"MinTokens": "min_tokens",
"Logprobs": "logprobs",
"PromptLogprobs": "prompt_logprobs",
"SkipSpecialTokens": "skip_special_tokens",
"SpacesBetweenSpecialTokens": "spaces_between_special_tokens",
"TruncatePromptTokens": "truncate_prompt_tokens",
"GuidedDecoding": "guided_decoding",
}

# Build sampling parameters
sampling_params = SamplingParams(top_p=0.9, max_tokens=200)
if request.TopP != 0:
sampling_params.top_p = request.TopP
if request.Tokens > 0:
sampling_params.max_tokens = request.Tokens
if request.Temperature != 0:
sampling_params.temperature = request.Temperature
if request.TopK != 0:
sampling_params.top_k = request.TopK
if request.PresencePenalty != 0:
sampling_params.presence_penalty = request.PresencePenalty
if request.FrequencyPenalty != 0:
sampling_params.frequency_penalty = request.FrequencyPenalty
if request.StopPrompts:
sampling_params.stop = request.StopPrompts
if request.IgnoreEOS:
sampling_params.ignore_eos = request.IgnoreEOS
if request.Seed != 0:
sampling_params.seed = request.Seed

for request_field, param_field in request_to_sampling_params.items():
if hasattr(request, request_field):
value = getattr(request, request_field)
if value not in (None, 0, [], False):
setattr(sampling_params, param_field, value)

# Extract image paths and process images
prompt = request.Prompt
Expand Down
47 changes: 47 additions & 0 deletions backend/python/vllm/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,53 @@ def test_text(self):
finally:
self.tearDown()

def test_sampling_params(self):
"""
This method tests if all sampling parameters are correctly processed
NOTE: this does NOT test for correctness, just that we received a compatible response
"""
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m"))
self.assertTrue(response.success)

req = backend_pb2.PredictOptions(
Prompt="The capital of France is",
TopP=0.8,
Tokens=50,
Temperature=0.7,
TopK=40,
PresencePenalty=0.1,
FrequencyPenalty=0.2,
RepetitionPenalty=1.1,
MinP=0.05,
Seed=42,
StopPrompts=["\n"],
StopTokenIds=[50256],
BadWords=["badword"],
IncludeStopStrInOutput=True,
IgnoreEOS=True,
MinTokens=5,
Logprobs=5,
PromptLogprobs=5,
SkipSpecialTokens=True,
SpacesBetweenSpecialTokens=True,
TruncatePromptTokens=10,
GuidedDecoding=True,
N=2,
)
resp = stub.Predict(req)
self.assertIsNotNone(resp.message)
self.assertIsNotNone(resp.logprobs)
except Exception as err:
print(err)
self.fail("sampling params service failed")
finally:
self.tearDown()


def test_embedding(self):
"""
This method tests if the embeddings are generated successfully
Expand Down
Loading