@@ -268,7 +268,7 @@ class SingletonInputsAdapter:
268
268
def prompt (self ) -> Optional [str ]:
269
269
inputs = self .inputs
270
270
271
- if inputs [ "type" ] == "token" or inputs [ "type" ] == "multimodal" :
271
+ if is_token_inputs ( inputs ) or is_multimodal_inputs ( inputs ) :
272
272
return inputs .get ("prompt" )
273
273
274
274
assert_never (inputs ) # type: ignore[arg-type]
@@ -277,7 +277,7 @@ def prompt(self) -> Optional[str]:
277
277
def prompt_token_ids (self ) -> List [int ]:
278
278
inputs = self .inputs
279
279
280
- if inputs [ "type" ] == "token" or inputs [ "type" ] == "multimodal" :
280
+ if is_token_inputs ( inputs ) or is_multimodal_inputs ( inputs ) :
281
281
return inputs .get ("prompt_token_ids" , [])
282
282
283
283
assert_never (inputs ) # type: ignore[arg-type]
@@ -286,7 +286,7 @@ def prompt_token_ids(self) -> List[int]:
286
286
def token_type_ids (self ) -> List [int ]:
287
287
inputs = self .inputs
288
288
289
- if inputs [ "type" ] == "token" or inputs [ "type" ] == "multimodal" :
289
+ if is_token_inputs ( inputs ) or is_multimodal_inputs ( inputs ) :
290
290
return inputs .get ("token_type_ids" , [])
291
291
292
292
assert_never (inputs ) # type: ignore[arg-type]
@@ -295,7 +295,7 @@ def token_type_ids(self) -> List[int]:
295
295
def prompt_embeds (self ) -> Optional [torch .Tensor ]:
296
296
inputs = self .inputs
297
297
298
- if inputs [ "type" ] == "token" or inputs [ "type" ] == "multimodal" :
298
+ if is_token_inputs ( inputs ) or is_multimodal_inputs ( inputs ) :
299
299
return None
300
300
301
301
assert_never (inputs ) # type: ignore[arg-type]
@@ -304,10 +304,9 @@ def prompt_embeds(self) -> Optional[torch.Tensor]:
304
304
def multi_modal_data (self ) -> "MultiModalDataDict" :
305
305
inputs = self .inputs
306
306
307
- if inputs [ "type" ] == "token" :
307
+ if is_token_inputs ( inputs ) :
308
308
return inputs .get ("multi_modal_data" , {})
309
-
310
- if inputs ["type" ] == "multimodal" :
309
+ elif is_multimodal_inputs (inputs ):
311
310
return inputs .get ("mm_kwargs" , {})
312
311
313
312
assert_never (inputs ) # type: ignore[arg-type]
@@ -316,10 +315,9 @@ def multi_modal_data(self) -> "MultiModalDataDict":
316
315
def multi_modal_inputs (self ) -> Union [Dict , "MultiModalKwargs" ]:
317
316
inputs = self .inputs
318
317
319
- if inputs [ "type" ] == "token" :
318
+ if is_token_inputs ( inputs ) :
320
319
return inputs .get ("multi_modal_inputs" , {})
321
-
322
- if inputs ["type" ] == "multimodal" :
320
+ elif is_multimodal_inputs (inputs ):
323
321
return inputs .get ("mm_kwargs" , {})
324
322
325
323
assert_never (inputs ) # type: ignore[arg-type]
@@ -331,7 +329,6 @@ def multi_modal_hashes(self) -> List[str]:
331
329
if is_token_inputs (inputs ):
332
330
return inputs .get ("multi_modal_hashes" , [])
333
331
elif is_multimodal_inputs (inputs ):
334
- # only the case when we use MultiModalInputsV2
335
332
return inputs .get ("mm_hashes" , [])
336
333
337
334
assert_never (inputs ) # type: ignore[arg-type]
@@ -340,10 +337,9 @@ def multi_modal_hashes(self) -> List[str]:
340
337
def multi_modal_placeholders (self ) -> "MultiModalPlaceholderDict" :
341
338
inputs = self .inputs
342
339
343
- if inputs [ "type" ] == "token" :
340
+ if is_token_inputs ( inputs ) :
344
341
return inputs .get ("multi_modal_placeholders" , {})
345
-
346
- if inputs ["type" ] == "multimodal" :
342
+ elif is_multimodal_inputs (inputs ):
347
343
return inputs .get ("mm_placeholders" , {})
348
344
349
345
assert_never (inputs ) # type: ignore[arg-type]
@@ -352,10 +348,9 @@ def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
352
348
def mm_processor_kwargs (self ) -> Dict [str , Any ]:
353
349
inputs = self .inputs
354
350
355
- if inputs [ "type" ] == "token" :
351
+ if is_token_inputs ( inputs ) :
356
352
return inputs .get ("mm_processor_kwargs" , {})
357
-
358
- if inputs ["type" ] == "multimodal" :
353
+ elif is_multimodal_inputs (inputs ):
359
354
return {}
360
355
361
356
assert_never (inputs ) # type: ignore[arg-type]
0 commit comments