Skip to content

Commit

Permalink
BUG: Fix bug with fwd restriction (#11694)
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner authored May 13, 2023
1 parent 68e4f89 commit 1e6aa7d
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 85 deletions.
2 changes: 1 addition & 1 deletion doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Enhancements

Bugs
~~~~
- None yet
- Fix bug with :func:`mne.forward.restrict_forward_to_label` where cortical patch information was not adjusted (:gh:`11694` by `Eric Larson`_)

API changes
~~~~~~~~~~~
Expand Down
83 changes: 7 additions & 76 deletions mne/forward/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1760,7 +1760,7 @@ def restrict_forward_to_stc(fwd, stc, on_missing="ignore"):
"""
_validate_type(on_missing, str, "on_missing")
_check_option("on_missing", on_missing, ("ignore", "warn", "raise"))
src_sel, _, vertices = _stc_src_sel(fwd["src"], stc, on_missing=on_missing)
src_sel, _, _ = _stc_src_sel(fwd["src"], stc, on_missing=on_missing)
del stc
return _restrict_forward_to_src_sel(fwd, src_sel)

Expand Down Expand Up @@ -1844,81 +1844,12 @@ def restrict_forward_to_label(fwd, labels):
vertices[i] = np.append(vertices[i], label.vertices)
# Remove duplicates and sort
vertices = [np.unique(vert_hemi) for vert_hemi in vertices]

fwd_out = deepcopy(fwd)
fwd_out["source_rr"] = np.zeros((0, 3))
fwd_out["nsource"] = 0
fwd_out["source_nn"] = np.zeros((0, 3))
fwd_out["sol"]["data"] = np.zeros((fwd["sol"]["data"].shape[0], 0))
fwd_out["_orig_sol"] = np.zeros((fwd["_orig_sol"].shape[0], 0))
if fwd["sol_grad"] is not None:
fwd_out["sol_grad"]["data"] = np.zeros((fwd["sol_grad"]["data"].shape[0], 0))
fwd_out["_orig_sol_grad"] = np.zeros((fwd["_orig_sol_grad"].shape[0], 0))
fwd_out["sol"]["ncol"] = 0
nuse_lh = fwd["src"][0]["nuse"]

for i in range(2):
fwd_out["src"][i]["vertno"] = np.array([], int)
fwd_out["src"][i]["nuse"] = 0
fwd_out["src"][i]["inuse"] = fwd["src"][i]["inuse"].copy()
fwd_out["src"][i]["inuse"].fill(0)
fwd_out["src"][i]["use_tris"] = np.array([[]], int)
fwd_out["src"][i]["nuse_tri"] = np.array([0])

# src_sel is idx to cols in fwd that are in any label per hemi
src_sel = np.intersect1d(fwd["src"][i]["vertno"], vertices[i])
src_sel = np.searchsorted(fwd["src"][i]["vertno"], src_sel)

# Reconstruct each src
vertno = fwd["src"][i]["vertno"][src_sel]
fwd_out["src"][i]["inuse"][vertno] = 1
fwd_out["src"][i]["nuse"] += len(vertno)
fwd_out["src"][i]["vertno"] = np.where(fwd_out["src"][i]["inuse"])[0]

# Reconstruct part of fwd that is not sol data
src_sel += i * nuse_lh # Add column shift to right hemi
fwd_out["source_rr"] = np.vstack(
[fwd_out["source_rr"], fwd["source_rr"][src_sel]]
)
fwd_out["nsource"] += len(src_sel)

if is_fixed_orient(fwd):
idx = src_sel
if fwd["sol_grad"] is not None:
idx_grad = (3 * src_sel[:, None] + np.arange(3)).ravel()
else:
idx = (3 * src_sel[:, None] + np.arange(3)).ravel()
if fwd["sol_grad"] is not None:
idx_grad = (9 * src_sel[:, None] + np.arange(9)).ravel()

fwd_out["source_nn"] = np.vstack([fwd_out["source_nn"], fwd["source_nn"][idx]])
fwd_out["sol"]["data"] = np.hstack(
[fwd_out["sol"]["data"], fwd["sol"]["data"][:, idx]]
)
if fwd["sol_grad"] is not None:
fwd_out["sol_grad"]["data"] = np.hstack(
[fwd_out["sol_grad"]["data"], fwd["sol_rad"]["data"][:, idx_grad]]
)
fwd_out["sol"]["ncol"] += len(idx)

if is_fixed_orient(fwd, orig=True):
idx = src_sel
if fwd["sol_grad"] is not None:
idx_grad = (3 * src_sel[:, None] + np.arange(3)).ravel()
else:
idx = (3 * src_sel[:, None] + np.arange(3)).ravel()
if fwd["sol_grad"] is not None:
idx_grad = (9 * src_sel[:, None] + np.arange(9)).ravel()

fwd_out["_orig_sol"] = np.hstack(
[fwd_out["_orig_sol"], fwd["_orig_sol"][:, idx]]
)
if fwd["sol_grad"] is not None:
fwd_out["_orig_sol_grad"] = np.hstack(
[fwd_out["_orig_sol_grad"], fwd["_orig_sol_grad"][:, idx_grad]]
)

return fwd_out
vertices = [
vert_hemi[np.in1d(vert_hemi, s["vertno"])]
for vert_hemi, s in zip(vertices, fwd["src"])
]
src_sel, _, _ = _stc_src_sel(fwd["src"], vertices, on_missing="raise")
return _restrict_forward_to_src_sel(fwd, src_sel)


def _do_forward_solution(
Expand Down
45 changes: 43 additions & 2 deletions mne/forward/tests/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
fname_evoked = (
Path(__file__).parent.parent.parent / "io" / "tests" / "data" / "test-ave.fif"
)
label_path = data_path / "MEG" / "sample" / "labels"


def assert_forward_allclose(f1, f2, rtol=1e-7):
Expand Down Expand Up @@ -318,7 +319,6 @@ def test_restrict_forward_to_label(tmp_path):
fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True, use_cps=True)
fwd = pick_types_forward(fwd, meg=True)

label_path = data_path / "MEG" / "sample" / "labels"
labels = ["Aud-lh", "Vis-rh"]
label_lh = read_label(label_path / (labels[0] + ".label"))
label_rh = read_label(label_path / (labels[1] + ".label"))
Expand All @@ -344,7 +344,6 @@ def test_restrict_forward_to_label(tmp_path):
fwd = read_forward_solution(fname_meeg)
fwd = pick_types_forward(fwd, meg=True)

label_path = data_path / "MEG" / "sample" / "labels"
labels = ["Aud-lh", "Vis-rh"]
label_lh = read_label(label_path / (labels[0] + ".label"))
label_rh = read_label(label_path / (labels[1] + ".label"))
Expand Down Expand Up @@ -375,6 +374,48 @@ def test_restrict_forward_to_label(tmp_path):
assert_forward_allclose(fwd_out, fwd_out_read)


@pytest.mark.parametrize("use_cps", [True, False])
@testing.requires_testing_data
def test_restrict_forward_to_label_cps(tmp_path, use_cps):
"""Test for gh-11689."""
label_lh = read_label(label_path / "Aud-lh.label")
fwd = read_forward_solution(fname_meeg)
convert_forward_solution(
fwd, surf_ori=True, force_fixed=False, copy=False, use_cps=use_cps
)
fwd = pick_types_forward(fwd, meg="mag")
fwd_out = restrict_forward_to_label(fwd, label_lh)
vert = fwd_out["src"][0]["vertno"][0]

assert fwd["surf_ori"]
assert not is_fixed_orient(fwd)
idx = list(fwd["src"][0]["vertno"]).index(vert)
assert idx == 126
go1 = fwd["_orig_sol"][:, idx * 3 : idx * 3 + 3].copy()
gs1 = fwd["sol"]["data"][:, idx * 3 : idx * 3 + 3].copy()

assert fwd_out["surf_ori"]
assert not is_fixed_orient(fwd_out)
idx = list(fwd_out["src"][0]["vertno"]).index(vert)
assert idx == 0
go2 = fwd_out["_orig_sol"][:, idx * 3 : idx * 3 + 3].copy()
gs2 = fwd_out["sol"]["data"][:, idx * 3 : idx * 3 + 3].copy()
assert_allclose(go2, go1)
assert_allclose(gs2, gs1)

# should be a no-op
convert_forward_solution(
fwd_out, surf_ori=True, force_fixed=False, copy=False, use_cps=use_cps
)
assert fwd_out["surf_ori"]
assert not is_fixed_orient(fwd_out)
assert list(fwd_out["src"][0]["vertno"]).index(vert) == 0
go3 = fwd_out["_orig_sol"][:, idx * 3 : idx * 3 + 3].copy()
gs3 = fwd_out["sol"]["data"][:, idx * 3 : idx * 3 + 3].copy()
assert_allclose(go3, go1)
assert_allclose(gs3, gs1)


@testing.requires_testing_data
@requires_mne
def test_average_forward_solution(tmp_path):
Expand Down
15 changes: 9 additions & 6 deletions mne/proj.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def compute_proj_raw(

@verbose
def sensitivity_map(
fwd, projs=None, ch_type="grad", mode="fixed", exclude=[], verbose=None
fwd, projs=None, ch_type="grad", mode="fixed", exclude=(), *, verbose=None
):
"""Compute sensitivity map.
Expand Down Expand Up @@ -420,6 +420,11 @@ def sensitivity_map(
stc : SourceEstimate | VolSourceEstimate
The sensitivity map as a SourceEstimate or VolSourceEstimate instance
for visualization.
Notes
-----
When mode is ``'fixed'`` or ``'free'``, the sensitivity map is normalized
by its maximum value.
"""
from scipy import linalg

Expand All @@ -444,10 +449,7 @@ def sensitivity_map(
convert_forward_solution(
fwd, surf_ori=True, force_fixed=False, copy=False, verbose=False
)
if not fwd["surf_ori"] or is_fixed_orient(fwd):
raise RuntimeError(
"Error converting solution, please notify " "mne-python developers"
)
assert fwd["surf_ori"] and not is_fixed_orient(fwd)

gain = fwd["sol"]["data"]

Expand Down Expand Up @@ -477,8 +479,9 @@ def sensitivity_map(
elif mode in residual_types:
raise ValueError("No projectors used, cannot compute %s" % mode)

n_sensors, n_dipoles = gain.shape
_, n_dipoles = gain.shape
n_locations = n_dipoles // 3
del n_dipoles
sensitivity_map = np.empty(n_locations)

for k in range(n_locations):
Expand Down

0 comments on commit 1e6aa7d

Please sign in to comment.