Skip to content

Commit

Permalink
get_sensor_space_variance with restricted leadfield, raise in adjust …
Browse files Browse the repository at this point in the history
…snr, tests
  • Loading branch information
astudenova committed Aug 10, 2024
1 parent 136d3af commit a73d909
Show file tree
Hide file tree
Showing 2 changed files with 202 additions and 7 deletions.
21 changes: 14 additions & 7 deletions src/meegsim/snr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numpy as np
import warnings
import mne
from scipy.signal import butter, filtfilt

def get_sensor_space_variance(stc, fwd, *, fmin=None, fmax=None, filter=False):
"""
Expand All @@ -8,7 +10,7 @@ def get_sensor_space_variance(stc, fwd, *, fmin=None, fmax=None, filter=False):
Parameters
----------
stc: mne.SourceEstimate
Source estimate containing with signal or noise (vertices x times).
Source estimate containing signal or noise (vertices x times).
fwd: mne.Forward
Forward model.
Expand All @@ -27,8 +29,7 @@ def get_sensor_space_variance(stc, fwd, *, fmin=None, fmax=None, filter=False):
stc_var: float
Variance with respect to leadfield.
"""
from scipy.signal import butter, filtfilt
leadfield = fwd['sol']['data']

if filter:
if fmin is None:
warnings.warn("fmin was None. Setting fmin to 8 Hz", UserWarning)
Expand All @@ -37,12 +38,14 @@ def get_sensor_space_variance(stc, fwd, *, fmin=None, fmax=None, filter=False):
warnings.warn("fmax was None. Setting fmax to 12 Hz", UserWarning)
fmax = 12.
b, a = butter(2, np.array([fmin, fmax]) / stc.sfreq * 2, btype='bandpass')
stc_data = filtfilt(b, a, stc._data, axis=1)
stc_data = filtfilt(b, a, stc.data, axis=1)
else:
stc_data = stc._data
stc_data = stc.data

fwd_restrict = mne.forward.restrict_forward_to_stc(fwd, stc, on_missing='ignore')
leadfield_restict = fwd_restrict['sol']['data']

nonzero_idx = np.mean(stc_data, axis=1) > 0
stc_var = np.mean(stc_data[nonzero_idx, :] ** 2) * np.mean(leadfield[:, nonzero_idx] ** 2)
stc_var = np.mean(stc_data ** 2) * np.mean(leadfield_restict ** 2)
return stc_var


