Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: add coupling functionality to SourceSimulator #34

Merged
merged 23 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Waveforms of white noise, narrowband oscillation (white noise filtered in a narrow frequency band) and 1/f noise with adjustable slope ([#8](https://github.com/ctrltz/meegsim/pull/8))
- Random vertices in the whole source space or in a subset of vertices as location for point sources ([#10](https://github.com/ctrltz/meegsim/pull/10))
- Adjustment of the SNR of the point sources based on sensor space power ([#9](https://github.com/ctrltz/meegsim/pull/9), [#31](https://github.com/ctrltz/meegsim/pull/31))
- Phase-phase coupling with a constant phase lag or a probabilistic phase lag according to the von Mises distribution ([#11](https://github.com/ctrltz/meegsim/pull/11))
- Traversal of the coupling graph to ensure that the coupling is set up correctly when multiple connectivity edges are defined ([#12](https://github.com/ctrltz/meegsim/pull/12))
- Phase-phase coupling with a constant phase lag or a probabilistic phase lag according to the von Mises distribution ([#11](https://github.com/ctrltz/meegsim/pull/11), [#34](https://github.com/ctrltz/meegsim/pull/34))
- Traversal of the coupling graph to ensure that the coupling is set up correctly when multiple connectivity edges are defined ([#12](https://github.com/ctrltz/meegsim/pull/12), [#34](https://github.com/ctrltz/meegsim/pull/34))
60 changes: 25 additions & 35 deletions examples/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import mne
import numpy as np

from harmoni.extratools import compute_plv
from pathlib import Path
from scipy.signal import butter, filtfilt

from meegsim.coupling import ppc_von_mises
from meegsim.location import select_random
from meegsim.simulate import SourceSimulator
from meegsim.waveform import narrowband_oscillation
Expand All @@ -19,15 +20,16 @@ def to_json(sources):


# Load the head model
fs_dir = Path('/data/hu_studenova/mne_data/MNE-fsaverage-data/fsaverage/')
fwd_path = fs_dir / 'bem' / 'fsaverage-oct6-fwd.fif'
src_path = fs_dir / 'bem' / 'fsaverage-oct6-src.fif'
fs_dir = Path('~/mne_data/MNE-fsaverage-data/fsaverage/')
fwd_path = fs_dir / 'bem_copy' / 'fsaverage-oct6-fwd.fif'
src_path = fs_dir / 'bem_copy' / 'fsaverage-oct6-src.fif'
src = mne.read_source_spaces(src_path)
fwd = mne.read_forward_solution(fwd_path)

# Simulation parameters
sfreq = 250
duration = 60
seed = 123
target_snr = 20

# Channel info
Expand All @@ -42,44 +44,32 @@ def to_json(sources):

sim = SourceSimulator(src)

# Select some vertices randomly (signal/noise does not matter for now)
# sim.add_point_sources(
# location=select_random,
# waveform=narrowband_oscillation,
# location_params=dict(n=10, vertices=[list(src[0]['vertno']), []]),
# waveform_params=dict(fmin=8, fmax=12)
# )
sim.add_patch_sources(
location=select_random,
waveform=narrowband_oscillation,
location_params=dict(n=2, vertices=[list(src[0]['vertno']), []]),
waveform_params=dict(fmin=8, fmax=12),
snr=target_snr,
snr_params=dict(fmin=8, fmax=12),
extents=10
)
sim.add_noise_sources(
location=select_random,
location_params=dict(n=10)
)

sc_noise = sim.simulate(sfreq, duration)
raw_noise = sc_noise.to_raw(fwd, info)

# Select some vertices randomly
sim.add_point_sources(
location=select_random,
waveform=narrowband_oscillation,
location_params=dict(n=1),
location_params=dict(n=3),
waveform_params=dict(fmin=8, fmax=12),
snr=target_snr,
snr_params=dict(fmin=8, fmax=12)
names=['s1', 's2', 's3']
)

sc_full = sim.simulate(sfreq, duration, fwd=fwd)
raw_full = sc_full.to_raw(fwd, info)
# Set coupling
sim.set_coupling(coupling={
('s1', 's2'): dict(kappa=1, phase_lag=np.pi/3),
('s2', 's3'): dict(kappa=10, phase_lag=-np.pi/2)
}, method=ppc_von_mises, fmin=8, fmax=12)

print(sim._coupling_graph)
print(sim._coupling_graph.edges(data=True))

sc = sim.simulate(sfreq, duration, fwd=fwd, random_state=seed)
raw = sc.to_raw(fwd, info)

source_data = np.vstack([s.waveform for s in sc._sources.values()])

print('PLV:', compute_plv(source_data, source_data, n=1, m=1))
print('iPLV:', compute_plv(source_data, source_data, n=1, m=1, plv_type='imag'))

raw = sc_full.to_raw(fwd, info)
spec = raw.compute_psd(n_fft=sfreq, n_overlap=sfreq // 2, n_per_seg=sfreq)
spec = raw.compute_psd(n_fft=sfreq, n_overlap=sfreq//2, n_per_seg=sfreq)
spec.plot(sphere='eeglab')
input('Press any key to continue')
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ classifiers = [
]
dependencies = [
"colorednoise",
"harmoni",
"mne",
"networkx"
]
Expand All @@ -30,6 +29,7 @@ Issues = "https://github.com/ctrltz/meegsim/issues"

[project.optional-dependencies]
dev = [
"harmoni",
"mock",
"pytest",
"pytest-cov"
Expand Down
160 changes: 150 additions & 10 deletions src/meegsim/_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .utils import logger


def check_callable(name, fun, *args, **kwargs):
def check_callable(context, fun, *args, **kwargs):
"""
Check whether the provided function can be run successfully.
The function is always run with random_state set to 0 for consistency.
Expand Down Expand Up @@ -48,7 +48,7 @@ def check_callable(name, fun, *args, **kwargs):
return fun(*args, **kwargs, random_state=0)
except:
logger.error(f'An error occurred when trying to call the '
f'provided {name} function')
f'provided function for: {context}')
raise


Expand Down Expand Up @@ -330,14 +330,154 @@ def check_snr_params(snr_params, snr):
return snr_params


def check_coupling():
# coupling_edge = list(coupling.keys())[0]
# coupling_params = list(coupling.values())[0]
# name1, name2 = coupling_edge[0]
# if missing:
# raise ValueError(f"The configuration contains no sources with the following names: {', '.join(missing)}")
# self.check_if_exist([name1, name2])
pass
def check_if_source_exists(name, existing):
"""
Check if a source exists when trying to set the coupling.

Parameters
----------
name: str
The name of the source to be checked.
existing: list of str
The name of all existing sources

Raises
------
ValueError
If the provided source name is not in the list of existing ones.
"""
if name not in existing:
raise ValueError(f'Source {name} was not defined yet')


def check_coupling_params(method, coupling_params, coupling_edge):
"""
Check whether all required coupling parameters were provided for the
selected method.

Parameters
----------
method: str
The name of the coupling method.
coupling_params: dict
The coupling parameters for the selected method.
coupling_edge: tuple
The coupling edge that the provided parameters apply to.
It is only used to be more specific in the error message.

Raises
------
ValueError
If the provided dictionary does not contain all required parameters.
"""

# Test on a 10 second segment of white noise
sfreq = 100
rng = np.random.default_rng(seed=0)
waveform = rng.random((10 * sfreq,))

# Temporarily remove 'method' from coupling_params
test_params = coupling_params.copy()
test_params.pop('method')

check_callable(f'coupling method, edge {coupling_edge}',
method, waveform, sfreq, **test_params)


def check_coupling(coupling_edge, coupling_params, common_params, names, current_graph):
"""
Check whether the provided coupling edge and parameters are valid.

Parameters
----------
coupling_edge: tuple
The coupling edge (source, target) that the provided parameters apply to.
coupling_params: dict
The coupling parameters that were defined for this edge specifically.
common_params: dict
The coupling parameters that apply to all edges.
names: list of str
The names of sources that exist in the simulation.
current_graph: nx.Graph
The coupling graph that was already defined in the simulation

Raises
------
ValueError
If source or target do not exist in the simulation.
If the coupling edge was defined previously.
If the coupling method or any of the required parameters for the method
are not provided.
"""

# Check that the coupling edge is defined as a tuple of two elements
if not isinstance(coupling_edge, tuple):
raise ValueError(
f'Coupling edges {coupling_edge} should be defined as a tuple'
)
if len(coupling_edge) != 2:
raise ValueError(
f'Coupling edges should contain two elements (names of '
f'the source and the target), got {coupling_edge}'
)

# Check that both source names already exist
source, target = coupling_edge
check_if_source_exists(source, names)
check_if_source_exists(target, names)

# Check that the edge is not a self-loop
if source == target:
raise ValueError(
f'The coupling edge {coupling_edge} is a self-loop, and '
f'only connections between distinct sources are allowed.'
)

# Check that this coupling edge has not been already added
if current_graph.has_edge(*coupling_edge):
raise ValueError(
f'The coupling edge {coupling_edge} already exists in the '
f'simulation, and multiple definitions are not allowed.'
)

# Coupling parameters should be provided in a dictionary
if not isinstance(coupling_params, dict):
actual_type = type(coupling_params).__name__
raise ValueError(
f'Coupling parameters should be provided as a dictionary, '
f'got {actual_type} for edge {coupling_edge}'
)

# Warn the user if some parameters were defined both for specific edges
# and as common ones
double_definition = set(common_params.keys()) & set(coupling_params.keys())
if double_definition:
double_defined = ', '.join(double_definition)
warnings.warn(
f'Parameters {double_defined} have double definition for edge '
f'{coupling_edge}. Edge-specific values have higher priority.'
)

# Overwrite the common coupling parameters with edge-specific ones
params = common_params.copy()
params.update(coupling_params)

# Check that the coupling method was defined
if 'method' not in params:
raise ValueError(f'Coupling method was not defined for the edge {coupling_edge}')
method = params['method']

# Check that the coupling method is a callable
if not callable(method):
raise ValueError(
f'Expected coupling method to be a callable, '
f'got {type(method).__name__} for edge {coupling_edge}'
)

# Check that all required coupling parameters were specified for the selected method
check_coupling_params(method, params, coupling_edge)

return params


def check_extents(extents, n_sources):
Expand Down
Loading