Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: lower-level functions for the adjustment of SNR #9

Merged
merged 7 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 71 additions & 4 deletions src/meegsim/snr.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,79 @@
import numpy as np
import warnings
import mne
from scipy.signal import butter, filtfilt

def get_sensor_space_variance(stc, fwd, *, fmin=None, fmax=None, filter=False):
"""
Estimate the sensor space variance of the provided stc
NB: we need to filter the signal in the frequency band of the oscillation

Parameters
----------
stc: mne.SourceEstimate
Source estimate containing signal or noise (vertices x times).

fwd: mne.Forward
Forward model.

fmin: float, optional
Lower cutoff frequency (in Hz). default = None.

fmax: float, optional
Upper cutoff frequency (in Hz). default = None.

filter: bool, optional
Indicate if filtering in the band of oscillations is required. default = False.

Returns
-------
stc_var: float
Variance with respect to leadfield.
"""
pass

if filter:
if fmin is None:
warnings.warn("fmin was None. Setting fmin to 8 Hz", UserWarning)
fmin = 8.
if fmax is None:
warnings.warn("fmax was None. Setting fmax to 12 Hz", UserWarning)
fmax = 12.
b, a = butter(2, np.array([fmin, fmax]) / stc.sfreq * 2, btype='bandpass')
stc_data = filtfilt(b, a, stc.data, axis=1)
else:
stc_data = stc.data

fwd_restrict = mne.forward.restrict_forward_to_stc(fwd, stc, on_missing='ignore')
leadfield_restict = fwd_restrict['sol']['data']

stc_var = np.mean(stc_data ** 2) * np.mean(leadfield_restict ** 2)
return stc_var

def adjust_snr():

def adjust_snr(signal_var, noise_var, *, target_snr=1):
"""
Derive the signal amplitude that allows obtaining target SNR
"""

Parameters
----------
signal_var: float
Variance of the simulated signal with respect to leadfield. Can be obtained with
a function snr.get_sensor_space_variance.

noise_var: float
Variance of the simulated noise with respect to leadfield. Can be obtained with
a function snr.get_sensor_space_variance.

target_snr: float, optional
Value of a desired SNR for the signal. default = 1.

