-
Notifications
You must be signed in to change notification settings - Fork 6k
[tests] Changes to the torch.compile()
CI and tests
#11508
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e6b9b67
a17d537
e9fee7c
81109a5
7dfd599
8cd92b5
a1ce459
23e2794
2db5af1
073fb7c
8d4c70a
3b383ed
08c26a8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -19,20 +19,16 @@ | |||||||||
from diffusers import HunyuanVideoTransformer3DModel | ||||||||||
from diffusers.utils.testing_utils import ( | ||||||||||
enable_full_determinism, | ||||||||||
is_torch_compile, | ||||||||||
require_torch_2, | ||||||||||
require_torch_gpu, | ||||||||||
slow, | ||||||||||
torch_device, | ||||||||||
) | ||||||||||
|
||||||||||
from ..test_modeling_common import ModelTesterMixin | ||||||||||
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin | ||||||||||
|
||||||||||
|
||||||||||
enable_full_determinism() | ||||||||||
|
||||||||||
|
||||||||||
class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): | ||||||||||
class HunyuanVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): | ||||||||||
model_class = HunyuanVideoTransformer3DModel | ||||||||||
main_input_name = "hidden_states" | ||||||||||
uses_custom_attn_processor = True | ||||||||||
|
@@ -96,23 +92,8 @@ def test_gradient_checkpointing_is_applied(self): | |||||||||
expected_set = {"HunyuanVideoTransformer3DModel"} | ||||||||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set) | ||||||||||
|
||||||||||
@require_torch_gpu | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would make sense to add these decorators to the top of the Mixin no? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These decorators are present in the mixin: diffusers/tests/models/test_modeling_common.py Lines 1762 to 1765 in b5c2050
|
||||||||||
@require_torch_2 | ||||||||||
@is_torch_compile | ||||||||||
@slow | ||||||||||
def test_torch_compile_recompilation_and_graph_break(self): | ||||||||||
torch._dynamo.reset() | ||||||||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() | ||||||||||
|
||||||||||
model = self.model_class(**init_dict).to(torch_device) | ||||||||||
model = torch.compile(model, fullgraph=True) | ||||||||||
|
||||||||||
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad(): | ||||||||||
_ = model(**inputs_dict) | ||||||||||
_ = model(**inputs_dict) | ||||||||||
|
||||||||||
|
||||||||||
class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): | ||||||||||
class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): | ||||||||||
model_class = HunyuanVideoTransformer3DModel | ||||||||||
main_input_name = "hidden_states" | ||||||||||
uses_custom_attn_processor = True | ||||||||||
|
@@ -179,23 +160,8 @@ def test_gradient_checkpointing_is_applied(self): | |||||||||
expected_set = {"HunyuanVideoTransformer3DModel"} | ||||||||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set) | ||||||||||
|
||||||||||
@require_torch_gpu | ||||||||||
@require_torch_2 | ||||||||||
@is_torch_compile | ||||||||||
@slow | ||||||||||
def test_torch_compile_recompilation_and_graph_break(self): | ||||||||||
torch._dynamo.reset() | ||||||||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() | ||||||||||
|
||||||||||
model = self.model_class(**init_dict).to(torch_device) | ||||||||||
model = torch.compile(model, fullgraph=True) | ||||||||||
|
||||||||||
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad(): | ||||||||||
_ = model(**inputs_dict) | ||||||||||
_ = model(**inputs_dict) | ||||||||||
|
||||||||||
|
||||||||||
class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): | ||||||||||
class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): | ||||||||||
model_class = HunyuanVideoTransformer3DModel | ||||||||||
main_input_name = "hidden_states" | ||||||||||
uses_custom_attn_processor = True | ||||||||||
|
@@ -260,23 +226,10 @@ def test_gradient_checkpointing_is_applied(self): | |||||||||
expected_set = {"HunyuanVideoTransformer3DModel"} | ||||||||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set) | ||||||||||
|
||||||||||
@require_torch_gpu | ||||||||||
@require_torch_2 | ||||||||||
@is_torch_compile | ||||||||||
@slow | ||||||||||
def test_torch_compile_recompilation_and_graph_break(self): | ||||||||||
torch._dynamo.reset() | ||||||||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() | ||||||||||
|
||||||||||
model = self.model_class(**init_dict).to(torch_device) | ||||||||||
model = torch.compile(model, fullgraph=True) | ||||||||||
|
||||||||||
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad(): | ||||||||||
_ = model(**inputs_dict) | ||||||||||
_ = model(**inputs_dict) | ||||||||||
|
||||||||||
|
||||||||||
class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): | ||||||||||
class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests( | ||||||||||
ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase | ||||||||||
): | ||||||||||
model_class = HunyuanVideoTransformer3DModel | ||||||||||
main_input_name = "hidden_states" | ||||||||||
uses_custom_attn_processor = True | ||||||||||
|
@@ -342,18 +295,3 @@ def test_output(self): | |||||||||
def test_gradient_checkpointing_is_applied(self): | ||||||||||
expected_set = {"HunyuanVideoTransformer3DModel"} | ||||||||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set) | ||||||||||
|
||||||||||
@require_torch_gpu | ||||||||||
@require_torch_2 | ||||||||||
@is_torch_compile | ||||||||||
@slow | ||||||||||
def test_torch_compile_recompilation_and_graph_break(self): | ||||||||||
torch._dynamo.reset() | ||||||||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() | ||||||||||
|
||||||||||
model = self.model_class(**init_dict).to(torch_device) | ||||||||||
model = torch.compile(model, fullgraph=True) | ||||||||||
|
||||||||||
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad(): | ||||||||||
_ = model(**inputs_dict) | ||||||||||
_ = model(**inputs_dict) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Small nit. We can remove the -k "compile" in the test runner step.