File tree Expand file tree Collapse file tree 3 files changed +6
-11
lines changed Expand file tree Collapse file tree 3 files changed +6
-11
lines changed Original file line number Diff line number Diff line change @@ -78,17 +78,9 @@ RUN apt-get install -y ocl-icd-libopencl1 clinfo && \
78
78
uv pip install --system /tmp/lightgbm/*.whl && \
79
79
rm -rf /tmp/lightgbm && \
80
80
/tmp/clean-layer.sh
81
-
82
- # Remove CUDA_VERSION from non-GPU image.
83
- {{ else }}
84
- ENV CUDA_VERSION=""
85
81
{{ end }}
86
82
87
83
88
- # Update GPG key per documentation at https://cloud.google.com/compute/docs/troubleshooting/known-issues
89
- RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -
90
- RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -
91
-
92
84
# Use a fixed apt-get repo to stop intermittent failures due to flaky httpredir connections,
93
85
# as described by Lionel Chan at http://stackoverflow.com/a/37426929/5881346
94
86
RUN sed -i "s/httpredir.debian.org/debian.uchicago.edu/" /etc/apt/sources.list && \
Original file line number Diff line number Diff line change @@ -11,7 +11,10 @@ def getAcceleratorName():
11
11
except FileNotFoundError :
12
12
return ("nvidia-smi not found." )
13
13
14
- gpu_test = unittest .skipIf (len (os .environ .get ('CUDA_VERSION' , '' )) == 0 , 'Not running GPU tests' )
14
+ def isGPU ():
15
+ return os .path .isfile ('/proc/driver/nvidia/version' )
16
+
17
+ gpu_test = unittest .skipIf (not isGPU (), 'Not running GPU tests' )
15
18
# b/342143152 P100s are slowly being unsupported in new release of popular ml tools such as RAPIDS.
16
19
p100_exempt = unittest .skipIf (getAcceleratorName () == "Tesla P100-PCIE-16GB" , 'Not running p100 exempt tests' )
17
20
tpu_test = unittest .skipIf (len (os .environ .get ('ISTPUVM' , '' )) == 0 , 'Not running TPU tests' )
Original file line number Diff line number Diff line change 6
6
import jax
7
7
import jax .numpy as np
8
8
9
- from common import gpu_test
9
+ from common import gpu_test , isGPU
10
10
from jax import grad , jit
11
11
12
12
@@ -21,5 +21,5 @@ def test_grad(self):
21
21
self .assertEqual (0.4199743 , ag )
22
22
23
23
def test_backend (self ):
24
- expected_backend = 'cpu' if len ( os . environ . get ( 'CUDA_VERSION' , '' )) == 0 else 'gpu'
24
+ expected_backend = 'cpu' if not isGPU () else 'gpu'
25
25
self .assertEqual (expected_backend , jax .default_backend ())
You can’t perform that action at this time.
0 commit comments