From 60b3300d30f51542bcae2b88086af13911adec70 Mon Sep 17 00:00:00 2001 From: Nikolai Kapralov <4dvlup@gmail.com> Date: Thu, 19 Sep 2024 17:17:43 +0200 Subject: [PATCH] ENH: adjust the SNR in SourceSimulator.simulate() (#31) Major: * connected the low-level SNR functions to the SourceSimulator * added checks for snr and snr_params * sensor space variance is now calculated for the sum of all sources, not the mean * sensor space variance is now calculated based on source space covariance (this will also be relevant for patches) Minor: * removed the default value for target SNR * removed the default frequency band for the adjustment of SNR * extracted dummy forward function to utils.prepare in tests * renamed adjust_snr to amplitude_adjustment --- CHANGELOG.md | 2 +- examples/dummy.py | 49 ++++++---- src/meegsim/_check.py | 76 +++++++++++++-- src/meegsim/simulate.py | 93 +++++++++++++----- src/meegsim/snr.py | 81 ++++++++++++---- src/meegsim/source_groups.py | 10 +- tests/test_check.py | 54 ++++++++++- tests/test_simulate.py | 126 +++++++++++++++++++++--- tests/test_snr.py | 183 ++++++++++++++++------------------- tests/utils/mocks.py | 14 +++ tests/utils/prepare.py | 51 ++++++++++ 11 files changed, 542 insertions(+), 197 deletions(-) create mode 100644 tests/utils/mocks.py diff --git a/CHANGELOG.md b/CHANGELOG.md index fcd95eb..e7fc424 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - SourceSimulator class that allows adding point sources with custom locations and waveforms to the simulation ([3bad4a8](https://github.com/ctrltz/meegsim/commit/3bad4a86a3712beb43fb404481c15e1a54250d87), [#24](https://github.com/ctrltz/meegsim/pull/24)) - Waveforms of white noise, narrowband oscillation (white noise filtered in a narrow frequency band) and 1/f noise with adjustable slope ([#8](https://github.com/ctrltz/meegsim/pull/8)) - Random vertices in the whole source space or in a subset of vertices as location for point sources ([#10](https://github.com/ctrltz/meegsim/pull/10)) -- Adjustment of the SNR of the point sources based on sensor space power ([#9](https://github.com/ctrltz/meegsim/pull/9)) +- Adjustment of the SNR of the point sources based on sensor space power ([#9](https://github.com/ctrltz/meegsim/pull/9), [#31](https://github.com/ctrltz/meegsim/pull/31)) - Phase-phase coupling with a constant phase lag or a probabilistic phase lag according to the von Mises distribution ([#11](https://github.com/ctrltz/meegsim/pull/11)) - Traversal of the coupling graph to ensure that the coupling is set up correctly when multiple connectivity edges are defined ([#12](https://github.com/ctrltz/meegsim/pull/12)) diff --git a/examples/dummy.py b/examples/dummy.py index 500971c..a4a2e0d 100644 --- a/examples/dummy.py +++ b/examples/dummy.py @@ -3,10 +3,11 @@ """ import json -import numpy as np import mne +import numpy as np from pathlib import Path +from scipy.signal import butter, filtfilt from meegsim.location import select_random from meegsim.simulate import SourceSimulator @@ -27,6 +28,10 @@ def to_json(sources): # Simulation parameters sfreq = 250 duration = 60 +seed = 1234 +target_snr = 20 + +b, a = butter(4, 2 * np.array([8, 12]) / sfreq, 'bandpass') # Channel info montage = mne.channels.make_standard_montage('standard_1020') @@ -40,29 +45,37 @@ def to_json(sources): sim = SourceSimulator(src) -# Select some vertices randomly (signal/noise does not matter for now) -sim.add_point_sources( - location=select_random, - waveform=narrowband_oscillation, - location_params=dict(n=10, vertices=[list(src[0]['vertno']), []]), - waveform_params=dict(fmin=8, fmax=12) -) +# Add noise sources sim.add_noise_sources( location=select_random, - location_params=dict(n=10, vertices=[[], list(src[1]['vertno'])]) + location_params=dict(n=10) ) -# Print the source groups to check internal structure -print(f'Source groups: {sim._source_groups}') -print(f'Noise groups: {sim._noise_groups}') +sc_noise = sim.simulate(sfreq, duration, random_state=seed) +raw_noise = sc_noise.to_raw(fwd, info) + +# Select some vertices randomly +sim.add_point_sources( + location=select_random, + waveform=narrowband_oscillation, + location_params=dict(n=1), + waveform_params=dict(fmin=8, fmax=12), + snr=target_snr, + snr_params=dict(fmin=8, fmax=12) +) -sc = sim.simulate(sfreq, duration, random_state=0) +sc_full = sim.simulate(sfreq, duration, fwd=fwd, random_state=seed) +raw_full = sc_full.to_raw(fwd, info) -# Print the sources to check internal structure -print(f'Simulated sources: {to_json(sc._sources)}') -print(f'Simulated noise sources: {to_json(sc._noise_sources)}') +n_samples = sc_full.times.size +noise_data = filtfilt(b, a, raw_noise.get_data()) +cov_raw_noise = (noise_data @ noise_data.T) / n_samples +full_data = filtfilt(b, a, raw_full.get_data()) +cov_raw_full = (full_data @ full_data.T) / n_samples +snr = np.mean(np.diag(cov_raw_full)) / np.mean(np.diag(cov_raw_noise)) - 1 +print(f'Target SNR = {target_snr:.2f}') +print(f'Actual SNR = {snr:.2f}') -raw = sc.to_raw(fwd, info) -spec = raw.compute_psd(n_fft=sfreq, n_overlap=sfreq//2, n_per_seg=sfreq) +spec = raw_full.compute_psd(n_fft=sfreq, n_overlap=sfreq//2, n_per_seg=sfreq) spec.plot(sphere='eeglab') input('Press any key to continue') \ No newline at end of file diff --git a/src/meegsim/_check.py b/src/meegsim/_check.py index c87bcdd..9e4f167 100644 --- a/src/meegsim/_check.py +++ b/src/meegsim/_check.py @@ -257,20 +257,76 @@ def check_names(names, n_sources, existing): raise ValueError('All names should be unique') -def check_snr(snr, n_vertices): - if snr is not None: - raise NotImplementedError('Adjustment of SNR is not supported yet') - # TODO: check that the number of SNR values matches the number of vertices - # or it is a single SNR value that can be applied to all vertices +def check_snr(snr, n_sources): + """ + Check the user input for SNR: it can either be None (no adjustment of SNR), + a single float value that applies to all sources or an array of values + with one for each source. + + Parameters + ---------- + snr: None, float, or array + The provided value(s) for SNR + n_sources: int + The number of sources. + + Raises + ------ + ValueError + If the provided SNR value(s) do not follow the format described above. + """ + + if snr is None: + return None + + snr = np.ravel(np.array(snr)) + if snr.size != 1 and snr.size != n_sources: + raise ValueError( + f'Expected either one SNR value that applies to all sources or ' + f'one SNR value for each of the {n_sources} sources, got {snr.size}' + ) + + # Only positive values make sense, raise error if negative ones are provided + if np.any(snr < 0): + raise ValueError('Each SNR value should be positive') + + # Broadcast to all sources if a single value was provided + if snr.size == 1: + snr = np.tile(snr, (n_sources,)) return snr -def check_snr_params(snr_params): - # TODO: we could try to extract fmin and fmax from waveform_params but - # not sure how confusing will it be, a dedicated waveform class could be - # easier to understand - pass +def check_snr_params(snr_params, snr): + """ + Check the user input for SNR parameters: if the SNR is adjusted (i.e., not None), + then fmin and fmax should be present in the dictionary to define a frequency band. + + Parameters + ---------- + snr_params: dict + The provided dictionary with parameters of the SNR adjustment. + snr: None, float, or array + The provided value for SNR + + Raises + ------ + ValueError + If the provided snr_params dictionary does not have the necessary parameters. + """ + if snr is None: + return snr_params + + if 'fmin' not in snr_params or 'fmax' not in snr_params: + raise ValueError( + 'Frequency band limits are required for the adjustment of SNR. ' + 'Please add fmin and fmax to the snr_params dictionary.' + ) + + if snr_params['fmin'] < 0 or snr_params['fmax'] < 0: + raise ValueError('Frequency limits should be positive') + + return snr_params def check_coupling(): diff --git a/src/meegsim/simulate.py b/src/meegsim/simulate.py index 6c85acc..3aa57f2 100644 --- a/src/meegsim/simulate.py +++ b/src/meegsim/simulate.py @@ -1,4 +1,5 @@ from .configuration import SourceConfiguration +from .snr import _adjust_snr from .source_groups import PointSourceGroup from .waveform import one_over_f_noise @@ -28,6 +29,10 @@ def __init__(self, src): # Store all coupling edges self._coupling = {} + # Keep track whether SNR of any source should be adjusted + # If yes, then a forward model is required for simulation + self.is_snr_adjusted = False + def add_point_sources( self, location, @@ -53,16 +58,18 @@ def add_point_sources( Waveforms of source activity provided either directly in an array (fixed for every configuration) or as a function that generates the waveforms (but differ between configurations if the generation is random). - snr: None (do not adjust SNR), float (same SNR for all sources), or array (one value per source) - TODO: fix when finalizing SNR - NB: only positive values make sense, raise error if negative ones are provided + snr: None, float, or array + SNR values for the defined sources. Can be None (no adjustment of SNR), + a single value that is used for all sources or an array with one SNR + value per source. location_params: dict, optional Keyword arguments that will be passed to the location function. waveform_params: dict, optional Keyword arguments that will be passed to the waveform function. snr_params: dict, optional - TODO: fix when finalizing SNR - fmin and fmax for the frequency band that will be used to adjust SNR. + Additional parameters required for the adjustment of SNR. + Specify fmin and fmax here to define the frequency band which + should used for calculating the SNR. names: list, optional A list of names for each source. If not specified, the names will be autogenerated using the format 'sgN-sM', where N is the index of the @@ -92,6 +99,10 @@ def add_point_sources( self._source_groups.append(point_sg) self._sources.extend(point_sg.names) + # Check if SNR should be adjusted + if point_sg.snr is not None: + self.is_snr_adjusted = True + # Return the names of newly added sources return point_sg.names @@ -177,44 +188,82 @@ def simulate( self, sfreq, duration, + fwd=None, random_state=None ): + """ + Simulate a configuration of defined sources. + + Parameters + ---------- + sfreq: float + The sampling frequency of the simulated data, in Hz. + duration: float + Duration of the simulated data, in seconds. + fwd: mne.Forward, optional + The forward model, only to be used for the adjustment of SNR. + If no adjustment is performed, the forward model is not required. + random_state: int or None, optional + The random state can be provided to obtain reproducible configurations. + If None (default), the simulated data will differ between function calls. + + Returns + ------- + sc: SourceConfiguration + The source configuration, which contains the defined sources and + their corresponding waveforms. + """ + if not (self._source_groups or self._noise_groups): raise ValueError('No sources were added to the configuration.') - return _simulate( + if self.is_snr_adjusted and fwd is None: + raise ValueError('A forward model is required for the adjustment ' + 'of SNR.') + + # Initialize the SourceConfiguration + sc = SourceConfiguration(self.src, sfreq, duration, random_state=random_state) + + # Simulate signal and noise + sources, noise_sources = _simulate( self._source_groups, self._noise_groups, + self.is_snr_adjusted, self.src, - sfreq, - duration, + sc.times, + fwd=fwd, random_state=random_state ) + # Add the sources to the simulated configuration + sc._sources = sources + sc._noise_sources = noise_sources + + return sc + def _simulate( source_groups, noise_groups, + is_snr_adjusted, src, - sfreq, - duration, + times, + fwd, random_state=None ): """ This function describes the simulation workflow. """ - # Initialize the SourceConfiguration - sc = SourceConfiguration(src, sfreq, duration, random_state=random_state) - + # Simulate all sources independently first (no coupling yet) noise_sources = [] for ng in noise_groups: - noise_sources.extend(ng.simulate(src, sc.times, random_state=random_state)) + noise_sources.extend(ng.simulate(src, times, random_state=random_state)) noise_sources = {s.name: s for s in noise_sources} sources = [] for sg in source_groups: - sources.extend(sg.simulate(src, sc.times, random_state=random_state)) + sources.extend(sg.simulate(src, times, random_state=random_state)) sources = {s.name: s for s in sources} # Setup the desired coupling patterns @@ -225,14 +274,8 @@ def _simulate( # If there are no cycles, traverse the graph and set coupling according to the selected method # Try calling the coupling with the provided parameters but be prepared for mismatches - # Adjust the SNR of sources in each source group - # 1. Estimate the noise variance in the specified band - # fwd_noise = mne.forward.restrict_forward_to_stc(fwd, stc_noise, on_missing='raise') - # noise_var = get_sensor_space_variance() - # 2. Adjust the amplitude of each signal source according to the desired SNR (if not None) - - # Add the sources to the simulated configuration - sc._sources = sources - sc._noise_sources = noise_sources + # Adjust the SNR if needed + if is_snr_adjusted: + sources = _adjust_snr(src, fwd, sources, source_groups, noise_sources) - return sc + return sources, noise_sources diff --git a/src/meegsim/snr.py b/src/meegsim/snr.py index de0ba5b..4df7d16 100644 --- a/src/meegsim/snr.py +++ b/src/meegsim/snr.py @@ -1,9 +1,10 @@ import numpy as np -import warnings import mne from scipy.signal import butter, filtfilt +from .sources import _combine_sources_into_stc + def get_sensor_space_variance(stc, fwd, *, fmin=None, fmax=None, filter=False): """ @@ -32,29 +33,38 @@ def get_sensor_space_variance(stc, fwd, *, fmin=None, fmax=None, filter=False): Variance with respect to leadfield. """ + stc_data = stc.data if filter: - if fmin is None: - warnings.warn("fmin was None. Setting fmin to 8 Hz", UserWarning) - fmin = 8. - if fmax is None: - warnings.warn("fmax was None. Setting fmax to 12 Hz", UserWarning) - fmax = 12. + if fmin is None or fmax is None: + raise ValueError( + 'Frequency band limits are required for the adjustment of SNR.' + ) b, a = butter(2, np.array([fmin, fmax]) / stc.sfreq * 2, btype='bandpass') - stc_data = filtfilt(b, a, stc.data, axis=1) - else: - stc_data = stc.data + stc_data = filtfilt(b, a, stc_data, axis=1) + + try: + fwd_restrict = mne.forward.restrict_forward_to_stc(fwd, stc, + on_missing='raise') + leadfield_restict = fwd_restrict['sol']['data'] + except ValueError: + raise ValueError( + 'The provided forward model does not contain some of the ' + 'simulated sources, so the SNR cannot be adjusted.' + ) - fwd_restrict = mne.forward.restrict_forward_to_stc(fwd, stc, on_missing='ignore') - leadfield_restict = fwd_restrict['sol']['data'] + n_samples = stc_data.shape[1] + n_sensors = leadfield_restict.shape[0] + source_cov = (stc_data @ stc_data.T) / n_samples + sensor_cov = leadfield_restict @ source_cov @ leadfield_restict.T + sensor_var = np.trace(sensor_cov) / n_sensors - stc_var = np.mean(stc_data ** 2) * np.mean(leadfield_restict ** 2) - return stc_var + return sensor_var -def adjust_snr(signal_var, noise_var, *, target_snr=1): +def amplitude_adjustment_factor(signal_var, noise_var, target_snr): """ - Derive the signal amplitude that allows obtaining target SNR + Derive the adjustment factor for signal amplitude that allows obtaining the target SNR Parameters ---------- @@ -66,13 +76,13 @@ def adjust_snr(signal_var, noise_var, *, target_snr=1): Variance of the simulated noise with respect to leadfield. Can be obtained with a function snr.get_sensor_space_variance. - target_snr: float, optional - Value of a desired SNR for the signal. default = 1. + target_snr: float + Value of a desired SNR for the signal. Returns ------- - out: float - The value that original signal should be scaled (multiplied) to in order to obtain desired SNR. + factor: float + The original signal should be multiplied by this value to obtain the desired SNR. """ snr_current = np.divide(signal_var, noise_var) @@ -89,3 +99,34 @@ def adjust_snr(signal_var, noise_var, *, target_snr=1): "signals.") return factor + + +def _adjust_snr(src, fwd, sources, source_groups, noise_sources): + # Get the stc and leadfield of all noise sources + stc_noise = _combine_sources_into_stc(noise_sources.values(), src) + + # Adjust the SNR of sources in each source group + for sg in source_groups: + if sg.snr is None: + continue + + # Estimate the noise variance in the specified frequency band + fmin, fmax = sg.snr_params['fmin'], sg.snr_params['fmax'] + noise_var = get_sensor_space_variance(stc_noise, fwd, + fmin=fmin, fmax=fmax, filter=True) + + # Adjust the amplitude of each source in the group to match the target SNR + for name, target_snr in zip(sg.names, sg.snr): + s = sources[name] + + # NOTE: taking a safer approach for now and filtering + # even if the signal is already a narrowband oscillation + signal_var = get_sensor_space_variance(s.to_stc(src), fwd, + fmin=fmin, fmax=fmax, filter=True) + + # NOTE: patch sources might require more complex calculations + # if the within-patch correlation is not equal to 1 + factor = amplitude_adjustment_factor(signal_var, noise_var, target_snr) + s.waveform *= factor + + return sources diff --git a/src/meegsim/source_groups.py b/src/meegsim/source_groups.py index 5326fd0..c6dda4e 100644 --- a/src/meegsim/source_groups.py +++ b/src/meegsim/source_groups.py @@ -113,14 +113,14 @@ def create( The location provided by the user. waveform: list of callable The waveform provided by the user. - snr: - TODO: fix when finalizing SNR + snr: None, float, or array + The SNR values provided by the user. location_params: dict, optional Additional keyword arguments for the location function. waveform_params: dict, optional Additional keyword arguments for the waveform function. - snr_params: - TODO: fix when finalizing SNR + snr_params: dict, optional + Additional parameters for the adjustment of SNR. names: The names of sources provided by the user. group: @@ -139,7 +139,7 @@ def create( location, n_sources = check_location(location, location_params, src) waveform = check_waveform(waveform, waveform_params, n_sources) snr = check_snr(snr, n_sources) - snr_params = check_snr_params(snr_params) + snr_params = check_snr_params(snr_params, snr) # Auto-generate or check the provided source names if not names: diff --git a/tests/test_check.py b/tests/test_check.py index b1ce739..dc75bb5 100644 --- a/tests/test_check.py +++ b/tests/test_check.py @@ -4,7 +4,7 @@ from functools import partial from meegsim._check import ( check_callable, check_vertices_list_of_tuples, check_vertices_in_src, - check_location, check_waveform, check_names + check_location, check_waveform, check_names, check_snr, check_snr_params ) from utils.prepare import prepare_source_space @@ -175,6 +175,58 @@ def waveform_fun(n_series, times, random_state=None): ) +def test_check_snr_is_none_passes(): + snr = check_snr(None, 5) + assert snr is None + + +@pytest.mark.parametrize("n_sources", [1, 5, 10]) +def test_check_snr_float_passes(n_sources): + snr = check_snr(1., n_sources) + assert snr.size == n_sources + assert np.all(snr == 1.) + + +def test_check_snr_array_valid_shape_passes(): + initial = [1, 2, 3, 4, 5] + snr = check_snr(initial, 5) + assert np.array_equal(snr, initial) + + +def test_check_snr_array_invalid_shape_raises(): + initial = [1, 2, 3, 4, 5] + with pytest.raises(ValueError, match="of the 3 sources, got 5"): + check_snr(initial, 3) + + +def test_check_snr_negative_snr_raises(): + with pytest.raises(ValueError, match="Each SNR value should be positive"): + check_snr([-1, 0, 1], 3) + + +def test_check_snr_params_snr_is_none(): + snr_params = check_snr_params(dict(), None) + assert not snr_params + + +def test_check_snr_params_no_fmin(): + with pytest.raises(ValueError, match="Please add fmin and fmax"): + check_snr_params(dict(fmax=12.), snr=1.) + + +def test_check_snr_params_no_fmax(): + with pytest.raises(ValueError, match="Please add fmin and fmax"): + check_snr_params(dict(fmin=8.), snr=1.) + + +def test_check_snr_params_negative_fmin_fmax(): + with pytest.raises(ValueError, match="Frequency limits should be positive"): + check_snr_params(dict(fmin=-8., fmax=12.), snr=1.) + + with pytest.raises(ValueError, match="Frequency limits should be positive"): + check_snr_params(dict(fmin=8., fmax=-12.), snr=1.) + + def test_check_names_should_pass(): initial = ['m1-lh', 'm1-rh', 's1-lh', 's1-rh'] check_names(initial, 4, ['v1-lh', 'v1-rh']) diff --git a/tests/test_simulate.py b/tests/test_simulate.py index 17ec464..01c7a6f 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -2,13 +2,13 @@ import numpy as np import pytest -from collections import namedtuple from mock import patch, Mock from meegsim.simulate import SourceSimulator, _simulate from meegsim.source_groups import PointSourceGroup -from utils.prepare import prepare_source_space +from utils.mocks import MockPointSource +from utils.prepare import prepare_source_space, prepare_forward def test_sourcesimulator_add_point_sources(): @@ -91,6 +91,47 @@ def test_sourcesimulator_add_noise_sources(): f"Expected eight sources to be created, got {len(sim._sources)}" +def test_sourcesimulator_is_snr_adjusted(): + src = prepare_source_space( + types=['surf', 'surf'], + vertices=[[0, 1], [0, 1]] + ) + sim = SourceSimulator(src) + + # Add noise sources + sim.add_noise_sources( + [(0, 0), (0, 1), (1, 0), (1, 1)], + np.ones((4, 100)) + ) + + # SNR should not be adjusted yet + assert not sim.is_snr_adjusted + + # Add point sources WITHOUT adjustment of SNR + sim.add_point_sources( + [(0, 0), (0, 1), (1, 0), (1, 1)], + np.ones((4, 100)) + ) + + # SNR should not be adjusted yet + assert not sim.is_snr_adjusted + + # Add point sources WITH adjustment of SNR + sim.add_point_sources( + [(0, 0), (0, 1), (1, 0), (1, 1)], + np.ones((4, 100)), + snr=10, + snr_params=dict(fmin=8, fmax=12) + ) + + # SNR should be adjusted now + assert sim.is_snr_adjusted + + # Forward model is required for simulations + with pytest.raises(ValueError, match="A forward model"): + sim.simulate(sfreq=100, duration=30) + + def test_sourcesimulator_simulate_empty_raises(): src = prepare_source_space( types=['surf', 'surf'], @@ -102,7 +143,7 @@ def test_sourcesimulator_simulate_empty_raises(): sim.simulate(sfreq=250, duration=30, random_state=0) -@patch('meegsim.simulate._simulate', return_value=0) +@patch('meegsim.simulate._simulate', return_value=([], [])) def test_sourcesimulator_simulate(simulate_mock): src = prepare_source_space( types=['surf', 'surf'], @@ -126,11 +167,19 @@ def test_sourcesimulator_simulate(simulate_mock): def test_simulate(): # return mock PointSource's # noise sources are created first (1 + 3), then actual sources (2) - MockSource = namedtuple("MockSource", ['name']) simulate_mock = Mock(side_effect=[ - [MockSource(name='s1')], - [MockSource(name='s4'), MockSource(name='s5'), MockSource(name='s6')], - [MockSource(name='s2'), MockSource(name='s3')], + [ + MockPointSource(name='s1') + ], + [ + MockPointSource(name='s4'), + MockPointSource(name='s5'), + MockPointSource(name='s6') + ], + [ + MockPointSource(name='s2'), + MockPointSource(name='s3') + ], ]) src = prepare_source_space( @@ -141,18 +190,21 @@ def test_simulate(): # some dummy data - 2 sources + (1 + 3 = 4) noise sources expected source_groups = [ PointSourceGroup(2, [(0, 0), (0, 1)], - np.ones((2, 100)), None, dict(), []), + np.ones((2, 100)), None, dict(), ['s2', 's3']), ] noise_groups = [ - PointSourceGroup(1, [(0, 0)], np.array([0]), None, dict(), []), + PointSourceGroup(1, [(0, 0)], np.array([0]), None, dict(), ['s1']), PointSourceGroup(3, [(0, 0), (0, 1), (1, 0)], - np.ones((3, 100)), None, dict(), []), + np.ones((3, 100)), None, dict(), ['s4', 's5', 's6']), ] with patch.object(meegsim.source_groups.PointSourceGroup, 'simulate', simulate_mock): - sc = _simulate(source_groups, noise_groups, src, - sfreq=250, duration=30, random_state=0) + sfreq = 100 + duration = 5 + times = np.arange(0, sfreq * duration) / sfreq + sources, noise_sources = _simulate(source_groups, noise_groups, False, src, + times=times, fwd=None, random_state=0) assert len(simulate_mock.call_args_list) == 3, \ f"Expected three calls of PointSourceGroup.simulate method" @@ -161,6 +213,50 @@ def test_simulate(): for kall in simulate_mock.call_args_list] assert all(random_states), "random_state was not passed correctly" - assert len(sc._sources) == 2, f"Expected 2 sources, got {len(sc._sources)}" - assert len(sc._noise_sources) == 4, \ - f"Expected 4 sources, got {len(sc._noise_sources)}" + assert len(sources) == 2, f"Expected 2 sources, got {len(sources)}" + assert len(noise_sources) == 4, \ + f"Expected 4 sources, got {len(noise_sources)}" + + +@patch('meegsim.simulate._adjust_snr', return_value = []) +def test_simulate_snr_adjustment(setup_snr_mock): + # return mock PointSource's - 1 noise source, 1 signal source + simulate_mock = Mock(side_effect=[ + [MockPointSource(name='n1')], + [MockPointSource(name='s1')] + ]) + + src = prepare_source_space( + types=['surf', 'surf'], + vertices=[[0, 1], [0, 1]] + ) + fwd = prepare_forward(5, 4) + + # Define source groups + source_groups = [ + PointSourceGroup( + n_sources=1, + location=[(0, 0)], + waveform=np.ones((1, 100)), + snr=np.array([5.]), + snr_params=dict(fmin=8, fmax=12), + names=['s1'] + ), + ] + noise_groups = [ + PointSourceGroup(1, [(1, 1)], np.ones((1, 100)), None, dict(), ['n1']), + ] + + with patch.object(meegsim.source_groups.PointSourceGroup, + 'simulate', simulate_mock): + sfreq = 100 + duration = 5 + times = np.arange(0, sfreq * duration) / sfreq + sources, noise_sources = _simulate(source_groups, noise_groups, True, src, + times=times, fwd=fwd, random_state=0) + + # Check that the SNR adjustment was performed + setup_snr_mock.assert_called() + + # Check that the result (empty list in the mock) was saved as is + assert not sources diff --git a/tests/test_snr.py b/tests/test_snr.py index 6bc8dc5..2527de3 100644 --- a/tests/test_snr.py +++ b/tests/test_snr.py @@ -4,80 +4,13 @@ import pytest -from meegsim.snr import get_sensor_space_variance, adjust_snr +from meegsim.snr import ( + get_sensor_space_variance, amplitude_adjustment_factor, _adjust_snr +) +from meegsim.source_groups import PointSourceGroup - -def create_dummy_sourcespace(vertices): - # Fill in dummy data as a constant time series equal to the vertex number - n_src_spaces = len(vertices) - type_src = 'surf' if n_src_spaces == 2 else 'vol' - src = [] - for i in range(n_src_spaces): - # Create a simple dummy data structure - n_verts = len(vertices[i]) - vertno = vertices[i] # Vertices for this hemisphere - xyz = np.random.rand(n_verts, 3) * 100 # Random positions - src_dict = dict( - vertno=vertno, - rr=xyz, - nn=np.random.rand(n_verts, 3), # Random normals - inuse=np.ones(n_verts, dtype=int), # All vertices in use - nuse=n_verts, - type=type_src, - id=i, - np=n_verts - ) - src.append(src_dict) - - return mne.SourceSpaces(src) - - -def create_dummy_forward(): - # Define the basic parameters for the forward solution - n_sources = 10 # Number of ipoles - n_channels = 5 # Number of MEG/EEG channels - - # Create a dummy info structure - info = mne.create_info(ch_names=['MEG1', 'MEG2', 'MEG3', 'EEG1', 'EEG2'], - sfreq=1000., ch_types=['mag', 'mag', 'mag', 'eeg', 'eeg']) - - # Generate random source space data (e.g., forward operator) - fwd_data = np.random.randn(n_channels, n_sources) - - # Create a dummy source space - lh_vertno = np.arange(n_sources // 2) - rh_vertno = np.arange(n_sources // 2) - - src = create_dummy_sourcespace([lh_vertno, rh_vertno]) - - # Generate random source positions - source_rr = np.random.rand(n_sources, 3) - - # Generate random source orientations - source_nn = np.random.randn(n_sources, 3) - source_nn /= np.linalg.norm(source_nn, axis=1, keepdims=True) - - # Create a forward solution - forward = { - 'sol': {'data': fwd_data}, - '_orig_sol': fwd_data, - 'sol_grad': None, - 'info': info, - 'source_ori': 1, - 'surf_ori': True, - 'nsource': n_sources, - 'nchan': n_channels, - 'coord_frame': 1, - 'src': src, - 'source_rr': source_rr, - 'source_nn': source_nn, - '_orig_source_ori': 1 - } - - # Convert the dictionary to an mne.Forward object - fwd = mne.Forward(**forward) - - return fwd +from utils.mocks import MockPointSource +from utils.prepare import prepare_source_space, prepare_forward def prepare_stc(vertices, num_samples=500): @@ -86,19 +19,26 @@ def prepare_stc(vertices, num_samples=500): return mne.SourceEstimate(data, vertices, tmin=0, tstep=0.01) -def test_get_sensor_space_variance_no_filter_all_vert(): - fwd = create_dummy_forward() - vertices = [list(np.arange(5)), list(np.arange(5))] +def test_get_sensor_space_variance_no_filter(): + fwd = prepare_forward(2, 4) + fwd['sol']['data'] = np.array([ + [0, 1, 0, -1], + [0, -1, 0, 1] + ]) + vertices = [[0, 1], [0, 1]] stc = prepare_stc(vertices) - leadfield = fwd['sol']['data'] - expected_variance = np.mean(stc.data ** 2) * np.mean(leadfield ** 2) + + # Vertices with vertno=1 in both hemispheres have constant activity (1) + # Since the leadfield values are opposite for these vertices, the + # activity should cancel out in sensor space + expected_variance = 0. variance = get_sensor_space_variance(stc, fwd, filter=False) assert np.isclose(variance, expected_variance), \ f"Expected variance {expected_variance}, but got {variance}" def test_get_sensor_space_variance_no_filter_sel_vert(): - fwd = create_dummy_forward() + fwd = prepare_forward(5, 10) vertices = [[0], [0]] stc = prepare_stc(vertices) @@ -109,13 +49,13 @@ def test_get_sensor_space_variance_no_filter_sel_vert(): f"Expected variance {expected_variance}, but got {variance}" -@patch('meegsim.snr.filtfilt', return_value=np.ones((1, 100))) +@patch('meegsim.snr.filtfilt', return_value=np.ones((4, 500))) @patch('meegsim.snr.butter', return_value=(0, 0)) def test_get_sensor_space_variance_with_filter(butter_mock, filtfilt_mock): - fwd = create_dummy_forward() + fwd = prepare_forward(5, 10) vertices = [[0, 1], [0, 1]] stc = prepare_stc(vertices) - variance = get_sensor_space_variance(stc, fwd, filter=True) + variance = get_sensor_space_variance(stc, fwd, fmin=8, fmax=12, filter=True) # Check that butter and filtfilt were called butter_mock.assert_called() @@ -136,13 +76,13 @@ def test_get_sensor_space_variance_with_filter(butter_mock, filtfilt_mock): assert variance >= 0, "Variance should be non-negative" -@patch('meegsim.snr.filtfilt', return_value=np.ones((1, 100))) +@patch('meegsim.snr.filtfilt', return_value=np.ones((4, 500))) @patch('meegsim.snr.butter', return_value=(0, 0)) def test_get_sensor_space_variance_with_filter_fmin_fmax(butter_mock, filtfilt_mock): - fwd = create_dummy_forward() + fwd = prepare_forward(5, 10) vertices = [[0, 1], [0, 1]] stc = prepare_stc(vertices) - variance = get_sensor_space_variance(stc, fwd, filter=True, fmin=20., fmax=30.) + get_sensor_space_variance(stc, fwd, filter=True, fmin=20., fmax=30.) # Check that butter and filtfilt were called butter_mock.assert_called() @@ -161,44 +101,83 @@ def test_get_sensor_space_variance_with_filter_fmin_fmax(butter_mock, filtfilt_m f"Expected fmax to be {expected_wmax}, got {actual_wmax}" -@pytest.mark.parametrize("target_snr", np.logspace(-6, 6, 10)) -def test_adjust_snr(target_snr): - signal_var = 10.0 - noise_var = 5.0 +def test_get_sensor_space_variance_no_fmin_fmax(): + fwd = prepare_forward(5, 10) + vertices = [[0, 1], [0, 1]] + stc = prepare_stc(vertices) - snr_current = np.divide(signal_var, noise_var) - expected_result = np.sqrt(target_snr / snr_current) + # No filtering required - should pass + get_sensor_space_variance(stc, fwd, filter=False) - result = adjust_snr(signal_var, noise_var, target_snr=target_snr) - assert np.isclose(result, expected_result), \ - f"Expected {expected_result}, but got {result}" + # No fmin + with pytest.raises(ValueError, match="Frequency band limits are required"): + get_sensor_space_variance(stc, fwd, fmax=12, filter=True) + + # No fmax + with pytest.raises(ValueError, match="Frequency band limits are required"): + get_sensor_space_variance(stc, fwd, fmin=8, filter=True) -def test_adjust_snr_default_target(): +@pytest.mark.parametrize("target_snr", np.logspace(-6, 6, 10)) +def test_amplitude_adjustment_factor(target_snr): signal_var = 10.0 noise_var = 5.0 - default_snr = 1.0 snr_current = np.divide(signal_var, noise_var) - expected_result = np.sqrt(default_snr / snr_current) + expected_result = np.sqrt(target_snr / snr_current) - result = adjust_snr(signal_var, noise_var) + result = amplitude_adjustment_factor(signal_var, noise_var, target_snr=target_snr) assert np.isclose(result, expected_result), \ f"Expected {expected_result}, but got {result}" -def test_adjust_snr_zero_signal_var(): +def test_amplitude_adjustment_zero_signal_var(): signal_var = 0.0 noise_var = 5.0 with pytest.raises(ValueError, match="initial SNR appear to be zero"): - adjust_snr(signal_var, noise_var) + amplitude_adjustment_factor(signal_var, noise_var, target_snr=1) -def test_adjust_snr_zero_noise_var(): +def test_amplitude_adjustment_zero_noise_var(): signal_var = 10.0 noise_var = 0.0 with pytest.raises(ValueError, match="noise variance appears to be zero"): - adjust_snr(signal_var, noise_var) + amplitude_adjustment_factor(signal_var, noise_var, target_snr=1) + + +@patch('meegsim.snr.amplitude_adjustment_factor', return_value=2.) +def test_adjust_snr(adjust_snr_mock): + src = prepare_source_space( + types=['surf', 'surf'], + vertices=[[0, 1], [0, 1]] + ) + fwd = prepare_forward(5, 4) + + # Define source groups + source_groups = [ + PointSourceGroup( + n_sources=1, + location=[(0, 0)], + waveform=np.ones((1, 100)), + snr=np.array([5.]), + snr_params=dict(fmin=8, fmax=12), + names=['s1'] + ), + ] + sources = { + 's1': MockPointSource(name='s1') + } + noise_sources = { + 'n1': MockPointSource(name='n1') + } + + sources = _adjust_snr(src, fwd, sources, source_groups, noise_sources) + + # Check the SNR adjustment was performed + adjust_snr_mock.assert_called() + # Check that the amplitude of the source was adjusted + target = sources['s1'] + assert np.all(target.waveform == 2) diff --git a/tests/utils/mocks.py b/tests/utils/mocks.py new file mode 100644 index 0000000..31ad02b --- /dev/null +++ b/tests/utils/mocks.py @@ -0,0 +1,14 @@ +import numpy as np +import mne + + +class MockPointSource: + """ + Mock PointSource class for testing purposes. + """ + def __init__(self, name, shape=(1, 100)): + self.name = name + self.waveform = np.ones(shape) + + def to_stc(self, *args, **kwargs): + return mne.SourceEstimate(self.waveform, [[0], []], 0, 0.01) diff --git a/tests/utils/prepare.py b/tests/utils/prepare.py index 4753a4c..4282ca8 100644 --- a/tests/utils/prepare.py +++ b/tests/utils/prepare.py @@ -38,3 +38,54 @@ def prepare_source_space(types, vertices): src.append(src_dict) return mne.SourceSpaces(src) + + +def prepare_forward(n_channels, n_sources, + ch_names=None, ch_types=None, sfreq=250): + + assert n_sources % 2 == 0, "Only an even number of sources is supported" + + # Create a dummy info structure + if ch_names is None: + ch_names = [f'EEG{i+1}' for i in range(n_channels)] + if ch_types is None: + ch_types = ['eeg'] * n_channels + info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types) + + # Generate random source space data (e.g., forward operator) + fwd_data = np.random.randn(n_channels, n_sources) + + # Create a dummy source space + lh_vertno = np.arange(n_sources // 2) + rh_vertno = np.arange(n_sources // 2) + + src = prepare_source_space(['surf', 'surf'], [lh_vertno, rh_vertno]) + + # Generate random source positions + source_rr = np.random.rand(n_sources, 3) + + # Generate random source orientations + source_nn = np.random.randn(n_sources, 3) + source_nn /= np.linalg.norm(source_nn, axis=1, keepdims=True) + + # Create a forward solution + forward = { + 'sol': {'data': fwd_data}, + '_orig_sol': fwd_data, + 'sol_grad': None, + 'info': info, + 'source_ori': 1, + 'surf_ori': True, + 'nsource': n_sources, + 'nchan': n_channels, + 'coord_frame': 1, + 'src': src, + 'source_rr': source_rr, + 'source_nn': source_nn, + '_orig_source_ori': 1 + } + + # Convert the dictionary to an mne.Forward object + fwd = mne.Forward(**forward) + + return fwd