Skip to content

Commit

Permalink
Add interpolation of nirs data. Replaces with nearest
Browse files Browse the repository at this point in the history
  • Loading branch information
rob-luke committed Apr 4, 2020
1 parent 59925bb commit 59e431e
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 4 deletions.
2 changes: 2 additions & 0 deletions doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ Changelog

- :func:`mne.viz.plot_dipole_locations` and :meth:`mne.Dipole.plot_locations` gained a ``title`` argument to specify a custom figure title in ``orthoview`` mode by `Richard Höchenberger`_

- Add functionality to interpolate bad NIRS channels by `Robert Luke`_

Bug
~~~

Expand Down
9 changes: 7 additions & 2 deletions mne/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,7 +962,7 @@ class InterpolationMixin(object):

@verbose
def interpolate_bads(self, reset_bads=True, mode='accurate',
origin='auto', verbose=None):
fnirs_method='nearest', origin='auto', verbose=None):
"""Interpolate bad MEG and EEG channels.
Operates in place.
Expand All @@ -975,6 +975,9 @@ def interpolate_bads(self, reset_bads=True, mode='accurate',
Either ``'accurate'`` or ``'fast'``, determines the quality of the
Legendre polynomial expansion used for interpolation of MEG
channels.
fnirs_method : str
Method to be used for fNIRS interpolation. Currently only 'nearest'
is supported.
origin : array-like, shape (3,) | str
Origin of the sphere in the head coordinate frame and in meters.
Can be ``'auto'`` (default), which means a head-digitization-based
Expand All @@ -993,7 +996,8 @@ def interpolate_bads(self, reset_bads=True, mode='accurate',
.. versionadded:: 0.9.0
"""
from ..bem import _check_origin
from .interpolation import _interpolate_bads_eeg, _interpolate_bads_meg
from .interpolation import _interpolate_bads_eeg,\
_interpolate_bads_meg, _interpolate_bads_nirs

_check_preload(self, "interpolation")

Expand All @@ -1003,6 +1007,7 @@ def interpolate_bads(self, reset_bads=True, mode='accurate',
origin = _check_origin(origin, self.info)
_interpolate_bads_eeg(self, origin=origin)
_interpolate_bads_meg(self, mode=mode, origin=origin)
_interpolate_bads_nirs(self, method=fnirs_method)

if reset_bads is True:
self.info['bads'] = []
Expand Down
57 changes: 57 additions & 0 deletions mne/channels/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ..io.pick import pick_types, pick_channels, pick_info
from ..surface import _normalize_vectors
from ..forward import _map_meg_channels
from ..utils import _check_option


def _calc_h(cosang, stiffness=4, n_legendre_terms=50):
Expand Down Expand Up @@ -213,3 +214,59 @@ def _interpolate_bads_meg(inst, mode='accurate', origin=(0., 0., 0.04),
info_to = pick_info(inst.info, picks_bad)
mapping = _map_meg_channels(info_from, info_to, mode=mode, origin=origin)
_do_interp_dots(inst, mapping, picks_good, picks_bad)


@verbose
def _interpolate_bads_nirs(inst, method='nearest', verbose=None):
"""Interpolate bad nirs channels. Simply replaces by closest non bad.
Parameters
----------
inst : mne.io.Raw, mne.Epochs or mne.Evoked
The data to interpolate. Must be preloaded.
method : str
Only the method 'nearest' is currently available. This method replaces
each bad channel with the nearest non bad channel.
%(verbose)s
"""
from scipy.spatial.distance import pdist, squareform
from mne.preprocessing.nirs import _channel_frequencies,\
_check_channels_ordered

# Returns pick of all nirs and ensures channels are correctly ordered
freqs = np.unique(_channel_frequencies(inst))
picks_nirs = _check_channels_ordered(inst, freqs)
if len(picks_nirs) == 0:
return

nirs_ch_names = [inst.info['ch_names'][p] for p in picks_nirs]
bads_nirs = [ch for ch in inst.info['bads'] if ch in nirs_ch_names]
if len(bads_nirs) == 0:
return
picks_bad = pick_channels(inst.info['ch_names'], bads_nirs, exclude=[])
bads_mask = [p in picks_bad for p in picks_nirs]

chs = [inst.info['chs'][i] for i in picks_nirs]
locs3d = np.array([ch['loc'][:3] for ch in chs])
dist = pdist(locs3d)
dist = squareform(dist)

_check_option('fnirs_method', method, ['nearest'])

if method == 'nearest':

for bad in picks_bad:
dists_to_bad = dist[bad]
# Ignore distances to self
dists_to_bad[dists_to_bad == 0] = np.inf
# Ignore distances to other bad channels
dists_to_bad[bads_mask] = np.inf
# Find closest remaining channels
closest_idx = np.where(dists_to_bad == np.amin(dists_to_bad))[0]
# Return the same frequency
closest_idx = closest_idx[bad % 2]
inst._data[bad] = inst._data[closest_idx]

inst.info['bads'] = []

return inst
18 changes: 18 additions & 0 deletions mne/channels/tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,4 +219,22 @@ def test_interpolation_ctf_comp():
assert raw.info['bads'] == []


@testing.requires_testing_data
def test_interpolation_nirs():
"""Test interpolating bad nirs channels."""
from mne.preprocessing.nirs import optical_density, scalp_coupling_index
from mne.datasets.testing import data_path
from mne.io import read_raw_nirx
from itertools import compress

fname = op.join(data_path(download=False),
'NIRx', 'nirx_15_2_recording_w_overlap')
raw_intensity = read_raw_nirx(fname, preload=False)
raw_od = optical_density(raw_intensity)
sci = scalp_coupling_index(raw_od)
raw_od.info['bads'] = list(compress(raw_od.ch_names, sci < 0.5))
raw_od.interpolate_bads()
assert raw_od.info['bads'] == []


run_tests_if_main()
4 changes: 2 additions & 2 deletions mne/preprocessing/nirs/nirs.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def short_channels(info, threshold=0.01):

def _channel_frequencies(raw):
"""Return the light frequency for each channel."""
picks = _picks_to_idx(raw.info, 'fnirs', exclude=[])
picks = _picks_to_idx(raw.info, 'fnirs', exclude=[], allow_empty=True)
freqs = np.empty(picks.size, int)
for ii in picks:
freqs[ii] = raw.info['chs'][ii]['loc'][9]
Expand All @@ -69,7 +69,7 @@ def _check_channels_ordered(raw, freqs):
"""Check channels followed expected fNIRS format."""
# Every second channel should be same SD pair
# and have the specified light frequencies.
picks = _picks_to_idx(raw.info, 'fnirs', exclude=[])
picks = _picks_to_idx(raw.info, 'fnirs', exclude=[], allow_empty=True)
if len(picks) % 2 != 0:
raise ValueError(
'NIRS channels not ordered correctly. An even number of NIRS '
Expand Down

0 comments on commit 59e431e

Please sign in to comment.