Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TEST: fix and extend tests for waveforms #14

Merged
merged 3 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ __pycache__
*.py[cod]

# Pytest
.coverage
.pytest_cache
9 changes: 6 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ classifiers = [
"Operating System :: OS Independent",
]
dependencies = [
"mne",
"colorednoise",
"mne",
]

[project.urls]
Expand All @@ -28,11 +28,14 @@ Issues = "https://github.com/ctrltz/meegsim/issues"

[project.optional-dependencies]
dev = [
"pytest"
"mock",
"pytest",
"pytest-cov"
]

[tool.pytest.ini_options]
pythonpath = "."
testpaths = [
"tests",
]
]
addopts = "--cov=src/meegsim --cov-report term-missing"
12 changes: 5 additions & 7 deletions src/meegsim/waveform.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,17 @@ def waveform_fn(n_series, times, *, kwarg1='aaa', kwarg2='bbb'):
Waveforms that are not urgent to have but could in principle be useful:
* non-sinusoidal stuff (harmonics, peak-trough asymmetry)
"""
import warnings

import colorednoise as cn
import numpy as np
import warnings

from scipy.signal import butter, filtfilt
import colorednoise as cn

from .utils import normalize_power, get_sfreq


def narrowband_oscillation(n_series, times, *, fmin=None, fmax=None, order=2, random_state=None):
#
# Ideas for tests
# Test 2 (order)
# - we could check that we pass correct value to filtfilt (requires mocking)
# - check the shape of the resulting array
"""
Generate time series in a requested frequency band by filtering white noise

Expand Down
21 changes: 19 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import numpy as np
import mne
import pytest

from meegsim.utils import combine_stcs, normalize_power
from meegsim.utils import combine_stcs, normalize_power, get_sfreq


def prepare_stc(vertices, num_samples=5):
Expand Down Expand Up @@ -52,4 +53,20 @@ def test_normalize_power():

# Should not change the shape but should change the norm
assert data.shape == normalized.shape
assert np.allclose(np.linalg.norm(normalized, axis=1), 1)
assert np.allclose(np.linalg.norm(normalized, axis=1), 1)


def test_get_sfreq():
sfreq = 250
times = np.arange(0, sfreq) / sfreq
assert get_sfreq(times) == sfreq


def test_get_sfreq_too_few_timepoints_raises():
with pytest.raises(ValueError, match='must contain at least two points'):
get_sfreq(np.array([0]))


def test_get_sfreq_unequal_spacing_raises():
with pytest.raises(ValueError, match='not uniformly spaced'):
get_sfreq(np.array([0, 0.01, 0.1]))
143 changes: 86 additions & 57 deletions tests/test_waveform.py
Original file line number Diff line number Diff line change
@@ -1,75 +1,88 @@
import numpy as np
from scipy.signal import welch
import pytest

from mock import patch
from scipy.signal import welch

from meegsim.utils import get_sfreq
from meegsim.waveform import white_noise, narrowband_oscillation, one_over_f_noise


def prepare_inputs():
def prepare_times(sfreq, duration):
n_times = sfreq * duration
times = np.arange(n_times) / sfreq
return n_times, times


@pytest.mark.parametrize(
"waveform,waveform_params",
[
(narrowband_oscillation, dict(fmin=8, fmax=12)),
(narrowband_oscillation, dict(fmin=16, fmax=24)),
(one_over_f_noise, dict(slope=1)),
(one_over_f_noise, dict(slope=2)),
(white_noise, dict()),
]
)
def test_waveforms_random_state(waveform, waveform_params):
"""
Test that all waveforms support random state.
"""
n_series = 10
n_times = 100
times = np.linspace(0, 1, num=n_times)
return n_series, n_times, times


def test_white_noise_shape():
n_series, n_times, times = prepare_inputs()

data = white_noise(n_series, times)
assert data.shape == (n_series, n_times)


def test_white_noise_random_state():
n_series, _, times = prepare_inputs()
_, times = prepare_times(sfreq=250, duration=30)

# Different time series are generated by default
data1 = white_noise(n_series, times)
data2 = white_noise(n_series, times)
data1 = waveform(n_series, times, **waveform_params)
data2 = waveform(n_series, times, **waveform_params)
assert not np.allclose(data1, data2)

# The results are reproducible when random_state is set
random_state = 1234567890
data1 = white_noise(n_series, times, random_state=random_state)
data2 = white_noise(n_series, times, random_state=random_state)
data1 = waveform(n_series, times, random_state=random_state, **waveform_params)
data2 = waveform(n_series, times, random_state=random_state, **waveform_params)
assert np.allclose(data1, data2)


