diff --git a/examples/dummy.py b/examples/dummy.py index 752e1e2..8a489a9 100644 --- a/examples/dummy.py +++ b/examples/dummy.py @@ -19,6 +19,13 @@ def to_json(sources): return json.dumps({k: str(s) for k, s in sources.items()}, indent=4) +def data2stc(data, src): + vertno = [s["vertno"] for s in src] + return mne.SourceEstimate( + data=data, vertices=vertno, tmin=0, tstep=0.01, subject="fsaverage" + ) + + # Load the head model fs_dir = Path("~/mne_data/MNE-fsaverage-data/fsaverage/") fwd_path = fs_dir / "bem_copy" / "fsaverage-oct6-fwd.fif" @@ -44,6 +51,17 @@ def to_json(sources): fwd = mne.convert_forward_solution(fwd, force_fixed=True) fwd = mne.pick_channels_forward(fwd, info.ch_names, ordered=True) +# Create a dummy stc for std based on the y-position of the sources +ypos = np.hstack([1 - 8 * np.abs(s["rr"][s["inuse"] > 0, 1]) for s in src]) +std_stc = data2stc(ypos, src) +std_stc.plot( + subject="fsaverage", + hemi="split", + views=["lat", "med"], + clim=dict(kind="value", lims=[0, 1, 2]), + transparent=False, +) + sim = SourceSimulator(src) sim.add_noise_sources(location=select_random, location_params=dict(n=10)) @@ -54,8 +72,9 @@ def to_json(sources): waveform=narrowband_oscillation, location_params=dict(n=3), waveform_params=dict(fmin=8, fmax=12), - std=[1, 1, 10], + std=std_stc, names=["s1", "s2", "s3"], + extents=10, ) # Set coupling diff --git a/src/meegsim/_check.py b/src/meegsim/_check.py index cc21556..3b75886 100644 --- a/src/meegsim/_check.py +++ b/src/meegsim/_check.py @@ -593,3 +593,16 @@ def check_extents(extents, n_sources): ) return extents + + +def check_stc_as_param(stc, src): + for src_idx, s in enumerate(src): + common = np.intersect1d(stc.vertices[src_idx], s["vertno"], assume_unique=True) + + missing_vertno = set(s["vertno"]) - set(common) + if missing_vertno: + raise ValueError( + f"The provided stc does not contain all vertices of the " + f"source space that used for simulations. Missing vertices: " + f"{", ".join(list(missing_vertno))}" + ) diff --git a/src/meegsim/snr.py b/src/meegsim/snr.py index cddba8e..4dd8207 100644 --- a/src/meegsim/snr.py +++ b/src/meegsim/snr.py @@ -131,7 +131,7 @@ def _adjust_snr_local(src, fwd, tstep, sources, source_groups, noise_sources): ) # Adjust the amplitude of each source in the group to match the target SNR - for name, target_snr in zip(sg.names, sg.snr): + for name, target_snr in zip(sg.names, sg.snr, strict=True): s = sources[name] # NOTE: taking a safer approach for now and filtering diff --git a/src/meegsim/source_groups.py b/src/meegsim/source_groups.py index d79e4ee..9b57742 100644 --- a/src/meegsim/source_groups.py +++ b/src/meegsim/source_groups.py @@ -3,15 +3,18 @@ defined by the user until we actually start simulating the data. """ -from ._check import ( +import mne + +from meegsim._check import ( check_location, check_waveform, check_numeric_array, check_snr_params, + check_stc_as_param, check_names, check_extents, ) -from .sources import PointSource, PatchSource +from meegsim.sources import PointSource, PatchSource def generate_names(group, n_sources): @@ -141,7 +144,10 @@ def create( "SNR", snr, n_sources, bounds=(0, None), allow_none=True ) snr_params = check_snr_params(snr_params, snr) - std = check_numeric_array("std", std, n_sources, bounds=(0, None)) + if isinstance(std, mne.SourceEstimate): + check_stc_as_param(std, src) + else: + std = check_numeric_array("std", std, n_sources, bounds=(0, None)) # Auto-generate or check the provided source names if not names: @@ -258,8 +264,11 @@ def create( "SNR", snr, n_sources, bounds=(0, None), allow_none=True ) snr_params = check_snr_params(snr_params, snr) - std = check_numeric_array("std", std, n_sources, bounds=(0, None)) extents = check_extents(extents, n_sources) + if isinstance(std, mne.SourceEstimate): + check_stc_as_param(std, src) + else: + std = check_numeric_array("std", std, n_sources, bounds=(0, None)) # Auto-generate or check the provided source names if not names: diff --git a/src/meegsim/sources.py b/src/meegsim/sources.py index 36a45dc..55b436d 100644 --- a/src/meegsim/sources.py +++ b/src/meegsim/sources.py @@ -8,7 +8,7 @@ import numpy as np import mne -from .utils import vertices_to_mne, _extract_hemi +from meegsim.utils import vertices_to_mne, _extract_hemi, get_param_from_stc class _BaseSource: @@ -175,9 +175,15 @@ def create( if data.shape[1] != len(times): raise ValueError("The number of samples in waveform does not match") + # Get the std values if an stc was provided + if isinstance(stds, mne.SourceEstimate): + stds = get_param_from_stc(stds, vertices) + # Create point sources and save them as a group sources = [] - for (src_idx, vertno), waveform, std, name in zip(vertices, data, stds, names): + for (src_idx, vertno), waveform, std, name in zip( + vertices, data, stds, names, strict=True + ): hemi = _extract_hemi(src[src_idx]) sources.append( cls( @@ -229,7 +235,10 @@ def __repr__(self): @property def data(self): - return np.tile(self.waveform, (len(self.vertno), 1)) + # NOTE: the scaling factor is introduced to make the total variance of + # patch activity invariant to the number of vertices in the patch + scaling_factor = 1 / np.sqrt(len(self.vertno)) + return np.tile(self.waveform, (len(self.vertno), 1)) * scaling_factor @property def vertices(self): @@ -273,6 +282,7 @@ def create( # find patch vertices subject = src[0].get("subject_his_id", None) patch_vertices = [] + patch_stds = [] if isinstance(stds, mne.SourceEstimate) else stds for isource, extent in enumerate(extents): src_idx, vertno = vertices[isource] @@ -288,6 +298,11 @@ def create( subject, vertno, extent, src_idx, subjects_dir=None )[0] + # Get the std values if an stc was provided + if isinstance(stds, mne.SourceEstimate): + std = get_param_from_stc(stds, [(src_idx, vertno)]) + patch_stds.append(std) + # Prune vertices patch_vertno = [ vert for vert in patch.vertices if vert in src[src_idx]["vertno"] @@ -297,7 +312,7 @@ def create( # Create patch sources and save them as a group sources = [] for (src_idx, _), patch_vertno, waveform, std, name in zip( - vertices, patch_vertices, data, stds, names + vertices, patch_vertices, data, patch_stds, names, strict=True ): hemi = _extract_hemi(src[src_idx]) sources.append( diff --git a/src/meegsim/utils.py b/src/meegsim/utils.py index 50e0a70..7f76247 100644 --- a/src/meegsim/utils.py +++ b/src/meegsim/utils.py @@ -40,7 +40,7 @@ def combine_stcs(stc1, stc2): stc = stc1.copy() new_data = stc2.data.copy() - for vi, (v_old, v_new) in enumerate(zip(stc.vertices, stc2.vertices)): + for vi, (v_old, v_new) in enumerate(zip(stc.vertices, stc2.vertices, strict=True)): v_common, ind1, ind2 = np.intersect1d(v_old, v_new, return_indices=True) if v_common.size > 0: # Sum up signals for vertices common to stc1 and stc2 @@ -60,7 +60,7 @@ def combine_stcs(stc1, stc2): offsets_old += [len(v_old)] offsets_new += [len(v_new)] - inds = [ii + offset for ii, offset in zip(inserters, offsets_old[:-1])] + inds = [ii + offset for ii, offset in zip(inserters, offsets_old[:-1], strict=True)] inds = np.concatenate(inds) stc.data = np.insert(stc.data, inds, new_data, axis=0) @@ -211,3 +211,13 @@ def vertices_to_mne(vertices, src): packed_vertices[src_idx] = src_vertno return packed_vertices + + +def get_param_from_stc(stc, vertices): + values = np.zeros((len(vertices),)) + offsets = [0, len(stc.vertices[0])] + for i, (src_idx, vertno) in enumerate(vertices): + idx = offsets[src_idx] + np.searchsorted(stc.vertices[src_idx], vertno) + values[i] = stc.data[idx] + + return values diff --git a/tests/test_utils.py b/tests/test_utils.py index 8e95933..55a286c 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -127,7 +127,7 @@ def test_extract_hemi(): ) expected_hemis = ["lh", "rh", None, None] - for s, hemi in zip(src, expected_hemis): + for s, hemi in zip(src, expected_hemis, strict=True): assert _extract_hemi(s) == hemi, f"Failed for {s['type']}" diff --git a/tests/utils/prepare.py b/tests/utils/prepare.py index a36cc62..6502e66 100644 --- a/tests/utils/prepare.py +++ b/tests/utils/prepare.py @@ -13,7 +13,7 @@ def prepare_source_space(types, vertices): # Create a simple dummy data structure for testing purposes src = [] - for i, (src_type, src_vertno) in enumerate(zip(types, vertices)): + for i, (src_type, src_vertno) in enumerate(zip(types, vertices, strict=True)): n_verts = len(src_vertno) # Generate random positions and random normals