Skip to content

Commit

Permalink
clean up tests/test_profiler.py (#867)
Browse files Browse the repository at this point in the history
* cleanup docstrings, _get_total_cprofile_duration in module

* relax profiler overhead tolerance
  • Loading branch information
jeremyjordan authored Feb 19, 2020
1 parent c58aab0 commit ea8878b
Showing 1 changed file with 15 additions and 20 deletions.
35 changes: 15 additions & 20 deletions tests/test_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,26 @@

from pytorch_lightning.profiler import Profiler, AdvancedProfiler

PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0001
PROFILER_OVERHEAD_MAX_TOLERANCE = 0.001


@pytest.fixture
def simple_profiler():
    """Provide a fresh ``Profiler`` to every test that requests `simple_profiler`."""
    return Profiler()


@pytest.fixture
def advanced_profiler():
    """Provide a fresh ``AdvancedProfiler`` to every test that requests `advanced_profiler`."""
    return AdvancedProfiler()


@pytest.mark.parametrize("action,expected", [("a", [3, 1]), ("b", [2]), ("c", [1])])
def test_simple_profiler_durations(simple_profiler, action, expected):
"""
ensure the reported durations are reasonably accurate
"""
"""Ensure the reported durations are reasonably accurate."""

for duration in expected:
with simple_profiler.profile(action):
Expand All @@ -37,9 +37,7 @@ def test_simple_profiler_durations(simple_profiler, action, expected):


def test_simple_profiler_overhead(simple_profiler, n_iter=5):
"""
ensure that the profiler doesn't introduce too much overhead during training
"""
"""Ensure that the profiler doesn't introduce too much overhead during training."""
for _ in range(n_iter):
with simple_profiler.profile("no-op"):
pass
Expand All @@ -49,24 +47,25 @@ def test_simple_profiler_overhead(simple_profiler, n_iter=5):


def test_simple_profiler_describe(simple_profiler):
"""
ensure the profiler won't fail when reporting the summary
"""
"""Ensure the profiler won't fail when reporting the summary."""
simple_profiler.describe()


def _get_total_cprofile_duration(profile):
    """Return the summed ``totaltime`` of all entries recorded by a cProfile run.

    Args:
        profile: a ``cProfile.Profile`` instance whose stats have been collected.

    Returns:
        float: total recorded duration in seconds.
    """
    # Generator expression: sum() consumes it lazily, so there is no need to
    # materialize an intermediate list (flake8-comprehensions C419).
    return sum(entry.totaltime for entry in profile.getstats())


@pytest.mark.parametrize("action,expected", [("a", [3, 1]), ("b", [2]), ("c", [1])])
def test_advanced_profiler_durations(advanced_profiler, action, expected):
def _get_total_duration(profile):
return sum([x.totaltime for x in profile.getstats()])
"""Ensure the reported durations are reasonably accurate."""

for duration in expected:
with advanced_profiler.profile(action):
time.sleep(duration)

# different environments have different precision when it comes to time.sleep()
# see: https://github.com/PyTorchLightning/pytorch-lightning/issues/796
recored_total_duration = _get_total_duration(
recored_total_duration = _get_total_cprofile_duration(
advanced_profiler.profiled_actions[action]
)
expected_total_duration = np.sum(expected)
Expand All @@ -76,21 +75,17 @@ def _get_total_duration(profile):


def test_advanced_profiler_overhead(advanced_profiler, n_iter=5):
    """The profiler's bookkeeping must stay below the per-iteration overhead budget."""
    # Profile an empty body so any time measured is pure profiler overhead.
    for _ in range(n_iter):
        with advanced_profiler.profile("no-op"):
            pass

    action_profile = advanced_profiler.profiled_actions["no-op"]
    total_duration = _get_total_cprofile_duration(action_profile)
    assert total_duration / n_iter < PROFILER_OVERHEAD_MAX_TOLERANCE


def test_advanced_profiler_describe(advanced_profiler):
"""
ensure the profiler won't fail when reporting the summary
"""
"""Ensure the profiler won't fail when reporting the summary."""
advanced_profiler.describe()

0 comments on commit ea8878b

Please sign in to comment.