Expand All @@ -68,5 +71,9 @@ def adjust_snr(signal_var, noise_var, *, target_snr=1):
out: float
The value that original signal should be scaled (divided) to in order to obtain desired SNR.
"""

if noise_var == 0:
raise ValueError("Noise variance is zero; SNR cannot be calculated.")

snr_current = signal_var / noise_var
return np.sqrt(snr_current / target_snr)
188 changes: 188 additions & 0 deletions tests/test_snr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
import numpy as np
import mne
from unittest.mock import patch

import pytest

from meegsim.snr import get_sensor_space_variance, adjust_snr


def create_dummy_sourcespace(vertices):
# Fill in dummy data as a constant time series equal to the vertex number
n_src_spaces = len(vertices)
type_src = 'surf' if n_src_spaces == 2 else 'vol'
src = []
for i in range(n_src_spaces):
# Create a simple dummy data structure
n_verts = len(vertices[i])
vertno = vertices[i] # Vertices for this hemisphere
xyz = np.random.rand(n_verts, 3) * 100 # Random positions
src_dict = dict(
vertno=vertno,
rr=xyz,
nn=np.random.rand(n_verts, 3), # Random normals
inuse=np.ones(n_verts, dtype=int), # All vertices in use
nuse=n_verts,
type=type_src,
id=i,
np=n_verts
)
src.append(src_dict)

return mne.SourceSpaces(src)


def create_dummy_forward():
# Define the basic parameters for the forward solution
n_sources = 10 # Number of ipoles
n_channels = 5 # Number of MEG/EEG channels

# Create a dummy info structure
info = mne.create_info(ch_names=['MEG1', 'MEG2', 'MEG3', 'EEG1', 'EEG2'],
sfreq=1000., ch_types=['mag', 'mag', 'mag', 'eeg', 'eeg'])

# Generate random source space data (e.g., forward operator)
fwd_data = np.random.randn(n_channels, n_sources)

# Create a dummy source space
lh_vertno = np.arange(n_sources // 2)
rh_vertno = np.arange(n_sources // 2)

src = create_dummy_sourcespace([lh_vertno, rh_vertno])

# Generate random source positions
source_rr = np.random.rand(n_sources, 3)

# Generate random source orientations
source_nn = np.random.randn(n_sources, 3)
source_nn /= np.linalg.norm(source_nn, axis=1, keepdims=True)

# Create a forward solution
forward = {
'sol': {'data': fwd_data},
'_orig_sol': fwd_data,
'sol_grad': None,
'info': info,
'source_ori': 1,
'surf_ori': True,
'nsource': n_sources,
'nchan': n_channels,
'coord_frame': 1,
'src': src,
'source_rr': source_rr,
'source_nn': source_nn,
'_orig_source_ori': 1
}

# Convert the dictionary to an mne.Forward object
fwd = mne.Forward(**forward)

return fwd


def prepare_stc(vertices, num_samples=500):
# Fill in dummy data as a constant time series equal to the vertex number
data = np.tile(vertices[0] + vertices[1], reps=(num_samples, 1)).T
return mne.SourceEstimate(data, vertices, tmin=0, tstep=0.01)


def test_get_sensor_space_variance_no_filter_all_vert():
fwd = create_dummy_forward()
vertices = [list(np.arange(5)), list(np.arange(5))]
stc = prepare_stc(vertices)
variance = get_sensor_space_variance(stc, fwd, filter=False)
assert variance >= 0, "Variance should be non-negative"


def test_get_sensor_space_variance_no_filter_sel_vert():
fwd = create_dummy_forward()
vertices = [[0, 1], [0, 1]]
stc = prepare_stc(vertices)
variance = get_sensor_space_variance(stc, fwd, filter=False)
assert variance >= 0, "Variance should be non-negative"


@patch('meegsim.snr.filtfilt', return_value=np.ones((1, 100)))
@patch('meegsim.snr.butter', return_value=(0, 0))
def test_get_sensor_space_variance_with_filter(butter_mock, filtfilt_mock):
fwd = create_dummy_forward()
vertices = [[0, 1], [0, 1]]
stc = prepare_stc(vertices)
variance = get_sensor_space_variance(stc, fwd, filter=True)

# Check that butter and filtfilt were called
butter_mock.assert_called()
filtfilt_mock.assert_called()

# Check that fmin and fmax are set to default values
butter_args = butter_mock.call_args
sfreq = stc.sfreq
normalized_fmin = 8.0 / (0.5 * sfreq)
normalized_fmax = 12.0 / (0.5 * sfreq)
assert np.isclose(butter_args[0][1][0], normalized_fmin), f"Expected fmin to be {normalized_fmin}"
assert np.isclose(butter_args[0][1][1], normalized_fmax), f"Expected fmax to be {normalized_fmax}"

assert variance >= 0, "Variance should be non-negative"


@patch('meegsim.snr.filtfilt', return_value=np.ones((1, 100)))
@patch('meegsim.snr.butter', return_value=(0, 0))
def test_get_sensor_space_variance_with_filter_fmin_fmax(butter_mock, filtfilt_mock):
fwd = create_dummy_forward()
vertices = [[0, 1], [0, 1]]
stc = prepare_stc(vertices)
variance = get_sensor_space_variance(stc, fwd, filter=True, fmin=20., fmax=30.)

# Check that butter and filtfilt were called
butter_mock.assert_called()
filtfilt_mock.assert_called()

# Check that fmin and fmax
butter_args = butter_mock.call_args
sfreq = stc.sfreq
normalized_fmin = 20.0 / (0.5 * sfreq)
normalized_fmax = 30.0 / (0.5 * sfreq)
assert np.isclose(butter_args[0][1][0], normalized_fmin), f"Expected fmin to be {normalized_fmin}"
assert np.isclose(butter_args[0][1][1], normalized_fmax), f"Expected fmax to be {normalized_fmax}"

assert variance >= 0, "Variance should be non-negative"


@pytest.mark.parametrize("target_snr", np.logspace(-6, 6, 10))
def test_adjust_snr(target_snr):
signal_var = 10.0
noise_var = 5.0

snr_current = signal_var / noise_var
expected_result = np.sqrt(snr_current / target_snr)

result = adjust_snr(signal_var, noise_var, target_snr=target_snr)
assert np.isclose(result, expected_result), f"Expected {expected_result}, but got {result}"


def test_adjust_snr_default_target():
signal_var = 10.0
noise_var = 5.0

default_snr = 1.0
snr_current = signal_var / noise_var
expected_result = np.sqrt(snr_current / default_snr)

result = adjust_snr(signal_var, noise_var)
assert np.isclose(result, expected_result), f"Expected {expected_result}, but got {result}"


def test_adjust_snr_zero_signal_var():
signal_var = 0.0
noise_var = 5.0

result = adjust_snr(signal_var, noise_var)
assert np.isclose(result, 0.0), f"Expected 0.0, but got {result}"


def test_adjust_snr_zero_noise_var():
signal_var = 10.0
noise_var = 0.0

with pytest.raises(ValueError, match="Noise variance is zero; SNR cannot be calculated."):
adjust_snr(signal_var, noise_var)

0 comments on commit a73d909

Please sign in to comment.