Skip to content

Commit 7988119

Browse files
authored
Merge pull request #2 from a-r-r-o-w/latte-2
update _toctree.yml for docs and fix example
2 parents 2ea37c9 + 521ed5c commit 7988119

File tree

6 files changed

+25
-13
lines changed

6 files changed

+25
-13
lines changed

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,8 @@
249249
title: DiTTransformer2DModel
250250
- local: api/models/hunyuan_transformer2d
251251
title: HunyuanDiT2DModel
252+
- local: api/models/latte_transformer3d
253+
title: LatteTransformer3DModel
252254
- local: api/models/lumina_nextdit2d
253255
title: LuminaNextDiT2DModel
254256
- local: api/models/transformer_temporal

docs/source/en/api/models/latte_transformer3d.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,3 @@ A Diffusion Transformer model for 3D data from [Latte](https://github.com/Vchite
1717
## LatteTransformer3DModel
1818

1919
[[autodoc]] LatteTransformer3DModel
20-

src/diffusers/models/embeddings.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,21 @@ def get_timestep_embedding(
3535
"""
3636
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
3737
38-
:param timesteps: a 1-D Tensor of N indices, one per batch element.
39-
These may be fractional.
40-
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
41-
embeddings. :return: an [N x dim] Tensor of positional embeddings.
38+
Args
39+
timesteps (torch.Tensor):
40+
a 1-D Tensor of N indices, one per batch element. These may be fractional.
41+
embedding_dim (int):
42+
the dimension of the output.
43+
flip_sin_to_cos (bool):
44+
Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
45+
downscale_freq_shift (float):
46+
Controls the delta between frequencies between dimensions
47+
scale (float):
48+
Scaling factor applied to the embeddings.
49+
max_period (int):
50+
Controls the maximum frequency of the embeddings
51+
Returns
52+
torch.Tensor: an [N x dim] Tensor of positional embeddings.
4253
"""
4354
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
4455

src/diffusers/pipelines/latte/pipeline_latte.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,12 @@
5656
>>> from diffusers.utils import export_to_gif
5757
5858
>>> # You can replace the checkpoint id with "maxin-cn/Latte-1" too.
59-
>>> pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16)
59+
>>> pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16).to("cuda")
6060
>>> # Enable memory optimizations.
6161
>>> pipe.enable_model_cpu_offload()
6262
6363
>>> prompt = "A small cactus with a happy face in the Sahara desert."
64-
>>> videos = pipe(prompt).frames
64+
>>> videos = pipe(prompt).frames[0]
6565
>>> export_to_gif(videos, "latte.gif")
6666
```
6767
"""
@@ -576,7 +576,7 @@ def prepare_latents(
576576
# scale the initial noise by the standard deviation required by the scheduler
577577
latents = latents * self.scheduler.init_noise_sigma
578578
return latents
579-
579+
580580
@property
581581
def guidance_scale(self):
582582
return self._guidance_scale

src/diffusers/utils/dummy_torch_and_transformers_objects.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,7 @@ def from_pretrained(cls, *args, **kwargs):
677677
requires_backends(cls, ["torch", "transformers"])
678678

679679

680-
class LDMTextToImagePipeline(metaclass=DummyObject):
680+
class LattePipeline(metaclass=DummyObject):
681681
_backends = ["torch", "transformers"]
682682

683683
def __init__(self, *args, **kwargs):
@@ -692,7 +692,7 @@ def from_pretrained(cls, *args, **kwargs):
692692
requires_backends(cls, ["torch", "transformers"])
693693

694694

695-
class LEditsPPPipelineStableDiffusion(metaclass=DummyObject):
695+
class LDMTextToImagePipeline(metaclass=DummyObject):
696696
_backends = ["torch", "transformers"]
697697

698698
def __init__(self, *args, **kwargs):
@@ -707,7 +707,7 @@ def from_pretrained(cls, *args, **kwargs):
707707
requires_backends(cls, ["torch", "transformers"])
708708

709709

710-
class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
710+
class LEditsPPPipelineStableDiffusion(metaclass=DummyObject):
711711
_backends = ["torch", "transformers"]
712712

713713
def __init__(self, *args, **kwargs):
@@ -722,7 +722,7 @@ def from_pretrained(cls, *args, **kwargs):
722722
requires_backends(cls, ["torch", "transformers"])
723723

724724

725-
class LattePipeline(metaclass=DummyObject):
725+
class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
726726
_backends = ["torch", "transformers"]
727727

728728
def __init__(self, *args, **kwargs):

tests/pipelines/latte/test_latte.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515

1616
import gc
17+
import inspect
1718
import tempfile
1819
import unittest
1920

@@ -38,7 +39,6 @@
3839
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
3940
from ..test_pipelines_common import PipelineTesterMixin, to_np
4041

41-
import inspect
4242

4343
enable_full_determinism()
4444

0 commit comments

Comments
 (0)