Skip to content

Commit ddcafd2

Browse files
Remove setting CUDA_VERSION to empty string and deprecated apt-key (#1486)
Remove setting CUDA_VERSION to empty string for CPU images. Remove getting deprecated apt-key.gpg key file from https://packages.cloud.google.com/
1 parent 8a20862 commit ddcafd2

File tree

3 files changed

+6
-11
lines changed

3 files changed

+6
-11
lines changed

Dockerfile.tmpl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,17 +78,9 @@ RUN apt-get install -y ocl-icd-libopencl1 clinfo && \
7878
uv pip install --system /tmp/lightgbm/*.whl && \
7979
rm -rf /tmp/lightgbm && \
8080
/tmp/clean-layer.sh
81-
82-
# Remove CUDA_VERSION from non-GPU image.
83-
{{ else }}
84-
ENV CUDA_VERSION=""
8581
{{ end }}
8682

8783

88-
# Update GPG key per documentation at https://cloud.google.com/compute/docs/troubleshooting/known-issues
89-
RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -
90-
RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -
91-
9284
# Use a fixed apt-get repo to stop intermittent failures due to flaky httpredir connections,
9385
# as described by Lionel Chan at http://stackoverflow.com/a/37426929/5881346
9486
RUN sed -i "s/httpredir.debian.org/debian.uchicago.edu/" /etc/apt/sources.list && \

tests/common.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@ def getAcceleratorName():
1111
except FileNotFoundError:
1212
return("nvidia-smi not found.")
1313

14-
gpu_test = unittest.skipIf(len(os.environ.get('CUDA_VERSION', '')) == 0, 'Not running GPU tests')
14+
def isGPU():
15+
return os.path.isfile('/proc/driver/nvidia/version')
16+
17+
gpu_test = unittest.skipIf(not isGPU(), 'Not running GPU tests')
1518
# b/342143152 P100s are slowly being unsupported in new release of popular ml tools such as RAPIDS.
1619
p100_exempt = unittest.skipIf(getAcceleratorName() == "Tesla P100-PCIE-16GB", 'Not running p100 exempt tests')
1720
tpu_test = unittest.skipIf(len(os.environ.get('ISTPUVM', '')) == 0, 'Not running TPU tests')

tests/test_jax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import jax
77
import jax.numpy as np
88

9-
from common import gpu_test
9+
from common import gpu_test, isGPU
1010
from jax import grad, jit
1111

1212

@@ -21,5 +21,5 @@ def test_grad(self):
2121
self.assertEqual(0.4199743, ag)
2222

2323
def test_backend(self):
24-
expected_backend = 'cpu' if len(os.environ.get('CUDA_VERSION', '')) == 0 else 'gpu'
24+
expected_backend = 'cpu' if not isGPU() else 'gpu'
2525
self.assertEqual(expected_backend, jax.default_backend())

0 commit comments

Comments
 (0)