diff --git a/gpustat/core.py b/gpustat/core.py index de0ecbe..85c85cf 100644 --- a/gpustat/core.py +++ b/gpustat/core.py @@ -250,12 +250,13 @@ def jsonify(self): class GPUStatCollection(object): - def __init__(self, gpu_list): + def __init__(self, gpu_list, driver_version=None): self.gpus = gpu_list # attach additional system information self.hostname = platform.node() self.query_time = datetime.now() + self.driver_version = driver_version @staticmethod def new_query(): @@ -263,6 +264,11 @@ def new_query(): N.nvmlInit() + def _decode(b): + if isinstance(b, bytes): + return b.decode() # for python3, to unicode + return b + def get_gpu_info(handle): """Get one GPU information specified by nvml handle""" @@ -284,11 +290,6 @@ def get_process_info(nv_process): process['pid'] = nv_process.pid return process - def _decode(b): - if isinstance(b, bytes): - return b.decode() # for python3, to unicode - return b - name = _decode(N.nvmlDeviceGetName(handle)) uuid = _decode(N.nvmlDeviceGetUUID(handle)) @@ -374,8 +375,14 @@ def _decode(b): gpu_stat = GPUStat(gpu_info) gpu_list.append(gpu_stat) + # 2. additional info (driver version, etc). + try: + driver_version = _decode(N.nvmlSystemGetDriverVersion()) + except N.NVMLError: + driver_version = None # N/A + N.nvmlShutdown() - return GPUStatCollection(gpu_list) + return GPUStatCollection(gpu_list, driver_version=driver_version) def __len__(self): return len(self.gpus) @@ -424,15 +431,19 @@ def print_formatted(self, fp=sys.stdout, force_color=False, no_color=False, if show_header: time_format = locale.nl_langinfo(locale.D_T_FMT) - header_template = '{t.bold_white}{hostname:{width}}{t.normal} {timestr}' # noqa: E501 + header_template = '{t.bold_white}{hostname:{width}}{t.normal} ' + header_template += '{timestr} ' + header_template += '{t.bold_black}{driver_version}{t.normal}' + header_msg = header_template.format( hostname=self.hostname, width=gpuname_width + 3, # len("[?]") timestr=self.query_time.strftime(time_format), + driver_version=self.driver_version, t=t_color, ) - fp.write(header_msg) + fp.write(header_msg.strip()) fp.write(eol_char) # body diff --git a/gpustat/test_gpustat.py b/gpustat/test_gpustat.py index d4451f5..58cea63 100644 --- a/gpustat/test_gpustat.py +++ b/gpustat/test_gpustat.py @@ -44,6 +44,7 @@ def _configure_mock(N, Process, N.nvmlInit = MagicMock() N.nvmlShutdown = MagicMock() N.nvmlDeviceGetCount.return_value = 3 + N.nvmlSystemGetDriverVersion.return_value = '415.27.mock' mock_handles = ['mock-handle-%d' % i for i in range(3)]