From 2dfe63fcf964fe249e16554b7863a8a716f1422a Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 14 Apr 2021 23:03:34 -0400 Subject: [PATCH] update savefig default to save when given file_name --- fooof/plts/utils.py | 6 +++++- fooof/tests/plts/test_utils.py | 20 ++++++++++++++------ 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/fooof/plts/utils.py b/fooof/plts/utils.py index 0a970ee9..b6c99616 100644 --- a/fooof/plts/utils.py +++ b/fooof/plts/utils.py @@ -181,10 +181,14 @@ def savefig(func): @wraps(func) def decorated(*args, **kwargs): - save_fig = kwargs.pop('save_fig', False) + # Grab file name and path arguments, if they are in kwargs file_name = kwargs.pop('file_name', None) file_path = kwargs.pop('file_path', None) + # Check for an explicit argument for whether to save figure or not + # Defaults to saving when file name given (since bool(str)->True; bool(None)->False) + save_fig = kwargs.pop('save_fig', bool(file_name)) + func(*args, **kwargs) if save_fig: diff --git a/fooof/tests/plts/test_utils.py b/fooof/tests/plts/test_utils.py index 0c4dce7f..ed03772f 100644 --- a/fooof/tests/plts/test_utils.py +++ b/fooof/tests/plts/test_utils.py @@ -1,12 +1,12 @@ """Tests for fooof.plts.utils.""" import os -import tempfile - -from fooof.tests.tutils import plot_test from fooof.core.modutils import safe_import +from fooof.tests.tutils import plot_test +from fooof.tests.settings import TEST_PLOTS_PATH + from fooof.plts.utils import * mpl = safe_import('matplotlib') @@ -79,6 +79,14 @@ def test_savefig(): def example_plot(): plt.plot([1, 2], [3, 4]) - with tempfile.NamedTemporaryFile(mode='w+') as file: - example_plot(save_fig=True, file_name=file.name) - assert os.path.exists(file.name) + # Test defaults to saving given file path & name + example_plot(file_path=TEST_PLOTS_PATH, file_name='test_savefig1.pdf') + assert os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_savefig1.pdf')) + + # Test works the same when explicitly given `save_fig` + example_plot(save_fig=True, file_path=TEST_PLOTS_PATH, file_name='test_savefig2.pdf') + assert os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_savefig2.pdf')) + + # Test does not save when `save_fig` set to False + example_plot(save_fig=False, file_path=TEST_PLOTS_PATH, file_name='test_savefig3.pdf') + assert not os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_savefig3.pdf'))