Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: lower-level functions for the adjustment of SNR #9

Merged
merged 7 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 83 additions & 4 deletions src/meegsim/snr.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,91 @@
import numpy as np
import warnings
import mne

from scipy.signal import butter, filtfilt


def get_sensor_space_variance(stc, fwd, *, fmin=None, fmax=None, filter=False):
"""
Estimate the sensor space variance of the provided stc
NB: we need to filter the signal in the frequency band of the oscillation

Parameters
----------
stc: mne.SourceEstimate
Source estimate containing signal or noise (vertices x times).

fwd: mne.Forward
Forward model.

fmin: float, optional
Lower cutoff frequency (in Hz). default = None.

fmax: float, optional
Upper cutoff frequency (in Hz). default = None.

filter: bool, optional
Indicate if filtering in the band of oscillations is required. default = False.

Returns
-------
stc_var: float
Variance with respect to leadfield.
"""
pass

if filter:
if fmin is None:
warnings.warn("fmin was None. Setting fmin to 8 Hz", UserWarning)
fmin = 8.
if fmax is None:
warnings.warn("fmax was None. Setting fmax to 12 Hz", UserWarning)
fmax = 12.

b, a = butter(2, np.array([fmin, fmax]) / stc.sfreq * 2, btype='bandpass')
stc_data = filtfilt(b, a, stc.data, axis=1)
else:
stc_data = stc.data

fwd_restrict = mne.forward.restrict_forward_to_stc(fwd, stc, on_missing='ignore')
leadfield_restict = fwd_restrict['sol']['data']

stc_var = np.mean(stc_data ** 2) * np.mean(leadfield_restict ** 2)
return stc_var

def adjust_snr():

def adjust_snr(signal_var, noise_var, *, target_snr=1):
"""
Derive the signal amplitude that allows obtaining target SNR
"""

Parameters
----------
signal_var: float
Variance of the simulated signal with respect to leadfield. Can be obtained with
a function snr.get_sensor_space_variance.

noise_var: float
Variance of the simulated noise with respect to leadfield. Can be obtained with
a function snr.get_sensor_space_variance.

target_snr: float, optional
Value of a desired SNR for the signal. default = 1.