Returns
-------
out: float
The value that original signal should be scaled (divided) to in order to obtain desired SNR.
"""

if noise_var == 0:
raise ValueError("Noise variance is zero; SNR cannot be calculated.")

snr_current = signal_var / noise_var
return np.sqrt(snr_current / target_snr)
188 changes: 188 additions & 0 deletions tests/test_snr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
import numpy as np
import mne
from unittest.mock import patch

import pytest

from meegsim.snr import get_sensor_space_variance, adjust_snr


def create_dummy_sourcespace(vertices):
# Fill in dummy data as a constant time series equal to the vertex number
n_src_spaces = len(vertices)
type_src = 'surf' if n_src_spaces == 2 else 'vol'
src = []
for i in range(n_src_spaces):
# Create a simple dummy data structure
n_verts = len(vertices[i])
vertno = vertices[i] # Vertices for this hemisphere
xyz = np.random.rand(n_verts, 3) * 100 # Random positions
src_dict = dict(
vertno=vertno,
rr=xyz,
nn=np.random.rand(n_verts, 3), # Random normals
inuse=np.ones(n_verts, dtype=int), # All vertices in use
nuse=n_verts,
type=type_src,
id=i,
np=n_verts
)
src.append(src_dict)

return mne.SourceSpaces(src)


def create_dummy_forward():
# Define the basic parameters for the forward solution
n_sources = 10 # Number of ipoles
n_channels = 5 # Number of MEG/EEG channels

# Create a dummy info structure
info = mne.create_info(ch_names=['MEG1', 'MEG2', 'MEG3', 'EEG1', 'EEG2'],
sfreq=1000., ch_types=['mag', 'mag', 'mag', 'eeg', 'eeg'])

# Generate random source space data (e.g., forward operator)
fwd_data = np.random.randn(n_channels, n_sources)

# Create a dummy source space
lh_vertno = np.arange(n_sources // 2)
rh_vertno = np.arange(n_sources // 2)

src = create_dummy_sourcespace([lh_vertno, rh_vertno])

# Generate random source positions
source_rr = np.random.rand(n_sources, 3)

# Generate random source orientations
source_nn = np.random.randn(n_sources, 3)
source_nn /= np.linalg.norm(source_nn, axis=1, keepdims=True)

# Create a forward solution
forward = {
'sol': {'data': fwd_data},
'_orig_sol': fwd_data,
'sol_grad': None,
'info': info,
'source_ori': 1,
'surf_ori': True,
'nsource': n_sources,
'nchan': n_channels,
'coord_frame': 1,
'src': src,
'source_rr': source_rr,
'source_nn': source_nn,
'_orig_source_ori': 1
}

# Convert the dictionary to an mne.Forward object
fwd = mne.Forward(**forward)

return fwd


def prepare_stc(vertices, num_samples=500):
# Fill in dummy data as a constant time series equal to the vertex number
data = np.tile(vertices[0] + vertices[1], reps=(num_samples, 1)).T
return mne.SourceEstimate(data, vertices, tmin=0, tstep=0.01)


def test_get_sensor_space_variance_no_filter_all_vert():
fwd = create_dummy_forward()
vertices = [list(np.arange(5)), list(np.arange(5))]
stc = prepare_stc(vertices)
variance = get_sensor_space_variance(stc, fwd, filter=False)
assert variance >= 0, "Variance should be non-negative"


def test_get_sensor_space_variance_no_filter_sel_vert():
fwd = create_dummy_forward()
vertices = [[0, 1], [0, 1]]
stc = prepare_stc(vertices)
variance = get_sensor_space_variance(stc, fwd, filter=False)
assert variance >= 0, "Variance should be non-negative"
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that this check is not sufficient - all values (stc and leadfield) are squared so the result should be positive regardless of which vertices are used. Please define the leadfield values yourself and calculate the expected variance with all sources and a subset of sources. Feel free to reduce the dimensionality to make the calculations easier (2x2 leadfield and 1-source stc should already be sufficient to check that the second source does not play a role).

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MEG/EEG aspect is probably not so important here, both are numbers in the end



@patch('meegsim.snr.filtfilt', return_value=np.ones((1, 100)))
@patch('meegsim.snr.butter', return_value=(0, 0))
def test_get_sensor_space_variance_with_filter(butter_mock, filtfilt_mock):
fwd = create_dummy_forward()
vertices = [[0, 1], [0, 1]]
stc = prepare_stc(vertices)
variance = get_sensor_space_variance(stc, fwd, filter=True)

# Check that butter and filtfilt were called
butter_mock.assert_called()
filtfilt_mock.assert_called()

# Check that fmin and fmax are set to default values
butter_args = butter_mock.call_args
sfreq = stc.sfreq
normalized_fmin = 8.0 / (0.5 * sfreq)
normalized_fmax = 12.0 / (0.5 * sfreq)
assert np.isclose(butter_args[0][1][0], normalized_fmin), f"Expected fmin to be {normalized_fmin}"
assert np.isclose(butter_args[0][1][1], normalized_fmax), f"Expected fmax to be {normalized_fmax}"

assert variance >= 0, "Variance should be non-negative"


@patch('meegsim.snr.filtfilt', return_value=np.ones((1, 100)))
@patch('meegsim.snr.butter', return_value=(0, 0))
def test_get_sensor_space_variance_with_filter_fmin_fmax(butter_mock, filtfilt_mock):
fwd = create_dummy_forward()
vertices = [[0, 1], [0, 1]]
stc = prepare_stc(vertices)
variance = get_sensor_space_variance(stc, fwd, filter=True, fmin=20., fmax=30.)

# Check that butter and filtfilt were called
butter_mock.assert_called()
filtfilt_mock.assert_called()

# Check that fmin and fmax
butter_args = butter_mock.call_args
sfreq = stc.sfreq
normalized_fmin = 20.0 / (0.5 * sfreq)
normalized_fmax = 30.0 / (0.5 * sfreq)
assert np.isclose(butter_args[0][1][0], normalized_fmin), f"Expected fmin to be {normalized_fmin}"
assert np.isclose(butter_args[0][1][1], normalized_fmax), f"Expected fmax to be {normalized_fmax}"

assert variance >= 0, "Variance should be non-negative"


@pytest.mark.parametrize("target_snr", np.logspace(-6, 6, 10))
def test_adjust_snr(target_snr):
signal_var = 10.0
noise_var = 5.0

snr_current = signal_var / noise_var
expected_result = np.sqrt(snr_current / target_snr)

result = adjust_snr(signal_var, noise_var, target_snr=target_snr)
assert np.isclose(result, expected_result), f"Expected {expected_result}, but got {result}"


def test_adjust_snr_default_target():
signal_var = 10.0
noise_var = 5.0

default_snr = 1.0
snr_current = signal_var / noise_var
expected_result = np.sqrt(snr_current / default_snr)

result = adjust_snr(signal_var, noise_var)
assert np.isclose(result, expected_result), f"Expected {expected_result}, but got {result}"


def test_adjust_snr_zero_signal_var():
signal_var = 0.0
noise_var = 5.0

result = adjust_snr(signal_var, noise_var)
assert np.isclose(result, 0.0), f"Expected 0.0, but got {result}"


def test_adjust_snr_zero_noise_var():
signal_var = 10.0
noise_var = 0.0

with pytest.raises(ValueError, match="Noise variance is zero; SNR cannot be calculated."):
adjust_snr(signal_var, noise_var)