
CogView4 (supports different length c and uc) #10649


Merged: 88 commits, Feb 15, 2025

Commits (88)
2640bcf
init
zRzRzRzRzRzRzR Jan 14, 2025
eba11fa
Merge branch 'huggingface:main' into cogview4
zRzRzRzRzRzRzR Jan 14, 2025
6163679
encode with glm
zRzRzRzRzRzRzR Jan 14, 2025
6090ea7
draft schedule
zRzRzRzRzRzRzR Jan 15, 2025
c7d1227
feat(scheduler): Add CogView scheduler implementation
OleehyO Jan 16, 2025
e9f6626
Merge remote-tracking branch 'origin/cogview4' into cogview4
OleehyO Jan 16, 2025
549b357
Merge branch 'huggingface:main' into cogview4
zRzRzRzRzRzRzR Jan 16, 2025
004d002
Merge branch 'huggingface:main' into cogview4
zRzRzRzRzRzRzR Jan 16, 2025
f4457fb
feat(embeddings): add CogView 2D rotary positional embedding
OleehyO Jan 17, 2025
5f8d33b
Merge remote-tracking branch 'origin/cogview4' into cogview4
OleehyO Jan 17, 2025
9a93218
1
zRzRzRzRzRzRzR Jan 17, 2025
ca000dd
Update pipeline_cogview4.py
zRzRzRzRzRzRzR Jan 17, 2025
7ab4a3f
fix the timestep init and sigma
zRzRzRzRzRzRzR Jan 18, 2025
56ceaa6
update latent
zRzRzRzRzRzRzR Jan 19, 2025
a7179a2
draft patch(not work)
zRzRzRzRzRzRzR Jan 19, 2025
c9ddf50
Merge branch 'cogview4'
zRzRzRzRzRzRzR Jan 22, 2025
2f30cc1
Merge pull request #2 from zRzRzRzRzRzRzR/main
zRzRzRzRzRzRzR Jan 22, 2025
e6b8907
fix
zRzRzRzRzRzRzR Jan 22, 2025
0ab7260
[WIP][cogview4]: implement initial CogView4 pipeline
OleehyO Jan 23, 2025
f608f82
[WIP][cogview4][refactor]: Split condition/uncondition forward pass i…
OleehyO Jan 23, 2025
b86bfd4
use with -2 hidden state
zRzRzRzRzRzRzR Jan 23, 2025
c4d1e69
remove text_projector
zRzRzRzRzRzRzR Jan 23, 2025
7916140
1
zRzRzRzRzRzRzR Jan 23, 2025
f8945ce
[WIP] Add tensor-reload to align input from transformer block
OleehyO Jan 24, 2025
bf7f322
[WIP] for older glm
zRzRzRzRzRzRzR Jan 24, 2025
dd6568b
use with cogview4 transformers forward twice of u and uc
zRzRzRzRzRzRzR Jan 25, 2025
6f5407e
Merge branch 'huggingface:main' into cogview4
zRzRzRzRzRzRzR Jan 25, 2025
9e5b991
Update convert_cogview4_to_diffusers.py
zRzRzRzRzRzRzR Jan 25, 2025
36b1682
remove this
zRzRzRzRzRzRzR Jan 26, 2025
804f5cc
Merge pull request #3 from zRzRzRzRzRzRzR/main
zRzRzRzRzRzRzR Jan 28, 2025
16c2397
use main example
zRzRzRzRzRzRzR Jan 28, 2025
601696d
change back
zRzRzRzRzRzRzR Jan 28, 2025
84115dc
reset
zRzRzRzRzRzRzR Jan 28, 2025
95a103f
setback
zRzRzRzRzRzRzR Jan 28, 2025
d932f67
back
zRzRzRzRzRzRzR Jan 28, 2025
b04f15d
back 4
zRzRzRzRzRzRzR Jan 28, 2025
5d33f3f
Fix qkv conversion logic for CogView4 to Diffusers format
zRzRzRzRzRzRzR Jan 28, 2025
b889b37
back5
zRzRzRzRzRzRzR Jan 28, 2025
e239c3c
revert to sat to cogview4 version
zRzRzRzRzRzRzR Jan 28, 2025
310da29
update a new convert from megatron
zRzRzRzRzRzRzR Jan 28, 2025
3bd6d30
[WIP][cogview4]: implement CogView4 attention processor
OleehyO Jan 28, 2025
f826aec
[cogview4] implement CogView4 transformer block
OleehyO Jan 28, 2025
8d8ed8b
Merge branch 'huggingface:main' into cogview4
zRzRzRzRzRzRzR Jan 28, 2025
bf1fdc8
with new attn
zRzRzRzRzRzRzR Jan 28, 2025
6a3a07f
[bugfix] fix dimension mismatch in CogView4 attention
OleehyO Jan 28, 2025
de274f3
[cogview4][WIP]: update final normalization in CogView4 transformer
OleehyO Jan 28, 2025
e94999e
Merge remote-tracking branch 'origin/cogview4' into cogview4
OleehyO Jan 28, 2025
e238284
Merge branch 'huggingface:main' into cogview4
zRzRzRzRzRzRzR Feb 1, 2025
a9b1e16
Merge branch 'huggingface:main' into cogview4
zRzRzRzRzRzRzR Feb 5, 2025
46277b2
1
zRzRzRzRzRzRzR Feb 5, 2025
ebbaa5b
put back
zRzRzRzRzRzRzR Feb 5, 2025
f1ccdd2
Update transformer_cogview4.py
zRzRzRzRzRzRzR Feb 5, 2025
030a467
change time_shift
zRzRzRzRzRzRzR Feb 6, 2025
ad40575
Update pipeline_cogview4.py
zRzRzRzRzRzRzR Feb 6, 2025
81d39ee
change timesteps
zRzRzRzRzRzRzR Feb 6, 2025
45f9e88
fix
zRzRzRzRzRzRzR Feb 6, 2025
1dbeaa8
change text_encoder_id
zRzRzRzRzRzRzR Feb 6, 2025
f209600
[cogview4][rope] align RoPE implementation with Megatron
OleehyO Feb 6, 2025
992f5a3
[cogview4][bugfix] apply silu activation to time embeddings in CogView4
OleehyO Feb 6, 2025
03a1c3b
[cogview4][chore] clean up pipeline code
OleehyO Feb 6, 2025
dd34794
Merge remote-tracking branch 'origin/cogview4' into cogview4
OleehyO Feb 6, 2025
3dab073
[cogview4][scheduler] Implement CogView4 scheduler and pipeline
OleehyO Feb 6, 2025
63982d6
now It work
zRzRzRzRzRzRzR Feb 6, 2025
90a5706
Merge branch 'huggingface:main' into cogview4
zRzRzRzRzRzRzR Feb 6, 2025
d4748e0
add timestep
zRzRzRzRzRzRzR Feb 7, 2025
95f851d
batch
zRzRzRzRzRzRzR Feb 7, 2025
cb56282
change convert scipt
zRzRzRzRzRzRzR Feb 7, 2025
fedf325
refactor pt. 1; make style
a-r-r-o-w Feb 10, 2025
90d29c7
Merge branch 'huggingface:main' into cogview4
zRzRzRzRzRzRzR Feb 10, 2025
4c01c9d
refactor pt. 2
a-r-r-o-w Feb 12, 2025
c1b8004
refactor pt. 3
a-r-r-o-w Feb 12, 2025
9d55d0a
add tests
a-r-r-o-w Feb 12, 2025
5e6de42
make fix-copies
a-r-r-o-w Feb 12, 2025
30dd0ad
Merge branch 'main' into cogview4
a-r-r-o-w Feb 12, 2025
2046cf2
update toctree.yml
a-r-r-o-w Feb 12, 2025
39e1198
use flow match scheduler instead of custom
a-r-r-o-w Feb 13, 2025
b566a9f
Merge branch 'main' into cogview4
a-r-r-o-w Feb 13, 2025
b4c9fde
remove scheduling_cogview.py
a-r-r-o-w Feb 13, 2025
a137e17
add tiktoken to test dependencies
a-r-r-o-w Feb 13, 2025
da420fb
Update src/diffusers/models/embeddings.py
a-r-r-o-w Feb 13, 2025
4003b9c
apply suggestions from review
a-r-r-o-w Feb 13, 2025
35c0ec6
use diffusers apply_rotary_emb
a-r-r-o-w Feb 13, 2025
d328c5e
update flow match scheduler to accept timesteps
a-r-r-o-w Feb 14, 2025
d637d3a
Merge branch 'main' into cogview4
a-r-r-o-w Feb 14, 2025
4c37ef0
fix comment
a-r-r-o-w Feb 14, 2025
90c240b
apply review sugestions
a-r-r-o-w Feb 14, 2025
5c11298
Merge branch 'main' into cogview4
a-r-r-o-w Feb 14, 2025
2f12b7a
Update src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
a-r-r-o-w Feb 14, 2025
use flow match scheduler instead of custom
a-r-r-o-w committed Feb 13, 2025
commit 39e1198029b8df98cdda202066073734f00d7d6d
17 changes: 3 additions & 14 deletions scripts/convert_cogview4_to_diffusers.py
@@ -32,7 +32,7 @@
from accelerate import init_empty_weights
from transformers import GlmForCausalLM, PreTrainedTokenizerFast

from diffusers import AutoencoderKL, CogView4DDIMScheduler, CogView4Pipeline, CogView4Transformer2DModel
from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
from diffusers.utils.import_utils import is_accelerate_available

@@ -222,19 +222,8 @@ def main(args):
for param in text_encoder.parameters():
param.data = param.data.contiguous()

scheduler = CogView4DDIMScheduler.from_config(
{
"shift_scale": 1.0,
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"clip_sample": False,
"num_train_timesteps": 1000,
"prediction_type": "v_prediction",
"rescale_betas_zero_snr": True,
"set_alpha_to_one": True,
"timestep_spacing": "linspace",
}
scheduler = FlowMatchEulerDiscreteScheduler(
base_shift=0.25, max_shift=0.75, base_image_seq_len=256, use_dynamic_shifting=True, time_shift_type="linear"
)

pipe = CogView4Pipeline(
22 changes: 3 additions & 19 deletions scripts/convert_cogview4_to_diffusers_megatron.py
@@ -27,12 +27,7 @@
from tqdm import tqdm
from transformers import GlmForCausalLM, PreTrainedTokenizerFast

from diffusers import (
AutoencoderKL,
CogView4DDIMScheduler,
CogView4Pipeline,
CogView4Transformer2DModel,
)
from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint


@@ -345,19 +340,8 @@ def main(args):
param.data = param.data.contiguous()

# Initialize the scheduler
scheduler = CogView4DDIMScheduler.from_config(
{
"shift_scale": 1.0,
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"clip_sample": False,
"num_train_timesteps": 1000,
"prediction_type": "v_prediction",
"rescale_betas_zero_snr": True,
"set_alpha_to_one": True,
"timestep_spacing": "linspace",
}
scheduler = FlowMatchEulerDiscreteScheduler(
base_shift=0.25, max_shift=0.75, base_image_seq_len=256, use_dynamic_shifting=True, time_shift_type="linear"
)

# Create the pipeline
2 changes: 0 additions & 2 deletions src/diffusers/__init__.py
@@ -188,7 +188,6 @@
"CMStochasticIterativeScheduler",
"CogVideoXDDIMScheduler",
"CogVideoXDPMScheduler",
"CogView4DDIMScheduler",
"DDIMInverseScheduler",
"DDIMParallelScheduler",
"DDIMScheduler",
@@ -707,7 +706,6 @@
CMStochasticIterativeScheduler,
CogVideoXDDIMScheduler,
CogVideoXDPMScheduler,
CogView4DDIMScheduler,
DDIMInverseScheduler,
DDIMParallelScheduler,
DDIMScheduler,
124 changes: 107 additions & 17 deletions src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -13,16 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from transformers import AutoTokenizer, GlmModel

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, CogView4Transformer2DModel
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogView4DDIMScheduler
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from .pipeline_output import CogView4PipelineOutput
@@ -53,6 +55,82 @@
"""


def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
base_shift: float = 0.25,
max_shift: float = 0.75,
):
# m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
# b = base_shift - m * base_seq_len
# mu = image_seq_len * m + b
# return mu

m = (image_seq_len / base_seq_len) ** 0.5
mu = m * max_shift + base_shift
return mu
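For intuition, here is a small standalone sketch of the shift computation above, plus what the scheduler presumably does with the returned `mu` when `time_shift_type="linear"` is configured. The `linear_time_shift` formula is an assumption about `FlowMatchEulerDiscreteScheduler` internals, not something taken from this diff:

```python
def calculate_shift(image_seq_len, base_seq_len=256, base_shift=0.25, max_shift=0.75):
    # Same formula as in the hunk above: mu grows with the square root
    # of the sequence-length ratio, so larger images get a larger shift.
    m = (image_seq_len / base_seq_len) ** 0.5
    return m * max_shift + base_shift

def linear_time_shift(mu, sigma):
    # Assumed form of the "linear" resolution shift:
    # sigma' = mu / (mu + (1 / sigma - 1)).
    return mu / (mu + (1.0 / sigma - 1.0))

# For a 1024x1024 image with an 8x VAE and patch_size=2:
# image_seq_len = (1024 // 8) * (1024 // 8) // 2**2 = 4096
mu = calculate_shift(4096)  # (4096 / 256) ** 0.5 = 4.0, so mu = 4 * 0.75 + 0.25
print(mu)                   # 3.25
print(linear_time_shift(mu, 0.5))  # 3.25 / 4.25
```

Note that at the base sequence length of 256 tokens `mu` comes out to exactly `base_shift + max_shift = 1.0`, i.e. no effective shift.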


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.

Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
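The dispatch in `retrieve_timesteps` hinges on runtime feature detection: it inspects the scheduler's `set_timesteps` signature to decide whether custom `timesteps` or `sigmas` can be forwarded. A minimal sketch of that check, with two hypothetical `set_timesteps` signatures standing in for real schedulers:

```python
import inspect

def set_timesteps_with_sigmas(num_inference_steps, device=None, sigmas=None):
    ...

def set_timesteps_plain(num_inference_steps, device=None):
    ...

def accepts(fn, name):
    # Mirrors the check in retrieve_timesteps: does this callable
    # expose a parameter with the given name?
    return name in set(inspect.signature(fn).parameters.keys())

print(accepts(set_timesteps_with_sigmas, "sigmas"))  # True
print(accepts(set_timesteps_plain, "sigmas"))        # False
```

If the check fails, `retrieve_timesteps` raises a `ValueError` rather than silently dropping the caller's custom schedule.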


class CogView4Pipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using CogView4.
@@ -86,7 +164,7 @@ def __init__(
text_encoder: GlmModel,
vae: AutoencoderKL,
transformer: CogView4Transformer2DModel,
scheduler: CogView4DDIMScheduler,
scheduler: FlowMatchEulerDiscreteScheduler,
):
super().__init__()

@@ -219,8 +297,10 @@ def encode_prompt(

return prompt_embeds, negative_prompt_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
if latents is not None:
return latents.to(device)

shape = (
batch_size,
num_channels_latents,
@@ -232,14 +312,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)

# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
return latents

def check_inputs(
@@ -322,6 +395,7 @@ def __call__(
width: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 5.0,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -359,6 +433,10 @@
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to `5.0`):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -491,9 +569,22 @@ def __call__(
image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
self.transformer.config.patch_size**2
)
self.scheduler.set_timesteps(num_inference_steps, image_seq_len, device)
timesteps = self.scheduler.timesteps
self._num_timesteps = len(timesteps)

timesteps = (
np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps)
if timesteps is None
else np.array(timesteps)
)
timesteps = timesteps.astype(np.int64)
Review thread on this change:

Member: This is a bit different from what we usually do. We don't use `self.scheduler.timesteps` returned from the call to `retrieve_timesteps`, because those timesteps are from after resolution-based timestep shifting is applied. For CogView4, it seems like we need to use the timesteps from before applying shifting, but sigmas from after applying shifting.

@yiyixuxu (Collaborator, Feb 13, 2025): That's really weird. But if it's on purpose (not an oversight or something), we need to support it from `retrieve_timesteps`, and the scheduler's `set_timesteps` method also needs to accept both custom timesteps and custom sigmas. Either that, or maybe add an option to calculate timesteps based on the pre-shifting sigmas.
Otherwise I don't think it would function correctly for img2img or training, where you do not start from the first timestep and need to search against `self.scheduler.timesteps`; e.g. this function won't work:

    def index_for_timestep(self, timestep, schedule_timesteps=None):

@a-r-r-o-w (Member, Feb 13, 2025): We don't have access to the original codebase yet, so it will be hard to check whether it's an oversight. It is weird that we have to do it this way, but if we don't (that is, having sigmas corresponding to timesteps), the final outputs come out with some residual noise.

Also, it seems like in my latest update I made a mistake doing `timesteps.astype(np.float32)` from some local testing. Basically, we want integer timesteps here first (to round down the float values from linspace), but then we need float32 timesteps for our scheduler to not raise an error:

    raise ValueError(
        (
            "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
            " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
            " one of the `scheduler.timesteps` as a timestep."
        ),
    )

So it will have to be something like `timesteps.astype(np.int64).astype(np.float32)` to be consistent with the behaviour when we started updating the PR and to not error out in our scheduler.

Contributor: We could return timesteps without shifting, then apply shifting on the fly in `scheduler.step`?

Collaborator:

> it will have to be something like `timesteps.astype(np.int64).astype(np.float32)`

OK, it seems like custom timesteps might be the way to go, because this logic here is just really custom (even if we calculate the timesteps without shifting, we also need to do this rounding first). Basically, you need to:

  1. remove the ValueError about passing sigmas and timesteps at the same time
  2. add timesteps to `set_timesteps`:
     timesteps: Optional[List[int]] = None,

Member: I'll do some more testing with a fresh mind in the morning to verify whether we need differing timesteps vs. sigmas here (to recheck whether it is a possible oversight or not). I do think I did everything correctly when I tried earlier today, and as a result we might need to do what you mentioned, but it wouldn't hurt to delay a little longer and verify again if it might save us a bunch of changes.

@zRzRzRzRzRzRzR If it would be possible to share just the scheduler-implementation-related files with us, it would really help us understand whether changes are required. No problem if not :) We can wait for the official release from THUDM and update our implementation.

Member: I think we're going to have to go forward with the update regarding both timesteps and sigmas being provided. Almost anything else I try seems to result in residual noise. I've pushed some changes regarding the handling, with some additional comments in `.step()` for better understanding. Okay to remove if it's not really required.
sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("base_shift", 0.25),
self.scheduler.config.get("max_shift", 0.75),
)
_, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu)
timesteps = torch.from_numpy(timesteps).to(device)
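Putting the hunk above together, here is a hedged pure-Python sketch of how the schedule appears to be constructed: integer pre-shift timesteps from a descending linspace, sigmas derived from those same unshifted timesteps, and a resolution shift applied with `mu`. The shift step is my assumption about what the scheduler does internally when given `sigmas` and `mu`; the function name is illustrative:

```python
def cogview4_schedule(num_inference_steps, image_seq_len,
                      num_train_timesteps=1000,
                      base_seq_len=256, base_shift=0.25, max_shift=0.75):
    # Pre-shift timesteps: linspace from num_train_timesteps down to 1,
    # truncated to integers then carried as floats (the int64 -> float32
    # round-trip discussed in the review thread above).
    if num_inference_steps == 1:
        raw = [float(num_train_timesteps)]
    else:
        step = (num_train_timesteps - 1.0) / (num_inference_steps - 1)
        raw = [num_train_timesteps - i * step for i in range(num_inference_steps)]
    timesteps = [float(int(t)) for t in raw]

    # Sigmas come from the *unshifted* timesteps...
    sigmas = [t / num_train_timesteps for t in timesteps]

    # ...and are then resolution-shifted with mu (assumed linear shift,
    # done inside the scheduler in the actual pipeline).
    mu = (image_seq_len / base_seq_len) ** 0.5 * max_shift + base_shift
    shifted = [mu / (mu + (1.0 / s - 1.0)) for s in sigmas]
    return timesteps, shifted

ts, sg = cogview4_schedule(4, 4096)
print(ts)  # [1000.0, 667.0, 334.0, 1.0]
```

The key quirk the reviewers debate is visible here: the transformer is conditioned on `ts` (pre-shift), while the noise update uses `sg` (post-shift).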

# Denoising loop
transformer_dtype = self.transformer.dtype
@@ -504,8 +595,7 @@
if self.interrupt:
continue

latent_model_input = self.scheduler.scale_model_input(latents, t)
latent_model_input = latent_model_input.to(transformer_dtype)
latent_model_input = latents.to(transformer_dtype)

# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0])
@@ -536,7 +626,7 @@ def __call__(
else:
noise_pred = noise_pred_cond

latents = self.scheduler.step(noise_pred, latents, t).prev_sample
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

# call the callback, if provided
if callback_on_step_end is not None:
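The denoising loop above runs the transformer separately for the conditional and unconditional prompts (which may have different lengths, per the PR title) and then combines the two predictions. A minimal sketch of that classifier-free-guidance combination, with toy per-element values standing in for noise-prediction tensors:

```python
def apply_cfg(noise_pred_cond, noise_pred_uncond, guidance_scale):
    # Classifier-free guidance: extrapolate from the unconditional
    # prediction toward the conditional one by guidance_scale.
    return [u + guidance_scale * (c - u)
            for c, u in zip(noise_pred_cond, noise_pred_uncond)]

# guidance_scale = 1.0 reduces to the conditional prediction;
# larger values push further away from the unconditional one.
print(apply_cfg([1.0, 2.0], [0.0, 1.0], 5.0))  # [5.0, 6.0]
```

Running the two prompts through the transformer in separate forward passes (rather than batching them) is what removes the need to pad the conditional and unconditional embeddings to a common length.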
2 changes: 0 additions & 2 deletions src/diffusers/schedulers/__init__.py
@@ -44,7 +44,6 @@
_import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"]
_import_structure["scheduling_ddim"] = ["DDIMScheduler"]
_import_structure["scheduling_ddim_cogvideox"] = ["CogVideoXDDIMScheduler"]
_import_structure["scheduling_ddim_cogview4"] = ["CogView4DDIMScheduler"]
_import_structure["scheduling_ddim_inverse"] = ["DDIMInverseScheduler"]
_import_structure["scheduling_ddim_parallel"] = ["DDIMParallelScheduler"]
_import_structure["scheduling_ddpm"] = ["DDPMScheduler"]
@@ -145,7 +144,6 @@
from .scheduling_consistency_models import CMStochasticIterativeScheduler
from .scheduling_ddim import DDIMScheduler
from .scheduling_ddim_cogvideox import CogVideoXDDIMScheduler
from .scheduling_ddim_cogview4 import CogView4DDIMScheduler
from .scheduling_ddim_inverse import DDIMInverseScheduler
from .scheduling_ddim_parallel import DDIMParallelScheduler
from .scheduling_ddpm import DDPMScheduler