Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: allow setting the std parameter using SourceEstimate #67

Open
wants to merge 26 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
1726ba1
Divide by std instead of norm when normalizing
ctrltz Jan 24, 2025
e7b7fde
Replace scaling factor with SourceSimulator.base_amplitude
ctrltz Jan 24, 2025
e6f0013
TEST: fix tests after introducing base amplitude
ctrltz Jan 24, 2025
8e87b37
Introduce standard deviation of the waveform
ctrltz Jan 24, 2025
3acb544
Add std to the public API (tests are broken)
ctrltz Jan 24, 2025
ee3e08b
TEST: tests for check_numeric_array
ctrltz Jan 27, 2025
f7feaa3
TEST: fix all other tests
ctrltz Jan 27, 2025
1a0d822
Docstring, base_amplitude -> base_std
ctrltz Jan 27, 2025
01d87ee
TEST: test the std-based scaling
ctrltz Jan 27, 2025
cf79cdc
DOC: docstrings
ctrltz Jan 28, 2025
fd053a2
MAINT: changelog entry
ctrltz Jan 28, 2025
919d5b4
DOC: docstring
ctrltz Jan 28, 2025
447e41c
ENH: allow setting std using a SourceEstimate
ctrltz Jan 28, 2025
0162bb0
Fix minor errors
ctrltz Jan 28, 2025
0436289
Merge branch 'master' into ctrltz/std-as-stc
ctrltz Jan 28, 2025
e7bcdeb
Merge branch 'master' into ctrltz/std-as-stc
ctrltz Feb 5, 2025
53aad08
Fix std check for patch sources
ctrltz Feb 5, 2025
7bc2d43
TEST: tests for check_stc_as_param
ctrltz Feb 5, 2025
88cb9b5
TEST: tests for _get_param_from_stc
ctrltz Feb 5, 2025
7da0407
Slowly getting rid of relative imports
ctrltz Feb 5, 2025
79dec97
TEST: fix all tests
ctrltz Feb 5, 2025
44eb37c
Remove strict=True from zips - not supported in py3.9
ctrltz Feb 5, 2025
b4891b7
DOC: describe the new option in the docstring
ctrltz Feb 5, 2025
fa6de8f
MAINT: changelog entry
ctrltz Feb 5, 2025
7c449de
Merge branch 'master' into ctrltz/std-as-stc
ctrltz Feb 6, 2025
7c8e084
Fix the case of manually provided patch vertices
ctrltz Feb 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ noise ([#58](https://github.com/ctrltz/meegsim/pull/58))
- A possibility to plot the source configuration ([#59](https://github.com/ctrltz/meegsim/pull/59))
- Adjustment of global (all signal vs. all noise sources) SNR ([#64](https://github.com/ctrltz/meegsim/pull/64))
- Adjustment of the standard deviation of source activity ([#66](https://github.com/ctrltz/meegsim/pull/66))
- Allow specifying standard deviation via a SourceEstimate object ([#67](https://github.com/ctrltz/meegsim/pull/67))

### Changed

Expand Down
68 changes: 39 additions & 29 deletions examples/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from pathlib import Path

from meegsim.coupling import ppc_von_mises
from meegsim.location import select_random
from meegsim.simulate import SourceSimulator
from meegsim.waveform import narrowband_oscillation
Expand All @@ -19,6 +18,17 @@ def to_json(sources):
return json.dumps({k: str(s) for k, s in sources.items()}, indent=4)


def data2stc(data, src):
vertno = [s["vertno"] for s in src]
return mne.SourceEstimate(
data=data, vertices=vertno, tmin=0, tstep=0.01, subject="fsaverage"
)


def extents_from_areas_cm2(areas_cm2):
return list(np.sqrt(np.array(areas_cm2) * 100 / np.pi))


# Load the head model
fs_dir = Path("~/mne_data/MNE-fsaverage-data/fsaverage/")
fwd_path = fs_dir / "bem_copy" / "fsaverage-oct6-fwd.fif"
Expand All @@ -44,45 +54,32 @@ def to_json(sources):
fwd = mne.convert_forward_solution(fwd, force_fixed=True)
fwd = mne.pick_channels_forward(fwd, info.ch_names, ordered=True)

# Create a dummy stc for std based on the y-position of the sources
ypos = np.hstack([1 - 8 * np.abs(s["rr"][s["inuse"] > 0, 1]) for s in src])
std_stc = data2stc(ypos, src)
std_stc.plot(
subject="fsaverage",
hemi="split",
views=["lat", "med"],
clim=dict(kind="value", lims=[0, 1, 2]),
transparent=False,
background="white",
)

sim = SourceSimulator(src)

sim.add_noise_sources(location=select_random, location_params=dict(n=10))

# Select some vertices randomly
sim.add_point_sources(
location=[(0, 0), (0, 87780), (0, 106307)],
waveform=narrowband_oscillation,
location_params=dict(n=3),
waveform_params=dict(fmin=8, fmax=12),
std=[1, 1, 10],
names=["s1", "s2", "s3"],
)

sim.add_patch_sources(
location=select_random,
waveform=narrowband_oscillation,
snr=1,
location_params=dict(n=3),
waveform_params=dict(fmin=8, fmax=12),
snr_params=dict(fmin=8, fmax=12),
extents=[10, 20, 50],
names=["s4", "s5", "s6"],
std=std_stc,
extents=extents_from_areas_cm2([2, 4, 8]),
)

# Set coupling
sim.set_coupling(
coupling={
("s1", "s2"): dict(kappa=1, phase_lag=np.pi / 3),
("s2", "s3"): dict(kappa=10, phase_lag=-np.pi / 2),
},
method=ppc_von_mises,
fmin=8,
fmax=12,
)

print(sim._coupling_graph)
print(sim._coupling_graph.edges(data=True))

sc = sim.simulate(
sfreq,
duration,
Expand All @@ -91,9 +88,22 @@ def to_json(sources):
snr_params=dict(fmin=8, fmax=12),
random_state=seed,
)
stc = sc.to_stc()
raw = sc.to_raw(fwd, info, sensor_noise_level=0.05)

print([np.var(s.waveform) for s in sc._sources.values()])
source_std = np.std(stc.data, axis=1)
lim = np.max(source_std)
std_stc_est = mne.SourceEstimate(source_std, stc.vertices, tmin=0, tstep=0.01)
std_stc_est.plot(
subject="fsaverage",
hemi="split",
views=["lat", "med"],
clim=dict(kind="value", lims=[0, lim / 2, lim]),
colormap="Reds",
time_viewer=False,
transparent=False,
background="white",
)

sc.plot(subject="fsaverage", hemi="split", views=["lat", "med"])

Expand Down
15 changes: 15 additions & 0 deletions src/meegsim/_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,21 @@ def check_extents(extents, n_sources):
return extents


def check_stc_as_param(stc, src):
for src_idx, s in enumerate(src):
common = np.intersect1d(stc.vertices[src_idx], s["vertno"], assume_unique=True)

# XXX: the code below overlaps with sources._BaseSource.check_compatibility
missing_vertno = set(s["vertno"]) - set(common)
if missing_vertno:
report_missing = ", ".join([str(v) for v in missing_vertno])
raise ValueError(
f"The provided stc does not contain all vertices of the "
f"source space that is used for simulations. The following vertices "
f"from src[{src_idx}] are missing: {report_missing}"
)


def check_colors(colors):
"""
Check the dictionary with colors provided for the visualization of the
Expand Down
50 changes: 36 additions & 14 deletions src/meegsim/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,20 @@ def add_point_sources(
``'local'``. Can be None (no adjustment of SNR), a single value
that is used for all sources or an array with one SNR
value per source.
std : float or array, optional
Desired standard deviation of the source activity, provided either as a
single value that applies to all sources or as an array with one value per
source. This parameter can be used in combination with the global SNR
std : float, array, or SourceEstimate, optional
Desired standard deviation of the source activity, provided via one of
the following options:

- a single value that applies to all sources
- an array with one value per source
- a :class:`~mne.SourceEstimate` object that contains values of all
vertices of the source space. In this case, the value will be adjusted
for each source automatically based on its location.

This parameter can be used in combination with the global SNR
mode to set an arbitrary spatial distribution of source activity.
By default, 1 is used so the variance of all sources is the same.
If the value of local SNR is specified, this parameter will effectively
If the value of ``snr`` is specified, this parameter will effectively
be ignored.
location_params : dict, optional
Keyword arguments that will be passed to ``location``
Expand Down Expand Up @@ -187,13 +194,21 @@ def add_patch_sources(
``'local'``. Can be None (no adjustment of SNR, default),
a single value that is used for all sources or an array
with one SNR value per source.
std : float or array, optional
Desired standard deviation of the source activity, provided either as a
single value that applies to all sources or as an array with one value per
source. This parameter can be used in combination with the global SNR
std : float, array, or SourceEstimate, optional
Desired standard deviation of the **total** source activity of the
patch (invariant to the number of vertices in the patch), provided via
one of the following options:

- a single value that applies to all sources
- an array with one value per source
- a :class:`~mne.SourceEstimate` object that contains values of all
vertices of the source space. In this case, the value will be adjusted
for each source automatically based on its location.

This parameter can be used in combination with the global SNR
mode to set an arbitrary spatial distribution of source activity.
By default, 1 is used so the variance of all sources is the same.
If the value of local SNR is specified, this parameter will effectively
If the value of ``snr`` is specified, this parameter will effectively
be ignored.
location_params : dict, optional
Keyword arguments that will be passed to ``location`` if a
Expand Down Expand Up @@ -272,10 +287,17 @@ def add_noise_sources(
waveform : array or callable
Waveform provided either directly as an array or as a function.
By default, 1/f noise with the slope of 1 is used for all noise sources.
std : float or array, optional
Desired standard deviation of the source activity, provided either as a
single value that applies to all sources or as an array with one value per
source. By default, 1 is used so the variance of all noise sources is
std : float, array, or SourceEstimate, optional
Desired standard deviation of the source activity, provided via one of
the following options:

- a single value that applies to all sources
- an array with one value per source
- a :class:`~mne.SourceEstimate` object that contains values of all
vertices of the source space. In this case, the value will be adjusted
for each source automatically based on its location.

By default, 1 is used so the variance of all noise sources is
the same.
location_params : dict, optional
Keyword arguments that will be passed to ``location`` if a
Expand Down
2 changes: 1 addition & 1 deletion src/meegsim/snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from scipy.signal import butter, filtfilt

from .sources import _combine_sources_into_stc
from meegsim.sources import _combine_sources_into_stc


def get_sensor_space_variance(stc, fwd, fmin=None, fmax=None, filter=False):
Expand Down
21 changes: 15 additions & 6 deletions src/meegsim/source_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,18 @@
defined by the user until we actually start simulating the data.
"""

from ._check import (
import mne

from meegsim._check import (
check_location,
check_waveform,
check_numeric_array,
check_snr_params,
check_stc_as_param,
check_names,
check_extents,
)
from .sources import PointSource, PatchSource
from meegsim.sources import PointSource, PatchSource


def generate_names(group, n_sources):
Expand Down Expand Up @@ -114,7 +117,7 @@ def create(
The waveform provided by the user.
snr: None, float, or array
The SNR values provided by the user.
std: float or array
std: float, array, or mne.SourceEstimate
The values of standard deviation provided by the user.
location_params: dict, optional
Additional keyword arguments for the location function.
Expand Down Expand Up @@ -143,7 +146,10 @@ def create(
"SNR", snr, n_sources, bounds=(0, None), allow_none=True
)
snr_params = check_snr_params(snr_params, snr)
std = check_numeric_array("std", std, n_sources, bounds=(0, None))
if isinstance(std, mne.SourceEstimate):
check_stc_as_param(std, src)
else:
std = check_numeric_array("std", std, n_sources, bounds=(0, None))

# Auto-generate or check the provided source names
if not names:
Expand Down Expand Up @@ -231,7 +237,7 @@ def create(
The waveform provided by the user.
snr:
The SNR values provided by the user.
std: float or array
std: float, array, or mne.SourceEstimate
The values of standard deviation provided by the user.
location_params: dict, optional
Additional keyword arguments for the location function.
Expand Down Expand Up @@ -262,7 +268,10 @@ def create(
"SNR", snr, n_sources, bounds=(0, None), allow_none=True
)
snr_params = check_snr_params(snr_params, snr)
std = check_numeric_array("std", std, n_sources, bounds=(0, None))
if isinstance(std, mne.SourceEstimate):
check_stc_as_param(std, src)
else:
std = check_numeric_array("std", std, n_sources, bounds=(0, None))
extents = check_extents(extents, n_sources)

# Auto-generate or check the provided source names
Expand Down
27 changes: 24 additions & 3 deletions src/meegsim/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
import numpy as np
import mne

from .utils import vertices_to_mne, _extract_hemi, _hemi_to_index
from meegsim.utils import (
vertices_to_mne,
_extract_hemi,
_get_param_from_stc,
_hemi_to_index,
)


class _BaseSource:
Expand Down Expand Up @@ -175,6 +180,10 @@ def create(
if data.shape[1] != len(times):
raise ValueError("The number of samples in waveform does not match")

# Get the std values if an stc was provided
if isinstance(stds, mne.SourceEstimate):
stds = _get_param_from_stc(stds, vertices)

# Create point sources and save them as a group
sources = []
for (src_idx, vertno), waveform, std, name in zip(vertices, data, stds, names):
Expand Down Expand Up @@ -229,7 +238,10 @@ def __repr__(self):

@property
def data(self):
return np.tile(self.waveform, (len(self.vertno), 1))
# NOTE: the scaling factor is introduced to make the total variance of
# patch activity invariant to the number of vertices in the patch
scaling_factor = 1 / np.sqrt(len(self.vertno))
return np.tile(self.waveform, (len(self.vertno), 1)) * scaling_factor

@property
def vertices(self):
Expand Down Expand Up @@ -273,9 +285,18 @@ def create(
# find patch vertices
subject = src[0].get("subject_his_id", None)
patch_vertices = []
patch_stds = [] if isinstance(stds, mne.SourceEstimate) else stds
for isource, extent in enumerate(extents):
src_idx, vertno = vertices[isource]

# Get the std values if an stc was provided
# The resulting value either corresponds to the center of the
# patch (extent is not None) or to the average over all
# vertices of the patch
if isinstance(stds, mne.SourceEstimate):
std = _get_param_from_stc(stds, [(src_idx, v) for v in vertno])
patch_stds.append(std.mean())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a misunderstanding. I thought we are populating std from the center of the patch to all other points. Could you clarify this and also under which circumstances the std.mean() could be applied?


# Add vertices as they are if no extent provided
if extent is None:
# Wrap vertno in a list if it is a single number
Expand All @@ -297,7 +318,7 @@ def create(
# Create patch sources and save them as a group
sources = []
for (src_idx, _), patch_vertno, waveform, std, name in zip(
vertices, patch_vertices, data, stds, names
vertices, patch_vertices, data, patch_stds, names
):
hemi = _extract_hemi(src[src_idx])
sources.append(
Expand Down
27 changes: 27 additions & 0 deletions src/meegsim/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,30 @@ def _hemi_to_index(hemi):
Get the index of the hemisphere (0 for lh, 1 for rh).
"""
return ["lh", "rh"].index(hemi)


def _get_param_from_stc(stc, vertices):
"""
Extract parameter values for specified vertices from the provided stc.

Parameters
----------
stc : mne.SourceEstimate
The stc object that contains values for all vertices.
vertices: list
List of tuples (src_idx, vertno) corresponding to the vertices of interest.

Returns
-------
values : array
One value from stc for each vertex.
"""
values = np.zeros((len(vertices),))

# NOTE: we only support surface source estimates for now
offsets = [0, len(stc.vertices[0])]
for i, (src_idx, vertno) in enumerate(vertices):
idx = offsets[src_idx] + np.searchsorted(stc.vertices[src_idx], vertno)
values[i] = stc.data[idx]

return values
Loading