Skip to content

Commit

Permalink
[bugfix] Properly name PyTorchProfiler traces (#8009)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholi <[email protected]>
  • Loading branch information
3 people authored and lexierule committed Jun 22, 2021
1 parent 7c58c4d commit 591d617
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 12 deletions.
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import _warn, rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE

Expand Down
17 changes: 11 additions & 6 deletions pytorch_lightning/profiler/profilers.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,19 @@ def _rank_zero_info(self, *args, **kwargs) -> None:
if self._local_rank in (None, 0):
log.info(*args, **kwargs)

def _prepare_filename(self, extension: str = ".txt") -> str:
filename = ""
def _prepare_filename(
self, action_name: Optional[str] = None, extension: str = ".txt", split_token: str = "-"
) -> str:
args = []
if self._stage is not None:
filename += f"{self._stage}-"
filename += str(self.filename)
args.append(self._stage)
if self.filename:
args.append(self.filename)
if self._local_rank is not None:
filename += f"-{self._local_rank}"
filename += extension
args.append(str(self._local_rank))
if action_name is not None:
args.append(action_name)
filename = split_token.join(args) + extension
return filename

def _prepare_streams(self) -> None:
Expand Down
8 changes: 6 additions & 2 deletions pytorch_lightning/profiler/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,11 +427,15 @@ def stop(self, action_name: str) -> None:
def on_trace_ready(profiler):
if self.dirpath is not None:
if self._export_to_chrome:
handler = tensorboard_trace_handler(self.dirpath, self._prepare_filename(extension=""))
handler = tensorboard_trace_handler(
self.dirpath, self._prepare_filename(action_name=action_name, extension="")
)
handler(profiler)

if self._export_to_flame_graph:
path = os.path.join(self.dirpath, self._prepare_filename(extension=".stack"))
path = os.path.join(
self.dirpath, self._prepare_filename(action_name=action_name, extension=".stack")
)
profiler.export_stacks(path, metric=self._metric)
else:
rank_zero_warn("The PyTorchProfiler failed to export trace as `dirpath` is None")
Expand Down
2 changes: 1 addition & 1 deletion tests/deprecated_api/test_remove_1-5.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from pytorch_lightning.utilities import device_parser
from pytorch_lightning.utilities.imports import _compare_version
from tests.deprecated_api import no_deprecated_call
from tests.helpers import BoringDataModule, BoringModel
from tests.helpers import BoringModel
from tests.helpers.utils import no_warning_call


Expand Down
4 changes: 2 additions & 2 deletions tests/test_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,8 @@ def test_pytorch_profiler_trainer_ddp(tmpdir, pytorch_profiler):
files = [file for file in files if file.endswith('.json')]
assert len(files) == 2, files
local_rank = trainer.local_rank
assert any(f'training_step_{local_rank}' in f for f in files)
assert any(f'validation_step_{local_rank}' in f for f in files)
assert any(f'{local_rank}-training_step_and_backward' in f for f in files)
assert any(f'{local_rank}-validation_step' in f for f in files)


def test_pytorch_profiler_trainer_test(tmpdir):
Expand Down

0 comments on commit 591d617

Please sign in to comment.