Skip to content

Commit 05b38c3

Browse files
authored
Fix Flux CLIP prompt embeds repeat for num_images_per_prompt > 1 (#9280)
update
1 parent 8f7fde5 commit 05b38c3

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def _get_clip_prompt_embeds(
280280
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
281281

282282
# duplicate text embeddings for each generation per prompt, using mps friendly method
283-
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
283+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
284284
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
285285

286286
return prompt_embeds

src/diffusers/pipelines/flux/pipeline_flux_controlnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def _get_clip_prompt_embeds(
302302
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
303303

304304
# duplicate text embeddings for each generation per prompt, using mps friendly method
305-
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
305+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
306306
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
307307

308308
return prompt_embeds

0 commit comments

Comments
 (0)