def test_one_over_f_noise_shape():
n_series, n_times, times = prepare_inputs()
@pytest.mark.parametrize(
"waveform,waveform_params",
[
(narrowband_oscillation, dict(fmin=8, fmax=12)),
(narrowband_oscillation, dict(fmin=16, fmax=24)),
(one_over_f_noise, dict(slope=1)),
(one_over_f_noise, dict(slope=2)),
(white_noise, dict()),
]
)
def test_waveforms_shape(waveform, waveform_params):
"""
Test that the result of all waveform functions has correct shape.
"""
n_series = 10
n_times, times = prepare_times(sfreq=250, duration=30)

data = one_over_f_noise(n_series, times)
data = waveform(n_series, times, **waveform_params)
assert data.shape == (n_series, n_times)


def test_one_over_f_noise_random_state():
n_series, _, times = prepare_inputs()

# Different time series are generated by default
data1 = one_over_f_noise(n_series, times)
data2 = one_over_f_noise(n_series, times)
assert not np.allclose(data1, data2)

# The results are reproducible when random_state is set
random_state = 1234567890
data1 = one_over_f_noise(n_series, times, random_state=random_state)
data2 = one_over_f_noise(n_series, times, random_state=random_state)
assert np.allclose(data1, data2)


@pytest.mark.parametrize("fmin, fmax", [
(4.0, 7.0),
(8.0, 12.0),
(20.0, 30.0),
(15.0, 35.0),
])
def test_frequencies_in_band(fmin, fmax):
# Test that frequencies within the specified band have higher power
n_series, n_times, times = prepare_inputs()
def test_narrowband_oscillation_fmin_fmax(fmin, fmax):
"""
Test that frequencies within the specified band have higher power
than the rest of the spectra.
"""
n_series = 10
n_times, times = prepare_times(sfreq=250, duration=30)

data = narrowband_oscillation(n_series, times, fmin=fmin, fmax=fmax)
fs = get_sfreq(times)

# Calculate power spectral density
freqs, power = welch(data, fs=fs, axis=1)
fs = get_sfreq(times)
freqs, power = welch(data, fs=fs, nfft=fs, nperseg=fs, axis=1)

# Sort frequencies by power
sorted_freqs = freqs[np.argsort(power.mean(axis=0))[::-1]]
Expand All @@ -78,22 +91,38 @@ def test_frequencies_in_band(fmin, fmax):
band_fmin_fmax = (freqs >= fmin) & (freqs <= fmax)
band_freqs = sorted_freqs[:np.sum(band_fmin_fmax)]
assert len(band_freqs) > 0, "No frequencies found in the specified band."
assert np.all((band_freqs >= fmin) & (band_freqs <= fmax)), "Not all powerful frequencies are in the specified band."
assert np.all((band_freqs >= fmin) & (band_freqs <= fmax)), \
"Not all powerful frequencies are in the specified band."
assert data.shape == (n_series, n_times), "Shape mismatch"


def test_random_state_consistency():
# Test that fixing the random_state gives consistent results
n_series, n_times, times = prepare_inputs()
fs = 100.0
random_state = 42
data1 = narrowband_oscillation(n_series, times, fs, random_state=random_state)
data2 = narrowband_oscillation(n_series, times, fs, random_state=random_state)
assert np.allclose(data1, data2), "Results differ with the same random_state."
# return dummy values for the function to run
# import the functions from our module to resolve 'from ... import ...' definition
# more about: https://nedbatchelder.com/blog/201908/why_your_mock_doesnt_work.html
@patch('meegsim.waveform.filtfilt', return_value=np.ones((1, 100)))
@patch('meegsim.waveform.butter', return_value=(0, 0))
def test_narrowband_oscillation_order(butter_mock, filtfilt_mock):
_, times = prepare_times(sfreq=250, duration=30)

# order is set to 2 by default
narrowband_oscillation(n_series=10, times=times, fmin=8, fmax=12)
butter_mock.assert_called()
assert butter_mock.call_args.kwargs['N'] == 2

# custom slope value also should work
narrowband_oscillation(n_series=10, times=times, fmin=8, fmax=12, order=4)
assert butter_mock.call_args.kwargs['N'] == 4


# return a dummy value for normalize_power to work
@patch('colorednoise.powerlaw_psd_gaussian', return_value=np.ones((1, 100)))
def test_one_over_f_noise_slope(noise_mock):
_, times = prepare_times(sfreq=250, duration=30)

# Test that different random_state gives different results
data3 = narrowband_oscillation(n_series, times, fs)
data4 = narrowband_oscillation(n_series, times, fs)
assert not np.allclose(data3, data4), "Results should differ with different random_state."
assert data1.shape == (n_series, n_times), "Shape mismatch"
# slope is set to 1 by default
one_over_f_noise(n_series=10, times=times)
assert 1 in noise_mock.call_args.args

# custom slope value also should work
one_over_f_noise(n_series=10, times=times, slope=1.5)
assert 1.5 in noise_mock.call_args.args