Skip to content

Commit

Permalink
ENH: allow setting std using a SourceEstimate
Browse files Browse the repository at this point in the history
  • Loading branch information
ctrltz committed Jan 28, 2025
1 parent 919d5b4 commit 447e41c
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 14 deletions.
21 changes: 20 additions & 1 deletion examples/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ def to_json(sources):
return json.dumps({k: str(s) for k, s in sources.items()}, indent=4)


def data2stc(data, src):
vertno = [s["vertno"] for s in src]
return mne.SourceEstimate(
data=data, vertices=vertno, tmin=0, tstep=0.01, subject="fsaverage"
)


# Load the head model
fs_dir = Path("~/mne_data/MNE-fsaverage-data/fsaverage/")
fwd_path = fs_dir / "bem_copy" / "fsaverage-oct6-fwd.fif"
Expand All @@ -44,6 +51,17 @@ def to_json(sources):
fwd = mne.convert_forward_solution(fwd, force_fixed=True)
fwd = mne.pick_channels_forward(fwd, info.ch_names, ordered=True)

# Create a dummy stc for std based on the y-position of the sources
ypos = np.hstack([1 - 8 * np.abs(s["rr"][s["inuse"] > 0, 1]) for s in src])
std_stc = data2stc(ypos, src)
std_stc.plot(
subject="fsaverage",
hemi="split",
views=["lat", "med"],
clim=dict(kind="value", lims=[0, 1, 2]),
transparent=False,
)

sim = SourceSimulator(src)

sim.add_noise_sources(location=select_random, location_params=dict(n=10))
Expand All @@ -54,8 +72,9 @@ def to_json(sources):
waveform=narrowband_oscillation,
location_params=dict(n=3),
waveform_params=dict(fmin=8, fmax=12),
std=[1, 1, 10],
std=std_stc,
names=["s1", "s2", "s3"],
extents=10,
)

# Set coupling
Expand Down
13 changes: 13 additions & 0 deletions src/meegsim/_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,3 +593,16 @@ def check_extents(extents, n_sources):
)

return extents


def check_stc_as_param(stc, src):
for src_idx, s in enumerate(src):
common = np.intersect1d(stc.vertices[src_idx], s["vertno"], assume_unique=True)

missing_vertno = set(s["vertno"]) - set(common)
if missing_vertno:
raise ValueError(
f"The provided stc does not contain all vertices of the "
f"source space that used for simulations. Missing vertices: "
f"{", ".join(list(missing_vertno))}"
)
2 changes: 1 addition & 1 deletion src/meegsim/snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def _adjust_snr_local(src, fwd, tstep, sources, source_groups, noise_sources):
)

# Adjust the amplitude of each source in the group to match the target SNR
for name, target_snr in zip(sg.names, sg.snr):
for name, target_snr in zip(sg.names, sg.snr, strict=True):
s = sources[name]

