Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] PCA flip for volumetric source estimates #13092

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions doc/changes/devel/13092.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add PCA-flip to pool sources in source reconstruction in :func:`mne.extract_label_time_course`, by :newcontrib:`Fabrice Guibert`.
1 change: 1 addition & 0 deletions doc/changes/names.inc
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
.. _Evgeny Goldstein: https://github.com/evgenygoldstein
.. _Ezequiel Mikulan: https://github.com/ezemikulan
.. _Ezequiel Mikulan: https://github.com/ezemikulan
.. _Fabrice Guibert: https://github.com/Shrecki
.. _Fahimeh Mamashli: https://github.com/fmamashli
.. _Farzin Negahbani: https://github.com/Farzin-Negahbani
.. _Federico Raimondo: https://github.com/fraimondo
Expand Down
37 changes: 31 additions & 6 deletions mne/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -1460,22 +1460,47 @@ def label_sign_flip(label, src):
flip : array
Sign flip vector (contains 1 or -1).
"""
if len(src) != 2:
raise ValueError("Only source spaces with 2 hemisphers are accepted")
if len(src) > 2 or len(src) == 0:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A better / more modern check would be something like:

_validate_type(src, SourceSpaces, "src")
_check_option("source space kind", src.kind, ("volume", "surface"))
if src.kind == "volume" and len(src) != 1:
    raise ValueError("Only single-segment volumes, are supported, got labelized volume source space")

And incidentally I think eventually we could add support for segmented volume source spaces, as well as mixed source spaces (once surface + volume are fully supported, mixed isn't too bad after that). But probably not as part of this PR!

raise ValueError(
"Only source spaces with between one and two "
+ "hemispheres are accepted, was {len(src)}"
)

lh_vertno = src[0]["vertno"]
rh_vertno = src[1]["vertno"]
if len(src) == 1 and label.hemi == "both":
raise ValueError(
'Cannot use hemisphere label "both" when source'
+ "space contains a single hemisphere."
)

isbi_hemi = len(src) == 2
lh_vertno = None
rh_vertno = None

lh_id = -1
rh_id = -1
if isbi_hemi:
lh_id = 0
rh_id = 1
lh_vertno = src[0]["vertno"]
rh_vertno = src[1]["vertno"]
elif label.hemi == "lh":
lh_vertno = src[0]["vertno"]
elif label.hemi == "rh":
rh_id = 0
rh_vertno = src[0]["vertno"]
else:
raise Exception(f'Unknown hemisphere type "{label.hemi}"')
Comment on lines +1481 to +1492
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't totally follow this logic. It's a surface source space if and only if it's bihemi, and it's volume source space if and only if it's a single-element list. I think it would be cleaner to triage based on source space type probably...


# get source orientations
ori = list()
if label.hemi in ("lh", "both"):
vertices = label.vertices if label.hemi == "lh" else label.lh.vertices
vertno_sel = np.intersect1d(lh_vertno, vertices)
ori.append(src[0]["nn"][vertno_sel])
ori.append(src[lh_id]["nn"][vertno_sel])
if label.hemi in ("rh", "both"):
vertices = label.vertices if label.hemi == "rh" else label.rh.vertices
vertno_sel = np.intersect1d(rh_vertno, vertices)
ori.append(src[1]["nn"][vertno_sel])
ori.append(src[rh_id]["nn"][vertno_sel])
if len(ori) == 0:
raise Exception(f'Unknown hemisphere type "{label.hemi}"')
ori = np.concatenate(ori, axis=0)
Expand Down
35 changes: 28 additions & 7 deletions mne/source_estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3376,12 +3376,19 @@ def _get_ico_tris(grade, verbose=None, return_surf=False):


def _pca_flip(flip, data):
U, s, V = _safe_svd(data, full_matrices=False)
# determine sign-flip
sign = np.sign(np.dot(U[:, 0], flip))
# use average power in label for scaling
scale = np.linalg.norm(s) / np.sqrt(len(data))
return sign * scale * V[0]
result = None
if flip is None:
result = 0
elif data.shape[0] < 2:
result = data.mean(axis=0) # Trivial accumulator
else:
U, s, V = _safe_svd(data, full_matrices=False)
# determine sign-flip
sign = np.sign(np.dot(U[:, 0], flip))
# use average power in label for scaling
scale = np.linalg.norm(s) / np.sqrt(len(data))
result = sign * scale * V[0]
return result


_label_funcs = {
Expand Down Expand Up @@ -3433,6 +3440,10 @@ def _prepare_label_extraction(stc, labels, src, mode, allow_empty, use_sparse):
# only computes vertex indices and label_flip will be list of None.
from .label import BiHemiLabel, Label, label_sign_flip

# logger.info("Selected mode: " + mode)
# print("Entering _prepare_label_extraction")
# print("Selected mode: " + mode)
Comment on lines +3443 to +3445
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cruft, should be removed. But really to help next time:

Suggested change
# logger.info("Selected mode: " + mode)
# print("Entering _prepare_label_extraction")
# print("Selected mode: " + mode)
logger.debug(f"Selected mode: {mode}")

then if you run with verbose='debug' you will see the statement


# if source estimate provided in stc, get vertices from source space and
# check that they are the same as in the stcs
_check_stc_src(stc, src)
Expand All @@ -3445,6 +3456,7 @@ def _prepare_label_extraction(stc, labels, src, mode, allow_empty, use_sparse):

bad_labels = list()
for li, label in enumerate(labels):
# print("Mode: " + mode + " li: " + str(li) + " label: " + str(label))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# print("Mode: " + mode + " li: " + str(li) + " label: " + str(label))

subject = label["subject"] if use_sparse else label.subject
# stc and src can each be None
_check_subject(
Expand Down Expand Up @@ -3510,6 +3522,9 @@ def _prepare_label_extraction(stc, labels, src, mode, allow_empty, use_sparse):
# So if we override vertno with the stc vertices, it will pick
# the correct normals.
with _temporary_vertices(src, stc.vertices):
# print(f"src: {src[:2]}")
# print(f"len(src): {len(src[:2])}")

Comment on lines +3525 to +3527
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# print(f"src: {src[:2]}")
# print(f"len(src): {len(src[:2])}")

this_flip = label_sign_flip(label, src[:2])[:, None]

label_vertidx.append(this_vertidx)
Expand Down Expand Up @@ -3644,8 +3659,10 @@ def _get_default_label_modes():


def _get_allowed_label_modes(stc):
if isinstance(stc, _BaseVolSourceEstimate | _BaseVectorSourceEstimate):
if isinstance(stc, _BaseVectorSourceEstimate):
return ("mean", "max", "auto")
elif isinstance(stc, _BaseVolSourceEstimate):
return ("mean", "pca_flip", "max", "auto")
else:
return _get_default_label_modes()

Expand Down Expand Up @@ -3732,6 +3749,10 @@ def _gen_extract_label_time_course(
this_data.shape = (this_data.shape[0],) + stc.data.shape[1:]
else:
this_data = stc.data[vertidx]
# if flip is None: # Happens if fewer than 2 vertices in the label
# if this_data.shape[]
# label_tc[i] = 0
# else:
Comment on lines +3752 to +3755
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# if flip is None: # Happens if fewer than 2 vertices in the label
# if this_data.shape[]
# label_tc[i] = 0
# else:

label_tc[i] = func(flip, this_data)

if mode is not None:
Expand Down
149 changes: 147 additions & 2 deletions mne/tests/test_source_estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,150 @@ def test_center_of_mass():
assert_equal(np.round(t, 2), 0.12)


@testing.requires_testing_data
@pytest.mark.parametrize(
"label_type, mri_res, test_label, cf, call",
[
(str, False, False, "head", "meth"), # head frame
(str, False, str, "mri", "func"), # fastest, default for testing
(str, True, str, "mri", "func"), # fastest, default for testing
(str, True, False, "mri", "func"), # mri_resolution
(list, True, False, "mri", "func"), # volume label as list
(dict, True, False, "mri", "func"), # volume label as dict
],
)
def test_extract_label_time_course_volume_pca_flip(
src_volume_labels, label_type, mri_res, test_label, cf, call
):
"""Test extraction of label timecourses on VolumetricSourceEstimate with PCA."""
# Setup of data
src_labels, volume_labels, lut = src_volume_labels
n_tot = 46
assert n_tot == len(src_labels)
inv = read_inverse_operator(fname_inv_vol)
if cf == "head":
src = inv["src"]
else:
src = read_source_spaces(fname_src_vol)
klass = VolVectorSourceEstimate._scalar_class
vertices = [src[0]["vertno"]]
n_verts = len(src[0]["vertno"])
n_times = 50
data = np.arange(1, n_verts + 1)
end_shape = (n_times,)
data = np.repeat(data[..., np.newaxis], n_times, -1)
stcs = [klass(data.astype(float), vertices, 0, 1)]

def eltc(*args, **kwargs):
if call == "func":
return extract_label_time_course(stcs, *args, **kwargs)
else:
return [stcs[0].extract_label_time_course(*args, **kwargs)]

# triage "labels" argument
if mri_res:
# All should be there
missing = []
else:
# Nearest misses these
missing = [
"Left-vessel",
"Right-vessel",
"5th-Ventricle",
"non-WM-hypointensities",
]
n_want = len(src_labels)
if label_type is str:
labels = fname_aseg
elif label_type is list:
labels = (fname_aseg, volume_labels)
else:
assert label_type is dict
labels = (fname_aseg, {k: lut[k] for k in volume_labels})
assert mri_res
assert len(missing) == 0
# we're going to add one that won't exist
missing = ["intentionally_bad"]
labels[1][missing[0]] = 10000
n_want += 1
n_tot += 1
n_want -= len(missing)

# _volume_labels(src, labels, mri_resolution)
# actually do the testing
from mne.source_estimate import _pca_flip, _prepare_label_extraction, _volume_labels

labels_expanded = _volume_labels(src, labels, mri_res)
_, src_flip = _prepare_label_extraction(
stcs[0], labels_expanded, src, "pca_flip", "ignore", bool(mri_res)
)

mode = "pca_flip"
with catch_logging() as log:
label_tc = eltc(
labels,
src,
mode=mode,
allow_empty="ignore",
mri_resolution=mri_res,
verbose=True,
)
log = log.getvalue()
assert re.search("^Reading atlas.*aseg\\.mgz\n", log) is not None
if len(missing):
# assert that the missing ones get logged
assert "does not contain" in log
assert repr(missing) in log
else:
assert "does not contain" not in log
assert f"\n{n_want}/{n_tot} atlas regions had at least" in log
assert len(label_tc) == 1
label_tc = label_tc[0]
assert label_tc.shape == (n_tot,) + end_shape
assert label_tc.shape == (n_tot, n_times)
# let's test some actual values by trusting the masks provided by
# setup_volume_source_space. mri_resolution=True does some
# interpolation so we should not expect equivalence, False does
# nearest so we should.
if mri_res:
rtol = 0.8 # max much more sensitive
else:
rtol = 0.0
for si, s in enumerate(src_labels):
func = _pca_flip
these = data[np.isin(src[0]["vertno"], s["vertno"])]
print(these.shape)
assert len(these) == s["nuse"]
if si == 0 and s["seg_name"] == "Unknown":
continue # unknown is crappy
if s["nuse"] == 0:
want = 0.0
if mri_res:
# this one is totally due to interpolation, so no easy
# test here
continue
else:
if src_flip[si] is None:
want = None
else:
want = func(src_flip[si], these)
if want is not None:
assert_allclose(label_tc[si], want, atol=1e-6, rtol=rtol)
# compare with in_label, only on every fourth for speed
if test_label is not False and si % 4 == 0:
label = s["seg_name"]
if test_label is int:
label = lut[label]
in_label = stcs[0].in_label(label, fname_aseg, src).data
assert in_label.shape == (s["nuse"],) + end_shape
if np.all(want == 0):
assert in_label.shape[0] == 0
else:
if src_flip[si] is not None:
in_label = func(src_flip[si], in_label)
assert_allclose(in_label, want, atol=1e-6, rtol=rtol)


@testing.requires_testing_data
@pytest.mark.parametrize("kind", ("surface", "mixed"))
@pytest.mark.parametrize("vector", (False, True))
Expand Down Expand Up @@ -943,7 +1087,8 @@ def eltc(*args, **kwargs):
if cf == "head" and not mri_res: # some missing
with pytest.warns(RuntimeWarning, match="any vertices"):
eltc(labels, src, allow_empty=True, mri_resolution=mri_res)
for mode in ("mean", "max"):
modes = ("mean", "max") if vector else ("mean", "max")
for mode in modes:
Comment on lines +1090 to +1091
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Revert?

with catch_logging() as log:
label_tc = eltc(
labels,
Expand Down Expand Up @@ -1469,7 +1614,7 @@ def objective(x):
assert_allclose(directions, want_nn, atol=2e-6)


@testing.requires_testing_data
# @testing.requires_testing_data
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# @testing.requires_testing_data
@testing.requires_testing_data

def test_source_estime_project_label():
"""Test projecting a source estimate onto direction of max power."""
fwd = read_forward_solution(fname_fwd)
Expand Down
Loading