Skip to content

Commit 221d6db

Browse files
Updated _encode_prompt_with_clip and encode_prompt in train_dreamboth_sd3 (#9800)
* updated encode prompt and clip encod prompt --------- Co-authored-by: Sayak Paul <[email protected]>
1 parent f546404 commit 221d6db

File tree

1 file changed

+17
-9
lines changed

1 file changed

+17
-9
lines changed

examples/dreambooth/train_dreambooth_sd3.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -902,20 +902,26 @@ def _encode_prompt_with_clip(
902902
tokenizer,
903903
prompt: str,
904904
device=None,
905+
text_input_ids=None,
905906
num_images_per_prompt: int = 1,
906907
):
907908
prompt = [prompt] if isinstance(prompt, str) else prompt
908909
batch_size = len(prompt)
909910

910-
text_inputs = tokenizer(
911-
prompt,
912-
padding="max_length",
913-
max_length=77,
914-
truncation=True,
915-
return_tensors="pt",
916-
)
911+
if tokenizer is not None:
912+
text_inputs = tokenizer(
913+
prompt,
914+
padding="max_length",
915+
max_length=77,
916+
truncation=True,
917+
return_tensors="pt",
918+
)
919+
920+
text_input_ids = text_inputs.input_ids
921+
else:
922+
if text_input_ids is None:
923+
raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
917924

918-
text_input_ids = text_inputs.input_ids
919925
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
920926

921927
pooled_prompt_embeds = prompt_embeds[0]
@@ -937,6 +943,7 @@ def encode_prompt(
937943
max_sequence_length,
938944
device=None,
939945
num_images_per_prompt: int = 1,
946+
text_input_ids_list=None,
940947
):
941948
prompt = [prompt] if isinstance(prompt, str) else prompt
942949

@@ -945,13 +952,14 @@ def encode_prompt(
945952

946953
clip_prompt_embeds_list = []
947954
clip_pooled_prompt_embeds_list = []
948-
for tokenizer, text_encoder in zip(clip_tokenizers, clip_text_encoders):
955+
for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)):
949956
prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
950957
text_encoder=text_encoder,
951958
tokenizer=tokenizer,
952959
prompt=prompt,
953960
device=device if device is not None else text_encoder.device,
954961
num_images_per_prompt=num_images_per_prompt,
962+
text_input_ids=text_input_ids_list[i] if text_input_ids_list else None,
955963
)
956964
clip_prompt_embeds_list.append(prompt_embeds)
957965
clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)

0 commit comments

Comments
 (0)