From 2b825ac69ed15f60e1f9e0cc141cdc6c0ae95fe1 Mon Sep 17 00:00:00 2001 From: Nikolai Kapralov <4dvlup@gmail.com> Date: Fri, 9 Aug 2024 16:39:56 +0200 Subject: [PATCH 1/3] TEST: fix and extend tests for waveforms --- .gitignore | 1 + pyproject.toml | 9 ++- src/meegsim/waveform.py | 12 ++-- tests/test_utils.py | 23 ++++++- tests/test_waveform.py | 143 ++++++++++++++++++++++++---------------- 5 files changed, 119 insertions(+), 69 deletions(-) diff --git a/.gitignore b/.gitignore index 88467f7..c30f27b 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ __pycache__ *.py[cod] # Pytest +.coverage .pytest_cache \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 61d67cf..661b71c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,8 +18,8 @@ classifiers = [ "Operating System :: OS Independent", ] dependencies = [ - "mne", "colorednoise", + "mne", ] [project.urls] @@ -28,11 +28,14 @@ Issues = "https://github.com/ctrltz/meegsim/issues" [project.optional-dependencies] dev = [ - "pytest" + "mock", + "pytest", + "pytest-cov" ] [tool.pytest.ini_options] pythonpath = "." testpaths = [ "tests", -] \ No newline at end of file +] +addopts = "--cov=src/meegsim --cov-report term-missing" \ No newline at end of file diff --git a/src/meegsim/waveform.py b/src/meegsim/waveform.py index 9db938f..b24d2b2 100644 --- a/src/meegsim/waveform.py +++ b/src/meegsim/waveform.py @@ -17,19 +17,17 @@ def waveform_fn(n_series, times, *, kwarg1='aaa', kwarg2='bbb'): Waveforms that are not urgent to have but could in principle be useful: * non-sinusoidal stuff (harmonics, peak-trough asymmetry) """ -import warnings + +import colorednoise as cn import numpy as np +import warnings + from scipy.signal import butter, filtfilt -import colorednoise as cn + from .utils import normalize_power, get_sfreq def narrowband_oscillation(n_series, times, *, fmin=None, fmax=None, order=2, random_state=None): - # - # Ideas for tests - # Test 2 (order) - # - we could check that we pass correct value to filtfilt (requires mocking) - # - check the shape of the resulting array """ Generate time series in a requested frequency band by filtering white noise diff --git a/tests/test_utils.py b/tests/test_utils.py index 8422308..df26c37 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,8 @@ import numpy as np import mne +import pytest -from meegsim.utils import combine_stcs, normalize_power +from meegsim.utils import combine_stcs, normalize_power, get_sfreq def prepare_stc(vertices, num_samples=5): @@ -52,4 +53,22 @@ def test_normalize_power(): # Should not change the shape but should change the norm assert data.shape == normalized.shape - assert np.allclose(np.linalg.norm(normalized, axis=1), 1) \ No newline at end of file + assert np.allclose(np.linalg.norm(normalized, axis=1), 1) + + +def test_get_sfreq(): + sfreq = 250 + times = np.arange(0, sfreq) / sfreq + assert get_sfreq(times) == sfreq + + +def test_get_sfreq_too_few_timepoints_raises(): + with pytest.raises(ValueError) as excinfo: + get_sfreq(np.array([0])) + assert "must contain at least two points" in str(excinfo.value) + + +def test_get_sfreq_unequal_spacing_raises(): + with pytest.raises(ValueError) as excinfo: + get_sfreq(np.array([0, 0.01, 0.1])) + assert "not uniformly spaced" in str(excinfo.value) diff --git a/tests/test_waveform.py b/tests/test_waveform.py index db3b882..07d3950 100644 --- a/tests/test_waveform.py +++ b/tests/test_waveform.py @@ -1,75 +1,88 @@ import numpy as np -from scipy.signal import welch import pytest +from mock import patch +from scipy.signal import welch + from meegsim.utils import get_sfreq from meegsim.waveform import white_noise, narrowband_oscillation, one_over_f_noise -def prepare_inputs(): +def prepare_times(sfreq, duration): + n_times = sfreq * duration + times = np.arange(n_times) / sfreq + return n_times, times + + +@pytest.mark.parametrize( + "waveform,waveform_params", + [ + (narrowband_oscillation, dict(fmin=8, fmax=12)), + (narrowband_oscillation, dict(fmin=16, fmax=24)), + (one_over_f_noise, dict(slope=1)), + (one_over_f_noise, dict(slope=2)), + (white_noise, dict()), + ] +) +def test_waveforms_random_state(waveform, waveform_params): + """ + Test that all waveforms support random state. + """ n_series = 10 - n_times = 100 - times = np.linspace(0, 1, num=n_times) - return n_series, n_times, times - - -def test_white_noise_shape(): - n_series, n_times, times = prepare_inputs() - - data = white_noise(n_series, times) - assert data.shape == (n_series, n_times) - - -def test_white_noise_random_state(): - n_series, _, times = prepare_inputs() + _, times = prepare_times(sfreq=250, duration=30) # Different time series are generated by default - data1 = white_noise(n_series, times) - data2 = white_noise(n_series, times) + data1 = waveform(n_series, times, **waveform_params) + data2 = waveform(n_series, times, **waveform_params) assert not np.allclose(data1, data2) # The results are reproducible when random_state is set random_state = 1234567890 - data1 = white_noise(n_series, times, random_state=random_state) - data2 = white_noise(n_series, times, random_state=random_state) + data1 = waveform(n_series, times, random_state=random_state, **waveform_params) + data2 = waveform(n_series, times, random_state=random_state, **waveform_params) assert np.allclose(data1, data2) -def test_one_over_f_noise_shape(): - n_series, n_times, times = prepare_inputs() +@pytest.mark.parametrize( + "waveform,waveform_params", + [ + (narrowband_oscillation, dict(fmin=8, fmax=12)), + (narrowband_oscillation, dict(fmin=16, fmax=24)), + (one_over_f_noise, dict(slope=1)), + (one_over_f_noise, dict(slope=2)), + (white_noise, dict()), + ] +) +def test_waveforms_shape(waveform, waveform_params): + """ + Test that the result of all waveform functions has correct shape. + """ + n_series = 10 + n_times, times = prepare_times(sfreq=250, duration=30) - data = one_over_f_noise(n_series, times) + data = waveform(n_series, times, **waveform_params) assert data.shape == (n_series, n_times) -def test_one_over_f_noise_random_state(): - n_series, _, times = prepare_inputs() - - # Different time series are generated by default - data1 = one_over_f_noise(n_series, times) - data2 = one_over_f_noise(n_series, times) - assert not np.allclose(data1, data2) - - # The results are reproducible when random_state is set - random_state = 1234567890 - data1 = one_over_f_noise(n_series, times, random_state=random_state) - data2 = one_over_f_noise(n_series, times, random_state=random_state) - assert np.allclose(data1, data2) - - @pytest.mark.parametrize("fmin, fmax", [ (4.0, 7.0), (8.0, 12.0), (20.0, 30.0), (15.0, 35.0), ]) -def test_frequencies_in_band(fmin, fmax): - # Test that frequencies within the specified band have higher power - n_series, n_times, times = prepare_inputs() +def test_narrowband_oscillation_fmin_fmax(fmin, fmax): + """ + Test that frequencies within the specified band have higher power + than the rest of the spectra. + """ + n_series = 10 + n_times, times = prepare_times(sfreq=250, duration=30) + data = narrowband_oscillation(n_series, times, fmin=fmin, fmax=fmax) - fs = get_sfreq(times) + # Calculate power spectral density - freqs, power = welch(data, fs=fs, axis=1) + fs = get_sfreq(times) + freqs, power = welch(data, fs=fs, nfft=fs, nperseg=fs, axis=1) # Sort frequencies by power sorted_freqs = freqs[np.argsort(power.mean(axis=0))[::-1]] @@ -78,22 +91,38 @@ def test_frequencies_in_band(fmin, fmax): band_fmin_fmax = (freqs >= fmin) & (freqs <= fmax) band_freqs = sorted_freqs[:np.sum(band_fmin_fmax)] assert len(band_freqs) > 0, "No frequencies found in the specified band." - assert np.all((band_freqs >= fmin) & (band_freqs <= fmax)), "Not all powerful frequencies are in the specified band." + assert np.all((band_freqs >= fmin) & (band_freqs <= fmax)), \ + "Not all powerful frequencies are in the specified band." assert data.shape == (n_series, n_times), "Shape mismatch" -def test_random_state_consistency(): - # Test that fixing the random_state gives consistent results - n_series, n_times, times = prepare_inputs() - fs = 100.0 - random_state = 42 - data1 = narrowband_oscillation(n_series, times, fs, random_state=random_state) - data2 = narrowband_oscillation(n_series, times, fs, random_state=random_state) - assert np.allclose(data1, data2), "Results differ with the same random_state." +# return dummy values for the function to run +# import the functions from our module to resolve 'from ... import ...' definition +# more about: https://nedbatchelder.com/blog/201908/why_your_mock_doesnt_work.html +@patch('meegsim.waveform.filtfilt', return_value=np.ones((1, 100))) +@patch('meegsim.waveform.butter', return_value=(0, 0)) +def test_narrowband_oscillation_order(butter_mock, filtfilt_mock): + _, times = prepare_times(sfreq=250, duration=30) + + # order is set to 2 by default + narrowband_oscillation(n_series=10, times=times, fmin=8, fmax=12) + butter_mock.assert_called() + assert butter_mock.call_args.kwargs['N'] == 2 + + # custom slope value also should work + narrowband_oscillation(n_series=10, times=times, fmin=8, fmax=12, order=4) + assert butter_mock.call_args.kwargs['N'] == 4 + + +# return a dummy value for normalize_power to work +@patch('colorednoise.powerlaw_psd_gaussian', return_value=np.ones((1, 100))) +def test_one_over_f_noise_slope(noise_mock): + _, times = prepare_times(sfreq=250, duration=30) - # Test that different random_state gives different results - data3 = narrowband_oscillation(n_series, times, fs) - data4 = narrowband_oscillation(n_series, times, fs) - assert not np.allclose(data3, data4), "Results should differ with different random_state." - assert data1.shape == (n_series, n_times), "Shape mismatch" + # slope is set to 1 by default + one_over_f_noise(n_series=10, times=times) + assert 1 in noise_mock.call_args.args + # custom slope value also should work + one_over_f_noise(n_series=10, times=times, slope=1.5) + assert 1.5 in noise_mock.call_args.args From 5c3aa9a13f11e7204b8fa13c9e55cdea728cecbc Mon Sep 17 00:00:00 2001 From: Nikolai Kapralov <4dvlup@gmail.com> Date: Fri, 9 Aug 2024 17:11:05 +0200 Subject: [PATCH 2/3] Use match in pytest.raises --- tests/test_utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index df26c37..2ebd1c5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -63,12 +63,10 @@ def test_get_sfreq(): def test_get_sfreq_too_few_timepoints_raises(): - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError, match='must contain at least two points'): get_sfreq(np.array([0])) - assert "must contain at least two points" in str(excinfo.value) def test_get_sfreq_unequal_spacing_raises(): - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError, match='not uniformly spaced') get_sfreq(np.array([0, 0.01, 0.1])) - assert "not uniformly spaced" in str(excinfo.value) From d74d34c7b0550f76b97f9cfda49e770af55cd079 Mon Sep 17 00:00:00 2001 From: Nikolai Kapralov <4dvlup@gmail.com> Date: Fri, 9 Aug 2024 17:11:31 +0200 Subject: [PATCH 3/3] missing semicolon --- tests/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 2ebd1c5..21d5bf0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -68,5 +68,5 @@ def test_get_sfreq_too_few_timepoints_raises(): def test_get_sfreq_unequal_spacing_raises(): - with pytest.raises(ValueError, match='not uniformly spaced') + with pytest.raises(ValueError, match='not uniformly spaced'): get_sfreq(np.array([0, 0.01, 0.1]))