Skip to content

Commit

Permalink
ENH: adjust the SNR in SourceSimulator.simulate() (#31)
Browse files Browse the repository at this point in the history
Major:

* connected the low-level SNR functions to the SourceSimulator
* added checks for snr and snr_params
* sensor space variance is now calculated for the sum of all sources, not the mean
* sensor space variance is now calculated based on source space covariance (this will also be relevant for patches)

Minor:

* removed the default value for target SNR
* removed the default frequency band for the adjustment of SNR
* extracted dummy forward function to utils.prepare in tests
* renamed adjust_snr to amplitude_adjustment
  • Loading branch information
ctrltz authored Sep 19, 2024
1 parent dedcee7 commit 60b3300
Show file tree
Hide file tree
Showing 11 changed files with 542 additions and 197 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- SourceSimulator class that allows adding point sources with custom locations and waveforms to the simulation ([3bad4a8](https://github.com/ctrltz/meegsim/commit/3bad4a86a3712beb43fb404481c15e1a54250d87), [#24](https://github.com/ctrltz/meegsim/pull/24))
- Waveforms of white noise, narrowband oscillation (white noise filtered in a narrow frequency band) and 1/f noise with adjustable slope ([#8](https://github.com/ctrltz/meegsim/pull/8))
- Random vertices in the whole source space or in a subset of vertices as location for point sources ([#10](https://github.com/ctrltz/meegsim/pull/10))
- Adjustment of the SNR of the point sources based on sensor space power ([#9](https://github.com/ctrltz/meegsim/pull/9))
- Adjustment of the SNR of the point sources based on sensor space power ([#9](https://github.com/ctrltz/meegsim/pull/9), [#31](https://github.com/ctrltz/meegsim/pull/31))
- Phase-phase coupling with a constant phase lag or a probabilistic phase lag according to the von Mises distribution ([#11](https://github.com/ctrltz/meegsim/pull/11))
- Traversal of the coupling graph to ensure that the coupling is set up correctly when multiple connectivity edges are defined ([#12](https://github.com/ctrltz/meegsim/pull/12))
49 changes: 31 additions & 18 deletions examples/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
"""

import json
import numpy as np
import mne
import numpy as np

from pathlib import Path
from scipy.signal import butter, filtfilt

from meegsim.location import select_random
from meegsim.simulate import SourceSimulator
Expand All @@ -27,6 +28,10 @@ def to_json(sources):
# Simulation parameters
sfreq = 250
duration = 60
seed = 1234
target_snr = 20

b, a = butter(4, 2 * np.array([8, 12]) / sfreq, 'bandpass')

# Channel info
montage = mne.channels.make_standard_montage('standard_1020')
Expand All @@ -40,29 +45,37 @@ def to_json(sources):

sim = SourceSimulator(src)

# Select some vertices randomly (signal/noise does not matter for now)
sim.add_point_sources(
location=select_random,
waveform=narrowband_oscillation,
location_params=dict(n=10, vertices=[list(src[0]['vertno']), []]),
waveform_params=dict(fmin=8, fmax=12)
)
# Add noise sources
sim.add_noise_sources(
location=select_random,
location_params=dict(n=10, vertices=[[], list(src[1]['vertno'])])
location_params=dict(n=10)
)

# Print the source groups to check internal structure
print(f'Source groups: {sim._source_groups}')
print(f'Noise groups: {sim._noise_groups}')
sc_noise = sim.simulate(sfreq, duration, random_state=seed)
raw_noise = sc_noise.to_raw(fwd, info)

# Select some vertices randomly
sim.add_point_sources(
location=select_random,
waveform=narrowband_oscillation,
location_params=dict(n=1),
waveform_params=dict(fmin=8, fmax=12),
snr=target_snr,
snr_params=dict(fmin=8, fmax=12)
)

sc = sim.simulate(sfreq, duration, random_state=0)
sc_full = sim.simulate(sfreq, duration, fwd=fwd, random_state=seed)
raw_full = sc_full.to_raw(fwd, info)

# Print the sources to check internal structure
print(f'Simulated sources: {to_json(sc._sources)}')
print(f'Simulated noise sources: {to_json(sc._noise_sources)}')
n_samples = sc_full.times.size
noise_data = filtfilt(b, a, raw_noise.get_data())
cov_raw_noise = (noise_data @ noise_data.T) / n_samples
full_data = filtfilt(b, a, raw_full.get_data())
cov_raw_full = (full_data @ full_data.T) / n_samples
snr = np.mean(np.diag(cov_raw_full)) / np.mean(np.diag(cov_raw_noise)) - 1
print(f'Target SNR = {target_snr:.2f}')
print(f'Actual SNR = {snr:.2f}')

raw = sc.to_raw(fwd, info)
spec = raw.compute_psd(n_fft=sfreq, n_overlap=sfreq//2, n_per_seg=sfreq)
spec = raw_full.compute_psd(n_fft=sfreq, n_overlap=sfreq//2, n_per_seg=sfreq)
spec.plot(sphere='eeglab')
input('Press any key to continue')
76 changes: 66 additions & 10 deletions src/meegsim/_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,20 +257,76 @@ def check_names(names, n_sources, existing):
raise ValueError('All names should be unique')


def check_snr(snr, n_vertices):
if snr is not None:
raise NotImplementedError('Adjustment of SNR is not supported yet')
# TODO: check that the number of SNR values matches the number of vertices
# or it is a single SNR value that can be applied to all vertices
def check_snr(snr, n_sources):
"""
Check the user input for SNR: it can either be None (no adjustment of SNR),
a single float value that applies to all sources or an array of values
with one for each source.
Parameters
----------
snr: None, float, or array
The provided value(s) for SNR
n_sources: int
The number of sources.
Raises
------
ValueError
If the provided SNR value(s) do not follow the format described above.
"""

if snr is None:
return None

snr = np.ravel(np.array(snr))
if snr.size != 1 and snr.size != n_sources:
raise ValueError(
f'Expected either one SNR value that applies to all sources or '
f'one SNR value for each of the {n_sources} sources, got {snr.size}'
)

# Only positive values make sense, raise error if negative ones are provided
if np.any(snr < 0):
raise ValueError('Each SNR value should be positive')

# Broadcast to all sources if a single value was provided
if snr.size == 1:
snr = np.tile(snr, (n_sources,))

return snr


def check_snr_params(snr_params):
# TODO: we could try to extract fmin and fmax from waveform_params but
# not sure how confusing will it be, a dedicated waveform class could be
# easier to understand
pass
def check_snr_params(snr_params, snr):
"""
Check the user input for SNR parameters: if the SNR is adjusted (i.e., not None),
then fmin and fmax should be present in the dictionary to define a frequency band.
Parameters
----------
snr_params: dict
The provided dictionary with parameters of the SNR adjustment.
snr: None, float, or array
The provided value for SNR
Raises
------
ValueError
If the provided snr_params dictionary does not have the necessary parameters.
"""
if snr is None:
return snr_params

if 'fmin' not in snr_params or 'fmax' not in snr_params:
raise ValueError(
'Frequency band limits are required for the adjustment of SNR. '
'Please add fmin and fmax to the snr_params dictionary.'
)

if snr_params['fmin'] < 0 or snr_params['fmax'] < 0:
raise ValueError('Frequency limits should be positive')

return snr_params


def check_coupling():
Expand Down
93 changes: 68 additions & 25 deletions src/meegsim/simulate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .configuration import SourceConfiguration
from .snr import _adjust_snr
from .source_groups import PointSourceGroup
from .waveform import one_over_f_noise

Expand Down Expand Up @@ -28,6 +29,10 @@ def __init__(self, src):
# Store all coupling edges
self._coupling = {}

# Keep track whether SNR of any source should be adjusted
# If yes, then a forward model is required for simulation
self.is_snr_adjusted = False

def add_point_sources(
self,
location,
Expand All @@ -53,16 +58,18 @@ def add_point_sources(
Waveforms of source activity provided either directly in an array (fixed
for every configuration) or as a function that generates the waveforms
(but differ between configurations if the generation is random).
snr: None (do not adjust SNR), float (same SNR for all sources), or array (one value per source)
TODO: fix when finalizing SNR
NB: only positive values make sense, raise error if negative ones are provided
snr: None, float, or array
SNR values for the defined sources. Can be None (no adjustment of SNR),
a single value that is used for all sources or an array with one SNR
value per source.
location_params: dict, optional
Keyword arguments that will be passed to the location function.
waveform_params: dict, optional
Keyword arguments that will be passed to the waveform function.
snr_params: dict, optional
TODO: fix when finalizing SNR
fmin and fmax for the frequency band that will be used to adjust SNR.
Additional parameters required for the adjustment of SNR.
Specify fmin and fmax here to define the frequency band which
should used for calculating the SNR.
names: list, optional
A list of names for each source. If not specified, the names will be
autogenerated using the format 'sgN-sM', where N is the index of the
Expand Down Expand Up @@ -92,6 +99,10 @@ def add_point_sources(
self._source_groups.append(point_sg)
self._sources.extend(point_sg.names)

# Check if SNR should be adjusted
if point_sg.snr is not None:
self.is_snr_adjusted = True

# Return the names of newly added sources
return point_sg.names

Expand Down Expand Up @@ -177,44 +188,82 @@ def simulate(
self,
sfreq,
duration,
fwd=None,
random_state=None
):
"""
Simulate a configuration of defined sources.
Parameters
----------
sfreq: float
The sampling frequency of the simulated data, in Hz.
duration: float
Duration of the simulated data, in seconds.
fwd: mne.Forward, optional
The forward model, only to be used for the adjustment of SNR.
If no adjustment is performed, the forward model is not required.
random_state: int or None, optional
The random state can be provided to obtain reproducible configurations.
If None (default), the simulated data will differ between function calls.
Returns
-------
sc: SourceConfiguration
The source configuration, which contains the defined sources and
their corresponding waveforms.
"""

if not (self._source_groups or self._noise_groups):
raise ValueError('No sources were added to the configuration.')

return _simulate(
if self.is_snr_adjusted and fwd is None:
raise ValueError('A forward model is required for the adjustment '
'of SNR.')

# Initialize the SourceConfiguration
sc = SourceConfiguration(self.src, sfreq, duration, random_state=random_state)

# Simulate signal and noise
sources, noise_sources = _simulate(
self._source_groups,
self._noise_groups,
self.is_snr_adjusted,
self.src,
sfreq,
duration,
sc.times,
fwd=fwd,
random_state=random_state
)

# Add the sources to the simulated configuration
sc._sources = sources
sc._noise_sources = noise_sources

return sc


def _simulate(
source_groups,
noise_groups,
is_snr_adjusted,
src,
sfreq,
duration,
times,
fwd,
random_state=None
):
"""
This function describes the simulation workflow.
"""
# Initialize the SourceConfiguration
sc = SourceConfiguration(src, sfreq, duration, random_state=random_state)


# Simulate all sources independently first (no coupling yet)
noise_sources = []
for ng in noise_groups:
noise_sources.extend(ng.simulate(src, sc.times, random_state=random_state))
noise_sources.extend(ng.simulate(src, times, random_state=random_state))
noise_sources = {s.name: s for s in noise_sources}

sources = []
for sg in source_groups:
sources.extend(sg.simulate(src, sc.times, random_state=random_state))
sources.extend(sg.simulate(src, times, random_state=random_state))
sources = {s.name: s for s in sources}

# Setup the desired coupling patterns
Expand All @@ -225,14 +274,8 @@ def _simulate(
# If there are no cycles, traverse the graph and set coupling according to the selected method
# Try calling the coupling with the provided parameters but be prepared for mismatches

# Adjust the SNR of sources in each source group
# 1. Estimate the noise variance in the specified band
# fwd_noise = mne.forward.restrict_forward_to_stc(fwd, stc_noise, on_missing='raise')
# noise_var = get_sensor_space_variance()
# 2. Adjust the amplitude of each signal source according to the desired SNR (if not None)

# Add the sources to the simulated configuration
sc._sources = sources
sc._noise_sources = noise_sources
# Adjust the SNR if needed
if is_snr_adjusted:
sources = _adjust_snr(src, fwd, sources, source_groups, noise_sources)

return sc
return sources, noise_sources
Loading

0 comments on commit 60b3300

Please sign in to comment.