@@ -23,7 +23,7 @@
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention, FeedForward
 from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
-from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
+from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
@@ -239,33 +239,29 @@ def __init__(
         super().__init__()
         inner_dim = num_attention_heads * attention_head_dim
 
-        post_patch_height = sample_height // patch_size
-        post_patch_width = sample_width // patch_size
-        post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
-        self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
-
         # 1. Patch embedding
-        self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
-        self.embedding_dropout = nn.Dropout(dropout)
-
-        # 2. 3D positional embeddings
-        spatial_pos_embedding = get_3d_sincos_pos_embed(
-            inner_dim,
-            (post_patch_width, post_patch_height),
-            post_time_compression_frames,
-            spatial_interpolation_scale,
-            temporal_interpolation_scale,
+        self.patch_embed = CogVideoXPatchEmbed(
+            patch_size=patch_size,
+            in_channels=in_channels,
+            embed_dim=inner_dim,
+            text_embed_dim=text_embed_dim,
+            bias=True,
+            sample_width=sample_width,
+            sample_height=sample_height,
+            sample_frames=sample_frames,
+            temporal_compression_ratio=temporal_compression_ratio,
+            max_text_seq_length=max_text_seq_length,
+            spatial_interpolation_scale=spatial_interpolation_scale,
+            temporal_interpolation_scale=temporal_interpolation_scale,
+            use_positional_embeddings=not use_rotary_positional_embeddings,
         )
-        spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
-        pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
-        pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
-        self.register_buffer("pos_embedding", pos_embedding, persistent=False)
+        self.embedding_dropout = nn.Dropout(dropout)
 
-        # 3. Time embeddings
+        # 2. Time embeddings
         self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
         self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
 
-        # 4. Define spatio-temporal transformers blocks
+        # 3. Define spatio-temporal transformers blocks
         self.transformer_blocks = nn.ModuleList(
             [
                 CogVideoXBlock(
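
The positional-embedding construction deleted above does not vanish: it moves into `CogVideoXPatchEmbed`, which now receives the full sample geometry and builds (or, under rotary embeddings, skips) its own buffer. A minimal sketch of that ownership pattern (illustrative only, not the actual `CogVideoXPatchEmbed` implementation; the class name and constructor arguments below are invented for this example):

```python
import torch
import torch.nn as nn


class PatchEmbedWithPos(nn.Module):
    """Toy module: a patch embed that owns its joint text+video positional table."""

    def __init__(self, embed_dim: int, num_patches: int, max_text_seq_length: int):
        super().__init__()
        # Text tokens keep zero positions; the tail of the buffer would be
        # filled with 3D sincos values for the video tokens (omitted here).
        pos = torch.zeros(1, max_text_seq_length + num_patches, embed_dim)
        self.register_buffer("pos_embedding", pos, persistent=False)

    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor) -> torch.Tensor:
        # Concatenate the text and video streams, then add positions in one shot.
        joint = torch.cat([text_embeds, image_embeds], dim=1)
        return joint + self.pos_embedding[:, : joint.shape[1]]
```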
@@ -284,7 +280,7 @@ def __init__(
         )
         self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
 
-        # 5. Output blocks
+        # 4. Output blocks
         self.norm_out = AdaLayerNorm(
             embedding_dim=time_embed_dim,
             output_dim=2 * inner_dim,
@@ -422,20 +418,13 @@ def forward(
 
         # 2. Patch embedding
         hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
+        hidden_states = self.embedding_dropout(hidden_states)
 
-        # 3. Position embedding
         text_seq_length = encoder_hidden_states.shape[1]
-        if not self.config.use_rotary_positional_embeddings:
-            seq_length = height * width * num_frames // (self.config.patch_size**2)
-
-            pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
-            hidden_states = hidden_states + pos_embeds
-            hidden_states = self.embedding_dropout(hidden_states)
-
         encoder_hidden_states = hidden_states[:, :text_seq_length]
         hidden_states = hidden_states[:, text_seq_length:]
 
-        # 4. Transformer blocks
+        # 3. Transformer blocks
         for i, block in enumerate(self.transformer_blocks):
             if self.training and self.gradient_checkpointing:
 
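Two things happen in this hunk. First, a subtle behavior fix: `embedding_dropout` previously ran only inside the non-rotary branch, whereas it now runs right after patch embedding in both the rotary and non-rotary cases. Second, the per-call recomputation of the video sequence length is gone because it always equals the `num_patches` value that `__init__` used to precompute, and both now live inside the patch embed. A quick sanity check of that equality, using CogVideoX-2b-style defaults as illustrative numbers (assumed for the example, not read from this diff):

```python
# Illustrative geometry, assumed for the example.
sample_height, sample_width, sample_frames = 60, 90, 49
patch_size, temporal_compression_ratio = 2, 4

# What __init__ used to precompute (deleted in the first hunk above):
post_patch_height = sample_height // patch_size                           # 30
post_patch_width = sample_width // patch_size                             # 45
post_time_frames = (sample_frames - 1) // temporal_compression_ratio + 1  # 13
num_patches = post_patch_height * post_patch_width * post_time_frames     # 17550

# What forward() used to recompute from the latent geometry (deleted here):
height, width, num_frames = sample_height, sample_width, post_time_frames
seq_length = height * width * num_frames // (patch_size**2)               # 17550

assert seq_length == num_patches
```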
@@ -471,11 +460,11 @@ def custom_forward(*inputs):
         hidden_states = self.norm_final(hidden_states)
         hidden_states = hidden_states[:, text_seq_length:]
 
-        # 5. Final block
+        # 4. Final block
         hidden_states = self.norm_out(hidden_states, temb=emb)
         hidden_states = self.proj_out(hidden_states)
 
-        # 6. Unpatchify
+        # 5. Unpatchify
         p = self.config.patch_size
         output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
         output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
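
For readers tracing the unpatchify step, a small shape walk-through with toy dimensions (values chosen purely for illustration) shows how each token unfolds back into a p x p pixel tile:

```python
import torch

# Toy sizes: 2 videos, 4 latent frames, 32x48 latents, 16 channels, patch 2.
batch_size, num_frames, height, width, channels, p = 2, 4, 32, 48, 16, 2

seq_len = num_frames * (height // p) * (width // p)  # 4 * 16 * 24 = 1536
hidden_states = torch.randn(batch_size, seq_len, channels * p * p)

# Same reshape/permute as in the diff: split each token into a (p x p) tile,
# then stitch the tiles back into full frames.
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

print(output.shape)  # torch.Size([2, 4, 16, 32, 48]) -> (batch, frames, channels, H, W)
```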