Skip to content

Commit b6d91fb

Browse files
authored
Merge pull request #90 from 907Resident/dev
GPU Info
2 parents 68de766 + f8cb88c commit b6d91fb

File tree

10 files changed

+106
-15
lines changed

10 files changed

+106
-15
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,9 @@ In line with [NEP 29][nep-29], this project supports:
176176

177177
[[top](#sections)]
178178

179+
#### v. 2.4.0 (May 23, 2023)
180+
181+
- Adds a new `--gpu` flag to print out GPU information (currently limited to NVIDIA devices) ([#90](https://github.com/rasbt/watermark/pull/63), via contribution by [907Resident](https://github.com/907Resident))
179182

180183

181184
#### v. 2.3.1 (May 27, 2022)

appveyor.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ install:
1818

1919
test_script:
2020
- set PYTHONPATH=%PYTHONPATH%;%CD%
21+
- pip install -e .
2122
- pytest -sv
2223

2324
notifications:

binder/requirements.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,4 @@ numpy
33
scipy
44
scikit-learn
55
jupyter
6-
7-
-e .
6+
py3nvml

requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
ipython >= 6.0
2+
importlib-metadata >= 1.4
3+
py3nvml >= 0.2

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[metadata]
2-
version = 2.3.1
2+
version = 2.4.0
33
license_file = LICENSE
44
classifiers =
55
Development Status :: 5 - Production/Stable

setup.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,20 @@
55
#
66
# License: BSD 3 clause
77

8+
from os.path import dirname, join, realpath
89
from textwrap import dedent
910

1011
from setuptools import find_packages, setup
1112

13+
14+
PROJECT_ROOT = dirname(realpath(__file__))
15+
REQUIREMENTS_FILE = join(PROJECT_ROOT, "requirements.txt")
16+
17+
with open(REQUIREMENTS_FILE) as f:
18+
install_reqs = f.read().splitlines()
19+
20+
install_reqs.append("setuptools")
21+
1222
# Also see settings in setup.cfg
1323
setup(
1424
name="watermark",
@@ -21,10 +31,7 @@
2131
author_email="[email protected]",
2232
url="https://github.com/rasbt/watermark",
2333
packages=find_packages(exclude=[]),
24-
install_requires=[
25-
"ipython",
26-
'importlib-metadata >= 1.4 ; python_version < "3.8"',
27-
],
34+
install_requires=install_reqs,
2835
long_description=dedent(
2936
"""\
3037
An IPython magic extension for printing date and time stamps, version

watermark/magic.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ class WaterMark(Magics):
7070
help='prints the current version of watermark')
7171
@argument('-iv', '--iversions', action='store_true',
7272
help='prints the name/version of all imported modules')
73+
@argument('--gpu', action='store_true',
74+
help='prints GPU information (currently limited to NVIDIA GPUs),'
75+
' if available')
7376
@line_magic
7477
def watermark(self, line):
7578
"""

watermark/tests/test_watermark.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# -*- coding: utf-8 -*-
22

3+
import sys
4+
import os
5+
6+
sys.path.append(os.path.join("../watermark"))
7+
38
import watermark
49

510

watermark/tests/test_watermark_gpu.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# -*- coding: utf-8 -*-
2+
import sys
3+
import os
4+
5+
sys.path.append(os.path.join("../watermark"))
6+
7+
import watermark
8+
9+
def test_gpu_info():
10+
a = watermark.watermark(gpu=True)
11+
txt = a.split('\n')
12+
clean_txt = []
13+
for t in txt:
14+
t = t.strip()
15+
if t:
16+
t = t.split(':')[0]
17+
clean_txt.append(t.strip())
18+
clean_txt = set(clean_txt)
19+
20+
expected = [
21+
'GPU Info',
22+
]
23+
24+
for i in expected:
25+
assert i in clean_txt, print(f'{i} not in {clean_txt}')

watermark/watermark.py

Lines changed: 53 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import types
1919
from multiprocessing import cpu_count
2020
from socket import gethostname
21+
import platform
22+
from py3nvml import py3nvml
2123

2224
try:
2325
import importlib.metadata as importlib_metadata
@@ -30,14 +32,32 @@
3032
from .version import __version__
3133

3234

33-
def watermark(author=None, email=None, github_username=None,
34-
website=None, current_date=False, datename=False,
35-
current_time=False, iso8601=False, timezone=False,
36-
updated=False, custom_time=None, python=False,
37-
packages=None, conda=False, hostname=False, machine=False,
38-
githash=False, gitrepo=False, gitbranch=False,
39-
watermark=False, iversions=False, watermark_self=None,
40-
globals_=None):
35+
def watermark(
36+
author=None,
37+
email=None,
38+
github_username=None,
39+
website=None,
40+
current_date=False,
41+
datename=False,
42+
current_time=False,
43+
iso8601=False,
44+
timezone=False,
45+
updated=False,
46+
custom_time=None,
47+
python=False,
48+
packages=None,
49+
conda=False,
50+
hostname=False,
51+
machine=False,
52+
githash=False,
53+
gitrepo=False,
54+
gitbranch=False,
55+
watermark=False,
56+
iversions=False,
57+
gpu=False,
58+
watermark_self=None,
59+
globals_=None
60+
):
4161

4262
'''Function to print date/time stamps and various system information.
4363
@@ -107,6 +127,9 @@ def watermark(author=None, email=None, github_username=None,
107127
108128
iversions :
109129
prints the name/version of all imported modules
130+
131+
gpu :
132+
prints GPU information (currently limited to NVIDIA GPUs), if available
110133
111134
watermark_self :
112135
instance of the watermark magics class, which is required
@@ -182,6 +205,8 @@ def watermark(author=None, email=None, github_username=None,
182205
"to show imported package versions."
183206
)
184207
output.append(_get_all_import_versions(ns))
208+
if args['gpu']:
209+
output.append(_get_gpu_info())
185210
if args['watermark']:
186211
output.append({"Watermark": __version__})
187212

@@ -306,3 +331,23 @@ def _get_all_import_versions(vars):
306331
def _get_conda_env():
307332
name = os.getenv('CONDA_DEFAULT_ENV', 'n/a')
308333
return {"conda environment": name}
334+
335+
336+
def _get_gpu_info():
337+
try:
338+
gpu_info = [""]
339+
py3nvml.nvmlInit()
340+
num_gpus = py3nvml.nvmlDeviceGetCount()
341+
for i in range(num_gpus):
342+
handle = py3nvml.nvmlDeviceGetHandleByIndex(i)
343+
gpu_name = py3nvml.nvmlDeviceGetName(handle)
344+
gpu_info.append(f"GPU {i}: {gpu_name}")
345+
py3nvml.nvmlShutdown()
346+
return {"GPU Info": "\n ".join(gpu_info)}
347+
348+
except py3nvml.NVMLError_LibraryNotFound:
349+
return {"GPU Info": "NVIDIA drivers do not appear "
350+
"to be installed on this machine."}
351+
except:
352+
return {"GPU Info": "GPU information is not "
353+
"available for this machine."}

0 commit comments

Comments
 (0)