-
Notifications
You must be signed in to change notification settings - Fork 1.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] PCA flip for volumetric source estimates #13092
base: main
Are you sure you want to change the base?
Changes from 9 commits
bcc07f1
99b593a
00e0487
3b3b754
d88591f
2e56dcb
a677854
8a74ffe
dd15522
6839d73
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Add PCA-flip to pool sources in source reconstruction in :func:`mne.extract_label_time_course`, by :newcontrib:`Fabrice Guibert`. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1460,22 +1460,47 @@ def label_sign_flip(label, src): | |
flip : array | ||
Sign flip vector (contains 1 or -1). | ||
""" | ||
if len(src) != 2: | ||
raise ValueError("Only source spaces with 2 hemisphers are accepted") | ||
if len(src) > 2 or len(src) == 0: | ||
raise ValueError( | ||
"Only source spaces with between one and two " | ||
+ "hemispheres are accepted, was {len(src)}" | ||
) | ||
|
||
lh_vertno = src[0]["vertno"] | ||
rh_vertno = src[1]["vertno"] | ||
if len(src) == 1 and label.hemi == "both": | ||
raise ValueError( | ||
'Cannot use hemisphere label "both" when source' | ||
+ "space contains a single hemisphere." | ||
) | ||
|
||
isbi_hemi = len(src) == 2 | ||
lh_vertno = None | ||
rh_vertno = None | ||
|
||
lh_id = -1 | ||
rh_id = -1 | ||
if isbi_hemi: | ||
lh_id = 0 | ||
rh_id = 1 | ||
lh_vertno = src[0]["vertno"] | ||
rh_vertno = src[1]["vertno"] | ||
elif label.hemi == "lh": | ||
lh_vertno = src[0]["vertno"] | ||
elif label.hemi == "rh": | ||
rh_id = 0 | ||
rh_vertno = src[0]["vertno"] | ||
else: | ||
raise Exception(f'Unknown hemisphere type "{label.hemi}"') | ||
Comment on lines
+1481
to
+1492
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't totally follow this logic. It's a surface source space if and only if it's bihemi, and it's volume source space if and only if it's a single-element list. I think it would be cleaner to triage based on source space type probably... |
||
|
||
# get source orientations | ||
ori = list() | ||
if label.hemi in ("lh", "both"): | ||
vertices = label.vertices if label.hemi == "lh" else label.lh.vertices | ||
vertno_sel = np.intersect1d(lh_vertno, vertices) | ||
ori.append(src[0]["nn"][vertno_sel]) | ||
ori.append(src[lh_id]["nn"][vertno_sel]) | ||
if label.hemi in ("rh", "both"): | ||
vertices = label.vertices if label.hemi == "rh" else label.rh.vertices | ||
vertno_sel = np.intersect1d(rh_vertno, vertices) | ||
ori.append(src[1]["nn"][vertno_sel]) | ||
ori.append(src[rh_id]["nn"][vertno_sel]) | ||
if len(ori) == 0: | ||
raise Exception(f'Unknown hemisphere type "{label.hemi}"') | ||
ori = np.concatenate(ori, axis=0) | ||
|
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -3376,12 +3376,19 @@ def _get_ico_tris(grade, verbose=None, return_surf=False): | |||||||||
|
||||||||||
|
||||||||||
def _pca_flip(flip, data): | ||||||||||
U, s, V = _safe_svd(data, full_matrices=False) | ||||||||||
# determine sign-flip | ||||||||||
sign = np.sign(np.dot(U[:, 0], flip)) | ||||||||||
# use average power in label for scaling | ||||||||||
scale = np.linalg.norm(s) / np.sqrt(len(data)) | ||||||||||
return sign * scale * V[0] | ||||||||||
result = None | ||||||||||
if flip is None: | ||||||||||
result = 0 | ||||||||||
elif data.shape[0] < 2: | ||||||||||
result = data.mean(axis=0) # Trivial accumulator | ||||||||||
else: | ||||||||||
U, s, V = _safe_svd(data, full_matrices=False) | ||||||||||
# determine sign-flip | ||||||||||
sign = np.sign(np.dot(U[:, 0], flip)) | ||||||||||
# use average power in label for scaling | ||||||||||
scale = np.linalg.norm(s) / np.sqrt(len(data)) | ||||||||||
result = sign * scale * V[0] | ||||||||||
return result | ||||||||||
|
||||||||||
|
||||||||||
_label_funcs = { | ||||||||||
|
@@ -3433,6 +3440,10 @@ def _prepare_label_extraction(stc, labels, src, mode, allow_empty, use_sparse): | |||||||||
# only computes vertex indices and label_flip will be list of None. | ||||||||||
from .label import BiHemiLabel, Label, label_sign_flip | ||||||||||
|
||||||||||
# logger.info("Selected mode: " + mode) | ||||||||||
# print("Entering _prepare_label_extraction") | ||||||||||
# print("Selected mode: " + mode) | ||||||||||
Comment on lines
+3443
to
+3445
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Cruft, should be removed. But really to help next time:
Suggested change
then if you run with |
||||||||||
|
||||||||||
# if source estimate provided in stc, get vertices from source space and | ||||||||||
# check that they are the same as in the stcs | ||||||||||
_check_stc_src(stc, src) | ||||||||||
|
@@ -3445,6 +3456,7 @@ def _prepare_label_extraction(stc, labels, src, mode, allow_empty, use_sparse): | |||||||||
|
||||||||||
bad_labels = list() | ||||||||||
for li, label in enumerate(labels): | ||||||||||
# print("Mode: " + mode + " li: " + str(li) + " label: " + str(label)) | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
subject = label["subject"] if use_sparse else label.subject | ||||||||||
# stc and src can each be None | ||||||||||
_check_subject( | ||||||||||
|
@@ -3510,6 +3522,9 @@ def _prepare_label_extraction(stc, labels, src, mode, allow_empty, use_sparse): | |||||||||
# So if we override vertno with the stc vertices, it will pick | ||||||||||
# the correct normals. | ||||||||||
with _temporary_vertices(src, stc.vertices): | ||||||||||
# print(f"src: {src[:2]}") | ||||||||||
# print(f"len(src): {len(src[:2])}") | ||||||||||
|
||||||||||
Comment on lines
+3525
to
+3527
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
this_flip = label_sign_flip(label, src[:2])[:, None] | ||||||||||
|
||||||||||
label_vertidx.append(this_vertidx) | ||||||||||
|
@@ -3644,8 +3659,10 @@ def _get_default_label_modes(): | |||||||||
|
||||||||||
|
||||||||||
def _get_allowed_label_modes(stc): | ||||||||||
if isinstance(stc, _BaseVolSourceEstimate | _BaseVectorSourceEstimate): | ||||||||||
if isinstance(stc, _BaseVectorSourceEstimate): | ||||||||||
return ("mean", "max", "auto") | ||||||||||
elif isinstance(stc, _BaseVolSourceEstimate): | ||||||||||
return ("mean", "pca_flip", "max", "auto") | ||||||||||
else: | ||||||||||
return _get_default_label_modes() | ||||||||||
|
||||||||||
|
@@ -3732,6 +3749,10 @@ def _gen_extract_label_time_course( | |||||||||
this_data.shape = (this_data.shape[0],) + stc.data.shape[1:] | ||||||||||
else: | ||||||||||
this_data = stc.data[vertidx] | ||||||||||
# if flip is None: # Happens if fewer than 2 vertices in the label | ||||||||||
# if this_data.shape[] | ||||||||||
# label_tc[i] = 0 | ||||||||||
# else: | ||||||||||
Comment on lines
+3752
to
+3755
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
label_tc[i] = func(flip, this_data) | ||||||||||
|
||||||||||
if mode is not None: | ||||||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -2,7 +2,6 @@ | |||||
# Authors: The MNE-Python contributors. | ||||||
# License: BSD-3-Clause | ||||||
# Copyright the MNE-Python contributors. | ||||||
|
||||||
import os | ||||||
import re | ||||||
from contextlib import nullcontext | ||||||
|
@@ -679,6 +678,150 @@ def test_center_of_mass(): | |||||
assert_equal(np.round(t, 2), 0.12) | ||||||
|
||||||
|
||||||
@testing.requires_testing_data | ||||||
@pytest.mark.parametrize( | ||||||
"label_type, mri_res, test_label, cf, call", | ||||||
[ | ||||||
(str, False, False, "head", "meth"), # head frame | ||||||
(str, False, str, "mri", "func"), # fastest, default for testing | ||||||
(str, True, str, "mri", "func"), # fastest, default for testing | ||||||
(str, True, False, "mri", "func"), # mri_resolution | ||||||
(list, True, False, "mri", "func"), # volume label as list | ||||||
(dict, True, False, "mri", "func"), # volume label as dict | ||||||
], | ||||||
) | ||||||
def test_extract_label_time_course_volume_pca_flip( | ||||||
src_volume_labels, label_type, mri_res, test_label, cf, call | ||||||
): | ||||||
"""Test extraction of label timecourses on VolumetricSourceEstimate with PCA.""" | ||||||
# Setup of data | ||||||
src_labels, volume_labels, lut = src_volume_labels | ||||||
n_tot = 46 | ||||||
assert n_tot == len(src_labels) | ||||||
inv = read_inverse_operator(fname_inv_vol) | ||||||
if cf == "head": | ||||||
src = inv["src"] | ||||||
else: | ||||||
src = read_source_spaces(fname_src_vol) | ||||||
klass = VolVectorSourceEstimate._scalar_class | ||||||
vertices = [src[0]["vertno"]] | ||||||
n_verts = len(src[0]["vertno"]) | ||||||
n_times = 50 | ||||||
data = np.arange(1, n_verts + 1) | ||||||
end_shape = (n_times,) | ||||||
data = np.repeat(data[..., np.newaxis], n_times, -1) | ||||||
stcs = [klass(data.astype(float), vertices, 0, 1)] | ||||||
|
||||||
def eltc(*args, **kwargs): | ||||||
if call == "func": | ||||||
return extract_label_time_course(stcs, *args, **kwargs) | ||||||
else: | ||||||
return [stcs[0].extract_label_time_course(*args, **kwargs)] | ||||||
|
||||||
# triage "labels" argument | ||||||
if mri_res: | ||||||
# All should be there | ||||||
missing = [] | ||||||
else: | ||||||
# Nearest misses these | ||||||
missing = [ | ||||||
"Left-vessel", | ||||||
"Right-vessel", | ||||||
"5th-Ventricle", | ||||||
"non-WM-hypointensities", | ||||||
] | ||||||
n_want = len(src_labels) | ||||||
if label_type is str: | ||||||
labels = fname_aseg | ||||||
elif label_type is list: | ||||||
labels = (fname_aseg, volume_labels) | ||||||
else: | ||||||
assert label_type is dict | ||||||
labels = (fname_aseg, {k: lut[k] for k in volume_labels}) | ||||||
assert mri_res | ||||||
assert len(missing) == 0 | ||||||
# we're going to add one that won't exist | ||||||
missing = ["intentionally_bad"] | ||||||
labels[1][missing[0]] = 10000 | ||||||
n_want += 1 | ||||||
n_tot += 1 | ||||||
n_want -= len(missing) | ||||||
|
||||||
# _volume_labels(src, labels, mri_resolution) | ||||||
# actually do the testing | ||||||
from mne.source_estimate import _pca_flip, _prepare_label_extraction, _volume_labels | ||||||
|
||||||
labels_expanded = _volume_labels(src, labels, mri_res) | ||||||
_, src_flip = _prepare_label_extraction( | ||||||
stcs[0], labels_expanded, src, "pca_flip", "ignore", bool(mri_res) | ||||||
) | ||||||
|
||||||
mode = "pca_flip" | ||||||
with catch_logging() as log: | ||||||
label_tc = eltc( | ||||||
labels, | ||||||
src, | ||||||
mode=mode, | ||||||
allow_empty="ignore", | ||||||
mri_resolution=mri_res, | ||||||
verbose=True, | ||||||
) | ||||||
log = log.getvalue() | ||||||
assert re.search("^Reading atlas.*aseg\\.mgz\n", log) is not None | ||||||
if len(missing): | ||||||
# assert that the missing ones get logged | ||||||
assert "does not contain" in log | ||||||
assert repr(missing) in log | ||||||
else: | ||||||
assert "does not contain" not in log | ||||||
assert f"\n{n_want}/{n_tot} atlas regions had at least" in log | ||||||
assert len(label_tc) == 1 | ||||||
label_tc = label_tc[0] | ||||||
assert label_tc.shape == (n_tot,) + end_shape | ||||||
assert label_tc.shape == (n_tot, n_times) | ||||||
# let's test some actual values by trusting the masks provided by | ||||||
# setup_volume_source_space. mri_resolution=True does some | ||||||
# interpolation so we should not expect equivalence, False does | ||||||
# nearest so we should. | ||||||
if mri_res: | ||||||
rtol = 0.8 # max much more sensitive | ||||||
else: | ||||||
rtol = 0.0 | ||||||
for si, s in enumerate(src_labels): | ||||||
func = _pca_flip | ||||||
these = data[np.isin(src[0]["vertno"], s["vertno"])] | ||||||
print(these.shape) | ||||||
assert len(these) == s["nuse"] | ||||||
if si == 0 and s["seg_name"] == "Unknown": | ||||||
continue # unknown is crappy | ||||||
if s["nuse"] == 0: | ||||||
want = 0.0 | ||||||
if mri_res: | ||||||
# this one is totally due to interpolation, so no easy | ||||||
# test here | ||||||
continue | ||||||
else: | ||||||
if src_flip[si] is None: | ||||||
want = None | ||||||
else: | ||||||
want = func(src_flip[si], these) | ||||||
if want is not None: | ||||||
assert_allclose(label_tc[si], want, atol=1e-6, rtol=rtol) | ||||||
# compare with in_label, only on every fourth for speed | ||||||
if test_label is not False and si % 4 == 0: | ||||||
label = s["seg_name"] | ||||||
if test_label is int: | ||||||
label = lut[label] | ||||||
in_label = stcs[0].in_label(label, fname_aseg, src).data | ||||||
assert in_label.shape == (s["nuse"],) + end_shape | ||||||
if np.all(want == 0): | ||||||
assert in_label.shape[0] == 0 | ||||||
else: | ||||||
if src_flip[si] is not None: | ||||||
in_label = func(src_flip[si], in_label) | ||||||
assert_allclose(in_label, want, atol=1e-6, rtol=rtol) | ||||||
|
||||||
|
||||||
@testing.requires_testing_data | ||||||
@pytest.mark.parametrize("kind", ("surface", "mixed")) | ||||||
@pytest.mark.parametrize("vector", (False, True)) | ||||||
|
@@ -943,7 +1086,8 @@ def eltc(*args, **kwargs): | |||||
if cf == "head" and not mri_res: # some missing | ||||||
with pytest.warns(RuntimeWarning, match="any vertices"): | ||||||
eltc(labels, src, allow_empty=True, mri_resolution=mri_res) | ||||||
for mode in ("mean", "max"): | ||||||
modes = ("mean", "max") if vector else ("mean", "max") | ||||||
for mode in modes: | ||||||
Comment on lines
+1090
to
+1091
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Revert? |
||||||
with catch_logging() as log: | ||||||
label_tc = eltc( | ||||||
labels, | ||||||
|
@@ -1469,7 +1613,7 @@ def objective(x): | |||||
assert_allclose(directions, want_nn, atol=2e-6) | ||||||
|
||||||
|
||||||
@testing.requires_testing_data | ||||||
# @testing.requires_testing_data | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
def test_source_estime_project_label(): | ||||||
"""Test projecting a source estimate onto direction of max power.""" | ||||||
fwd = read_forward_solution(fname_fwd) | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A better / more modern check would be something like:
And incidentally I think eventually we could add support for segmented volume source spaces, as well as mixed source spaces (once surface + volume are fully supported, mixed isn't too bad after that). But probably not as part of this PR!