@@ -902,20 +902,26 @@ def _encode_prompt_with_clip(
902
902
tokenizer ,
903
903
prompt : str ,
904
904
device = None ,
905
+ text_input_ids = None ,
905
906
num_images_per_prompt : int = 1 ,
906
907
):
907
908
prompt = [prompt ] if isinstance (prompt , str ) else prompt
908
909
batch_size = len (prompt )
909
910
910
- text_inputs = tokenizer (
911
- prompt ,
912
- padding = "max_length" ,
913
- max_length = 77 ,
914
- truncation = True ,
915
- return_tensors = "pt" ,
916
- )
911
+ if tokenizer is not None :
912
+ text_inputs = tokenizer (
913
+ prompt ,
914
+ padding = "max_length" ,
915
+ max_length = 77 ,
916
+ truncation = True ,
917
+ return_tensors = "pt" ,
918
+ )
919
+
920
+ text_input_ids = text_inputs .input_ids
921
+ else :
922
+ if text_input_ids is None :
923
+ raise ValueError ("text_input_ids must be provided when the tokenizer is not specified" )
917
924
918
- text_input_ids = text_inputs .input_ids
919
925
prompt_embeds = text_encoder (text_input_ids .to (device ), output_hidden_states = True )
920
926
921
927
pooled_prompt_embeds = prompt_embeds [0 ]
@@ -937,6 +943,7 @@ def encode_prompt(
937
943
max_sequence_length ,
938
944
device = None ,
939
945
num_images_per_prompt : int = 1 ,
946
+ text_input_ids_list = None ,
940
947
):
941
948
prompt = [prompt ] if isinstance (prompt , str ) else prompt
942
949
@@ -945,13 +952,14 @@ def encode_prompt(
945
952
946
953
clip_prompt_embeds_list = []
947
954
clip_pooled_prompt_embeds_list = []
948
- for tokenizer , text_encoder in zip (clip_tokenizers , clip_text_encoders ):
955
+ for i , ( tokenizer , text_encoder ) in enumerate ( zip (clip_tokenizers , clip_text_encoders ) ):
949
956
prompt_embeds , pooled_prompt_embeds = _encode_prompt_with_clip (
950
957
text_encoder = text_encoder ,
951
958
tokenizer = tokenizer ,
952
959
prompt = prompt ,
953
960
device = device if device is not None else text_encoder .device ,
954
961
num_images_per_prompt = num_images_per_prompt ,
962
+ text_input_ids = text_input_ids_list [i ] if text_input_ids_list else None ,
955
963
)
956
964
clip_prompt_embeds_list .append (prompt_embeds )
957
965
clip_pooled_prompt_embeds_list .append (pooled_prompt_embeds )
0 commit comments