Skip to content

Commit 0af18f2

Browse files
ezyangpytorchmergebot
authored andcommitted
Unify TEST_CUDNN definition (pytorch#105594)
Signed-off-by: Edward Z. Yang <[email protected]> Pull Request resolved: pytorch#105594 Approved by: https://github.com/larryliu0820, https://github.com/voznesenskym
1 parent b64bd4a commit 0af18f2

File tree

3 files changed

+2
-16
lines changed

3 files changed

+2
-16
lines changed

test/jit/test_freezing.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from torch.testing._internal.common_quantization import skipIfNoFBGEMM
1414
from torch.testing._internal.common_quantized import override_quantized_engine
1515
from torch.testing._internal.common_utils import set_default_dtype, skipCUDAMemoryLeakCheckIf, TEST_WITH_ROCM
16+
from torch.testing._internal.common_cuda import TEST_CUDNN
1617
from torch.testing._internal.jit_utils import JitTestCase
1718
from torch.utils import mkldnn as mkldnn_utils
1819

@@ -30,10 +31,6 @@
3031

3132
TEST_CUDA = torch.cuda.is_available()
3233
TEST_ROCM = torch.cuda.is_available() and torch.version.hip is not None
33-
TEST_CUDNN = False
34-
if TEST_CUDA and not TEST_ROCM: # Skip ROCM
35-
torch.ones(1).cuda() # initialize cuda context
36-
TEST_CUDNN = TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.tensor(1., device=torch.device('cuda:0')))
3734

3835
def removeExceptions(graph):
3936
for n in graph.findAllNodes('prim::RaiseException'):

test/test_cpp_extensions_jit.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import glob
1212

1313
import torch.testing._internal.common_utils as common
14+
from torch.testing._internal.common_cuda import TEST_CUDNN
1415
import torch
1516
import torch.backends.cudnn
1617
import torch.utils.cpp_extension
@@ -20,13 +21,7 @@
2021

2122

2223
TEST_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
23-
TEST_CUDNN = False
2424
TEST_ROCM = torch.cuda.is_available() and torch.version.hip is not None and ROCM_HOME is not None
25-
if TEST_CUDA and torch.version.cuda is not None: # the skip CUDNN test for ROCm
26-
CUDNN_HEADER_EXISTS = os.path.isfile(os.path.join(CUDA_HOME, "include/cudnn.h"))
27-
TEST_CUDNN = (
28-
TEST_CUDA and CUDNN_HEADER_EXISTS and torch.backends.cudnn.is_available()
29-
)
3025
TEST_MPS = torch.backends.mps.is_available()
3126
IS_WINDOWS = sys.platform == "win32"
3227

test/test_cpp_extensions_open_device_registration.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,7 @@
1515

1616

1717
TEST_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
18-
TEST_CUDNN = False
1918
TEST_ROCM = torch.cuda.is_available() and torch.version.hip is not None and ROCM_HOME is not None
20-
if TEST_CUDA and torch.version.cuda is not None: # the skip CUDNN test for ROCm
21-
CUDNN_HEADER_EXISTS = os.path.isfile(os.path.join(CUDA_HOME, "include/cudnn.h"))
22-
TEST_CUDNN = (
23-
TEST_CUDA and CUDNN_HEADER_EXISTS and torch.backends.cudnn.is_available()
24-
)
2519

2620

2721
def remove_build_path():

0 commit comments

Comments
 (0)