Skip to content

Commit

Permalink
Cleanup after merge, fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ctrltz committed Oct 2, 2024
1 parent fb1c9fd commit 5b7dbf4
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 108 deletions.
24 changes: 12 additions & 12 deletions src/meegsim/source_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,18 +205,18 @@ def simulate(self, src, times, random_state=None):

@classmethod
def create(
cls,
src,
location,
waveform,
snr,
location_params,
waveform_params,
snr_params,
extents,
names,
group,
existing
cls,
src,
location,
waveform,
snr,
location_params,
waveform_params,
snr_params,
extents,
names,
group,
existing
):
"""
Check the provided input for all fields and create a source group that
Expand Down
56 changes: 31 additions & 25 deletions src/meegsim/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,9 @@ def __init__(self, name, src_idx, vertno, waveform, sfreq, hemi=None):
def __repr__(self):
# Use human readable names of hemispheres if possible
src_desc = self.hemi if self.hemi else f'src[{self.src_idx}]'
return f'<PatchSource | {self.name} | {src_desc} | {self.vertno} >'
n_vertno = len(self.vertno)
vertno_desc = f'{n_vertno} vertex' if n_vertno == 1 else f'{n_vertno} vertices'
return f'<PatchSource | {self.name} | {src_desc} | {vertno_desc} >'

def to_stc(self, src, subject=None):
"""
Expand Down Expand Up @@ -231,11 +233,13 @@ def to_stc(self, src, subject=None):
f"which is not present in the provided src object."
)

if len(set(self.vertno) - set(src[self.src_idx]['vertno'])) > 0:
missing_vertno = set(self.vertno) - set(src[self.src_idx]['vertno'])
if missing_vertno:
report_missing = ', '.join([str(v) for v in missing_vertno])
raise ValueError(
f"The patch source cannot be added to the provided src. "
f"The source space with index {self.src_idx} does not "
f"contain the vertex {self.vertno}"
f"contain the following vertices: {report_missing}"
)

# Resolve the subject name as done in MNE
Expand All @@ -248,7 +252,6 @@ def to_stc(self, src, subject=None):
vertices[self.src_idx].extend(self.vertno)
data = np.tile(self.waveform[np.newaxis, :], (len(self.vertno), 1))