# NOTE: taking a safer approach for now and filtering
Expand Down
17 changes: 13 additions & 4 deletions src/meegsim/source_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,18 @@
defined by the user until we actually start simulating the data.
"""

from ._check import (
import mne

from meegsim._check import (
check_location,
check_waveform,
check_numeric_array,
check_snr_params,
check_stc_as_param,
check_names,
check_extents,
)
from .sources import PointSource, PatchSource
from meegsim.sources import PointSource, PatchSource


def generate_names(group, n_sources):
Expand Down Expand Up @@ -141,7 +144,10 @@ def create(
"SNR", snr, n_sources, bounds=(0, None), allow_none=True
)
snr_params = check_snr_params(snr_params, snr)
std = check_numeric_array("std", std, n_sources, bounds=(0, None))
if isinstance(std, mne.SourceEstimate):
check_stc_as_param(std, src)
else:
std = check_numeric_array("std", std, n_sources, bounds=(0, None))

# Auto-generate or check the provided source names
if not names:
Expand Down Expand Up @@ -258,8 +264,11 @@ def create(
"SNR", snr, n_sources, bounds=(0, None), allow_none=True
)
snr_params = check_snr_params(snr_params, snr)
std = check_numeric_array("std", std, n_sources, bounds=(0, None))
extents = check_extents(extents, n_sources)
if isinstance(std, mne.SourceEstimate):
check_stc_as_param(std, src)
else:
std = check_numeric_array("std", std, n_sources, bounds=(0, None))

# Auto-generate or check the provided source names
if not names:
Expand Down
23 changes: 19 additions & 4 deletions src/meegsim/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
import mne

from .utils import vertices_to_mne, _extract_hemi
from meegsim.utils import vertices_to_mne, _extract_hemi, get_param_from_stc


class _BaseSource:
Expand Down Expand Up @@ -175,9 +175,15 @@ def create(
if data.shape[1] != len(times):
raise ValueError("The number of samples in waveform does not match")

# Get the std values if an stc was provided
if isinstance(stds, mne.SourceEstimate):
stds = get_param_from_stc(stds, vertices)

# Create point sources and save them as a group
sources = []
for (src_idx, vertno), waveform, std, name in zip(vertices, data, stds, names):
for (src_idx, vertno), waveform, std, name in zip(
vertices, data, stds, names, strict=True
):
hemi = _extract_hemi(src[src_idx])
sources.append(
cls(
Expand Down Expand Up @@ -229,7 +235,10 @@ def __repr__(self):

@property
def data(self):
return np.tile(self.waveform, (len(self.vertno), 1))
# NOTE: the scaling factor is introduced to make the total variance of
# patch activity invariant to the number of vertices in the patch
scaling_factor = 1 / np.sqrt(len(self.vertno))
return np.tile(self.waveform, (len(self.vertno), 1)) * scaling_factor

@property
def vertices(self):
Expand Down Expand Up @@ -273,6 +282,7 @@ def create(
# find patch vertices
subject = src[0].get("subject_his_id", None)
patch_vertices = []
patch_stds = [] if isinstance(stds, mne.SourceEstimate) else stds
for isource, extent in enumerate(extents):
src_idx, vertno = vertices[isource]

Expand All @@ -288,6 +298,11 @@ def create(
subject, vertno, extent, src_idx, subjects_dir=None
)[0]

# Get the std values if an stc was provided
if isinstance(stds, mne.SourceEstimate):
std = get_param_from_stc(stds, [(src_idx, vertno)])
patch_stds.append(std)

# Prune vertices
patch_vertno = [
vert for vert in patch.vertices if vert in src[src_idx]["vertno"]
Expand All @@ -297,7 +312,7 @@ def create(
# Create patch sources and save them as a group
sources = []
for (src_idx, _), patch_vertno, waveform, std, name in zip(
vertices, patch_vertices, data, stds, names
vertices, patch_vertices, data, patch_stds, names, strict=True
):
hemi = _extract_hemi(src[src_idx])
sources.append(
Expand Down
14 changes: 12 additions & 2 deletions src/meegsim/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def combine_stcs(stc1, stc2):

stc = stc1.copy()
new_data = stc2.data.copy()
for vi, (v_old, v_new) in enumerate(zip(stc.vertices, stc2.vertices)):
for vi, (v_old, v_new) in enumerate(zip(stc.vertices, stc2.vertices, strict=True)):
v_common, ind1, ind2 = np.intersect1d(v_old, v_new, return_indices=True)
if v_common.size > 0:
# Sum up signals for vertices common to stc1 and stc2
Expand All @@ -60,7 +60,7 @@ def combine_stcs(stc1, stc2):
offsets_old += [len(v_old)]
offsets_new += [len(v_new)]

inds = [ii + offset for ii, offset in zip(inserters, offsets_old[:-1])]
inds = [ii + offset for ii, offset in zip(inserters, offsets_old[:-1], strict=True)]
inds = np.concatenate(inds)
stc.data = np.insert(stc.data, inds, new_data, axis=0)

Expand Down Expand Up @@ -211,3 +211,13 @@ def vertices_to_mne(vertices, src):
packed_vertices[src_idx] = src_vertno

return packed_vertices


def get_param_from_stc(stc, vertices):
values = np.zeros((len(vertices),))
offsets = [0, len(stc.vertices[0])]
for i, (src_idx, vertno) in enumerate(vertices):
idx = offsets[src_idx] + np.searchsorted(stc.vertices[src_idx], vertno)
values[i] = stc.data[idx]

return values
2 changes: 1 addition & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def test_extract_hemi():
)
expected_hemis = ["lh", "rh", None, None]

for s, hemi in zip(src, expected_hemis):
for s, hemi in zip(src, expected_hemis, strict=True):
assert _extract_hemi(s) == hemi, f"Failed for {s['type']}"


Expand Down
2 changes: 1 addition & 1 deletion tests/utils/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def prepare_source_space(types, vertices):

# Create a simple dummy data structure for testing purposes
src = []
for i, (src_type, src_vertno) in enumerate(zip(types, vertices)):
for i, (src_type, src_vertno) in enumerate(zip(types, vertices, strict=True)):
n_verts = len(src_vertno)

# Generate random positions and random normals
Expand Down

0 comments on commit 447e41c

Please sign in to comment.