Skip to content

Commit

Permalink
PERF: rework source classes, speed up _combine_sources_into_stc (#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
ctrltz authored Oct 17, 2024
1 parent 406027b commit 562e524
Show file tree
Hide file tree
Showing 14 changed files with 387 additions and 230 deletions.
73 changes: 73 additions & 0 deletions examples/snr_adjustment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""
Testing the adjustment of SNR
"""

import mne
import numpy as np
import matplotlib.pyplot as plt

from pathlib import Path

from meegsim.location import select_random
from meegsim.simulate import SourceSimulator
from meegsim.waveform import narrowband_oscillation


# Load the head model
fs_dir = Path('~/mne_data/MNE-fsaverage-data/fsaverage/')
fwd_path = fs_dir / 'bem_copy' / 'fsaverage-oct6-fwd.fif'
src_path = fs_dir / 'bem_copy' / 'fsaverage-oct6-src.fif'
src = mne.read_source_spaces(src_path)
fwd = mne.read_forward_solution(fwd_path)

# Simulation parameters
sfreq = 250
duration = 60
seed = 123

# Channel info
montage = mne.channels.make_standard_montage('standard_1020')
ch_names = [ch for ch in montage.ch_names if ch not in ['O9', 'O10']]
info = mne.create_info(ch_names, sfreq, ch_types='eeg')
info.set_montage('standard_1020')

# Adapt fwd to the info (could be done by our structure in principle)
fwd = mne.convert_forward_solution(fwd, force_fixed=True)
fwd = mne.pick_channels_forward(fwd, info.ch_names, ordered=True)

fig, axes = plt.subplots(ncols=3, figsize=(8, 3))
snr_values = [1, 5, 10]

for i_snr, target_snr in enumerate(snr_values):
sim = SourceSimulator(src)

# Select some vertices randomly
sim.add_point_sources(
location=select_random,
waveform=narrowband_oscillation,
location_params=dict(n=3),
waveform_params=dict(fmin=8, fmax=12),
snr=target_snr,
snr_params=dict(fmin=8, fmax=12),
names=['s1', 's2', 's3']
)

sim.add_noise_sources(
location=select_random,
location_params=dict(n=10)
)

sc = sim.simulate(sfreq, duration, fwd=fwd, random_state=seed)
raw = sc.to_raw(fwd, info)

spec = raw.compute_psd(fmax=40, n_fft=sfreq,
n_overlap=sfreq//2, n_per_seg=sfreq)
spec.plot(average=True, dB=False, axes=axes[i_snr], amplitude=False)

axes[i_snr].set_title(f'SNR={target_snr}')
axes[i_snr].set_xlabel('Frequency (Hz)')
axes[i_snr].set_ylabel('PSD (uV^2/Hz)')
axes[i_snr].set_ylim([0, 1.25])

fig.tight_layout()
plt.show(block=True)
3 changes: 2 additions & 1 deletion src/meegsim/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(self, src, sfreq, duration, random_state=None):
self.duration = duration
self.n_samples = self.sfreq * self.duration
self.times = np.arange(self.n_samples) / self.sfreq
self.tstep = self.times[1] - self.times[0]

# Random state (for reproducibility)
self.random_state = random_state
Expand All @@ -54,7 +55,7 @@ def to_stc(self):
if not all_sources:
raise ValueError('No sources were added to the configuration.')

return _combine_sources_into_stc(all_sources, self.src)
return _combine_sources_into_stc(all_sources, self.src, self.tstep)

def to_raw(self, fwd, info, scaling_factor=1e-6):
# Parameters:
Expand Down
3 changes: 2 additions & 1 deletion src/meegsim/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,7 @@ def _simulate(

# Adjust the SNR if needed
if is_snr_adjusted:
sources = _adjust_snr(src, fwd, sources, source_groups, noise_sources)
tstep = times[1] - times[0]
sources = _adjust_snr(src, fwd, tstep, sources, source_groups, noise_sources)

return sources, noise_sources
11 changes: 8 additions & 3 deletions src/meegsim/snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,14 @@ def amplitude_adjustment_factor(signal_var, noise_var, target_snr):
return factor


def _adjust_snr(src, fwd, sources, source_groups, noise_sources):
def _adjust_snr(src, fwd, tstep, sources, source_groups, noise_sources):
# Get the stc and leadfield of all noise sources
stc_noise = _combine_sources_into_stc(noise_sources.values(), src)
if not noise_sources:
raise ValueError(
'No noise sources were added to the simulation, so the SNR '
'cannot be adjusted.'
)
stc_noise = _combine_sources_into_stc(noise_sources.values(), src, tstep)

# Adjust the SNR of sources in each source group
for sg in source_groups:
Expand All @@ -121,7 +126,7 @@ def _adjust_snr(src, fwd, sources, source_groups, noise_sources):

# NOTE: taking a safer approach for now and filtering
# even if the signal is already a narrowband oscillation
signal_var = get_sensor_space_variance(s.to_stc(src), fwd,
signal_var = get_sensor_space_variance(s.to_stc(src, tstep), fwd,
fmin=fmin, fmax=fmax, filter=True)

# NOTE: patch sources might require more complex calculations
Expand Down
Loading

0 comments on commit 562e524

Please sign in to comment.