Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: handle logical CUDA device IDs for GPUStatsMonitor if CUDA_VISIBLE_DEVICES set #8260

Merged
merged 5 commits into from
Jul 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed the `GPUStatsMonitor` callbacks to use the correct GPU IDs if `CUDA_VISIBLE_DEVICES` set ([#8260](https://github.com/PyTorchLightning/pytorch-lightning/pull/8260))

- Fixed `lr_scheduler` checkpointed state by calling `update_lr_schedulers` before saving checkpoints ([#7877](https://github.com/PyTorchLightning/pytorch-lightning/pull/7877))


Expand Down
59 changes: 47 additions & 12 deletions pytorch_lightning/callbacks/gpu_stats_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,16 @@
import shutil
import subprocess
import time
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple

import torch

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import DeviceType, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.types import STEP_OUTPUT


class GPUStatsMonitor(Callback):
Expand Down Expand Up @@ -101,7 +105,13 @@ def __init__(
'temperature': temperature
})

def on_train_start(self, trainer, pl_module) -> None:
# The logical device IDs for selected devices
self._device_ids: List[int] = [] # will be assigned later in setup()

# The unmasked real GPU IDs
self._gpu_ids: List[str] = [] # will be assigned later in setup()

def setup(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', stage: Optional[str] = None) -> None:
if not trainer.logger:
raise MisconfigurationException('Cannot use GPUStatsMonitor callback with Trainer that has no logger.')

Expand All @@ -111,14 +121,20 @@ def on_train_start(self, trainer, pl_module) -> None:
f' since gpus attribute in Trainer is set to {trainer.gpus}.'
)

self._gpu_ids = ','.join(map(str, trainer.data_parallel_device_ids))
# The logical device IDs for selected devices
self._device_ids: List[int] = sorted(set(trainer.data_parallel_device_ids))

# The unmasked real GPU IDs
self._gpu_ids: List[int] = self._get_gpu_ids(self._device_ids)

def on_train_epoch_start(self, trainer, pl_module) -> None:
def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
self._snap_intra_step_time = None
self._snap_inter_step_time = None

@rank_zero_only
def on_train_batch_start(self, trainer, pl_module, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
def on_train_batch_start(
self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
if self._log_stats.intra_step_time:
self._snap_intra_step_time = time.time()

Expand All @@ -127,7 +143,7 @@ def on_train_batch_start(self, trainer, pl_module, batch: Any, batch_idx: int, d

gpu_stat_keys = self._get_gpu_stat_keys()
gpu_stats = self._get_gpu_stats([k for k, _ in gpu_stat_keys])
logs = self._parse_gpu_stats(self._gpu_ids, gpu_stats, gpu_stat_keys)
logs = self._parse_gpu_stats(self._device_ids, gpu_stats, gpu_stat_keys)

if self._log_stats.inter_step_time and self._snap_inter_step_time:
# First log at beginning of second step
Expand All @@ -137,7 +153,13 @@ def on_train_batch_start(self, trainer, pl_module, batch: Any, batch_idx: int, d

@rank_zero_only
def on_train_batch_end(
self, trainer, pl_module, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int
self,
trainer: 'pl.Trainer',
pl_module: 'pl.LightningModule',
outputs: STEP_OUTPUT,
batch: Any,
batch_idx: int,
dataloader_idx: int,
) -> None:
if self._log_stats.inter_step_time:
self._snap_inter_step_time = time.time()
Expand All @@ -147,19 +169,28 @@ def on_train_batch_end(

gpu_stat_keys = self._get_gpu_stat_keys() + self._get_gpu_device_stat_keys()
gpu_stats = self._get_gpu_stats([k for k, _ in gpu_stat_keys])
logs = self._parse_gpu_stats(self._gpu_ids, gpu_stats, gpu_stat_keys)
logs = self._parse_gpu_stats(self._device_ids, gpu_stats, gpu_stat_keys)

if self._log_stats.intra_step_time and self._snap_intra_step_time:
logs['batch_time/intra_step (ms)'] = (time.time() - self._snap_intra_step_time) * 1000

trainer.logger.log_metrics(logs, step=trainer.global_step)

@staticmethod
def _get_gpu_ids(device_ids: List[int]) -> List[str]:
"""Get the unmasked real GPU IDs"""
# All devices if `CUDA_VISIBLE_DEVICES` unset
default = ','.join(str(i) for i in range(torch.cuda.device_count()))
cuda_visible_devices: List[str] = os.getenv('CUDA_VISIBLE_DEVICES', default=default).split(',')
return [cuda_visible_devices[device_id].strip() for device_id in device_ids]

def _get_gpu_stats(self, queries: List[str]) -> List[List[float]]:
"""Run nvidia-smi to get the gpu stats"""
gpu_query = ','.join(queries)
format = 'csv,nounits,noheader'
gpu_ids = ','.join(self._gpu_ids)
result = subprocess.run(
[shutil.which('nvidia-smi'), f'--query-gpu={gpu_query}', f'--format={format}', f'--id={self._gpu_ids}'],
[shutil.which('nvidia-smi'), f'--query-gpu={gpu_query}', f'--format={format}', f'--id={gpu_ids}'],
encoding="utf-8",
stdout=subprocess.PIPE,
stderr=subprocess.PIPE, # for backward compatibility with python version 3.6
Expand All @@ -177,12 +208,16 @@ def _to_float(x: str) -> float:
return stats

@staticmethod
def _parse_gpu_stats(gpu_ids: str, stats: List[List[float]], keys: List[Tuple[str, str]]) -> Dict[str, float]:
def _parse_gpu_stats(
device_ids: List[int],
stats: List[List[float]],
keys: List[Tuple[str, str]],
) -> Dict[str, float]:
"""Parse the gpu stats into a loggable dict"""
logs = {}
for i, gpu_id in enumerate(gpu_ids.split(',')):
for i, device_id in enumerate(device_ids):
for j, (x, unit) in enumerate(keys):
logs[f'gpu_id: {gpu_id}/{x} ({unit})'] = stats[i][j]
logs[f'device_id: {device_id}/{x} ({unit})'] = stats[i][j]
return logs

def _get_gpu_stat_keys(self) -> List[Tuple[str, str]]:
Expand Down
37 changes: 35 additions & 2 deletions tests/callbacks/test_gpu_stats_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import mock

import numpy as np
import pytest
Expand Down Expand Up @@ -116,6 +117,38 @@ def test_gpu_stats_monitor_no_gpu_warning(tmpdir):


def test_gpu_stats_monitor_parse_gpu_stats():
XuehaiPan marked this conversation as resolved.
Show resolved Hide resolved
logs = GPUStatsMonitor._parse_gpu_stats('1,2', [[3, 4, 5], [6, 7]], [('gpu', 'a'), ('memory', 'b')])
expected = {'gpu_id: 1/gpu (a)': 3, 'gpu_id: 1/memory (b)': 4, 'gpu_id: 2/gpu (a)': 6, 'gpu_id: 2/memory (b)': 7}
logs = GPUStatsMonitor._parse_gpu_stats([1, 2], [[3, 4, 5], [6, 7]], [('gpu', 'a'), ('memory', 'b')])
expected = {
'device_id: 1/gpu (a)': 3,
'device_id: 1/memory (b)': 4,
'device_id: 2/gpu (a)': 6,
'device_id: 2/memory (b)': 7
}
assert logs == expected


@mock.patch.dict(os.environ, {})
@mock.patch('torch.cuda.is_available', return_value=True)
@mock.patch('torch.cuda.device_count', return_value=2)
def test_gpu_stats_monitor_get_gpu_ids_cuda_visible_devices_unset(device_count_mock, is_available_mock):
gpu_ids = GPUStatsMonitor._get_gpu_ids([1, 0])
expected = ['1', '0']
assert gpu_ids == expected


@mock.patch.dict(os.environ, {'CUDA_VISIBLE_DEVICES': '3,2,4'})
@mock.patch('torch.cuda.is_available', return_value=True)
@mock.patch('torch.cuda.device_count', return_value=3)
def test_gpu_stats_monitor_get_gpu_ids_cuda_visible_devices_integers(device_count_mock, is_available_mock):
gpu_ids = GPUStatsMonitor._get_gpu_ids([1, 2])
expected = ['2', '4']
assert gpu_ids == expected


@mock.patch.dict(os.environ, {'CUDA_VISIBLE_DEVICES': 'GPU-01a23b4c,GPU-56d78e9f,GPU-02a46c8e'})
@mock.patch('torch.cuda.is_available', return_value=True)
@mock.patch('torch.cuda.device_count', return_value=3)
def test_gpu_stats_monitor_get_gpu_ids_cuda_visible_devices_uuids(device_count_mock, is_available_mock):
gpu_ids = GPUStatsMonitor._get_gpu_ids([1, 2])
expected = ['GPU-56d78e9f', 'GPU-02a46c8e']
assert gpu_ids == expected