Skip to content

Commit

Permalink
bug fix mic
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns committed Jun 26, 2023
1 parent 2778693 commit f2e4dad
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions mne_connectivity/spectral/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def __init__(self, n_signals, n_cons, n_freqs, n_times, n_jobs=1):
if self.name == 'MIC':
self.patterns = np.empty((2, self.n_cons), dtype=object)

def compute_con(self, indices, ranks, n_epochs):
def compute_con(self, indices, ranks, n_epochs=1):
"""Compute multivariate imag. part of coherency between signals."""
assert self.name in ['MIC', 'MIM'], (
'the class name is not recognised, please contact the '
Expand Down Expand Up @@ -439,11 +439,7 @@ def compute_con(self, indices, ranks, n_epochs):
self._compute_mic(
E, C, n_seeds, n_times, U_bar_aa, U_bar_bb, con_i)
else:
self._compute_mim(E, con_i)

# Eq. 15 for MIM (same principle for MIC)
if all(np.unique(seed_idcs) == np.unique(target_idcs)):
self.con_scores[con_i] *= 0.5
self._compute_mim(E, seed_idcs, target_idcs, con_i)

con_i += 1

Expand Down Expand Up @@ -550,12 +546,16 @@ def _compute_mic(self, E, C, n_seeds, n_times, U_bar_aa, U_bar_bb, con_i):
beta, axis=3))[..., 0]
) / np.linalg.norm(alpha, axis=2) * np.linalg.norm(beta, axis=2)).T

def _compute_mim(self, E, con_i):
def _compute_mim(self, E, seed_idcs, target_idcs, con_i):
"""Compute MIM (a.k.a. GIM if seeds == targets)."""
# Eq. 14
self.con_scores[con_i] = np.matmul(
E, E.transpose(0, 1, 3, 2)).trace(axis1=2, axis2=3).T

# Eq. 15
if all(np.unique(seed_idcs) == np.unique(target_idcs)):
self.con_scores[con_i] *= 0.5

def reshape_results(self):
"""Remove time dimension from results, if necessary."""
if self.n_times == 0:
Expand Down Expand Up @@ -845,7 +845,7 @@ def __init__(self, n_signals, n_cons, n_freqs, n_times, n_lags, n_jobs=1):
'frequency resolution (%i)' % (n_lags, self.freq_res, ))
self.n_lags = n_lags

def compute_con(self, indices, ranks, n_epochs):
def compute_con(self, indices, ranks, n_epochs=1):
"""Compute multivariate state-space Granger causality."""
assert self.name in ['GC', 'GC time-reversed'], (
'the class name is not recognised, please contact the '
Expand Down Expand Up @@ -1181,7 +1181,7 @@ class _GCEst(_GCEstBase):
name = "GC"


class _TRGCEst(_GCEstBase):
class _GCTREst(_GCEstBase):
"""time-reversed[seeds -> targets] state-space GC estimator."""

name = "GC time-reversed"
Expand Down Expand Up @@ -1415,7 +1415,7 @@ def _get_and_verify_data_sizes(data, sfreq, n_signals=None, n_times=None,
'pli': _PLIEst, 'pli2_unbiased': _PLIUnbiasedEst,
'dpli': _DPLIEst, 'wpli': _WPLIEst,
'wpli2_debiased': _WPLIDebiasedEst, 'mic': _MICEst,
'mim': _MIMEst, 'gc': _GCEst, 'gc_tr': _TRGCEst}
'mim': _MIMEst, 'gc': _GCEst, 'gc_tr': _GCTREst}


def _check_estimators(method):
Expand Down Expand Up @@ -2063,9 +2063,10 @@ def _check_rank_input(rank, data, sfreq, indices):
con_data = EpochsArray(
data_arr[:, con_idcs], con_info, verbose=False)

rank[group_i][con_i] = sum(
compute_rank(con_data, tol=1e-10, tol_kind='relative',
verbose=False).values())
s = np.linalg.svd(con_data.get_data(), compute_uv=False)
rank[group_i][con_i] = np.min(
[np.count_nonzero(epoch >= epoch[0] * 1e-10)
for epoch in s])

logger.info('Estimated data ranks:')
con_i = 1
Expand Down

0 comments on commit f2e4dad

Please sign in to comment.