Skip to content

Commit

Permalink
Cleanup, rename SNR functions
Browse files Browse the repository at this point in the history
  • Loading branch information
ctrltz committed Sep 18, 2024
1 parent e24d527 commit fc80f1d
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 55 deletions.
27 changes: 6 additions & 21 deletions examples/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def to_json(sources):
sfreq = 250
duration = 60
seed = 1234
target_snr = 20

b, a = butter(4, 2 * np.array([8, 12]) / sfreq, 'bandpass')

Expand All @@ -51,45 +52,29 @@ def to_json(sources):
)

sc_noise = sim.simulate(sfreq, duration, random_state=seed)
stc_noise = sc_noise.to_stc()
stc_data = filtfilt(b, a, stc_noise.data, axis=1)

n_sources, n_samples = stc_noise.data.shape
cov_noise = (stc_data @ stc_data.T) / n_samples

fwd_noise = mne.forward.restrict_forward_to_stc(fwd, stc_noise)
L_noise = fwd_noise['sol']['data']
n_sensors, _ = L_noise.shape

print(f'Source space (cov): {np.trace(cov_noise) / n_sources}')
print(f'Source space (mean): {np.mean(stc_data ** 2)}')
print(f'Sensor space (cov): {np.trace(L_noise @ cov_noise @ L_noise.T) / n_sensors}')
print(f'Sensor space (mean): {np.mean(stc_data ** 2) * np.mean(L_noise ** 2)}')

raw_noise = sc_noise.to_raw(fwd, info)
noise_data = filtfilt(b, a, raw_noise.get_data())
cov_raw = (noise_data @ noise_data.T) / n_samples
print(f'Sensor space (cov, raw): {np.mean(np.diag(cov_raw)) * 1e12}')
print(f'Sensor space (mean, raw): {np.mean(noise_data ** 2) * 1e12}')

# Select some vertices randomly
sim.add_point_sources(
location=select_random,
waveform=narrowband_oscillation,
location_params=dict(n=1),
waveform_params=dict(fmin=8, fmax=12),
snr=20,
snr=target_snr,
snr_params=dict(fmin=8, fmax=12)
)

sc_full = sim.simulate(sfreq, duration, fwd=fwd, random_state=seed)
raw_full = sc_full.to_raw(fwd, info)

n_samples = sc_full.times.size
noise_data = filtfilt(b, a, raw_noise.get_data())
cov_raw_noise = (noise_data @ noise_data.T) / n_samples
full_data = filtfilt(b, a, raw_full.get_data())
cov_raw_full = (full_data @ full_data.T) / n_samples
print(np.mean(np.diag(cov_raw_full)) / np.mean(np.diag(cov_raw_noise)) - 1)
snr = np.mean(np.diag(cov_raw_full)) / np.mean(np.diag(cov_raw_noise)) - 1
print(f'Target SNR = {target_snr:.2f}')
print(f'Actual SNR = {snr:.2f}')

