
Commit ead2110

[Core][Bugfix] Fix Online MM Beam Search (#19688)
Signed-off-by: Alex-Brooks <[email protected]>
1 parent 01220ce commit ead2110

File tree

3 files changed: +45 −12 lines changed

tests/entrypoints/openai/test_vision.py
vllm/engine/protocol.py
vllm/entrypoints/llm.py

tests/entrypoints/openai/test_vision.py

Lines changed: 27 additions & 4 deletions

@@ -25,6 +25,25 @@
     "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
 ]
 
+EXPECTED_MM_BEAM_SEARCH_RES = [
+    [
+        "The image shows a wooden boardwalk leading through a",
+        "The image shows a wooden boardwalk extending into a",
+    ],
+    [
+        "The image shows two parrots perched on",
+        "The image shows two birds perched on a cur",
+    ],
+    [
+        "The image shows a Venn diagram with three over",
+        "This image shows a Venn diagram with three over",
+    ],
+    [
+        "This image displays a gradient of colors ranging from",
+        "This image displays a gradient of colors transitioning from",
+    ],
+]
+
 
 @pytest.fixture(scope="module")
 def server():
@@ -270,10 +289,13 @@ async def test_single_chat_session_image_base64encoded(
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
+@pytest.mark.parametrize("image_idx", list(range(len(TEST_IMAGE_URLS))))
 async def test_single_chat_session_image_base64encoded_beamsearch(
-        client: openai.AsyncOpenAI, model_name: str, image_url: str,
+        client: openai.AsyncOpenAI, model_name: str, image_idx: int,
         base64_encoded_image: dict[str, str]):
+    # NOTE: This test also validates that we pass MM data through beam search
+    image_url = TEST_IMAGE_URLS[image_idx]
+    expected_res = EXPECTED_MM_BEAM_SEARCH_RES[image_idx]
 
     messages = [{
         "role":
@@ -297,10 +319,11 @@ async def test_single_chat_session_image_base64encoded_beamsearch(
         messages=messages,
         n=2,
         max_completion_tokens=10,
+        temperature=0.0,
         extra_body=dict(use_beam_search=True))
     assert len(chat_completion.choices) == 2
-    assert chat_completion.choices[
-        0].message.content != chat_completion.choices[1].message.content
+    for actual, expected_str in zip(chat_completion.choices, expected_res):
+        assert actual.message.content == expected_str
 
 
 @pytest.mark.asyncio
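
The request shape this test exercises can be reproduced against a running server. The sketch below is illustrative only: the base URL, API key, model name, and prompt text are placeholder assumptions for a local `vllm serve` deployment, not values taken from this test suite.

```python
# Minimal sketch of the beam-search request the test sends; base_url,
# api_key, MODEL, and the prompt text are placeholder assumptions.
import asyncio

import openai

MODEL = "<your-multimodal-model>"  # hypothetical; the suite uses MODEL_NAME
IMAGE_URL = ("https://upload.wikimedia.org/wikipedia/commons/0/0b/"
             "RGBA_comp.png")


async def main() -> None:
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    chat_completion = await client.chat.completions.create(
        model=MODEL,
        messages=[{
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": IMAGE_URL}},
                {"type": "text", "text": "What's in this image?"},
            ],
        }],
        n=2,
        max_completion_tokens=10,
        # temperature=0.0 makes the two returned beams reproducible,
        # which is what lets the test assert exact strings.
        temperature=0.0,
        extra_body=dict(use_beam_search=True))
    for choice in chat_completion.choices:
        print(choice.message.content)


asyncio.run(main())
```
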

vllm/engine/protocol.py

Lines changed: 11 additions & 2 deletions

@@ -88,9 +88,18 @@ async def beam_search(
         if processed_inputs["type"] == "embeds":
             raise NotImplementedError
 
-        prompt_token_ids = processed_inputs["prompt_token_ids"]
+        # This is a workaround to fix multimodal beam search; this is a
+        # bandaid fix for 2 small problems:
+        # 1. Multi_modal_data on the processed_inputs currently resolves to
+        #    `None`.
+        # 2. preprocessing above expands the multimodal placeholders. However,
+        #    this happens again in generation, so the double expansion causes
+        #    a mismatch.
+        # TODO - would be ideal to handle this more gracefully.
+        prompt_token_ids = prompt.get("prompt_token_ids")
+        multi_modal_data = prompt.get("multi_modal_data")
+
         prompt_text = processed_inputs.get("prompt")
-        multi_modal_data = processed_inputs.get("multi_modal_data")
         mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs")
 
         tokenized_length = len(prompt_token_ids)
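
To make problem 2 concrete, here is a purely illustrative sketch; the token IDs and per-image expansion factor are invented for the example and are not vLLM's real values.

```python
# Illustrative only: invented token ids and expansion factor, showing why
# expanding multimodal placeholders twice corrupts beam search state.
EXPANSION = 3  # pretend each image placeholder becomes 3 tokens
IMG = 32000    # pretend <image> placeholder token id

raw_prompt = [101, IMG, 102]                   # what the user sent
processed = [101] + [IMG] * EXPANSION + [102]  # after preprocessing

# Seeding beams with `processed` meant generation expanded placeholders a
# second time, so the model consumed a longer sequence than the beams track:
double_expanded = [101] + [IMG] * (EXPANSION * EXPANSION) + [102]

# tokenized_length computed from `processed` (5) no longer matches what the
# model actually sees (11); taking ids from the raw `prompt` avoids this.
assert len(processed) != len(double_expanded)
print(len(raw_prompt), len(processed), len(double_expanded))  # 3 5 11
```
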

vllm/entrypoints/llm.py

Lines changed: 7 additions & 6 deletions

@@ -15,7 +15,8 @@
 from typing_extensions import TypeVar, deprecated
 
 from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
-                              BeamSearchSequence, get_beam_search_score)
+                              BeamSearchSequence,
+                              create_sort_beams_key_function)
 from vllm.config import (CompilationConfig, ModelDType, TokenizerMode,
                          is_init_field)
 from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
@@ -575,10 +576,11 @@ def beam_search(
         lora_requests = self._get_beam_search_lora_requests(
             lora_request, prompts)
 
-        def sort_beams_key(x: BeamSearchSequence) -> float:
-            return get_beam_search_score(x.tokens, x.cum_logprob,
-                                         tokenizer.eos_token_id,
-                                         length_penalty)
+        tokenizer = self.get_tokenizer()
+        sort_beams_key = create_sort_beams_key_function(
+            tokenizer.eos_token_id,
+            length_penalty,
+        )
 
         def create_tokens_prompt_from_beam(
                 beam: BeamSearchSequence) -> TokensPrompt:
@@ -593,7 +595,6 @@ def create_tokens_prompt_from_beam(
                 "mm_processor_kwargs"] = beam.mm_processor_kwargs
             return TokensPrompt(**token_prompt_kwargs)
 
-        tokenizer = self.get_tokenizer()
         # generate 2 * beam_width candidates at each step
         # following the huggingface transformers implementation
         # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
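
The `create_sort_beams_key_function` factory itself is not shown in this diff. A plausible shape, inferred from the inline closure it replaces rather than from the actual vllm/beam_search.py source, would be:

```python
# Inferred sketch of the helper this commit switches to; based on the inline
# closure removed above, not on the actual vllm/beam_search.py source.
from typing import Callable

from vllm.beam_search import BeamSearchSequence, get_beam_search_score


def create_sort_beams_key_function(
        eos_token_id: int,
        length_penalty: float) -> Callable[[BeamSearchSequence], float]:
    """Binds scoring parameters once so callers only pass the sequence."""

    def sort_beams_key(x: BeamSearchSequence) -> float:
        # Same scoring as the removed closure: cumulative logprob with a
        # length penalty, treating a trailing EOS as non-generated.
        return get_beam_search_score(x.tokens, x.cum_logprob, eos_token_id,
                                     length_penalty)

    return sort_beams_key
```

Note that hoisting `tokenizer = self.get_tokenizer()` above the key construction is required by this refactor: `eos_token_id` is now bound when the key function is created rather than looked up each time it is called.
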
