diff --git a/examples/snr_adjustment.py b/examples/snr_adjustment.py new file mode 100644 index 0000000..debc10b --- /dev/null +++ b/examples/snr_adjustment.py @@ -0,0 +1,73 @@ +""" +Testing the adjustment of SNR +""" + +import mne +import numpy as np +import matplotlib.pyplot as plt + +from pathlib import Path + +from meegsim.location import select_random +from meegsim.simulate import SourceSimulator +from meegsim.waveform import narrowband_oscillation + + +# Load the head model +fs_dir = Path('~/mne_data/MNE-fsaverage-data/fsaverage/') +fwd_path = fs_dir / 'bem_copy' / 'fsaverage-oct6-fwd.fif' +src_path = fs_dir / 'bem_copy' / 'fsaverage-oct6-src.fif' +src = mne.read_source_spaces(src_path) +fwd = mne.read_forward_solution(fwd_path) + +# Simulation parameters +sfreq = 250 +duration = 60 +seed = 123 + +# Channel info +montage = mne.channels.make_standard_montage('standard_1020') +ch_names = [ch for ch in montage.ch_names if ch not in ['O9', 'O10']] +info = mne.create_info(ch_names, sfreq, ch_types='eeg') +info.set_montage('standard_1020') + +# Adapt fwd to the info (could be done by our structure in principle) +fwd = mne.convert_forward_solution(fwd, force_fixed=True) +fwd = mne.pick_channels_forward(fwd, info.ch_names, ordered=True) + +fig, axes = plt.subplots(ncols=3, figsize=(8, 3)) +snr_values = [1, 5, 10] + +for i_snr, target_snr in enumerate(snr_values): + sim = SourceSimulator(src) + + # Select some vertices randomly + sim.add_point_sources( + location=select_random, + waveform=narrowband_oscillation, + location_params=dict(n=3), + waveform_params=dict(fmin=8, fmax=12), + snr=target_snr, + snr_params=dict(fmin=8, fmax=12), + names=['s1', 's2', 's3'] + ) + + sim.add_noise_sources( + location=select_random, + location_params=dict(n=10) + ) + + sc = sim.simulate(sfreq, duration, fwd=fwd, random_state=seed) + raw = sc.to_raw(fwd, info) + + spec = raw.compute_psd(fmax=40, n_fft=sfreq, + n_overlap=sfreq//2, n_per_seg=sfreq) + spec.plot(average=True, dB=False, axes=axes[i_snr], amplitude=False) + + axes[i_snr].set_title(f'SNR={target_snr}') + axes[i_snr].set_xlabel('Frequency (Hz)') + axes[i_snr].set_ylabel('PSD (uV^2/Hz)') + axes[i_snr].set_ylim([0, 1.25]) + +fig.tight_layout() +plt.show(block=True) diff --git a/src/meegsim/configuration.py b/src/meegsim/configuration.py index 7db1c35..9d41784 100644 --- a/src/meegsim/configuration.py +++ b/src/meegsim/configuration.py @@ -38,6 +38,7 @@ def __init__(self, src, sfreq, duration, random_state=None): self.duration = duration self.n_samples = self.sfreq * self.duration self.times = np.arange(self.n_samples) / self.sfreq + self.tstep = self.times[1] - self.times[0] # Random state (for reproducibility) self.random_state = random_state @@ -54,7 +55,7 @@ def to_stc(self): if not all_sources: raise ValueError('No sources were added to the configuration.') - return _combine_sources_into_stc(all_sources, self.src) + return _combine_sources_into_stc(all_sources, self.src, self.tstep) def to_raw(self, fwd, info, scaling_factor=1e-6): # Parameters: diff --git a/src/meegsim/simulate.py b/src/meegsim/simulate.py index 7afbb53..f52997f 100644 --- a/src/meegsim/simulate.py +++ b/src/meegsim/simulate.py @@ -403,6 +403,7 @@ def _simulate( # Adjust the SNR if needed if is_snr_adjusted: - sources = _adjust_snr(src, fwd, sources, source_groups, noise_sources) + tstep = times[1] - times[0] + sources = _adjust_snr(src, fwd, tstep, sources, source_groups, noise_sources) return sources, noise_sources diff --git a/src/meegsim/snr.py b/src/meegsim/snr.py index 4df7d16..c327cfc 100644 --- a/src/meegsim/snr.py +++ b/src/meegsim/snr.py @@ -101,9 +101,14 @@ def amplitude_adjustment_factor(signal_var, noise_var, target_snr): return factor -def _adjust_snr(src, fwd, sources, source_groups, noise_sources): +def _adjust_snr(src, fwd, tstep, sources, source_groups, noise_sources): # Get the stc and leadfield of all noise sources - stc_noise = _combine_sources_into_stc(noise_sources.values(), src) + if not noise_sources: + raise ValueError( + 'No noise sources were added to the simulation, so the SNR ' + 'cannot be adjusted.' + ) + stc_noise = _combine_sources_into_stc(noise_sources.values(), src, tstep) # Adjust the SNR of sources in each source group for sg in source_groups: @@ -121,7 +126,7 @@ def _adjust_snr(src, fwd, sources, source_groups, noise_sources): # NOTE: taking a safer approach for now and filtering # even if the signal is already a narrowband oscillation - signal_var = get_sensor_space_variance(s.to_stc(src), fwd, + signal_var = get_sensor_space_variance(s.to_stc(src, tstep), fwd, fmin=fmin, fmax=fmax, filter=True) # NOTE: patch sources might require more complex calculations diff --git a/src/meegsim/sources.py b/src/meegsim/sources.py index aaf1269..0814536 100644 --- a/src/meegsim/sources.py +++ b/src/meegsim/sources.py @@ -8,61 +8,66 @@ import numpy as np import mne -from .utils import combine_stcs, get_sfreq, _extract_hemi +from .utils import vertices_to_mne, get_sfreq, _extract_hemi class _BaseSource: """ An abstract class representing a source of activity. """ + kind = "base" - def __init__(self, waveform, sfreq): + def __init__(self, waveform): # Current constraint: one source corresponds to one waveform # Point source: the waveform is present in one vertex # Patch source: the waveform is mixed with noise in several vertices self.waveform = waveform - self.sfreq = sfreq - def to_stc(self): + @property + def data(self): raise NotImplementedError( - 'The to_stc() method should be implemented in the subclass.' + 'The .data property should be implemented in a subclass.' + ) + + @property + def vertices(self): + raise NotImplementedError( + 'The .vertices property should be implemented in a subclass.' ) + def _check_compatibility(self, src): + """ + Checks that the source is can be added to the provided src. + + Parameters + ---------- + src: mne.SourceSpaces + The source space where the source should be considered. -class PointSource(_BaseSource): - """ - Point source of activity that is located in one of the vertices in - the source space. - - Attributes - ---------- - src_idx: int - The index of source space that the point source belong to. - vertno: int - The vertex that the point source correspond to - waveform: np.array - The waveform of source activity. - sfreq: float - The sampling frequency of the activity time course. - hemi: str or None, optional - Human-readable name of the hemisphere (e.g, lh or rh). - """ - - def __init__(self, name, src_idx, vertno, waveform, sfreq, hemi=None): - super().__init__(waveform, sfreq) - - self.name = name - self.src_idx = src_idx - self.vertno = vertno - self.sfreq = sfreq - self.hemi = hemi + Raises + ------ + ValueError + If the source does not exist in the provided src. + """ + + if self.src_idx >= len(src): + raise ValueError( + f"The {self.kind} source cannot be added to the provided src. " + f"The {self.kind} source was assigned to source space {self.src_idx}, " + f"which is not present in the provided src object." + ) - def __repr__(self): - # Use human readable names of hemispheres if possible - src_desc = self.hemi if self.hemi else f'src[{self.src_idx}]' - return f'' + own_vertno = [self.vertno] if self.kind == "point" else self.vertno + missing_vertno = set(own_vertno) - set(src[self.src_idx]['vertno']) + if missing_vertno: + report_missing = ', '.join([str(v) for v in missing_vertno]) + raise ValueError( + f"The {self.kind} source cannot be added to the provided src. " + f"The source space with index {self.src_idx} does not " + f"contain the following vertices: {report_missing}" + ) - def to_stc(self, src, subject=None): + def to_stc(self, src, tstep, subject=None): """ Convert the point source into a SourceEstimate object in the context of the provided SourceSpaces. @@ -71,6 +76,8 @@ def to_stc(self, src, subject=None): ---------- src: mne.SourceSpaces The source space where the point source should be considered. + tstep: float + The sampling interval of the source time series (1 / sfreq). subject: str or None, optional Name of the subject that the stc corresponds to. If None, the subject name from the provided src is used if present. @@ -84,41 +91,65 @@ def to_stc(self, src, subject=None): Raises ------ ValueError - If the point source does not exist in the provided src. + If the source does not exist in the provided src. """ - - if self.src_idx >= len(src): - raise ValueError( - f"The point source cannot be added to the provided src. " - f"The point source was assigned to source space {self.src_idx}, " - f"which is not present in the provided src object." - ) - if self.vertno not in src[self.src_idx]['vertno']: - raise ValueError( - f"The point source cannot be added to the provided src. " - f"The source space with index {self.src_idx} does not " - f"contain the vertex {self.vertno}" - ) + self._check_compatibility(src) # Resolve the subject name as done in MNE if subject is None: subject = src[0].get("subject_his_id", None) - data = self.waveform[np.newaxis, :] - - # Create a list of vertices for each src - vertices = [[] for _ in src] - vertices[self.src_idx].append(self.vertno) - + # Convert the vertices to MNE format and construct the stc + vertices = vertices_to_mne(self.vertices, src) return mne.SourceEstimate( - data=data, + data=self.data, vertices=vertices, tmin=0, - tstep=1.0 / self.sfreq, + tstep=tstep, subject=subject ) + +class PointSource(_BaseSource): + """ + Point source of activity that is located in one of the vertices in + the source space. + + Attributes + ---------- + src_idx: int + The index of source space that the point source belong to. + vertno: int + The vertex that the point source correspond to + waveform: np.array + The waveform of source activity. + hemi: str or None, optional + Human-readable name of the hemisphere (e.g, lh or rh). + """ + kind = "point" + + def __init__(self, name, src_idx, vertno, waveform, hemi=None): + super().__init__(waveform) + + self.name = name + self.src_idx = src_idx + self.vertno = vertno + self.hemi = hemi + + def __repr__(self): + # Use human readable names of hemispheres if possible + src_desc = self.hemi if self.hemi else f'src[{self.src_idx}]' + return f'' + + @property + def data(self): + return np.atleast_2d(self.waveform) + + @property + def vertices(self): + return np.atleast_2d(np.array([self.src_idx, self.vertno])) + @classmethod def create( cls, @@ -134,9 +165,6 @@ def create( This function creates point sources according to the provided input. """ - # Get the sampling frequency - sfreq = get_sfreq(times) - # Get the list of vertices (directly from the provided input or through the function) vertices = location(src, random_state=random_state) if callable(location) else location if len(vertices) != n_sources: @@ -157,8 +185,7 @@ def create( name=name, src_idx=src_idx, vertno=vertno, - waveform=waveform, - sfreq=sfreq, + waveform=waveform, hemi=hemi )) @@ -178,19 +205,17 @@ class PatchSource(_BaseSource): The vertices that the patch sources correspond to including the central vertex. waveform: np.array The waveform of source activity. - sfreq: float - The sampling frequency of the activity time course. hemi: str or None, optional Human-readable name of the hemisphere (e.g, lh or rh). """ + kind = "patch" - def __init__(self, name, src_idx, vertno, waveform, sfreq, hemi=None): - super().__init__(waveform, sfreq) + def __init__(self, name, src_idx, vertno, waveform, hemi=None): + super().__init__(waveform) self.name = name self.src_idx = src_idx self.vertno = vertno - self.sfreq = sfreq self.hemi = hemi def __repr__(self): @@ -200,64 +225,13 @@ def __repr__(self): vertno_desc = f'{n_vertno} vertex' if n_vertno == 1 else f'{n_vertno} vertices' return f'' - def to_stc(self, src, subject=None): - """ - Convert the patch source into a SourceEstimate object in the context - of the provided SourceSpaces. - - Parameters - ---------- - src: mne.SourceSpaces - The source space where the patch source should be considered. - subject: str or None, optional - Name of the subject that the stc corresponds to. - If None, the subject name from the provided src is used if present. - - Returns - ------- - stc: mne.SourceEstimate - SourceEstimate that corresponds to the provided src and contains - one active vertex. - - Raises - ------ - ValueError - If the patch source does not exist in the provided src. - """ - - if self.src_idx >= len(src): - raise ValueError( - f"The patch source cannot be added to the provided src. " - f"The patch source was assigned to source space {self.src_idx}, " - f"which is not present in the provided src object." - ) - - missing_vertno = set(self.vertno) - set(src[self.src_idx]['vertno']) - if missing_vertno: - report_missing = ', '.join([str(v) for v in missing_vertno]) - raise ValueError( - f"The patch source cannot be added to the provided src. " - f"The source space with index {self.src_idx} does not " - f"contain the following vertices: {report_missing}" - ) - - # Resolve the subject name as done in MNE - if subject is None: - subject = src[0].get("subject_his_id", None) - - # Create a list of vertices for each src - vertices = [[] for _ in src] - - vertices[self.src_idx].extend(self.vertno) - data = np.tile(self.waveform[np.newaxis, :], (len(self.vertno), 1)) - - return mne.SourceEstimate( - data=data, - vertices=vertices, - tmin=0, - tstep=1.0 / self.sfreq, - subject=subject - ) + @property + def data(self): + return np.tile(self.waveform, (len(self.vertno), 1)) + + @property + def vertices(self): + return np.array([[self.src_idx, v] for v in self.vertno]) @classmethod def create( @@ -275,9 +249,6 @@ def create( This function creates patch sources according to the provided input. """ - # Get the sampling frequency - sfreq = get_sfreq(times) - # Get the list of vertices (directly from the provided input or through the function) vertices = location(src, random_state=random_state) if callable(location) else location if len(vertices) != n_sources: @@ -319,22 +290,60 @@ def create( src_idx=src_idx, vertno=patch_vertno, waveform=waveform, - sfreq=sfreq, hemi=hemi )) return sources -def _combine_sources_into_stc(sources, src): - stc_combined = None - - for s in sources: - stc_source = s.to_stc(src) - if stc_combined is None: - stc_combined = stc_source - continue +def _combine_sources_into_stc(sources, src, tstep): + """ + Create an stc object that contains the waveforms of all provided sources. - stc_combined = combine_stcs(stc_combined, stc_source) + Parameters + ---------- + sources: list + The list of point or patch sources. + src: mne.SourceSpaces + The source space with all candidate source locations. + tstep: float + The sampling interval of the source time series (1 / sfreq). + + Returns + ------- + stc: mne.SourceEstimate + The resulting stc object that contains all sources. + """ - return stc_combined \ No newline at end of file + # Return immediately if no sources were provided + if not sources: + return None + + # Collect the data and vertices from all sources first + data = [] + vertices = [] + for s in sources: + s._check_compatibility(src) + data.append(s.data) + vertices.append(s.vertices) + + # Stack the data and vertices of all sources + data_stacked = np.vstack(data) + vertices_stacked = np.vstack(vertices) + + # Resolve potential repetitions: if several signals apply to the same + # vertex, they should be summed + unique_vertices, indices = np.unique(vertices_stacked, axis=0, + return_inverse=True) + n_unique = unique_vertices.shape[0] + n_samples = data_stacked.shape[1] + + # Place the time courses correctly accounting for repetitions + data = np.zeros((n_unique, n_samples)) + for idx_orig, idx_new in enumerate(indices): + data[idx_new, :] += data_stacked[idx_orig, :] + + # Convert vertices to the MNE format + vertices = vertices_to_mne(unique_vertices, src) + + return mne.SourceEstimate(data, vertices, tmin=0, tstep=tstep) \ No newline at end of file diff --git a/src/meegsim/utils.py b/src/meegsim/utils.py index d5237dd..af1559c 100644 --- a/src/meegsim/utils.py +++ b/src/meegsim/utils.py @@ -184,3 +184,18 @@ def unpack_vertices(vertices_lists): def theoretical_plv(kappa): return i1(kappa) / i0(kappa) + + +def vertices_to_mne(vertices, src): + """ + Convert the vertices to the MNE format (list of lists). + """ + + vertices = np.array(vertices) + packed_vertices = [[] for _ in src] + for src_idx in np.unique(vertices[:, 0]): + src_vertices = vertices[vertices[:, 0] == src_idx, :] + src_vertno = list(np.sort(src_vertices[:, 1])) + packed_vertices[src_idx] = src_vertno + + return packed_vertices diff --git a/tests/test_configuration.py b/tests/test_configuration.py index fdef962..997300b 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -3,7 +3,7 @@ from mock import patch from meegsim.configuration import SourceConfiguration -from meegsim.sources import PointSource +from meegsim.sources import PointSource, PatchSource from utils.prepare import prepare_source_space @@ -28,8 +28,8 @@ def test_sourceconfiguration_to_stc_noise_only(): sc = SourceConfiguration(src, sfreq=250, duration=30) sc._noise_sources = { - 'n1': PointSource('n1', 0, 0, np.ones((250 * 30,)), sfreq=250), - 'n2': PointSource('n2', 0, 1, np.ones((250 * 30,)), sfreq=250), + 'n1': PointSource('n1', 0, 0, np.ones((250 * 30,))), + 'n2': PointSource('n2', 0, 1, np.ones((250 * 30,))), } stc = sc.to_stc() assert stc.data.shape[0] == 2, 'Expected two sources in stc' @@ -43,8 +43,8 @@ def test_sourceconfiguration_to_stc_signal_only(): sc = SourceConfiguration(src, sfreq=250, duration=30) sc._sources = { - 's1': PointSource('s1', 0, 0, np.ones((250 * 30,)), sfreq=250), - 's2': PointSource('s2', 0, 1, np.ones((250 * 30,)), sfreq=250), + 's1': PointSource('s1', 0, 0, np.ones((250 * 30,))), + 's2': PointSource('s2', 0, 1, np.ones((250 * 30,))), } stc = sc.to_stc() assert stc.data.shape[0] == 2, 'Expected two sources in stc' @@ -58,15 +58,32 @@ def test_sourceconfiguration_to_stc_signal_and_noise(): sc = SourceConfiguration(src, sfreq=250, duration=30) sc._sources = { - 's1': PointSource('s1', 0, 0, np.ones((250 * 30,)), sfreq=250), + 's1': PointSource('s1', 0, 0, np.ones((250 * 30,))), } sc._noise_sources = { - 'n1': PointSource('n1', 0, 1, np.ones((250 * 30,)), sfreq=250), + 'n1': PointSource('n1', 0, 1, np.ones((250 * 30,))), } stc = sc.to_stc() assert stc.data.shape[0] == 2, 'Expected two sources in stc' +def test_sourceconfiguration_to_stc_patch(): + src = prepare_source_space( + types=['surf', 'surf'], + vertices=[[0, 1, 2], [0, 1, 2]] + ) + + sc = SourceConfiguration(src, sfreq=250, duration=30) + n_samples = sc.sfreq * sc.duration + sources = [ + PatchSource('s1', 0, [0, 2], np.ones((n_samples,))), + PatchSource('s2', 1, [0, 1], np.ones((n_samples,))) + ] + sc._sources = {s.name: s for s in sources} + stc = sc.to_stc() + assert stc.data.shape[0] == 4, 'Expected four sources in stc' + + @patch('mne.apply_forward_raw', return_value=0) def test_sourceconfiguration_to_raw(apply_forward_mock): src = prepare_source_space( @@ -75,11 +92,13 @@ def test_sourceconfiguration_to_raw(apply_forward_mock): ) sc = SourceConfiguration(src, sfreq=250, duration=30) + n_samples = sc.sfreq * sc.duration sc._sources = { - 's1': PointSource('s1', 0, 0, np.ones((250 * 30,)), sfreq=250), + 's1': PointSource('s1', 0, 0, np.ones((n_samples,))), + 's2': PatchSource('s2', 1, [0, 1], np.ones((n_samples,))) } sc._noise_sources = { - 'n1': PointSource('n1', 0, 1, np.ones((250 * 30,)), sfreq=250), + 'n1': PointSource('n1', 0, 1, np.ones((n_samples,))), } raw = sc.to_raw([], []) diff --git a/tests/test_coupling_graph.py b/tests/test_coupling_graph.py index 73a2567..4fe7257 100644 --- a/tests/test_coupling_graph.py +++ b/tests/test_coupling_graph.py @@ -6,7 +6,7 @@ from meegsim.coupling_graph import generate_walkaround, traverse_tree, _set_coupling -from utils.mocks import MockPointSource +from utils.prepare import prepare_point_source def test_traverse_tree_with_start_node(): @@ -118,7 +118,7 @@ def coupling_fn(waveform, sfreq, kappa, random_state=0): return (kappa, side_effect[kappa]) sources = { - k: MockPointSource(name=k) for k in ['s1', 's2', 's3'] + k: prepare_point_source(name=k) for k in ['s1', 's2', 's3'] } coupling = [ ('s1', 's2', dict(method=coupling_fn, kappa=0)), @@ -145,7 +145,7 @@ def coupling_fn(waveform, sfreq, random_state=0): return random_state sources = { - k: MockPointSource(name=k) for k in ['s1', 's2', 's3'] + k: prepare_point_source(name=k) for k in ['s1', 's2', 's3'] } coupling = [ ('s1', 's2', dict(method=coupling_fn)) diff --git a/tests/test_simulate.py b/tests/test_simulate.py index 11fe6a9..f36d159 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -8,8 +8,7 @@ from meegsim.simulate import SourceSimulator, _simulate from meegsim.source_groups import PointSourceGroup -from utils.mocks import MockPointSource -from utils.prepare import prepare_source_space, prepare_forward +from utils.prepare import prepare_source_space, prepare_forward, prepare_point_source def test_sourcesimulator_add_point_sources(): @@ -231,16 +230,16 @@ def test_simulate(): # noise sources are created first (1 + 3), then actual sources (2) simulate_mock = Mock(side_effect=[ [ - MockPointSource(name='s1') + prepare_point_source(name='s1') ], [ - MockPointSource(name='s4'), - MockPointSource(name='s5'), - MockPointSource(name='s6') + prepare_point_source(name='s4'), + prepare_point_source(name='s5'), + prepare_point_source(name='s6') ], [ - MockPointSource(name='s2'), - MockPointSource(name='s3') + prepare_point_source(name='s2'), + prepare_point_source(name='s3') ], ]) @@ -284,8 +283,8 @@ def test_simulate(): def test_simulate_snr_adjustment(adjust_snr_mock): # return mock PointSource's - 1 noise source, 1 signal source simulate_mock = Mock(side_effect=[ - [MockPointSource(name='n1')], - [MockPointSource(name='s1')] + [prepare_point_source(name='n1')], + [prepare_point_source(name='s1')] ]) src = prepare_source_space( @@ -328,8 +327,8 @@ def test_simulate_snr_adjustment(adjust_snr_mock): def test_simulate_coupling_setup(set_coupling_mock): # return 2 mock PointSource's simulate_mock = Mock(side_effect=[ - [MockPointSource(name='s1')], - [MockPointSource(name='s2')] + [prepare_point_source(name='s1')], + [prepare_point_source(name='s2')] ]) src = prepare_source_space( diff --git a/tests/test_snr.py b/tests/test_snr.py index 2527de3..f5221fb 100644 --- a/tests/test_snr.py +++ b/tests/test_snr.py @@ -9,8 +9,7 @@ ) from meegsim.source_groups import PointSourceGroup -from utils.mocks import MockPointSource -from utils.prepare import prepare_source_space, prepare_forward +from utils.prepare import prepare_source_space, prepare_forward, prepare_point_source def prepare_stc(vertices, num_samples=500): @@ -167,13 +166,14 @@ def test_adjust_snr(adjust_snr_mock): ), ] sources = { - 's1': MockPointSource(name='s1') + 's1': prepare_point_source(name='s1') } noise_sources = { - 'n1': MockPointSource(name='n1') + 'n1': prepare_point_source(name='n1') } + tstep = 0.01 - sources = _adjust_snr(src, fwd, sources, source_groups, noise_sources) + sources = _adjust_snr(src, fwd, tstep, sources, source_groups, noise_sources) # Check the SNR adjustment was performed adjust_snr_mock.assert_called() @@ -181,3 +181,15 @@ def test_adjust_snr(adjust_snr_mock): # Check that the amplitude of the source was adjusted target = sources['s1'] assert np.all(target.waveform == 2) + + +def test_adjust_snr_no_noise_sources_raises(): + src = prepare_source_space( + types=['surf', 'surf'], + vertices=[[0, 1], [0, 1]] + ) + fwd = prepare_forward(5, 4) + + # it's only important that the noise sources list is empty + with pytest.raises(ValueError, match="No noise sources"): + _adjust_snr(src, fwd, 0.01, [], [], []) diff --git a/tests/test_sources.py b/tests/test_sources.py index b1ce886..20ddee6 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -11,9 +11,9 @@ def test_basesource_is_abstract(): waveform = np.ones((100,)) - s = _BaseSource(waveform, sfreq=250) - with pytest.raises(NotImplementedError, match="in the subclass"): - s.to_stc() + s = _BaseSource(waveform) + with pytest.raises(NotImplementedError, match="in a subclass"): + s.data @@ -30,7 +30,7 @@ def test_basesource_is_abstract(): ) def test_pointsource_repr(src_idx, vertno, hemi): # Waveform is not required for repr, leaving it empty - s = PointSource('mysource', src_idx, vertno, np.array([]), sfreq=250, hemi=hemi) + s = PointSource('mysource', src_idx, vertno, np.array([]), hemi=hemi) if hemi is None: assert f'src[{src_idx}]' in repr(s) @@ -53,8 +53,8 @@ def test_pointsource_to_stc(src_idx, vertno): types=['surf', 'surf'], vertices=[[0, 1], [0, 1]] ) - s = PointSource('mysource', src_idx, vertno, waveform, sfreq=100) - stc = s.to_stc(src) + s = PointSource('mysource', src_idx, vertno, waveform) + stc = s.to_stc(src, tstep=0.01) assert stc.data.shape[0] == 1, \ f"Expected one active vertex in stc, got {stc.data.shape[0]}" @@ -64,18 +64,19 @@ def test_pointsource_to_stc(src_idx, vertno): f"The source waveform should not change during conversion to stc" -@pytest.mark.parametrize("sfreq", [100, 250, 500]) -def test_pointsource_to_stc_sfreq(sfreq): +@pytest.mark.parametrize("tstep", [0.01, 0.025, 0.05]) +def test_pointsource_to_stc_tstep(tstep): waveform = np.ones((100,)) src = prepare_source_space( types=['surf', 'surf'], vertices=[[0, 1], [0, 1]] ) - s = PointSource('mysource', 0, 0, waveform, sfreq=sfreq) - stc = s.to_stc(src) + s = PointSource('mysource', 0, 0, waveform) + stc = s.to_stc(src, tstep=tstep) - assert stc.sfreq == sfreq, \ - f"Expected stc.sfreq to be {sfreq}, got {stc.sfreq}" + expected_sfreq = 1.0 / tstep + assert stc.sfreq == expected_sfreq, \ + f"Expected stc.sfreq to be {expected_sfreq}, got {stc.sfreq}" def test_pointsource_to_stc_subject(): @@ -84,13 +85,13 @@ def test_pointsource_to_stc_subject(): types=['surf', 'surf'], vertices=[[0, 1], [0, 1]] ) - s = PointSource('mysource', 0, 0, waveform, sfreq=250) - stc = s.to_stc(src) + s = PointSource('mysource', 0, 0, waveform) + stc = s.to_stc(src, tstep=0.01) assert stc.subject == 'meegsim', \ f"Expected stc.subject to be derived from src, got {stc.subject}" - stc = s.to_stc(src, subject='mysubject') + stc = s.to_stc(src, tstep=0.01, subject='mysubject') assert stc.subject == 'mysubject', \ f"Expected stc.subject to be mysubject, got {stc.subject}" @@ -104,9 +105,9 @@ def test_pointsource_to_stc_bad_src_raises(): ) # src[2] is out of range - s = PointSource('mysource', 2, 0, waveform, sfreq=250) - with pytest.raises(ValueError, match="not present in the provided src"): - s.to_stc(src, subject='mysubject') + s = PointSource('mysource', 2, 0, waveform) + with pytest.raises(ValueError, match="point source was assigned to source space 2"): + s.to_stc(src, tstep=0.01, subject='mysubject') def test_pointsource_to_stc_bad_vertno_raises(): @@ -117,9 +118,9 @@ def test_pointsource_to_stc_bad_vertno_raises(): ) # vertex 2 is not in src[0] - s = PointSource('mysource', 0, 2, waveform, sfreq=250) - with pytest.raises(ValueError, match="does not contain the vertex"): - s.to_stc(src, subject='mysubject') + s = PointSource('mysource', 0, 2, waveform) + with pytest.raises(ValueError, match="contain the following vertices: 2"): + s.to_stc(src, tstep=0.01, subject='mysubject') def test_pointsource_create_from_arrays(): @@ -188,7 +189,7 @@ def waveform_constant(n_sources, times, random_state=None): ) def test_patchsource_repr(src_idx, vertno, hemi): # Waveform is not required for repr, leaving it empty - s = PatchSource('mysource', src_idx, vertno, np.array([]), sfreq=250, hemi=hemi) + s = PatchSource('mysource', src_idx, vertno, np.array([]), hemi=hemi) if hemi is None: assert f'src[{src_idx}]' in repr(s) @@ -211,8 +212,8 @@ def test_patchsource_to_stc(src_idx, vertno): types=['surf', 'surf'], vertices=[[0, 1], [0, 1, 2]] ) - s = PatchSource('mysource', src_idx, vertno, waveform, sfreq=100) - stc = s.to_stc(src) + s = PatchSource('mysource', src_idx, vertno, waveform) + stc = s.to_stc(src, tstep=0.01) assert stc.data.shape[0] == 2, \ f"Expected two active vertices in stc, got {stc.data.shape[0]}" @@ -230,9 +231,9 @@ def test_patchsource_to_stc_bad_src_raises(): ) # src[2] is out of range - s = PatchSource('mysource', 2, [0, 1], waveform, sfreq=250) - with pytest.raises(ValueError, match="not present in the provided src"): - s.to_stc(src, subject='mysubject') + s = PatchSource('mysource', 2, [0, 1], waveform) + with pytest.raises(ValueError, match="patch source was assigned to source space 2"): + s.to_stc(src, tstep=0.01, subject='mysubject') def test_patchsource_to_stc_bad_vertno_raises(): @@ -243,9 +244,9 @@ def test_patchsource_to_stc_bad_vertno_raises(): ) # vertex 2 is not in src[0] - s = PatchSource('mysource', 0, [0, 2], waveform, sfreq=250) + s = PatchSource('mysource', 0, [0, 2], waveform) with pytest.raises(ValueError, match="does not contain the following vertices: 2"): - s.to_stc(src, subject='mysubject') + s.to_stc(src, tstep=0.01, subject='mysubject') def test_patch_source_with_extent(): @@ -322,17 +323,17 @@ def test_combine_sources_into_stc_point(): vertices=[[0, 1], [0, 1]] ) - s1 = PointSource('s1', 0, 0, np.ones((100,)), 250) - s2 = PointSource('s2', 0, 0, np.ones((100,)), 250) - s3 = PointSource('s3', 0, 1, np.ones((100,)), 250) + s1 = PointSource('s1', 0, 0, np.ones((100,))) + s2 = PointSource('s2', 0, 0, np.ones((100,))) + s3 = PointSource('s3', 0, 1, np.ones((100,))) # s1 and s2 are the same vertex, should be summed - stc1 = _combine_sources_into_stc([s1, s2], src) + stc1 = _combine_sources_into_stc([s1, s2], src, tstep=0.01) assert stc1.data.shape[0] == 1, 'Expected 1 active vertices in stc' assert np.all(stc1.data == 2), 'Expected source activity to be summed' # s1 and s3 are different vertices, should be concatenated - stc2 = _combine_sources_into_stc([s1, s3], src) + stc2 = _combine_sources_into_stc([s1, s3], src, tstep=0.01) assert stc2.data.shape[0] == 2, 'Expected 2 active vertices in stc' assert np.all(stc2.data == 1), 'Expected source activity not to be summed' @@ -343,16 +344,16 @@ def test_combine_sources_into_stc_patch(): vertices=[[0, 1], [0, 1]] ) - s1 = PatchSource('s1', 0, [0, 1], np.ones((100,)), 250) - s2 = PatchSource('s2', 1, [0, 1], np.ones((100,)), 250) - s3 = PatchSource('s3', 0, [0, 1], np.ones((100,)), 250) + s1 = PatchSource('s1', 0, [0, 1], np.ones((100,))) + s2 = PatchSource('s2', 1, [0, 1], np.ones((100,))) + s3 = PatchSource('s3', 0, [0, 1], np.ones((100,))) # s1 and s2 are the same vertex, should be summed - stc1 = _combine_sources_into_stc([s1, s2], src) + stc1 = _combine_sources_into_stc([s1, s2], src, tstep=0.01) assert stc1.data.shape[0] == 4, 'Expected 1 active vertices in stc' assert np.all(stc1.data == 1), 'Expected source activity not to be summed' # s1 and s3 are different vertices, should be concatenated - stc2 = _combine_sources_into_stc([s1, s3], src) + stc2 = _combine_sources_into_stc([s1, s3], src, tstep=0.01) assert stc2.data.shape[0] == 2, 'Expected 2 active vertices in stc' assert np.all(stc2.data == 2), 'Expected source activity to be summed' diff --git a/tests/test_utils.py b/tests/test_utils.py index 42c6fc0..f2e6666 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -4,9 +4,12 @@ from mne.io.constants import FIFF from meegsim.utils import ( - _extract_hemi, unpack_vertices, combine_stcs, normalize_power, get_sfreq + _extract_hemi, unpack_vertices, combine_stcs, normalize_power, + get_sfreq, vertices_to_mne ) +from utils.prepare import prepare_source_space + def test_unpack_single_list(): vertices_lists = [[1, 2, 3]] @@ -134,4 +137,24 @@ def test_extract_hemi_raises(): ] with pytest.raises(ValueError, match='Unexpected ID'): - _extract_hemi(src[0]) \ No newline at end of file + _extract_hemi(src[0]) + + +def test_vertices_to_mne(): + src = prepare_source_space( + ['surf', 'surf'], + [[0, 1, 2], [0, 1, 2]] + ) + + packed = vertices_to_mne([(0, 0)], src) + assert packed == [[0], []] + + # vertices should be sorted + packed = vertices_to_mne([(0, 2), (0, 0)], src) + assert packed == [[0, 2], []] + + packed = vertices_to_mne([(0, 0), (1, 2)], src) + assert packed == [[0], [2]] + + packed = vertices_to_mne([(1, 0), (1, 2)], src) + assert packed == [[], [0, 2]] \ No newline at end of file diff --git a/tests/utils/mocks.py b/tests/utils/mocks.py deleted file mode 100644 index 31ad02b..0000000 --- a/tests/utils/mocks.py +++ /dev/null @@ -1,14 +0,0 @@ -import numpy as np -import mne - - -class MockPointSource: - """ - Mock PointSource class for testing purposes. - """ - def __init__(self, name, shape=(1, 100)): - self.name = name - self.waveform = np.ones(shape) - - def to_stc(self, *args, **kwargs): - return mne.SourceEstimate(self.waveform, [[0], []], 0, 0.01) diff --git a/tests/utils/prepare.py b/tests/utils/prepare.py index 8f5ef5a..583e9ab 100644 --- a/tests/utils/prepare.py +++ b/tests/utils/prepare.py @@ -3,6 +3,8 @@ from mne.io.constants import FIFF +from meegsim.sources import PointSource, PatchSource + def prepare_source_space(types, vertices): assert len(types) == len(vertices), \ @@ -32,6 +34,7 @@ def prepare_source_space(types, vertices): nuse=int(n_verts), type=str(src_type), id=int(src_id), + coord_frame=FIFF.FIFFV_COORD_MRI, np=int(n_verts), subject_his_id='meegsim' ) @@ -92,3 +95,13 @@ def prepare_forward(n_channels, n_sources, fwd = mne.Forward(**forward) return fwd + + +def prepare_point_source(name, src_idx=0, vertno=0, n_samples=100): + waveform = np.ones((n_samples,)) + return PointSource(name, src_idx, vertno, waveform) + + +def prepare_patch_source(name, src_idx=0, vertno=[0, 1], n_samples=100): + waveform = np.ones((n_samples,)) + return PatchSource(name, src_idx, vertno, waveform)