Returns
-------
out: float
The value that original signal should be scaled (multiplied) to in order to obtain desired SNR.
"""

snr_current = np.divide(signal_var, noise_var)

if np.isinf(snr_current):
raise ValueError("The noise variance appears to be zero, so the initial SNR "
"cannot be calculated. Please check the created noise.")

factor = np.sqrt(target_snr / snr_current)

if np.isinf(factor):
raise ValueError("The signal variance and thus the initial SNR appear to be "
"zero, so SNR cannot be adjusted. Please check the created "
"signals.")

return factor
204 changes: 204 additions & 0 deletions tests/test_snr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
import numpy as np
import mne
from unittest.mock import patch

import pytest

from meegsim.snr import get_sensor_space_variance, adjust_snr


def create_dummy_sourcespace(vertices):
# Fill in dummy data as a constant time series equal to the vertex number
n_src_spaces = len(vertices)
type_src = 'surf' if n_src_spaces == 2 else 'vol'
src = []
for i in range(n_src_spaces):
# Create a simple dummy data structure
n_verts = len(vertices[i])
vertno = vertices[i] # Vertices for this hemisphere
xyz = np.random.rand(n_verts, 3) * 100 # Random positions
src_dict = dict(
vertno=vertno,
rr=xyz,
nn=np.random.rand(n_verts, 3), # Random normals
inuse=np.ones(n_verts, dtype=int), # All vertices in use
nuse=n_verts,
type=type_src,
id=i,
np=n_verts
)
src.append(src_dict)

return mne.SourceSpaces(src)


def create_dummy_forward():
# Define the basic parameters for the forward solution
n_sources = 10 # Number of ipoles
n_channels = 5 # Number of MEG/EEG channels

# Create a dummy info structure
info = mne.create_info(ch_names=['MEG1', 'MEG2', 'MEG3', 'EEG1', 'EEG2'],
sfreq=1000., ch_types=['mag', 'mag', 'mag', 'eeg', 'eeg'])

# Generate random source space data (e.g., forward operator)
fwd_data = np.random.randn(n_channels, n_sources)

# Create a dummy source space
lh_vertno = np.arange(n_sources // 2)
rh_vertno = np.arange(n_sources // 2)

src = create_dummy_sourcespace([lh_vertno, rh_vertno])

# Generate random source positions
source_rr = np.random.rand(n_sources, 3)

# Generate random source orientations
source_nn = np.random.randn(n_sources, 3)
source_nn /= np.linalg.norm(source_nn, axis=1, keepdims=True)

# Create a forward solution
forward = {
'sol': {'data': fwd_data},
'_orig_sol': fwd_data,
'sol_grad': None,
'info': info,
'source_ori': 1,
'surf_ori': True,
'nsource': n_sources,
'nchan': n_channels,
'coord_frame': 1,
'src': src,
'source_rr': source_rr,
'source_nn': source_nn,
'_orig_source_ori': 1
}

# Convert the dictionary to an mne.Forward object
fwd = mne.Forward(**forward)

return fwd


def prepare_stc(vertices, num_samples=500):
# Fill in dummy data as a constant time series equal to the vertex number
data = np.tile(vertices[0] + vertices[1], reps=(num_samples, 1)).T
return mne.SourceEstimate(data, vertices, tmin=0, tstep=0.01)


def test_get_sensor_space_variance_no_filter_all_vert():
fwd = create_dummy_forward()
vertices = [list(np.arange(5)), list(np.arange(5))]
stc = prepare_stc(vertices)
leadfield = fwd['sol']['data']
expected_variance = np.mean(stc.data ** 2) * np.mean(leadfield ** 2)
variance = get_sensor_space_variance(stc, fwd, filter=False)
assert np.isclose(variance, expected_variance), \
f"Expected variance {expected_variance}, but got {variance}"


def test_get_sensor_space_variance_no_filter_sel_vert():
fwd = create_dummy_forward()
vertices = [[0], [0]]
stc = prepare_stc(vertices)

# Both vertices in the stc have corresponding zero time series
expected_variance = 0
variance = get_sensor_space_variance(stc, fwd, filter=False)
assert np.isclose(variance, expected_variance), \
f"Expected variance {expected_variance}, but got {variance}"


@patch('meegsim.snr.filtfilt', return_value=np.ones((1, 100)))
@patch('meegsim.snr.butter', return_value=(0, 0))
def test_get_sensor_space_variance_with_filter(butter_mock, filtfilt_mock):
fwd = create_dummy_forward()
vertices = [[0, 1], [0, 1]]
stc = prepare_stc(vertices)
variance = get_sensor_space_variance(stc, fwd, filter=True)

# Check that butter and filtfilt were called
butter_mock.assert_called()
filtfilt_mock.assert_called()

# Check that fmin and fmax are set to default values by looking at
# the normalized frequencies (second argument of scipy.signal.butter)
butter_args = butter_mock.call_args
sfreq = stc.sfreq
expected_wmin = 8.0 / (0.5 * sfreq)
expected_wmax = 12.0 / (0.5 * sfreq)
actual_wmin, actual_wmax = butter_args.args[1]
assert np.isclose(actual_wmin, expected_wmin), \
f"Expected fmin to be {expected_wmin}, got {actual_wmin}"
assert np.isclose(actual_wmax, expected_wmax), \
f"Expected fmax to be {expected_wmax}, got {actual_wmax}"

assert variance >= 0, "Variance should be non-negative"


@patch('meegsim.snr.filtfilt', return_value=np.ones((1, 100)))
@patch('meegsim.snr.butter', return_value=(0, 0))
def test_get_sensor_space_variance_with_filter_fmin_fmax(butter_mock, filtfilt_mock):
fwd = create_dummy_forward()
vertices = [[0, 1], [0, 1]]
stc = prepare_stc(vertices)
variance = get_sensor_space_variance(stc, fwd, filter=True, fmin=20., fmax=30.)

# Check that butter and filtfilt were called
butter_mock.assert_called()
filtfilt_mock.assert_called()

# Check that fmin and fmax are set to custom values by looking at
# the normalized frequencies (second argument of scipy.signal.butter)
butter_args = butter_mock.call_args
sfreq = stc.sfreq
expected_wmin = 20.0 / (0.5 * sfreq)
expected_wmax = 30.0 / (0.5 * sfreq)
actual_wmin, actual_wmax = butter_args.args[1]
assert np.isclose(actual_wmin, expected_wmin), \
f"Expected fmin to be {expected_wmin}, got {actual_wmin}"
assert np.isclose(actual_wmax, expected_wmax), \
f"Expected fmax to be {expected_wmax}, got {actual_wmax}"


@pytest.mark.parametrize("target_snr", np.logspace(-6, 6, 10))
def test_adjust_snr(target_snr):
signal_var = 10.0
noise_var = 5.0

snr_current = np.divide(signal_var, noise_var)
expected_result = np.sqrt(target_snr / snr_current)

result = adjust_snr(signal_var, noise_var, target_snr=target_snr)
assert np.isclose(result, expected_result), \
f"Expected {expected_result}, but got {result}"


def test_adjust_snr_default_target():
signal_var = 10.0
noise_var = 5.0

default_snr = 1.0
snr_current = np.divide(signal_var, noise_var)
expected_result = np.sqrt(default_snr / snr_current)

result = adjust_snr(signal_var, noise_var)
assert np.isclose(result, expected_result), \
f"Expected {expected_result}, but got {result}"


def test_adjust_snr_zero_signal_var():
signal_var = 0.0
noise_var = 5.0

with pytest.raises(ValueError, match="initial SNR appear to be zero"):
adjust_snr(signal_var, noise_var)


def test_adjust_snr_zero_noise_var():
signal_var = 10.0
noise_var = 0.0

with pytest.raises(ValueError, match="noise variance appears to be zero"):
adjust_snr(signal_var, noise_var)