Skip to content

Commit 5986992

Browse files
committed
refactor vllm/inputs/data.py to use newly defined functions
Signed-off-by: Tobias Pitters <[email protected]>
1 parent d6add6a commit 5986992

File tree

1 file changed

+12
-17
lines changed

1 file changed

+12
-17
lines changed

vllm/inputs/data.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ class SingletonInputsAdapter:
268268
def prompt(self) -> Optional[str]:
269269
inputs = self.inputs
270270

271-
if inputs["type"] == "token" or inputs["type"] == "multimodal":
271+
if is_token_inputs(inputs) or is_multimodal_inputs(inputs):
272272
return inputs.get("prompt")
273273

274274
assert_never(inputs) # type: ignore[arg-type]
@@ -277,7 +277,7 @@ def prompt(self) -> Optional[str]:
277277
def prompt_token_ids(self) -> List[int]:
278278
inputs = self.inputs
279279

280-
if inputs["type"] == "token" or inputs["type"] == "multimodal":
280+
if is_token_inputs(inputs) or is_multimodal_inputs(inputs):
281281
return inputs.get("prompt_token_ids", [])
282282

283283
assert_never(inputs) # type: ignore[arg-type]
@@ -286,7 +286,7 @@ def prompt_token_ids(self) -> List[int]:
286286
def token_type_ids(self) -> List[int]:
287287
inputs = self.inputs
288288

289-
if inputs["type"] == "token" or inputs["type"] == "multimodal":
289+
if is_token_inputs(inputs) or is_multimodal_inputs(inputs):
290290
return inputs.get("token_type_ids", [])
291291

292292
assert_never(inputs) # type: ignore[arg-type]
@@ -295,7 +295,7 @@ def token_type_ids(self) -> List[int]:
295295
def prompt_embeds(self) -> Optional[torch.Tensor]:
296296
inputs = self.inputs
297297

298-
if inputs["type"] == "token" or inputs["type"] == "multimodal":
298+
if is_token_inputs(inputs) or is_multimodal_inputs(inputs):
299299
return None
300300

301301
assert_never(inputs) # type: ignore[arg-type]
@@ -304,10 +304,9 @@ def prompt_embeds(self) -> Optional[torch.Tensor]:
304304
def multi_modal_data(self) -> "MultiModalDataDict":
305305
inputs = self.inputs
306306

307-
if inputs["type"] == "token":
307+
if is_token_inputs(inputs):
308308
return inputs.get("multi_modal_data", {})
309-
310-
if inputs["type"] == "multimodal":
309+
elif is_multimodal_inputs(inputs):
311310
return inputs.get("mm_kwargs", {})
312311

313312
assert_never(inputs) # type: ignore[arg-type]
@@ -316,10 +315,9 @@ def multi_modal_data(self) -> "MultiModalDataDict":
316315
def multi_modal_inputs(self) -> Union[Dict, "MultiModalKwargs"]:
317316
inputs = self.inputs
318317

319-
if inputs["type"] == "token":
318+
if is_token_inputs(inputs):
320319
return inputs.get("multi_modal_inputs", {})
321-
322-
if inputs["type"] == "multimodal":
320+
elif is_multimodal_inputs(inputs):
323321
return inputs.get("mm_kwargs", {})
324322

325323
assert_never(inputs) # type: ignore[arg-type]
@@ -331,7 +329,6 @@ def multi_modal_hashes(self) -> List[str]:
331329
if is_token_inputs(inputs):
332330
return inputs.get("multi_modal_hashes", [])
333331
elif is_multimodal_inputs(inputs):
334-
# only the case when we use MultiModalInputsV2
335332
return inputs.get("mm_hashes", [])
336333

337334
assert_never(inputs) # type: ignore[arg-type]
@@ -340,10 +337,9 @@ def multi_modal_hashes(self) -> List[str]:
340337
def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
341338
inputs = self.inputs
342339

343-
if inputs["type"] == "token":
340+
if is_token_inputs(inputs):
344341
return inputs.get("multi_modal_placeholders", {})
345-
346-
if inputs["type"] == "multimodal":
342+
elif is_multimodal_inputs(inputs):
347343
return inputs.get("mm_placeholders", {})
348344

349345
assert_never(inputs) # type: ignore[arg-type]
@@ -352,10 +348,9 @@ def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
352348
def mm_processor_kwargs(self) -> Dict[str, Any]:
353349
inputs = self.inputs
354350

355-
if inputs["type"] == "token":
351+
if is_token_inputs(inputs):
356352
return inputs.get("mm_processor_kwargs", {})
357-
358-
if inputs["type"] == "multimodal":
353+
elif is_multimodal_inputs(inputs):
359354
return {}
360355

361356
assert_never(inputs) # type: ignore[arg-type]

0 commit comments

Comments
 (0)