Skip to content

Commit aae0fce

Browse files
authored
Update test_models_transformer_hunyuan_video.py
1 parent a7e9f85 commit aae0fce

File tree

1 file changed

+66
-1
lines changed

tests/models/transformers/test_models_transformer_hunyuan_video.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from diffusers import HunyuanVideoTransformer3DModel
2020
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
21+
from diffusers.utils.testing_utils import require_torch_gpu, require_torch_2, is_torch_compile, slow
2122

2223
from ..test_modeling_common import ModelTesterMixin
2324

@@ -88,6 +89,25 @@ def prepare_init_args_and_inputs_for_common(self):
8889
def test_gradient_checkpointing_is_applied(self):
    """Verify gradient checkpointing is wired to the expected module classes.

    NOTE(review): the commit added a second, byte-identical copy of this
    method immediately after the existing one; in Python the later ``def``
    in a class body silently shadows the earlier, so the duplicate was dead
    code. Collapsed back to a single definition.
    """
    expected_set = {"HunyuanVideoTransformer3DModel"}
    super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
96+
97+
@require_torch_gpu
@require_torch_2
@is_torch_compile
@slow
def test_torch_compile_recompilation_and_graph_break(self):
    """Compile the model with ``fullgraph=True`` and run two forward passes;
    ``error_on_recompile=True`` turns any recompilation on the second pass
    into a hard failure, so passing proves a stable single graph."""
    torch._dynamo.reset()
    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

    eager_model = self.model_class(**init_dict).to(torch_device)
    compiled_model = torch.compile(eager_model, fullgraph=True)

    no_recompile = torch._dynamo.config.patch(error_on_recompile=True)
    with no_recompile, torch.no_grad():
        for _ in range(2):
            _ = compiled_model(**inputs_dict)
91111

92112

93113
class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
@@ -157,6 +177,21 @@ def test_gradient_checkpointing_is_applied(self):
157177
expected_set = {"HunyuanVideoTransformer3DModel"}
158178
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
159179

180+
@require_torch_gpu
@require_torch_2
@is_torch_compile
@slow
def test_torch_compile_recompilation_and_graph_break(self):
    """Ensure a fullgraph-compiled model neither graph-breaks nor recompiles
    across repeated forward passes (recompilation raises under the patch)."""
    torch._dynamo.reset()
    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

    base_model = self.model_class(**init_dict).to(torch_device)
    model = torch.compile(base_model, fullgraph=True)

    with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
        for _ in range(2):
            _ = model(**inputs_dict)
194+
160195

161196
class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
162197
model_class = HunyuanVideoTransformer3DModel
@@ -222,6 +257,21 @@ def test_output(self):
222257
def test_gradient_checkpointing_is_applied(self):
    """Checkpointing must target the HunyuanVideo transformer class."""
    super().test_gradient_checkpointing_is_applied(
        expected_set={"HunyuanVideoTransformer3DModel"}
    )
260+
261+
@require_torch_gpu
@require_torch_2
@is_torch_compile
@slow
def test_torch_compile_recompilation_and_graph_break(self):
    """Two no-grad forward passes through a ``fullgraph=True`` compile must
    complete without triggering a recompile (patched to raise if one occurs)."""
    torch._dynamo.reset()
    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

    model = torch.compile(
        self.model_class(**init_dict).to(torch_device), fullgraph=True
    )

    with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
        _ = model(**inputs_dict)
        _ = model(**inputs_dict)
225275

226276

227277
class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
@@ -286,7 +336,22 @@ def prepare_init_args_and_inputs_for_common(self):
286336

287337
def test_output(self):
    """Forward output must match the configured shape, batch dimension 1."""
    expected_shape = (1, *self.output_shape)
    super().test_output(expected_output_shape=expected_shape)
289-
339+
290340
def test_gradient_checkpointing_is_applied(self):
    """Gradient checkpointing should be applied on the transformer class."""
    classes_with_checkpointing = {"HunyuanVideoTransformer3DModel"}
    super().test_gradient_checkpointing_is_applied(
        expected_set=classes_with_checkpointing
    )
343+
344+
@require_torch_gpu
@require_torch_2
@is_torch_compile
@slow
def test_torch_compile_recompilation_and_graph_break(self):
    """Compile fullgraph, then run the forward pass twice under
    ``error_on_recompile=True`` — a recompile on the warm pass raises."""
    torch._dynamo.reset()
    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

    uncompiled = self.model_class(**init_dict).to(torch_device)
    compiled = torch.compile(uncompiled, fullgraph=True)

    guard = torch._dynamo.config.patch(error_on_recompile=True)
    with guard, torch.no_grad():
        _ = compiled(**inputs_dict)
        _ = compiled(**inputs_dict)

0 commit comments

Comments
 (0)