[Bug fix] "previous_timestep()" in DDPM scheduling compatible with "trailing" and "linspace" options #9384


Merged
merged 6 commits into main on Dec 3, 2024

Conversation

AnandK27
Contributor

@AnandK27 AnandK27 commented Sep 7, 2024

What does this PR do?

Fixes the bug where previous_timestep() returns the wrong timestep under the "trailing" and "linspace" options when the number of inference steps does not evenly divide 1000. previous_timestep() only needs to respect these options during inference, and since the timesteps array is available at inference time, I simply read the value from it, as is already done for custom_timesteps. During training, the previous timestep remains one less than the current timestep.

(This is the simplest solution I could come up with; if you need the previous timestep computed for each case without the timesteps array, let me know!)
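The approach described above can be sketched as a standalone function (a minimal illustration, not the exact merged code; here it operates on a plain timesteps tensor instead of being a scheduler method):

```python
import torch

def previous_timestep(timesteps, t, num_inference_steps=None):
    """Sketch of the fixed lookup (illustrative, not the merged diffusers code).

    During inference the timesteps tensor already encodes the chosen spacing
    ("trailing", "linspace", or custom), so the previous timestep is simply
    the next entry in the tensor. During training (num_inference_steps is
    None) the schedule is contiguous and the previous timestep is t - 1.
    """
    if num_inference_steps is not None:
        # Find where t sits in the schedule and return its successor.
        index = (timesteps == t).nonzero(as_tuple=True)[0][0]
        if index == timesteps.shape[0] - 1:
            return torch.tensor(-1)  # last step: no previous timestep
        return timesteps[index + 1]
    return t - 1
```

With the "trailing" schedule from the test below, looking up 999 in [999, 856, 713, 570, 428, 285, 142] yields 856, and the final entry 142 maps to -1.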

Test Script

import torch
from diffusers import DDPMScheduler

# Initialize the DDPM scheduler
scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256")

# Test params
scheduler.config.timestep_spacing = 'trailing'  # or 'linspace'
scheduler.set_timesteps(num_inference_steps=7)  # any step count that does not evenly divide 1000

# Seed a generator for reproducible noise
generator = torch.manual_seed(0)

# Initialize some dummy latents (normally this would come from your model)
latents = torch.randn((1, 3, 64, 64))

# Inference loop
for i, t in enumerate(scheduler.timesteps):
    # In a real scenario, you'd run your model here to get model_output
    model_output = torch.randn_like(latents)
    
    # Get the previous timestep
    prev_timestep = scheduler.previous_timestep(t)
    
    # Print current and previous timesteps
    print(f"Step {i}:")
    print(f"  Current timestep: {t}")
    print(f"  Previous timestep: {prev_timestep}")
    
    # Scheduler step
    latents = scheduler.step(model_output, t, latents, generator=generator).prev_sample

print("Test complete.")

Output

Step 0:
  Current timestep: 999
  Previous timestep: 856
Step 1:
  Current timestep: 856
  Previous timestep: 713
Step 2:
  Current timestep: 713
  Previous timestep: 570
Step 3:
  Current timestep: 570
  Previous timestep: 428
Step 4:
  Current timestep: 428
  Previous timestep: 285
Step 5:
  Current timestep: 285
  Previous timestep: 142
Step 6:
  Current timestep: 142
  Previous timestep: -1
Test complete.
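The off-by-one is easy to reproduce outside the scheduler. With 7 inference steps, the constant integer stride 1000 // 7 = 142 used by the old arithmetic drifts away from the rounded "trailing" schedule. A minimal numpy sketch (mirroring the trailing spacing formula, shown here only to illustrate the drift):

```python
import numpy as np

num_train_timesteps, num_inference_steps = 1000, 7
step_ratio = num_train_timesteps / num_inference_steps

# "trailing" spacing: step down from num_train_timesteps, round, shift by 1
timesteps = np.round(np.arange(num_train_timesteps, 0, -step_ratio)).astype(np.int64) - 1
print(timesteps.tolist())  # [999, 856, 713, 570, 428, 285, 142]

# Old arithmetic: subtract a constant integer stride from each timestep
naive_prev = timesteps - num_train_timesteps // num_inference_steps
print(naive_prev.tolist())  # [857, 714, 571, 428, 286, 143, 0]
# The actual next entries are [856, 713, 570, 428, 285, 142, -1],
# so the constant-stride result is off by one almost everywhere.
```

This matches the schedule printed in the output above, and shows why reading the previous timestep from the timesteps array is the robust choice.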

Fixes #9261

Before submitting

Who can review?

@yiyixuxu @RochMollero @bghira
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@RochMollero

RochMollero commented Sep 7, 2024 via email

@AnandK27
Contributor Author

AnandK27 commented Sep 7, 2024

Yeah, the change would have to be propagated to the remaining schedulers. There is also a lot of duplication across the schedulers, which will need an overhaul to modularize. The maintainers can probably open issues for these.

@yiyixuxu yiyixuxu self-requested a review September 9, 2024 16:24
@yiyixuxu
Collaborator

yiyixuxu commented Sep 9, 2024

thanks so much for your PR! will take a look!
all our newer schedulers have a step counter, so they would not have this issue
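For context on the step-counter approach mentioned above: when a scheduler tracks its own position in the timesteps array, the previous timestep falls out of simple indexing and no arithmetic on timestep values is needed. A toy illustration (not diffusers code; the class and method names here are hypothetical):

```python
import torch

class StepIndexedSchedule:
    """Toy step-counter schedule (hypothetical, for illustration only)."""

    def __init__(self, timesteps):
        self.timesteps = timesteps
        self._step_index = 0  # advanced once per step() call

    def step(self):
        # Current timestep and its successor come straight from the array,
        # so any spacing ("trailing", "linspace", custom) is handled uniformly.
        t = int(self.timesteps[self._step_index])
        if self._step_index + 1 < len(self.timesteps):
            prev_t = int(self.timesteps[self._step_index + 1])
        else:
            prev_t = -1
        self._step_index += 1
        return t, prev_t
```

Iterating a three-entry schedule [999, 856, 713] yields (999, 856), (856, 713), and (713, -1), regardless of how the entries were spaced.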

@AnandK27
Contributor Author

AnandK27 commented Sep 9, 2024

I will close the PR then

@yiyixuxu
Collaborator

hey! sorry I just noticed this
it appears my comment #9384 (comment) was misleading: I only meant that we do not need to apply the change to the remaining schedulers (not all of them, anyway). This is still very much an issue with DDPM

feel free to re-open this

@AnandK27 AnandK27 reopened this Oct 16, 2024
Contributor

github-actions bot commented Nov 9, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale label Nov 9, 2024
@yiyixuxu yiyixuxu removed the stale label Dec 3, 2024
@yiyixuxu yiyixuxu requested a review from hlky December 3, 2024 03:52
@yiyixuxu
Collaborator

yiyixuxu commented Dec 3, 2024

can you take a look? @hlky

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Contributor

@hlky hlky left a comment


Thanks @AnandK27! 🤗

@yiyixuxu yiyixuxu merged commit 5effcd3 into huggingface:main Dec 3, 2024
15 checks passed
lawrence-cj pushed a commit to lawrence-cj/diffusers that referenced this pull request Dec 4, 2024
…railing" and "linspace" options (huggingface#9384)

* Update scheduling_ddpm.py

* fix copies

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: hlky <[email protected]>
sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
…railing" and "linspace" options (#9384)

* Update scheduling_ddpm.py

* fix copies

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: hlky <[email protected]>

Successfully merging this pull request may close these issues.

"previous_timestep()" in DDPM scheduling not compatible with "trailing" option. DDIM bugged too
5 participants