Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MRG: Add interpolation for NIRS signals #7428

Merged
merged 2 commits into from
Apr 6, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Changelog

- Added temporal derivative distribution repair :func:`mne.preprocessing.nirs.temporal_derivative_distribution_repair` by `Robert Luke`_

- Add functionality to interpolate bad NIRS channels by `Robert Luke`_

Bug
~~~
Expand Down
9 changes: 7 additions & 2 deletions mne/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,7 +962,7 @@ class InterpolationMixin(object):

@verbose
def interpolate_bads(self, reset_bads=True, mode='accurate',
origin='auto', verbose=None):
fnirs_method='nearest', origin='auto', verbose=None):
"""Interpolate bad MEG and EEG channels.

Operates in place.
Expand All @@ -975,6 +975,9 @@ def interpolate_bads(self, reset_bads=True, mode='accurate',
Either ``'accurate'`` or ``'fast'``, determines the quality of the
Legendre polynomial expansion used for interpolation of MEG
channels.
fnirs_method : str
Method to be used for fNIRS interpolation. Currently only 'nearest'
is supported.
origin : array-like, shape (3,) | str
Origin of the sphere in the head coordinate frame and in meters.
Can be ``'auto'`` (default), which means a head-digitization-based
Expand All @@ -993,7 +996,8 @@ def interpolate_bads(self, reset_bads=True, mode='accurate',
.. versionadded:: 0.9.0
"""
from ..bem import _check_origin
from .interpolation import _interpolate_bads_eeg, _interpolate_bads_meg
from .interpolation import _interpolate_bads_eeg,\
_interpolate_bads_meg, _interpolate_bads_nirs

_check_preload(self, "interpolation")

Expand All @@ -1003,6 +1007,7 @@ def interpolate_bads(self, reset_bads=True, mode='accurate',
origin = _check_origin(origin, self.info)
_interpolate_bads_eeg(self, origin=origin)
_interpolate_bads_meg(self, mode=mode, origin=origin)
_interpolate_bads_nirs(self, method=fnirs_method)

if reset_bads is True:
self.info['bads'] = []
Expand Down
55 changes: 55 additions & 0 deletions mne/channels/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ..io.pick import pick_types, pick_channels, pick_info
from ..surface import _normalize_vectors
from ..forward import _map_meg_channels
from ..utils import _check_option


def _calc_h(cosang, stiffness=4, n_legendre_terms=50):
Expand Down Expand Up @@ -213,3 +214,57 @@ def _interpolate_bads_meg(inst, mode='accurate', origin=(0., 0., 0.04),
info_to = pick_info(inst.info, picks_bad)
mapping = _map_meg_channels(info_from, info_to, mode=mode, origin=origin)
_do_interp_dots(inst, mapping, picks_good, picks_bad)


@verbose
def _interpolate_bads_nirs(inst, method='nearest', verbose=None):
"""Interpolate bad nirs channels. Simply replaces by closest non bad.

Parameters
----------
inst : mne.io.Raw, mne.Epochs or mne.Evoked
The data to interpolate. Must be preloaded.
method : str
Only the method 'nearest' is currently available. This method replaces
each bad channel with the nearest non bad channel.
%(verbose)s
"""
from scipy.spatial.distance import pdist, squareform
from mne.preprocessing.nirs import _channel_frequencies,\
_check_channels_ordered

# Returns pick of all nirs and ensures channels are correctly ordered
freqs = np.unique(_channel_frequencies(inst))
picks_nirs = _check_channels_ordered(inst, freqs)
if len(picks_nirs) == 0:
return

nirs_ch_names = [inst.info['ch_names'][p] for p in picks_nirs]
bads_nirs = [ch for ch in inst.info['bads'] if ch in nirs_ch_names]
if len(bads_nirs) == 0:
return
picks_bad = pick_channels(inst.info['ch_names'], bads_nirs, exclude=[])
bads_mask = [p in picks_bad for p in picks_nirs]

chs = [inst.info['chs'][i] for i in picks_nirs]
locs3d = np.array([ch['loc'][:3] for ch in chs])
dist = pdist(locs3d)
dist = squareform(dist)

_check_option('fnirs_method', method, ['nearest'])

if method == 'nearest':

for bad in picks_bad:
dists_to_bad = dist[bad]
# Ignore distances to self
dists_to_bad[dists_to_bad == 0] = np.inf
# Ignore distances to other bad channels
dists_to_bad[bads_mask] = np.inf
# Find closest remaining channels for same frequency
closest_idx = np.argmin(dists_to_bad) + (bad % 2)
inst._data[bad] = inst._data[closest_idx]

inst.info['bads'] = []

return inst
17 changes: 17 additions & 0 deletions mne/channels/tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@
import numpy as np
from numpy.testing import (assert_allclose, assert_array_equal)
import pytest
from itertools import compress

from mne import io, pick_types, pick_channels, read_events, Epochs
from mne.channels.interpolation import _make_interpolation_matrix
from mne.datasets import testing
from mne.utils import run_tests_if_main
from mne.preprocessing.nirs import optical_density, scalp_coupling_index
from mne.datasets.testing import data_path
from mne.io import read_raw_nirx

base_dir = op.join(op.dirname(__file__), '..', '..', 'io', 'tests', 'data')
raw_fname = op.join(base_dir, 'test_raw.fif')
Expand Down Expand Up @@ -219,4 +223,17 @@ def test_interpolation_ctf_comp():
assert raw.info['bads'] == []


@testing.requires_testing_data
def test_interpolation_nirs():
"""Test interpolating bad nirs channels."""
fname = op.join(data_path(download=False),
'NIRx', 'nirx_15_2_recording_w_overlap')
raw_intensity = read_raw_nirx(fname, preload=False)
raw_od = optical_density(raw_intensity)
sci = scalp_coupling_index(raw_od)
raw_od.info['bads'] = list(compress(raw_od.ch_names, sci < 0.5))
raw_od.interpolate_bads()
assert raw_od.info['bads'] == []


run_tests_if_main()
4 changes: 2 additions & 2 deletions mne/preprocessing/nirs/nirs.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def short_channels(info, threshold=0.01):

def _channel_frequencies(raw):
"""Return the light frequency for each channel."""
picks = _picks_to_idx(raw.info, 'fnirs', exclude=[])
picks = _picks_to_idx(raw.info, 'fnirs', exclude=[], allow_empty=True)
freqs = np.empty(picks.size, int)
for ii in picks:
freqs[ii] = raw.info['chs'][ii]['loc'][9]
Expand All @@ -69,7 +69,7 @@ def _check_channels_ordered(raw, freqs):
"""Check channels followed expected fNIRS format."""
# Every second channel should be same SD pair
# and have the specified light frequencies.
picks = _picks_to_idx(raw.info, 'fnirs', exclude=[])
picks = _picks_to_idx(raw.info, 'fnirs', exclude=[], allow_empty=True)
if len(picks) % 2 != 0:
raise ValueError(
'NIRS channels not ordered correctly. An even number of NIRS '
Expand Down