Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: Fix bug with fwd restriction #11694

Merged
merged 2 commits into from
May 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Enhancements

Bugs
~~~~
- None yet
- Fix bug with :func:`mne.forward.restrict_forward_to_label` where cortical patch information was not adjusted (:gh:`11694` by `Eric Larson`_)

API changes
~~~~~~~~~~~
Expand Down
83 changes: 7 additions & 76 deletions mne/forward/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1760,7 +1760,7 @@ def restrict_forward_to_stc(fwd, stc, on_missing="ignore"):
"""
_validate_type(on_missing, str, "on_missing")
_check_option("on_missing", on_missing, ("ignore", "warn", "raise"))
src_sel, _, vertices = _stc_src_sel(fwd["src"], stc, on_missing=on_missing)
src_sel, _, _ = _stc_src_sel(fwd["src"], stc, on_missing=on_missing)
del stc
return _restrict_forward_to_src_sel(fwd, src_sel)

Expand Down Expand Up @@ -1844,81 +1844,12 @@ def restrict_forward_to_label(fwd, labels):
vertices[i] = np.append(vertices[i], label.vertices)
# Remove duplicates and sort
vertices = [np.unique(vert_hemi) for vert_hemi in vertices]

fwd_out = deepcopy(fwd)
fwd_out["source_rr"] = np.zeros((0, 3))
fwd_out["nsource"] = 0
fwd_out["source_nn"] = np.zeros((0, 3))
fwd_out["sol"]["data"] = np.zeros((fwd["sol"]["data"].shape[0], 0))
fwd_out["_orig_sol"] = np.zeros((fwd["_orig_sol"].shape[0], 0))
if fwd["sol_grad"] is not None:
fwd_out["sol_grad"]["data"] = np.zeros((fwd["sol_grad"]["data"].shape[0], 0))
fwd_out["_orig_sol_grad"] = np.zeros((fwd["_orig_sol_grad"].shape[0], 0))
fwd_out["sol"]["ncol"] = 0
nuse_lh = fwd["src"][0]["nuse"]

for i in range(2):
fwd_out["src"][i]["vertno"] = np.array([], int)
fwd_out["src"][i]["nuse"] = 0
fwd_out["src"][i]["inuse"] = fwd["src"][i]["inuse"].copy()
fwd_out["src"][i]["inuse"].fill(0)
fwd_out["src"][i]["use_tris"] = np.array([[]], int)
fwd_out["src"][i]["nuse_tri"] = np.array([0])

# src_sel is idx to cols in fwd that are in any label per hemi
src_sel = np.intersect1d(fwd["src"][i]["vertno"], vertices[i])
src_sel = np.searchsorted(fwd["src"][i]["vertno"], src_sel)

# Reconstruct each src
vertno = fwd["src"][i]["vertno"][src_sel]
fwd_out["src"][i]["inuse"][vertno] = 1
fwd_out["src"][i]["nuse"] += len(vertno)
fwd_out["src"][i]["vertno"] = np.where(fwd_out["src"][i]["inuse"])[0]

# Reconstruct part of fwd that is not sol data
src_sel += i * nuse_lh # Add column shift to right hemi
fwd_out["source_rr"] = np.vstack(
[fwd_out["source_rr"], fwd["source_rr"][src_sel]]
)
fwd_out["nsource"] += len(src_sel)

if is_fixed_orient(fwd):
idx = src_sel
if fwd["sol_grad"] is not None:
idx_grad = (3 * src_sel[:, None] + np.arange(3)).ravel()
else:
idx = (3 * src_sel[:, None] + np.arange(3)).ravel()
if fwd["sol_grad"] is not None:
idx_grad = (9 * src_sel[:, None] + np.arange(9)).ravel()

fwd_out["source_nn"] = np.vstack([fwd_out["source_nn"], fwd["source_nn"][idx]])
fwd_out["sol"]["data"] = np.hstack(
[fwd_out["sol"]["data"], fwd["sol"]["data"][:, idx]]
)
if fwd["sol_grad"] is not None:
fwd_out["sol_grad"]["data"] = np.hstack(
[fwd_out["sol_grad"]["data"], fwd["sol_rad"]["data"][:, idx_grad]]
)
fwd_out["sol"]["ncol"] += len(idx)

if is_fixed_orient(fwd, orig=True):
idx = src_sel
if fwd["sol_grad"] is not None:
idx_grad = (3 * src_sel[:, None] + np.arange(3)).ravel()
else:
idx = (3 * src_sel[:, None] + np.arange(3)).ravel()
if fwd["sol_grad"] is not None:
idx_grad = (9 * src_sel[:, None] + np.arange(9)).ravel()

fwd_out["_orig_sol"] = np.hstack(
[fwd_out["_orig_sol"], fwd["_orig_sol"][:, idx]]
)
if fwd["sol_grad"] is not None:
fwd_out["_orig_sol_grad"] = np.hstack(
[fwd_out["_orig_sol_grad"], fwd["_orig_sol_grad"][:, idx_grad]]
)

return fwd_out
vertices = [
vert_hemi[np.in1d(vert_hemi, s["vertno"])]
for vert_hemi, s in zip(vertices, fwd["src"])
]
src_sel, _, _ = _stc_src_sel(fwd["src"], vertices, on_missing="raise")
return _restrict_forward_to_src_sel(fwd, src_sel)


def _do_forward_solution(
Expand Down
45 changes: 43 additions & 2 deletions mne/forward/tests/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
fname_evoked = (
Path(__file__).parent.parent.parent / "io" / "tests" / "data" / "test-ave.fif"
)
label_path = data_path / "MEG" / "sample" / "labels"


def assert_forward_allclose(f1, f2, rtol=1e-7):
Expand Down Expand Up @@ -318,7 +319,6 @@ def test_restrict_forward_to_label(tmp_path):
fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True, use_cps=True)
fwd = pick_types_forward(fwd, meg=True)

label_path = data_path / "MEG" / "sample" / "labels"
labels = ["Aud-lh", "Vis-rh"]
label_lh = read_label(label_path / (labels[0] + ".label"))
label_rh = read_label(label_path / (labels[1] + ".label"))
Expand All @@ -344,7 +344,6 @@ def test_restrict_forward_to_label(tmp_path):
fwd = read_forward_solution(fname_meeg)
fwd = pick_types_forward(fwd, meg=True)

label_path = data_path / "MEG" / "sample" / "labels"
labels = ["Aud-lh", "Vis-rh"]
label_lh = read_label(label_path / (labels[0] + ".label"))
label_rh = read_label(label_path / (labels[1] + ".label"))
Expand Down Expand Up @@ -375,6 +374,48 @@ def test_restrict_forward_to_label(tmp_path):
assert_forward_allclose(fwd_out, fwd_out_read)


@pytest.mark.parametrize("use_cps", [True, False])
@testing.requires_testing_data
def test_restrict_forward_to_label_cps(tmp_path, use_cps):
"""Test for gh-11689."""
label_lh = read_label(label_path / "Aud-lh.label")
fwd = read_forward_solution(fname_meeg)
convert_forward_solution(
fwd, surf_ori=True, force_fixed=False, copy=False, use_cps=use_cps
)
fwd = pick_types_forward(fwd, meg="mag")
fwd_out = restrict_forward_to_label(fwd, label_lh)
vert = fwd_out["src"][0]["vertno"][0]

assert fwd["surf_ori"]
assert not is_fixed_orient(fwd)
idx = list(fwd["src"][0]["vertno"]).index(vert)
assert idx == 126
go1 = fwd["_orig_sol"][:, idx * 3 : idx * 3 + 3].copy()
gs1 = fwd["sol"]["data"][:, idx * 3 : idx * 3 + 3].copy()

assert fwd_out["surf_ori"]
assert not is_fixed_orient(fwd_out)
idx = list(fwd_out["src"][0]["vertno"]).index(vert)
assert idx == 0
go2 = fwd_out["_orig_sol"][:, idx * 3 : idx * 3 + 3].copy()
gs2 = fwd_out["sol"]["data"][:, idx * 3 : idx * 3 + 3].copy()
assert_allclose(go2, go1)
assert_allclose(gs2, gs1)

# should be a no-op
convert_forward_solution(
fwd_out, surf_ori=True, force_fixed=False, copy=False, use_cps=use_cps
)
assert fwd_out["surf_ori"]
assert not is_fixed_orient(fwd_out)
assert list(fwd_out["src"][0]["vertno"]).index(vert) == 0
go3 = fwd_out["_orig_sol"][:, idx * 3 : idx * 3 + 3].copy()
gs3 = fwd_out["sol"]["data"][:, idx * 3 : idx * 3 + 3].copy()
assert_allclose(go3, go1)
assert_allclose(gs3, gs1)


@testing.requires_testing_data
@requires_mne
def test_average_forward_solution(tmp_path):
Expand Down
15 changes: 9 additions & 6 deletions mne/proj.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def compute_proj_raw(

@verbose
def sensitivity_map(
fwd, projs=None, ch_type="grad", mode="fixed", exclude=[], verbose=None
fwd, projs=None, ch_type="grad", mode="fixed", exclude=(), *, verbose=None
):
"""Compute sensitivity map.

Expand Down Expand Up @@ -420,6 +420,11 @@ def sensitivity_map(
stc : SourceEstimate | VolSourceEstimate
The sensitivity map as a SourceEstimate or VolSourceEstimate instance
for visualization.

Notes
-----
When mode is ``'fixed'`` or ``'free'``, the sensitivity map is normalized
by its maximum value.
"""
from scipy import linalg

Expand All @@ -444,10 +449,7 @@ def sensitivity_map(
convert_forward_solution(
fwd, surf_ori=True, force_fixed=False, copy=False, verbose=False
)
if not fwd["surf_ori"] or is_fixed_orient(fwd):
raise RuntimeError(
"Error converting solution, please notify " "mne-python developers"
)
assert fwd["surf_ori"] and not is_fixed_orient(fwd)

gain = fwd["sol"]["data"]

Expand Down Expand Up @@ -477,8 +479,9 @@ def sensitivity_map(
elif mode in residual_types:
raise ValueError("No projectors used, cannot compute %s" % mode)

n_sensors, n_dipoles = gain.shape
_, n_dipoles = gain.shape
n_locations = n_dipoles // 3
del n_dipoles
sensitivity_map = np.empty(n_locations)

for k in range(n_locations):
Expand Down