return mne.SourceEstimate(
data=data,
vertices=vertices,
Expand All @@ -259,22 +262,19 @@ def to_stc(self, src, subject=None):

@classmethod
def create(
cls,
src,
times,
n_sources,
location,
waveform,
names,
extents,
random_state=None
cls,
src,
times,
n_sources,
location,
waveform,
names,
extents,
random_state=None
):
"""
This function creates patch sources according to the provided input.
"""
# Check extents
extents = check_extents(extents, n_sources)


# Get the sampling frequency
sfreq = get_sfreq(times)
Expand All @@ -295,15 +295,21 @@ def create(
subject = src[0].get("subject_his_id", None)
patch_vertices = []
for isource, extent in enumerate(extents):
if extent is not None:
src_idx = vertices[isource][0]
vertno = vertices[isource][1]
patch = mne.grow_labels(subject, vertno, extent, src_idx, subjects_dir=None)[0]
# prune vertices
patch_vertno = [vert for vert in patch.vertices if vert in src[src_idx]['vertno']]
patch_vertices.append(patch_vertno)
else: # if locations is a label
patch_vertices.append([vertices[isource][1]])
src_idx, vertno = vertices[isource]

# Add vertices as they are if no extent provided
if extent is None:
# Wrap vertno in a list if it is a single number
vertno = vertno if isinstance(vertno, list) else [vertno]
patch_vertices.append(vertno)
continue

# Grow the patch from center otherwise
patch = mne.grow_labels(subject, vertno, extent, src_idx, subjects_dir=None)[0]

# Prune vertices
patch_vertno = [vert for vert in patch.vertices if vert in src[src_idx]['vertno']]
patch_vertices.append(patch_vertno)

# Create patch sources and save them as a group
sources = []
Expand Down
142 changes: 71 additions & 71 deletions tests/test_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,71 +174,50 @@ def waveform_constant(n_sources, times, random_state=None):
assert [s.name for s in sources] == names


def test_combine_sources_into_stc():
src = prepare_source_space(
types=['surf', 'surf'],
vertices=[[0, 1], [0, 1]]
)

s1 = PointSource('s1', 0, 0, np.ones((100,)), 250)
s2 = PointSource('s2', 0, 0, np.ones((100,)), 250)
s3 = PointSource('s3', 0, 1, np.ones((100,)), 250)

# s1 and s2 are the same vertex, should be summed
stc1 = _combine_sources_into_stc([s1, s2], src)
assert stc1.data.shape[0] == 1, 'Expected 1 active vertices in stc'
assert np.all(stc1.data == 2), 'Expected source activity to be summed'

# s1 and s3 are different vertices, should be concatenated
stc2 = _combine_sources_into_stc([s1, s3], src)
assert stc2.data.shape[0] == 2, 'Expected 2 active vertices in stc'
assert np.all(stc2.data == 1), 'Expected source activity not to be summed'


# =================================
# Patch source
# =================================

@pytest.mark.parametrize(
"src_idx,vertno,patch_vertno,hemi", [
(0, 123, [123], None),
(0, 234, [234], 'lh'),
(1, 345, [345], None),
(1, 456, [456], 'rh'),
"src_idx,vertno,hemi", [
(0, [123], None),
(0, [123, 234], 'lh'),
(1, [345], None),
(1, [123, 234, 345, 456], 'rh'),
]
)
def test_patchsource_repr(src_idx, vertno, patch_vertno, hemi):
def test_patchsource_repr(src_idx, vertno, hemi):
# Waveform is not required for repr, leaving it empty
s = PatchSource('mysource', src_idx, vertno, patch_vertno, np.array([]), sfreq=250, extent=1, hemi=hemi)
s = PatchSource('mysource', src_idx, vertno, np.array([]), sfreq=250, hemi=hemi)

if hemi is None:
assert f'src[{src_idx}]' in repr(s)
else:
assert hemi in repr(s)

assert str(vertno) in repr(s)
assert str(patch_vertno) in repr(s)
assert str(len(vertno)) in repr(s)
assert 'mysource' in repr(s)


@pytest.mark.parametrize(
"src_idx,vertno,patch_vertno", [
(0, 0, [0, 1]),
(1, 1, [1, 2])
"src_idx,vertno", [
(0, [0, 1]),
(1, [1, 2])
]
)
def test_patchsource_to_stc(src_idx, vertno, patch_vertno):
def test_patchsource_to_stc(src_idx, vertno):
waveform = np.ones((100,))
src = prepare_source_space(
types=['surf', 'surf'],
vertices=[[0, 1], [0, 1]]
vertices=[[0, 1], [0, 1, 2]]
)
s = PatchSource('mysource', src_idx, vertno, patch_vertno, waveform, sfreq=100, extent=1)
s = PatchSource('mysource', src_idx, vertno, waveform, sfreq=100)
stc = s.to_stc(src)

assert stc.data.shape[0] == 2, \
f"Expected one active vertex in stc, got {stc.data.shape[0]}"
assert vertno in stc.vertices[src_idx], \
f"Expected the vertex to be put in src {src_idx}, but it is not there"
f"Expected two active vertices in stc, got {stc.data.shape[0]}"
assert np.all(vertno in stc.vertices[src_idx]), \
f"Expected all vertno to be put in src {src_idx}"
assert np.allclose(stc.data, waveform), \
f"The source waveform should not change during conversion to stc"

Expand All @@ -251,7 +230,7 @@ def test_patchsource_to_stc_bad_src_raises():
)

# src[2] is out of range
s = PatchSource('mysource', 2, 0, [0, 1], waveform, sfreq=250, extent=1)
s = PatchSource('mysource', 2, [0, 1], waveform, sfreq=250)
with pytest.raises(ValueError, match="not present in the provided src"):
s.to_stc(src, subject='mysubject')

Expand All @@ -264,32 +243,11 @@ def test_patchsource_to_stc_bad_vertno_raises():
)

# vertex 2 is not in src[0]
s = PatchSource('mysource', 0, 2, [0, 2], waveform, sfreq=250, extent=1)
with pytest.raises(ValueError, match="does not contain the vertex"):
s = PatchSource('mysource', 0, [0, 2], waveform, sfreq=250)
with pytest.raises(ValueError, match="does not contain the following vertices: 2"):
s.to_stc(src, subject='mysubject')


def test_combine_patchsources_into_stc():
src = prepare_source_space(
types=['surf', 'surf'],
vertices=[[0, 1], [0, 1]]
)

s1 = PatchSource('s1', 0, 0, [0, 1], np.ones((100,)), 250, extent=1)
s2 = PatchSource('s2', 1, 0, [0, 1], np.ones((100,)), 250, extent=1)
s3 = PatchSource('s3', 0, 1, [0, 1], np.ones((100,)), 250, extent=1)

# s1 and s2 are the same vertex, should be summed
stc1 = _combine_sources_into_stc([s1, s2], src)
assert stc1.data.shape[0] == 4, 'Expected 1 active vertices in stc'
assert np.all(stc1.data == 1), 'Expected source activity not to be summed'

# s1 and s3 are different vertices, should be concatenated
stc2 = _combine_sources_into_stc([s1, s3], src)
assert stc2.data.shape[0] == 2, 'Expected 2 active vertices in stc'
assert np.all(stc2.data == 2), 'Expected source activity to be summed'


def test_patch_source_with_extent():
"""Test that PatchSource properly handles 'extent' parameter."""

Expand All @@ -315,12 +273,12 @@ def test_patch_source_with_extent():
extents = [3, None] # First source has extent, second has no extent

# Mock grow_labels return values (for the first source with extent)
patch_vertno = [2, 3, 4] # Vertices grown for the first source
vertno = [2, 3, 4] # Vertices grown for the first source

with patch('mne.grow_labels') as mock_grow_labels:

# Mock grow_labels to return the desired vertices
mock_grow_labels.return_value = [MagicMock(vertices=patch_vertno)]
mock_grow_labels.return_value = [MagicMock(vertices=vertno)]

# Call the create method
sources = PatchSource.create(
Expand All @@ -341,18 +299,60 @@ def test_patch_source_with_extent():
source_1 = sources[0]
assert source_1.name == "source_1", "First source name mismatch"
assert source_1.src_idx == 0, "First source src_idx mismatch"
assert source_1.vertno == 2, "First source vertno mismatch"
assert source_1.patch_vertno == patch_vertno, "First source patch_vertno mismatch"
assert source_1.extent == 3, "First source extent mismatch"
assert source_1.vertno == vertno, "First source vertno mismatch"

# Check the second source (without extent)
source_2 = sources[1]
assert source_2.name == "source_2", "Second source name mismatch"
assert source_2.src_idx == 1, "Second source src_idx mismatch"
assert source_2.vertno == 8, "Second source vertno mismatch"
assert source_2.patch_vertno == [], "Second source patch_vertno should be empty"
assert source_2.extent is None, "Second source extent mismatch"
assert source_2.vertno == [8], "Second source vertno mismatch"

# Verify that grow_labels was called once for the source with extent
mock_grow_labels.assert_called_once_with('meegsim', 2, 3, 0, subjects_dir=None)


###
# _combine_sources_into_stc
###


def test_combine_sources_into_stc_point():
src = prepare_source_space(
types=['surf', 'surf'],
vertices=[[0, 1], [0, 1]]
)

s1 = PointSource('s1', 0, 0, np.ones((100,)), 250)
s2 = PointSource('s2', 0, 0, np.ones((100,)), 250)
s3 = PointSource('s3', 0, 1, np.ones((100,)), 250)

# s1 and s2 are the same vertex, should be summed
stc1 = _combine_sources_into_stc([s1, s2], src)
assert stc1.data.shape[0] == 1, 'Expected 1 active vertices in stc'
assert np.all(stc1.data == 2), 'Expected source activity to be summed'

# s1 and s3 are different vertices, should be concatenated
stc2 = _combine_sources_into_stc([s1, s3], src)
assert stc2.data.shape[0] == 2, 'Expected 2 active vertices in stc'
assert np.all(stc2.data == 1), 'Expected source activity not to be summed'


def test_combine_sources_into_stc_patch():
src = prepare_source_space(
types=['surf', 'surf'],
vertices=[[0, 1], [0, 1]]
)

s1 = PatchSource('s1', 0, [0, 1], np.ones((100,)), 250)
s2 = PatchSource('s2', 1, [0, 1], np.ones((100,)), 250)
s3 = PatchSource('s3', 0, [0, 1], np.ones((100,)), 250)

# s1 and s2 are the same vertex, should be summed
stc1 = _combine_sources_into_stc([s1, s2], src)
assert stc1.data.shape[0] == 4, 'Expected 1 active vertices in stc'
assert np.all(stc1.data == 1), 'Expected source activity not to be summed'

# s1 and s3 are different vertices, should be concatenated
stc2 = _combine_sources_into_stc([s1, s3], src)
assert stc2.data.shape[0] == 2, 'Expected 2 active vertices in stc'
assert np.all(stc2.data == 2), 'Expected source activity to be summed'

0 comments on commit 5b7dbf4

Please sign in to comment.