|
18 | 18 |
|
19 | 19 | from diffusers import HunyuanVideoTransformer3DModel
|
20 | 20 | from diffusers.utils.testing_utils import enable_full_determinism, torch_device
|
| 21 | +from diffusers.utils.testing_utils import require_torch_gpu, require_torch_2, is_torch_compile, slow |
21 | 22 |
|
22 | 23 | from ..test_modeling_common import ModelTesterMixin
|
23 | 24 |
|
@@ -88,6 +89,25 @@ def prepare_init_args_and_inputs_for_common(self):
|
88 | 89 | def test_gradient_checkpointing_is_applied(self):
|
89 | 90 | expected_set = {"HunyuanVideoTransformer3DModel"}
|
90 | 91 | super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
| 92 | + |
| 93 | + def test_gradient_checkpointing_is_applied(self): |
| 94 | + expected_set = {"HunyuanVideoTransformer3DModel"} |
| 95 | + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) |
| 96 | + |
| 97 | + @require_torch_gpu |
| 98 | + @require_torch_2 |
| 99 | + @is_torch_compile |
| 100 | + @slow |
| 101 | + def test_torch_compile_recompilation_and_graph_break(self): |
| 102 | + torch._dynamo.reset() |
| 103 | + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
| 104 | + |
| 105 | + model = self.model_class(**init_dict).to(torch_device) |
| 106 | + model = torch.compile(model, fullgraph=True) |
| 107 | + |
| 108 | + with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad(): |
| 109 | + _ = model(**inputs_dict) |
| 110 | + _ = model(**inputs_dict) |
91 | 111 |
|
92 | 112 |
|
93 | 113 | class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
@@ -157,6 +177,21 @@ def test_gradient_checkpointing_is_applied(self):
|
157 | 177 | expected_set = {"HunyuanVideoTransformer3DModel"}
|
158 | 178 | super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
159 | 179 |
|
| 180 | + @require_torch_gpu |
| 181 | + @require_torch_2 |
| 182 | + @is_torch_compile |
| 183 | + @slow |
| 184 | + def test_torch_compile_recompilation_and_graph_break(self): |
| 185 | + torch._dynamo.reset() |
| 186 | + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
| 187 | + |
| 188 | + model = self.model_class(**init_dict).to(torch_device) |
| 189 | + model = torch.compile(model, fullgraph=True) |
| 190 | + |
| 191 | + with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad(): |
| 192 | + _ = model(**inputs_dict) |
| 193 | + _ = model(**inputs_dict) |
| 194 | + |
160 | 195 |
|
161 | 196 | class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
162 | 197 | model_class = HunyuanVideoTransformer3DModel
|
@@ -222,6 +257,21 @@ def test_output(self):
|
222 | 257 | def test_gradient_checkpointing_is_applied(self):
|
223 | 258 | expected_set = {"HunyuanVideoTransformer3DModel"}
|
224 | 259 | super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
| 260 | + |
| 261 | + @require_torch_gpu |
| 262 | + @require_torch_2 |
| 263 | + @is_torch_compile |
| 264 | + @slow |
| 265 | + def test_torch_compile_recompilation_and_graph_break(self): |
| 266 | + torch._dynamo.reset() |
| 267 | + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
| 268 | + |
| 269 | + model = self.model_class(**init_dict).to(torch_device) |
| 270 | + model = torch.compile(model, fullgraph=True) |
| 271 | + |
| 272 | + with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad(): |
| 273 | + _ = model(**inputs_dict) |
| 274 | + _ = model(**inputs_dict) |
225 | 275 |
|
226 | 276 |
|
227 | 277 | class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
@@ -286,7 +336,22 @@ def prepare_init_args_and_inputs_for_common(self):
|
286 | 336 |
|
287 | 337 | def test_output(self):
|
288 | 338 | super().test_output(expected_output_shape=(1, *self.output_shape))
|
289 |
| - |
| 339 | + |
290 | 340 | def test_gradient_checkpointing_is_applied(self):
|
291 | 341 | expected_set = {"HunyuanVideoTransformer3DModel"}
|
292 | 342 | super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
| 343 | + |
| 344 | + @require_torch_gpu |
| 345 | + @require_torch_2 |
| 346 | + @is_torch_compile |
| 347 | + @slow |
| 348 | + def test_torch_compile_recompilation_and_graph_break(self): |
| 349 | + torch._dynamo.reset() |
| 350 | + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
| 351 | + |
| 352 | + model = self.model_class(**init_dict).to(torch_device) |
| 353 | + model = torch.compile(model, fullgraph=True) |
| 354 | + |
| 355 | + with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad(): |
| 356 | + _ = model(**inputs_dict) |
| 357 | + _ = model(**inputs_dict) |
0 commit comments