spec = raw_full.compute_psd(n_fft=sfreq, n_overlap=sfreq//2, n_per_seg=sfreq)
spec.plot(sphere='eeglab')
Expand Down
4 changes: 2 additions & 2 deletions src/meegsim/_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,14 +266,14 @@ def check_snr(snr, n_sources):
Parameters
----------
snr: None, float, or array
The provided value for SNR
The provided value(s) for SNR
n_sources: int
The number of sources.
Raises
------
ValueError
If the provided SNR value does not follow the requirements described above.
If the provided SNR value(s) do not follow the format described above.
"""

if snr is None:
Expand Down
4 changes: 2 additions & 2 deletions src/meegsim/simulate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .configuration import SourceConfiguration
from .snr import _setup_snr
from .snr import _adjust_snr
from .source_groups import PointSourceGroup
from .waveform import one_over_f_noise

Expand Down Expand Up @@ -276,6 +276,6 @@ def _simulate(

# Adjust the SNR if needed
if is_snr_adjusted:
sources = _setup_snr(src, fwd, sources, source_groups, noise_sources)
sources = _adjust_snr(src, fwd, sources, source_groups, noise_sources)

return sources, noise_sources
30 changes: 19 additions & 11 deletions src/meegsim/snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .sources import _combine_sources_into_stc


def _get_sensor_space_variance(stc, fwd, *, fmin=None, fmax=None, filter=False):
def get_sensor_space_variance(stc, fwd, *, fmin=None, fmax=None, filter=False):
"""
Estimate the sensor space variance of the provided stc
Expand Down Expand Up @@ -43,9 +43,15 @@ def _get_sensor_space_variance(stc, fwd, *, fmin=None, fmax=None, filter=False):
b, a = butter(2, np.array([fmin, fmax]) / stc.sfreq * 2, btype='bandpass')
stc_data = filtfilt(b, a, stc_data, axis=1)

fwd_restrict = mne.forward.restrict_forward_to_stc(fwd, stc,
on_missing='raise')
leadfield_restict = fwd_restrict['sol']['data']
try:
fwd_restrict = mne.forward.restrict_forward_to_stc(fwd, stc,
on_missing='raise')
leadfield_restict = fwd_restrict['sol']['data']
except ValueError:
raise ValueError(
'The provided forward model does not contain some of the '
'simulated sources, so the SNR cannot be adjusted.'
)

n_samples = stc_data.shape[1]
n_sensors = leadfield_restict.shape[0]
Expand All @@ -56,7 +62,7 @@ def _get_sensor_space_variance(stc, fwd, *, fmin=None, fmax=None, filter=False):
return sensor_var


def _adjust_snr(signal_var, noise_var, target_snr):
def amplitude_adjustment(signal_var, noise_var, target_snr):
"""
Derive the signal amplitude that allows obtaining target SNR
Expand All @@ -75,8 +81,9 @@ def _adjust_snr(signal_var, noise_var, target_snr):
Returns
-------
out: float
The value that original signal should be scaled (multiplied) to in order to obtain desired SNR.
amp: float
The amplitude of the signal that allows obtaining the desired SNR.
The original signal should be multiplied by this value.
"""

snr_current = np.divide(signal_var, noise_var)
Expand All @@ -95,7 +102,7 @@ def _adjust_snr(signal_var, noise_var, target_snr):
return factor


def _setup_snr(src, fwd, sources, source_groups, noise_sources):
def _adjust_snr(src, fwd, sources, source_groups, noise_sources):
# Get the stc and leadfield of all noise sources
stc_noise = _combine_sources_into_stc(noise_sources.values(), src)

Expand All @@ -106,20 +113,21 @@ def _setup_snr(src, fwd, sources, source_groups, noise_sources):

# Estimate the noise variance in the specified frequency band
fmin, fmax = sg.snr_params['fmin'], sg.snr_params['fmax']
noise_var = _get_sensor_space_variance(stc_noise, fwd,
noise_var = get_sensor_space_variance(stc_noise, fwd,
fmin=fmin, fmax=fmax, filter=True)

# Adjust the amplitude of each source in the group to match the target SNR
for name, target_snr in zip(sg.names, sg.snr):
s = sources[name]

# NOTE: taking a safer approach for now and filtering
# even if the signal is already a narrowband oscillation
signal_var = _get_sensor_space_variance(s.to_stc(src), fwd,
signal_var = get_sensor_space_variance(s.to_stc(src), fwd,
fmin=fmin, fmax=fmax, filter=True)

# NOTE: patch sources might require more complex calculations
# if the within-patch correlation is not equal to 1
amp = _adjust_snr(signal_var, noise_var, target_snr)
amp = amplitude_adjustment(signal_var, noise_var, target_snr)
s.waveform *= amp

return sources
2 changes: 1 addition & 1 deletion tests/test_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def test_simulate():
f"Expected 4 sources, got {len(noise_sources)}"


@patch('meegsim.simulate._setup_snr', return_value = [])
@patch('meegsim.simulate._adjust_snr', return_value = [])
def test_simulate_snr_adjustment(setup_snr_mock):
# return mock PointSource's - 1 noise source, 1 signal source
simulate_mock = Mock(side_effect=[
Expand Down
34 changes: 17 additions & 17 deletions tests/test_snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest

from meegsim.snr import _get_sensor_space_variance, _adjust_snr, _setup_snr
from meegsim.snr import get_sensor_space_variance, amplitude_adjustment, _adjust_snr
from meegsim.source_groups import PointSourceGroup

from utils.mocks import MockPointSource
Expand All @@ -30,7 +30,7 @@ def test_get_sensor_space_variance_no_filter():
# Since the leadfield values are opposite for these vertices, the
# activity should cancel out in sensor space
expected_variance = 0.
variance = _get_sensor_space_variance(stc, fwd, filter=False)
variance = get_sensor_space_variance(stc, fwd, filter=False)
assert np.isclose(variance, expected_variance), \
f"Expected variance {expected_variance}, but got {variance}"

Expand All @@ -42,7 +42,7 @@ def test_get_sensor_space_variance_no_filter_sel_vert():

# Both vertices in the stc have corresponding zero time series
expected_variance = 0
variance = _get_sensor_space_variance(stc, fwd, filter=False)
variance = get_sensor_space_variance(stc, fwd, filter=False)
assert np.isclose(variance, expected_variance), \
f"Expected variance {expected_variance}, but got {variance}"

Expand All @@ -53,7 +53,7 @@ def test_get_sensor_space_variance_with_filter(butter_mock, filtfilt_mock):
fwd = prepare_forward(5, 10)
vertices = [[0, 1], [0, 1]]
stc = prepare_stc(vertices)
variance = _get_sensor_space_variance(stc, fwd, fmin=8, fmax=12, filter=True)
variance = get_sensor_space_variance(stc, fwd, fmin=8, fmax=12, filter=True)

# Check that butter and filtfilt were called
butter_mock.assert_called()
Expand All @@ -80,7 +80,7 @@ def test_get_sensor_space_variance_with_filter_fmin_fmax(butter_mock, filtfilt_m
fwd = prepare_forward(5, 10)
vertices = [[0, 1], [0, 1]]
stc = prepare_stc(vertices)
_get_sensor_space_variance(stc, fwd, filter=True, fmin=20., fmax=30.)
get_sensor_space_variance(stc, fwd, filter=True, fmin=20., fmax=30.)

# Check that butter and filtfilt were called
butter_mock.assert_called()
Expand All @@ -105,48 +105,48 @@ def test_get_sensor_space_variance_no_fmin_fmax():
stc = prepare_stc(vertices)

# No filtering required - should pass
_get_sensor_space_variance(stc, fwd, filter=False)
get_sensor_space_variance(stc, fwd, filter=False)

# No fmin
with pytest.raises(ValueError, match="Frequency band limits are required"):
_get_sensor_space_variance(stc, fwd, fmax=12, filter=True)
get_sensor_space_variance(stc, fwd, fmax=12, filter=True)

# No fmax
with pytest.raises(ValueError, match="Frequency band limits are required"):
_get_sensor_space_variance(stc, fwd, fmin=8, filter=True)
get_sensor_space_variance(stc, fwd, fmin=8, filter=True)


@pytest.mark.parametrize("target_snr", np.logspace(-6, 6, 10))
def test_adjust_snr(target_snr):
def test_amplitude_adjustment(target_snr):
signal_var = 10.0
noise_var = 5.0

snr_current = np.divide(signal_var, noise_var)
expected_result = np.sqrt(target_snr / snr_current)

result = _adjust_snr(signal_var, noise_var, target_snr=target_snr)
result = amplitude_adjustment(signal_var, noise_var, target_snr=target_snr)
assert np.isclose(result, expected_result), \
f"Expected {expected_result}, but got {result}"


def test_adjust_snr_zero_signal_var():
def test_amplitude_adjustment_zero_signal_var():
signal_var = 0.0
noise_var = 5.0

with pytest.raises(ValueError, match="initial SNR appear to be zero"):
_adjust_snr(signal_var, noise_var, target_snr=1)
amplitude_adjustment(signal_var, noise_var, target_snr=1)


def test_adjust_snr_zero_noise_var():
def test_amplitude_adjustment_zero_noise_var():
signal_var = 10.0
noise_var = 0.0

with pytest.raises(ValueError, match="noise variance appears to be zero"):
_adjust_snr(signal_var, noise_var, target_snr=1)
amplitude_adjustment(signal_var, noise_var, target_snr=1)


@patch('meegsim.snr._adjust_snr', return_value=2.)
def test_setup_snr(adjust_snr_mock):
@patch('meegsim.snr.amplitude_adjustment', return_value=2.)
def test_adjust_snr(adjust_snr_mock):
src = prepare_source_space(
types=['surf', 'surf'],
vertices=[[0, 1], [0, 1]]
Expand All @@ -171,7 +171,7 @@ def test_setup_snr(adjust_snr_mock):
'n1': MockPointSource(name='n1')
}

sources = _setup_snr(src, fwd, sources, source_groups, noise_sources)
sources = _adjust_snr(src, fwd, sources, source_groups, noise_sources)

# Check the SNR adjustment was performed
adjust_snr_mock.assert_called()
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def prepare_source_space(types, vertices):
def prepare_forward(n_channels, n_sources,
ch_names=None, ch_types=None, sfreq=250):

assert n_sources % 2 == 0, "Only even number of sources is supported"
assert n_sources % 2 == 0, "Only an even number of sources is supported"

# Create a dummy info structure
if ch_names is None:
Expand Down

0 comments on commit fc80f1d

Please sign in to comment.