Skip to content

Commit 6152c90

Browse files
filter torch cuda arch < 6.0 (#955)
* filter arch < 6.0 * remove unused codes
1 parent 748a9c7 commit 6152c90

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

setup.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717

1818
TORCH_CUDA_ARCH_LIST = os.environ.get("TORCH_CUDA_ARCH_LIST")
1919

20+
if TORCH_CUDA_ARCH_LIST:
21+
arch_list = [arch for arch in TORCH_CUDA_ARCH_LIST.split() if float(arch.split('+')[0]) >= 6.0]
22+
os.environ["TORCH_CUDA_ARCH_LIST"] = " ".join(arch_list)
23+
2024
version_vars = {}
2125
exec("exec(open('gptqmodel/version.py').read()); version=__version__", {}, version_vars)
2226
gptqmodel_version = version_vars['version']
@@ -109,6 +113,7 @@ def get_version_tag(is_cuda_release: bool = True) -> str:
109113
if got_cuda_between_v6_and_v8:
110114
FORCE_BUILD = True
111115

116+
112117
if BUILD_CUDA_EXT:
113118
if CUDA_RELEASE == "1":
114119
common_setup_kwargs["version"] += f"+{get_version_tag(True)}"

0 commit comments

Comments
 (0)