diff --git a/src/meegsim/source_groups.py b/src/meegsim/source_groups.py index 2ca0970..7dd5d6b 100644 --- a/src/meegsim/source_groups.py +++ b/src/meegsim/source_groups.py @@ -205,18 +205,18 @@ def simulate(self, src, times, random_state=None): @classmethod def create( - cls, - src, - location, - waveform, - snr, - location_params, - waveform_params, - snr_params, - extents, - names, - group, - existing + cls, + src, + location, + waveform, + snr, + location_params, + waveform_params, + snr_params, + extents, + names, + group, + existing ): """ Check the provided input for all fields and create a source group that diff --git a/src/meegsim/sources.py b/src/meegsim/sources.py index e004b78..8891bd5 100644 --- a/src/meegsim/sources.py +++ b/src/meegsim/sources.py @@ -197,7 +197,9 @@ def __init__(self, name, src_idx, vertno, waveform, sfreq, hemi=None): def __repr__(self): # Use human readable names of hemispheres if possible src_desc = self.hemi if self.hemi else f'src[{self.src_idx}]' - return f'' + n_vertno = len(self.vertno) + vertno_desc = f'{n_vertno} vertex' if n_vertno == 1 else f'{n_vertno} vertices' + return f'' def to_stc(self, src, subject=None): """ @@ -231,11 +233,13 @@ def to_stc(self, src, subject=None): f"which is not present in the provided src object." ) - if len(set(self.vertno) - set(src[self.src_idx]['vertno'])) > 0: + missing_vertno = set(self.vertno) - set(src[self.src_idx]['vertno']) + if missing_vertno: + report_missing = ', '.join([str(v) for v in missing_vertno]) raise ValueError( f"The patch source cannot be added to the provided src. " f"The source space with index {self.src_idx} does not " - f"contain the vertex {self.vertno}" + f"contain the following vertices: {report_missing}" ) # Resolve the subject name as done in MNE @@ -248,7 +252,6 @@ def to_stc(self, src, subject=None): vertices[self.src_idx].extend(self.vertno) data = np.tile(self.waveform[np.newaxis, :], (len(self.vertno), 1)) - return mne.SourceEstimate( data=data, vertices=vertices, @@ -259,22 +262,19 @@ def to_stc(self, src, subject=None): @classmethod def create( - cls, - src, - times, - n_sources, - location, - waveform, - names, - extents, - random_state=None + cls, + src, + times, + n_sources, + location, + waveform, + names, + extents, + random_state=None ): """ This function creates patch sources according to the provided input. """ - # Check extents - extents = check_extents(extents, n_sources) - # Get the sampling frequency sfreq = get_sfreq(times) @@ -295,15 +295,21 @@ def create( subject = src[0].get("subject_his_id", None) patch_vertices = [] for isource, extent in enumerate(extents): - if extent is not None: - src_idx = vertices[isource][0] - vertno = vertices[isource][1] - patch = mne.grow_labels(subject, vertno, extent, src_idx, subjects_dir=None)[0] - # prune vertices - patch_vertno = [vert for vert in patch.vertices if vert in src[src_idx]['vertno']] - patch_vertices.append(patch_vertno) - else: # if locations is a label - patch_vertices.append([vertices[isource][1]]) + src_idx, vertno = vertices[isource] + + # Add vertices as they are if no extent provided + if extent is None: + # Wrap vertno in a list if it is a single number + vertno = vertno if isinstance(vertno, list) else [vertno] + patch_vertices.append(vertno) + continue + + # Grow the patch from center otherwise + patch = mne.grow_labels(subject, vertno, extent, src_idx, subjects_dir=None)[0] + + # Prune vertices + patch_vertno = [vert for vert in patch.vertices if vert in src[src_idx]['vertno']] + patch_vertices.append(patch_vertno) # Create patch sources and save them as a group sources = [] diff --git a/tests/test_sources.py b/tests/test_sources.py index 54a7469..b1ce886 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -174,71 +174,50 @@ def waveform_constant(n_sources, times, random_state=None): assert [s.name for s in sources] == names -def test_combine_sources_into_stc(): - src = prepare_source_space( - types=['surf', 'surf'], - vertices=[[0, 1], [0, 1]] - ) - - s1 = PointSource('s1', 0, 0, np.ones((100,)), 250) - s2 = PointSource('s2', 0, 0, np.ones((100,)), 250) - s3 = PointSource('s3', 0, 1, np.ones((100,)), 250) - - # s1 and s2 are the same vertex, should be summed - stc1 = _combine_sources_into_stc([s1, s2], src) - assert stc1.data.shape[0] == 1, 'Expected 1 active vertices in stc' - assert np.all(stc1.data == 2), 'Expected source activity to be summed' - - # s1 and s3 are different vertices, should be concatenated - stc2 = _combine_sources_into_stc([s1, s3], src) - assert stc2.data.shape[0] == 2, 'Expected 2 active vertices in stc' - assert np.all(stc2.data == 1), 'Expected source activity not to be summed' - - # ================================= # Patch source # ================================= + @pytest.mark.parametrize( - "src_idx,vertno,patch_vertno,hemi", [ - (0, 123, [123], None), - (0, 234, [234], 'lh'), - (1, 345, [345], None), - (1, 456, [456], 'rh'), + "src_idx,vertno,hemi", [ + (0, [123], None), + (0, [123, 234], 'lh'), + (1, [345], None), + (1, [123, 234, 345, 456], 'rh'), ] ) -def test_patchsource_repr(src_idx, vertno, patch_vertno, hemi): +def test_patchsource_repr(src_idx, vertno, hemi): # Waveform is not required for repr, leaving it empty - s = PatchSource('mysource', src_idx, vertno, patch_vertno, np.array([]), sfreq=250, extent=1, hemi=hemi) + s = PatchSource('mysource', src_idx, vertno, np.array([]), sfreq=250, hemi=hemi) if hemi is None: assert f'src[{src_idx}]' in repr(s) else: assert hemi in repr(s) - assert str(vertno) in repr(s) - assert str(patch_vertno) in repr(s) + assert str(len(vertno)) in repr(s) assert 'mysource' in repr(s) @pytest.mark.parametrize( - "src_idx,vertno,patch_vertno", [ - (0, 0, [0, 1]), - (1, 1, [1, 2]) + "src_idx,vertno", [ + (0, [0, 1]), + (1, [1, 2]) ] ) -def test_patchsource_to_stc(src_idx, vertno, patch_vertno): +def test_patchsource_to_stc(src_idx, vertno): waveform = np.ones((100,)) src = prepare_source_space( types=['surf', 'surf'], - vertices=[[0, 1], [0, 1]] + vertices=[[0, 1], [0, 1, 2]] ) - s = PatchSource('mysource', src_idx, vertno, patch_vertno, waveform, sfreq=100, extent=1) + s = PatchSource('mysource', src_idx, vertno, waveform, sfreq=100) stc = s.to_stc(src) assert stc.data.shape[0] == 2, \ - f"Expected one active vertex in stc, got {stc.data.shape[0]}" - assert vertno in stc.vertices[src_idx], \ - f"Expected the vertex to be put in src {src_idx}, but it is not there" + f"Expected two active vertices in stc, got {stc.data.shape[0]}" + assert np.all(vertno in stc.vertices[src_idx]), \ + f"Expected all vertno to be put in src {src_idx}" assert np.allclose(stc.data, waveform), \ f"The source waveform should not change during conversion to stc" @@ -251,7 +230,7 @@ def test_patchsource_to_stc_bad_src_raises(): ) # src[2] is out of range - s = PatchSource('mysource', 2, 0, [0, 1], waveform, sfreq=250, extent=1) + s = PatchSource('mysource', 2, [0, 1], waveform, sfreq=250) with pytest.raises(ValueError, match="not present in the provided src"): s.to_stc(src, subject='mysubject') @@ -264,32 +243,11 @@ def test_patchsource_to_stc_bad_vertno_raises(): ) # vertex 2 is not in src[0] - s = PatchSource('mysource', 0, 2, [0, 2], waveform, sfreq=250, extent=1) - with pytest.raises(ValueError, match="does not contain the vertex"): + s = PatchSource('mysource', 0, [0, 2], waveform, sfreq=250) + with pytest.raises(ValueError, match="does not contain the following vertices: 2"): s.to_stc(src, subject='mysubject') -def test_combine_patchsources_into_stc(): - src = prepare_source_space( - types=['surf', 'surf'], - vertices=[[0, 1], [0, 1]] - ) - - s1 = PatchSource('s1', 0, 0, [0, 1], np.ones((100,)), 250, extent=1) - s2 = PatchSource('s2', 1, 0, [0, 1], np.ones((100,)), 250, extent=1) - s3 = PatchSource('s3', 0, 1, [0, 1], np.ones((100,)), 250, extent=1) - - # s1 and s2 are the same vertex, should be summed - stc1 = _combine_sources_into_stc([s1, s2], src) - assert stc1.data.shape[0] == 4, 'Expected 1 active vertices in stc' - assert np.all(stc1.data == 1), 'Expected source activity not to be summed' - - # s1 and s3 are different vertices, should be concatenated - stc2 = _combine_sources_into_stc([s1, s3], src) - assert stc2.data.shape[0] == 2, 'Expected 2 active vertices in stc' - assert np.all(stc2.data == 2), 'Expected source activity to be summed' - - def test_patch_source_with_extent(): """Test that PatchSource properly handles 'extent' parameter.""" @@ -315,12 +273,12 @@ def test_patch_source_with_extent(): extents = [3, None] # First source has extent, second has no extent # Mock grow_labels return values (for the first source with extent) - patch_vertno = [2, 3, 4] # Vertices grown for the first source + vertno = [2, 3, 4] # Vertices grown for the first source with patch('mne.grow_labels') as mock_grow_labels: # Mock grow_labels to return the desired vertices - mock_grow_labels.return_value = [MagicMock(vertices=patch_vertno)] + mock_grow_labels.return_value = [MagicMock(vertices=vertno)] # Call the create method sources = PatchSource.create( @@ -341,18 +299,60 @@ def test_patch_source_with_extent(): source_1 = sources[0] assert source_1.name == "source_1", "First source name mismatch" assert source_1.src_idx == 0, "First source src_idx mismatch" - assert source_1.vertno == 2, "First source vertno mismatch" - assert source_1.patch_vertno == patch_vertno, "First source patch_vertno mismatch" - assert source_1.extent == 3, "First source extent mismatch" + assert source_1.vertno == vertno, "First source vertno mismatch" # Check the second source (without extent) source_2 = sources[1] assert source_2.name == "source_2", "Second source name mismatch" assert source_2.src_idx == 1, "Second source src_idx mismatch" - assert source_2.vertno == 8, "Second source vertno mismatch" - assert source_2.patch_vertno == [], "Second source patch_vertno should be empty" - assert source_2.extent is None, "Second source extent mismatch" + assert source_2.vertno == [8], "Second source vertno mismatch" # Verify that grow_labels was called once for the source with extent mock_grow_labels.assert_called_once_with('meegsim', 2, 3, 0, subjects_dir=None) + +### +# _combine_sources_into_stc +### + + +def test_combine_sources_into_stc_point(): + src = prepare_source_space( + types=['surf', 'surf'], + vertices=[[0, 1], [0, 1]] + ) + + s1 = PointSource('s1', 0, 0, np.ones((100,)), 250) + s2 = PointSource('s2', 0, 0, np.ones((100,)), 250) + s3 = PointSource('s3', 0, 1, np.ones((100,)), 250) + + # s1 and s2 are the same vertex, should be summed + stc1 = _combine_sources_into_stc([s1, s2], src) + assert stc1.data.shape[0] == 1, 'Expected 1 active vertices in stc' + assert np.all(stc1.data == 2), 'Expected source activity to be summed' + + # s1 and s3 are different vertices, should be concatenated + stc2 = _combine_sources_into_stc([s1, s3], src) + assert stc2.data.shape[0] == 2, 'Expected 2 active vertices in stc' + assert np.all(stc2.data == 1), 'Expected source activity not to be summed' + + +def test_combine_sources_into_stc_patch(): + src = prepare_source_space( + types=['surf', 'surf'], + vertices=[[0, 1], [0, 1]] + ) + + s1 = PatchSource('s1', 0, [0, 1], np.ones((100,)), 250) + s2 = PatchSource('s2', 1, [0, 1], np.ones((100,)), 250) + s3 = PatchSource('s3', 0, [0, 1], np.ones((100,)), 250) + + # s1 and s2 are the same vertex, should be summed + stc1 = _combine_sources_into_stc([s1, s2], src) + assert stc1.data.shape[0] == 4, 'Expected 1 active vertices in stc' + assert np.all(stc1.data == 1), 'Expected source activity not to be summed' + + # s1 and s3 are different vertices, should be concatenated + stc2 = _combine_sources_into_stc([s1, s3], src) + assert stc2.data.shape[0] == 2, 'Expected 2 active vertices in stc' + assert np.all(stc2.data == 2), 'Expected source activity to be summed'