Commit c7407a3

[refactor] move positional embeddings to patch embed layer for CogVideoX (#9263)
* remove frame limit in cogvideox
* remove debug prints
* Update src/diffusers/models/transformers/cogvideox_transformer_3d.py
* revert pipeline; remove frame limitation
* revert transformer changes
* address review comments
* add error message
* apply suggestions from review
1 parent 7be2611 commit c7407a3

File tree

2 files changed: 81 additions, 34 deletions


src/diffusers/models/embeddings.py

Lines changed: 58 additions & 0 deletions
@@ -342,15 +342,58 @@ def __init__(
         embed_dim: int = 1920,
         text_embed_dim: int = 4096,
         bias: bool = True,
+        sample_width: int = 90,
+        sample_height: int = 60,
+        sample_frames: int = 49,
+        temporal_compression_ratio: int = 4,
+        max_text_seq_length: int = 226,
+        spatial_interpolation_scale: float = 1.875,
+        temporal_interpolation_scale: float = 1.0,
+        use_positional_embeddings: bool = True,
     ) -> None:
         super().__init__()
+
         self.patch_size = patch_size
+        self.embed_dim = embed_dim
+        self.sample_height = sample_height
+        self.sample_width = sample_width
+        self.sample_frames = sample_frames
+        self.temporal_compression_ratio = temporal_compression_ratio
+        self.max_text_seq_length = max_text_seq_length
+        self.spatial_interpolation_scale = spatial_interpolation_scale
+        self.temporal_interpolation_scale = temporal_interpolation_scale
+        self.use_positional_embeddings = use_positional_embeddings
 
         self.proj = nn.Conv2d(
             in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
         )
         self.text_proj = nn.Linear(text_embed_dim, embed_dim)
 
+        if use_positional_embeddings:
+            pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
+            self.register_buffer("pos_embedding", pos_embedding, persistent=False)
+
+    def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
+        post_patch_height = sample_height // self.patch_size
+        post_patch_width = sample_width // self.patch_size
+        post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
+        num_patches = post_patch_height * post_patch_width * post_time_compression_frames
+
+        pos_embedding = get_3d_sincos_pos_embed(
+            self.embed_dim,
+            (post_patch_width, post_patch_height),
+            post_time_compression_frames,
+            self.spatial_interpolation_scale,
+            self.temporal_interpolation_scale,
+        )
+        pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
+        joint_pos_embedding = torch.zeros(
+            1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
+        )
+        joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
+
+        return joint_pos_embedding
+
     def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
         r"""
         Args:
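
Note: with the defaults above, the precomputed buffer covers the full text-plus-video sequence. A quick sizing sketch (patch_size = 2 is assumed here; its default is not visible in this hunk):

# Sequence sizing implied by the new defaults (assuming patch_size = 2):
sample_height, sample_width, sample_frames = 60, 90, 49
patch_size, temporal_compression_ratio, max_text_seq_length = 2, 4, 226

post_patch_height = sample_height // patch_size                                        # 30
post_patch_width = sample_width // patch_size                                          # 45
post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1   # 13
num_patches = post_patch_height * post_patch_width * post_time_compression_frames      # 17550

print(max_text_seq_length + num_patches)  # 17776 joint positions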
@@ -371,6 +414,21 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
         embeds = torch.cat(
             [text_embeds, image_embeds], dim=1
         ).contiguous()  # [batch, seq_length + num_frames x height x width, channels]
+
+        if self.use_positional_embeddings:
+            pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
+            if (
+                self.sample_height != height
+                or self.sample_width != width
+                or self.sample_frames != pre_time_compression_frames
+            ):
+                pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
+                pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
+            else:
+                pos_embedding = self.pos_embedding
+
+            embeds = embeds + pos_embedding
+
         return embeds
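Note: num_frames in forward counts post-compression (latent) frames, so the check above first expands it back to pixel-space frames before comparing with sample_frames. A minimal sketch of that round trip:

# Latent frames -> pixel-space frames -> latent frames (ratio = 4, as above)
temporal_compression_ratio = 4
num_frames = 13  # latent frames seen by forward()
pre_time_compression_frames = (num_frames - 1) * temporal_compression_ratio + 1  # 49
assert (pre_time_compression_frames - 1) // temporal_compression_ratio + 1 == num_frames
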
src/diffusers/models/transformers/cogvideox_transformer_3d.py

Lines changed: 23 additions & 34 deletions
@@ -23,7 +23,7 @@
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention, FeedForward
 from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
-from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
+from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
@@ -239,33 +239,29 @@ def __init__(
         super().__init__()
         inner_dim = num_attention_heads * attention_head_dim
 
-        post_patch_height = sample_height // patch_size
-        post_patch_width = sample_width // patch_size
-        post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
-        self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
-
         # 1. Patch embedding
-        self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
-        self.embedding_dropout = nn.Dropout(dropout)
-
-        # 2. 3D positional embeddings
-        spatial_pos_embedding = get_3d_sincos_pos_embed(
-            inner_dim,
-            (post_patch_width, post_patch_height),
-            post_time_compression_frames,
-            spatial_interpolation_scale,
-            temporal_interpolation_scale,
+        self.patch_embed = CogVideoXPatchEmbed(
+            patch_size=patch_size,
+            in_channels=in_channels,
+            embed_dim=inner_dim,
+            text_embed_dim=text_embed_dim,
+            bias=True,
+            sample_width=sample_width,
+            sample_height=sample_height,
+            sample_frames=sample_frames,
+            temporal_compression_ratio=temporal_compression_ratio,
+            max_text_seq_length=max_text_seq_length,
+            spatial_interpolation_scale=spatial_interpolation_scale,
+            temporal_interpolation_scale=temporal_interpolation_scale,
+            use_positional_embeddings=not use_rotary_positional_embeddings,
         )
-        spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
-        pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
-        pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
-        self.register_buffer("pos_embedding", pos_embedding, persistent=False)
+        self.embedding_dropout = nn.Dropout(dropout)
 
-        # 3. Time embeddings
+        # 2. Time embeddings
         self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
         self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
 
-        # 4. Define spatio-temporal transformers blocks
+        # 3. Define spatio-temporal transformers blocks
        self.transformer_blocks = nn.ModuleList(
             [
                 CogVideoXBlock(
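
For reference, a standalone sketch of the layer this call now constructs. The concrete values mirror the defaults visible in this diff plus an assumed CogVideoX-2b config (patch_size=2, in_channels=16, inner_dim = 30 * 64 = 1920); treat them as illustrative, not canonical:

from diffusers.models.embeddings import CogVideoXPatchEmbed

patch_embed = CogVideoXPatchEmbed(
    patch_size=2,
    in_channels=16,
    embed_dim=1920,
    text_embed_dim=4096,
    bias=True,
    sample_width=90,
    sample_height=60,
    sample_frames=49,
    temporal_compression_ratio=4,
    max_text_seq_length=226,
    spatial_interpolation_scale=1.875,
    temporal_interpolation_scale=1.0,
    use_positional_embeddings=True,  # i.e. not use_rotary_positional_embeddings
)
print(patch_embed.pos_embedding.shape)  # torch.Size([1, 17776, 1920]), i.e. 226 + 17550 positions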
@@ -284,7 +280,7 @@ def __init__(
         )
         self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
 
-        # 5. Output blocks
+        # 4. Output blocks
         self.norm_out = AdaLayerNorm(
             embedding_dim=time_embed_dim,
             output_dim=2 * inner_dim,
@@ -422,20 +418,13 @@ def forward(
 
         # 2. Patch embedding
         hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
+        hidden_states = self.embedding_dropout(hidden_states)
 
-        # 3. Position embedding
         text_seq_length = encoder_hidden_states.shape[1]
-        if not self.config.use_rotary_positional_embeddings:
-            seq_length = height * width * num_frames // (self.config.patch_size**2)
-
-            pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
-            hidden_states = hidden_states + pos_embeds
-            hidden_states = self.embedding_dropout(hidden_states)
-
         encoder_hidden_states = hidden_states[:, :text_seq_length]
         hidden_states = hidden_states[:, text_seq_length:]
 
-        # 4. Transformer blocks
+        # 3. Transformer blocks
         for i, block in enumerate(self.transformer_blocks):
             if self.training and self.gradient_checkpointing:
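With positional embeddings now handled inside patch_embed, the forward pass here reduces to dropout plus splitting the joint sequence back into text and video streams. A toy sketch of that split, using the default sizes from above:

import torch

text_seq_length, num_patches, inner_dim = 226, 17550, 1920
hidden_states = torch.randn(2, text_seq_length + num_patches, inner_dim)  # output of patch_embed
encoder_hidden_states = hidden_states[:, :text_seq_length]  # text tokens
hidden_states = hidden_states[:, text_seq_length:]          # video patch tokens
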
@@ -471,11 +460,11 @@ def custom_forward(*inputs):
         hidden_states = self.norm_final(hidden_states)
         hidden_states = hidden_states[:, text_seq_length:]
 
-        # 5. Final block
+        # 4. Final block
         hidden_states = self.norm_out(hidden_states, temb=emb)
         hidden_states = self.proj_out(hidden_states)
 
-        # 6. Unpatchify
+        # 5. Unpatchify
         p = self.config.patch_size
         output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
         output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
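
The unpatchify step itself is unchanged; a minimal shape check with made-up toy dimensions confirms the permute/flatten sequence recovers [batch, frames, channels, height, width]:

import torch

batch_size, num_frames, channels, height, width, p = 1, 2, 16, 8, 8, 2
hidden_states = torch.randn(batch_size, num_frames * (height // p) * (width // p), channels * p * p)
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
assert output.shape == (batch_size, num_frames, channels, height, width)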
