From ea8878bc143d899a961e30ded59b2a8588a92e01 Mon Sep 17 00:00:00 2001 From: Jeremy Jordan <13970565+jeremyjordan@users.noreply.github.com> Date: Wed, 19 Feb 2020 07:09:28 -0500 Subject: [PATCH] clean up tests/test_profiler.py (#867) * cleanup docstrings, _get_total_cprofile_duration in module * relax profiler overhead tolerance --- tests/test_profiler.py | 35 +++++++++++++++-------------------- 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/tests/test_profiler.py b/tests/test_profiler.py index 410e452bca577..5fae874e925b5 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -4,26 +4,26 @@ from pytorch_lightning.profiler import Profiler, AdvancedProfiler -PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0001 +PROFILER_OVERHEAD_MAX_TOLERANCE = 0.001 @pytest.fixture def simple_profiler(): + """Creates a new profiler for every test with `simple_profiler` as an arg.""" profiler = Profiler() return profiler @pytest.fixture def advanced_profiler(): + """Creates a new profiler for every test with `advanced_profiler` as an arg.""" profiler = AdvancedProfiler() return profiler @pytest.mark.parametrize("action,expected", [("a", [3, 1]), ("b", [2]), ("c", [1])]) def test_simple_profiler_durations(simple_profiler, action, expected): - """ - ensure the reported durations are reasonably accurate - """ + """Ensure the reported durations are reasonably accurate.""" for duration in expected: with simple_profiler.profile(action): @@ -37,9 +37,7 @@ def test_simple_profiler_durations(simple_profiler, action, expected): def test_simple_profiler_overhead(simple_profiler, n_iter=5): - """ - ensure that the profiler doesn't introduce too much overhead during training - """ + """Ensure that the profiler doesn't introduce too much overhead during training.""" for _ in range(n_iter): with simple_profiler.profile("no-op"): pass @@ -49,16 +47,17 @@ def test_simple_profiler_overhead(simple_profiler, n_iter=5): def test_simple_profiler_describe(simple_profiler): - """ - ensure the profiler won't fail when reporting the summary - """ + """Ensure the profiler won't fail when reporting the summary.""" simple_profiler.describe() +def _get_total_cprofile_duration(profile): + return sum([x.totaltime for x in profile.getstats()]) + + @pytest.mark.parametrize("action,expected", [("a", [3, 1]), ("b", [2]), ("c", [1])]) def test_advanced_profiler_durations(advanced_profiler, action, expected): - def _get_total_duration(profile): - return sum([x.totaltime for x in profile.getstats()]) + """Ensure the reported durations are reasonably accurate.""" for duration in expected: with advanced_profiler.profile(action): @@ -66,7 +65,7 @@ def _get_total_duration(profile): # different environments have different precision when it comes to time.sleep() # see: https://github.com/PyTorchLightning/pytorch-lightning/issues/796 - recored_total_duration = _get_total_duration( + recored_total_duration = _get_total_cprofile_duration( advanced_profiler.profiled_actions[action] ) expected_total_duration = np.sum(expected) @@ -76,21 +75,17 @@ def _get_total_duration(profile): def test_advanced_profiler_overhead(advanced_profiler, n_iter=5): - """ - ensure that the profiler doesn't introduce too much overhead during training - """ + """Ensure that the profiler doesn't introduce too much overhead during training.""" for _ in range(n_iter): with advanced_profiler.profile("no-op"): pass action_profile = advanced_profiler.profiled_actions["no-op"] - total_duration = sum([x.totaltime for x in action_profile.getstats()]) + total_duration = _get_total_cprofile_duration(action_profile) average_duration = total_duration / n_iter assert average_duration < PROFILER_OVERHEAD_MAX_TOLERANCE def test_advanced_profiler_describe(advanced_profiler): - """ - ensure the profiler won't fail when reporting the summary - """ + """Ensure the profiler won't fail when reporting the summary.""" advanced_profiler.describe()