
Commit 4076ea0

fix: vllm missing logprobs (#5279)
* working to address missing items referencing #3436, #2930 - if i could test it, this might show that the output from the vllm backend is processed and returned to the user
* adding in vllm tests to test-extras
* adding in tests to pipeline for execution
* removing todo block, test via pipeline

---------

Signed-off-by: Wyatt Neal <[email protected]>
1 parent 26cbf77 commit 4076ea0

File tree

4 files changed: +101 −19 lines

.github/workflows/test-extra.yml
Makefile
backend/python/vllm/backend.py
backend/python/vllm/test.py

.github/workflows/test-extra.yml

Lines changed: 20 additions & 0 deletions
@@ -78,6 +78,26 @@ jobs:
           make --jobs=5 --output-sync=target -C backend/python/diffusers
           make --jobs=5 --output-sync=target -C backend/python/diffusers test

+  tests-vllm:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Clone
+        uses: actions/checkout@v4
+        with:
+          submodules: true
+      - name: Dependencies
+        run: |
+          sudo apt-get update
+          sudo apt-get install -y build-essential ffmpeg
+          sudo apt-get install -y ca-certificates cmake curl patch python3-pip
+          sudo apt-get install -y libopencv-dev
+          # Install UV
+          curl -LsSf https://astral.sh/uv/install.sh | sh
+          pip install --user --no-cache-dir grpcio-tools==1.64.1
+      - name: Test vllm backend
+        run: |
+          make --jobs=5 --output-sync=target -C backend/python/vllm
+          make --jobs=5 --output-sync=target -C backend/python/vllm test
   # tests-transformers-musicgen:
   #   runs-on: ubuntu-latest
   #   steps:

Makefile

Lines changed: 2 additions & 0 deletions
@@ -598,10 +598,12 @@ prepare-extra-conda-environments: protogen-python
 prepare-test-extra: protogen-python
 	$(MAKE) -C backend/python/transformers
 	$(MAKE) -C backend/python/diffusers
+	$(MAKE) -C backend/python/vllm

 test-extra: prepare-test-extra
 	$(MAKE) -C backend/python/transformers test
 	$(MAKE) -C backend/python/diffusers test
+	$(MAKE) -C backend/python/vllm test

 backend-assets:
 	mkdir -p backend-assets
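With these targets in place, the new backend tests can also be exercised locally using the same commands the CI job above runs: make -C backend/python/vllm to prepare the environment, followed by make -C backend/python/vllm test.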

backend/python/vllm/backend.py

Lines changed: 32 additions & 19 deletions
@@ -194,27 +194,40 @@ async def PredictStream(self, request, context):
         await iterations.aclose()

     async def _predict(self, request, context, streaming=False):
+        # Build the sampling parameters
+        # NOTE: this must stay in sync with the vllm backend
+        request_to_sampling_params = {
+            "N": "n",
+            "PresencePenalty": "presence_penalty",
+            "FrequencyPenalty": "frequency_penalty",
+            "RepetitionPenalty": "repetition_penalty",
+            "Temperature": "temperature",
+            "TopP": "top_p",
+            "TopK": "top_k",
+            "MinP": "min_p",
+            "Seed": "seed",
+            "StopPrompts": "stop",
+            "StopTokenIds": "stop_token_ids",
+            "BadWords": "bad_words",
+            "IncludeStopStrInOutput": "include_stop_str_in_output",
+            "IgnoreEOS": "ignore_eos",
+            "Tokens": "max_tokens",
+            "MinTokens": "min_tokens",
+            "Logprobs": "logprobs",
+            "PromptLogprobs": "prompt_logprobs",
+            "SkipSpecialTokens": "skip_special_tokens",
+            "SpacesBetweenSpecialTokens": "spaces_between_special_tokens",
+            "TruncatePromptTokens": "truncate_prompt_tokens",
+            "GuidedDecoding": "guided_decoding",
+        }

-        # Build sampling parameters
         sampling_params = SamplingParams(top_p=0.9, max_tokens=200)
-        if request.TopP != 0:
-            sampling_params.top_p = request.TopP
-        if request.Tokens > 0:
-            sampling_params.max_tokens = request.Tokens
-        if request.Temperature != 0:
-            sampling_params.temperature = request.Temperature
-        if request.TopK != 0:
-            sampling_params.top_k = request.TopK
-        if request.PresencePenalty != 0:
-            sampling_params.presence_penalty = request.PresencePenalty
-        if request.FrequencyPenalty != 0:
-            sampling_params.frequency_penalty = request.FrequencyPenalty
-        if request.StopPrompts:
-            sampling_params.stop = request.StopPrompts
-        if request.IgnoreEOS:
-            sampling_params.ignore_eos = request.IgnoreEOS
-        if request.Seed != 0:
-            sampling_params.seed = request.Seed
+
+        for request_field, param_field in request_to_sampling_params.items():
+            if hasattr(request, request_field):
+                value = getattr(request, request_field)
+                if value not in (None, 0, [], False):
+                    setattr(sampling_params, param_field, value)

         # Extract image paths and process images
         prompt = request.Prompt
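The table-driven loop above replaces the old per-field if chain: any request field whose value is unset or falsy (None, 0, empty list, False) is skipped so the SamplingParams defaults survive, while everything else, including the previously missing Logprobs and PromptLogprobs, is copied through. Below is a minimal, self-contained sketch of the same pattern; SimpleRequest and SimpleSamplingParams are hypothetical stand-ins for the gRPC request message and vllm.SamplingParams, and only the copying logic mirrors the diff.

# Self-contained sketch of the mapping pattern from the diff above.
# SimpleRequest / SimpleSamplingParams are hypothetical stand-ins.
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class SimpleRequest:
    TopP: float = 0.0
    Tokens: int = 0
    Logprobs: int = 0
    StopPrompts: List[str] = field(default_factory=list)

@dataclass
class SimpleSamplingParams:
    top_p: float = 0.9
    max_tokens: int = 200
    logprobs: Optional[int] = None
    stop: List[str] = field(default_factory=list)

# Subset of the request-field -> sampling-param mapping used in the backend.
request_to_sampling_params = {
    "TopP": "top_p",
    "Tokens": "max_tokens",
    "Logprobs": "logprobs",
    "StopPrompts": "stop",
}

def build_params(request: SimpleRequest) -> SimpleSamplingParams:
    params = SimpleSamplingParams()
    for request_field, param_field in request_to_sampling_params.items():
        if hasattr(request, request_field):
            value = getattr(request, request_field)
            # Unset proto fields arrive as falsy values and are skipped,
            # so the defaults (top_p=0.9, max_tokens=200) stay in place.
            if value not in (None, 0, [], False):
                setattr(params, param_field, value)
    return params

if __name__ == "__main__":
    # Tokens is left at 0, so max_tokens keeps its default of 200.
    print(build_params(SimpleRequest(TopP=0.8, Logprobs=5)))

One consequence of the falsy check is that a caller cannot explicitly request a value of zero (for example Seed=0); that matches the old if chain, which also ignored zero values.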

backend/python/vllm/test.py

Lines changed: 47 additions & 0 deletions
@@ -75,6 +75,53 @@ def test_text(self):
             finally:
                 self.tearDown()

+    def test_sampling_params(self):
+        """
+        This method tests if all sampling parameters are correctly processed
+        NOTE: this does NOT test for correctness, just that we received a compatible response
+        """
+        try:
+            self.setUp()
+            with grpc.insecure_channel("localhost:50051") as channel:
+                stub = backend_pb2_grpc.BackendStub(channel)
+                response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m"))
+                self.assertTrue(response.success)
+
+                req = backend_pb2.PredictOptions(
+                    Prompt="The capital of France is",
+                    TopP=0.8,
+                    Tokens=50,
+                    Temperature=0.7,
+                    TopK=40,
+                    PresencePenalty=0.1,
+                    FrequencyPenalty=0.2,
+                    RepetitionPenalty=1.1,
+                    MinP=0.05,
+                    Seed=42,
+                    StopPrompts=["\n"],
+                    StopTokenIds=[50256],
+                    BadWords=["badword"],
+                    IncludeStopStrInOutput=True,
+                    IgnoreEOS=True,
+                    MinTokens=5,
+                    Logprobs=5,
+                    PromptLogprobs=5,
+                    SkipSpecialTokens=True,
+                    SpacesBetweenSpecialTokens=True,
+                    TruncatePromptTokens=10,
+                    GuidedDecoding=True,
+                    N=2,
+                )
+                resp = stub.Predict(req)
+                self.assertIsNotNone(resp.message)
+                self.assertIsNotNone(resp.logprobs)
+        except Exception as err:
+            print(err)
+            self.fail("sampling params service failed")
+        finally:
+            self.tearDown()
+
+
     def test_embedding(self):
         """
         This method tests if the embeddings are generated successfully
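For manual verification outside the unittest harness, the same call path can be driven by a small client. Below is a minimal sketch, assuming the backend from this commit is already serving on localhost:50051 and that the generated backend_pb2 / backend_pb2_grpc stubs used by test.py are importable.

# Minimal client sketch; assumes a running backend on localhost:50051 and the
# generated gRPC stubs (backend_pb2, backend_pb2_grpc) from this repository.
import grpc
import backend_pb2
import backend_pb2_grpc

with grpc.insecure_channel("localhost:50051") as channel:
    stub = backend_pb2_grpc.BackendStub(channel)
    stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m"))
    resp = stub.Predict(backend_pb2.PredictOptions(
        Prompt="The capital of France is",
        Tokens=50,
        Logprobs=5,  # the field this commit forwards to vllm's SamplingParams
    ))
    print(resp.message)
    print(resp.logprobs)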
