[tests] Changes to the torch.compile() CI and tests #11508


Merged: 13 commits, May 26, 2025
propagate similar to the pipeline tests.
sayakpaul committed May 6, 2025
commit 81109a574882321f86340f3b3ebcc80389d2c0b0
tests/pipelines/test_pipelines.py: 6 changes (3 additions, 3 deletions)
@@ -2334,21 +2334,21 @@ def test_hotswapping_pipeline(self, rank0, rank1):
     def test_hotswapping_compiled_pipline_linear(self, rank0, rank1):
         # It's important to add this context to raise an error on recompilation
         target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
-        with torch._dynamo.config.patch(error_on_recompile=True):
+        with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
             self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
 
     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_compiled_pipline_conv2d(self, rank0, rank1):
         # It's important to add this context to raise an error on recompilation
         target_modules = ["conv", "conv1", "conv2"]
-        with torch._dynamo.config.patch(error_on_recompile=True):
+        with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
             self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
 
     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_compiled_pipline_both_linear_and_conv2d(self, rank0, rank1):
         # It's important to add this context to raise an error on recompilation
         target_modules = ["to_q", "conv"]
-        with torch._dynamo.config.patch(error_on_recompile=True):
+        with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
             self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
 
     def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
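For context, here is a minimal, self-contained sketch (not part of the diff) of the pattern this commit applies to the hotswap tests: compiling under `torch._dynamo.config.patch(error_on_recompile=True)` turns any recompilation into a hard error, while `torch._inductor.utils.fresh_inductor_cache()` gives the block its own temporary compilation cache so artifacts left behind by earlier tests cannot mask or cause recompiles. The `SmallModel` class and tensor shapes below are illustrative assumptions, not diffusers code.

```python
import torch

# Hypothetical toy module standing in for the pipeline components under test.
class SmallModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.to_q = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.to_q(x)

model = torch.compile(SmallModel())

# error_on_recompile=True makes torch.compile raise instead of silently
# recompiling; fresh_inductor_cache() isolates this block's compiled
# artifacts from any cache state shared with other tests.
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
    out = model(torch.randn(2, 8))  # first call: compiles into the fresh cache
    out = model(torch.randn(2, 8))  # same shapes and guards: must not recompile
```

In the actual tests, the second call is replaced by running the pipeline again after a LoRA hotswap, so a swap that invalidates the compiled graph fails the test immediately rather than degrading into a silent recompile.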