Skip to content

Commit

Permalink
MRG: Add interpolation for NIRS signals (#7428)
Browse files Browse the repository at this point in the history
* Add interpolation of nirs data. Replaces with nearest

* fNIRS interpolation remove method name and add test
  • Loading branch information
rob-luke authored Apr 6, 2020
1 parent 50f9f7b commit 86d7988
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 3 deletions.
1 change: 1 addition & 0 deletions doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Changelog

- Added temporal derivative distribution repair :func:`mne.preprocessing.nirs.temporal_derivative_distribution_repair` by `Robert Luke`_

- Add functionality to interpolate bad NIRS channels by `Robert Luke`_

Bug
~~~
Expand Down
4 changes: 3 additions & 1 deletion mne/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,7 +993,8 @@ def interpolate_bads(self, reset_bads=True, mode='accurate',
.. versionadded:: 0.9.0
"""
from ..bem import _check_origin
from .interpolation import _interpolate_bads_eeg, _interpolate_bads_meg
from .interpolation import _interpolate_bads_eeg,\
_interpolate_bads_meg, _interpolate_bads_nirs

_check_preload(self, "interpolation")

Expand All @@ -1003,6 +1004,7 @@ def interpolate_bads(self, reset_bads=True, mode='accurate',
origin = _check_origin(origin, self.info)
_interpolate_bads_eeg(self, origin=origin)
_interpolate_bads_meg(self, mode=mode, origin=origin)
_interpolate_bads_nirs(self)

if reset_bads is True:
self.info['bads'] = []
Expand Down
56 changes: 56 additions & 0 deletions mne/channels/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ..io.pick import pick_types, pick_channels, pick_info
from ..surface import _normalize_vectors
from ..forward import _map_meg_channels
from ..utils import _check_option


def _calc_h(cosang, stiffness=4, n_legendre_terms=50):
Expand Down Expand Up @@ -213,3 +214,58 @@ def _interpolate_bads_meg(inst, mode='accurate', origin=(0., 0., 0.04),
info_to = pick_info(inst.info, picks_bad)
mapping = _map_meg_channels(info_from, info_to, mode=mode, origin=origin)
_do_interp_dots(inst, mapping, picks_good, picks_bad)


@verbose
def _interpolate_bads_nirs(inst, method='nearest', verbose=None):
"""Interpolate bad nirs channels. Simply replaces by closest non bad.
Parameters
----------
inst : mne.io.Raw, mne.Epochs or mne.Evoked
The data to interpolate. Must be preloaded.
method : str
Only the method 'nearest' is currently available. This method replaces
each bad channel with the nearest non bad channel.
%(verbose)s
"""
from scipy.spatial.distance import pdist, squareform
from mne.preprocessing.nirs import _channel_frequencies,\
_check_channels_ordered

# Returns pick of all nirs and ensures channels are correctly ordered
freqs = np.unique(_channel_frequencies(inst))
picks_nirs = _check_channels_ordered(inst, freqs)
if len(picks_nirs) == 0:
return

nirs_ch_names = [inst.info['ch_names'][p] for p in picks_nirs]
bads_nirs = [ch for ch in inst.info['bads'] if ch in nirs_ch_names]
if len(bads_nirs) == 0:
return
picks_bad = pick_channels(inst.info['ch_names'], bads_nirs, exclude=[])
bads_mask = [p in picks_bad for p in picks_nirs]

chs = [inst.info['chs'][i] for i in picks_nirs]
locs3d = np.array([ch['loc'][:3] for ch in chs])

_check_option('fnirs_method', method, ['nearest'])

if method == 'nearest':

dist = pdist(locs3d)
dist = squareform(dist)

for bad in picks_bad:
dists_to_bad = dist[bad]
# Ignore distances to self
dists_to_bad[dists_to_bad == 0] = np.inf
# Ignore distances to other bad channels
dists_to_bad[bads_mask] = np.inf
# Find closest remaining channels for same frequency
closest_idx = np.argmin(dists_to_bad) + (bad % 2)
inst._data[bad] = inst._data[closest_idx]

inst.info['bads'] = []

return inst
21 changes: 21 additions & 0 deletions mne/channels/tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@
import numpy as np
from numpy.testing import (assert_allclose, assert_array_equal)
import pytest
from itertools import compress

from mne import io, pick_types, pick_channels, read_events, Epochs
from mne.channels.interpolation import _make_interpolation_matrix
from mne.datasets import testing
from mne.utils import run_tests_if_main
from mne.preprocessing.nirs import optical_density, scalp_coupling_index
from mne.datasets.testing import data_path
from mne.io import read_raw_nirx

base_dir = op.join(op.dirname(__file__), '..', '..', 'io', 'tests', 'data')
raw_fname = op.join(base_dir, 'test_raw.fif')
Expand Down Expand Up @@ -219,4 +223,21 @@ def test_interpolation_ctf_comp():
assert raw.info['bads'] == []


@testing.requires_testing_data
def test_interpolation_nirs():
"""Test interpolating bad nirs channels."""
fname = op.join(data_path(download=False),
'NIRx', 'nirx_15_2_recording_w_overlap')
raw_intensity = read_raw_nirx(fname, preload=False)
raw_od = optical_density(raw_intensity)
sci = scalp_coupling_index(raw_od)
raw_od.info['bads'] = list(compress(raw_od.ch_names, sci < 0.5))
bad_0 = np.where([name == raw_od.info['bads'][0] for
name in raw_od.ch_names])[0][0]
bad_0_std_pre_interp = np.std(raw_od._data[bad_0])
raw_od.interpolate_bads()
assert raw_od.info['bads'] == []
assert bad_0_std_pre_interp > np.std(raw_od._data[bad_0])


run_tests_if_main()
4 changes: 2 additions & 2 deletions mne/preprocessing/nirs/nirs.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def short_channels(info, threshold=0.01):

def _channel_frequencies(raw):
"""Return the light frequency for each channel."""
picks = _picks_to_idx(raw.info, 'fnirs', exclude=[])
picks = _picks_to_idx(raw.info, 'fnirs', exclude=[], allow_empty=True)
freqs = np.empty(picks.size, int)
for ii in picks:
freqs[ii] = raw.info['chs'][ii]['loc'][9]
Expand All @@ -69,7 +69,7 @@ def _check_channels_ordered(raw, freqs):
"""Check channels followed expected fNIRS format."""
# Every second channel should be same SD pair
# and have the specified light frequencies.
picks = _picks_to_idx(raw.info, 'fnirs', exclude=[])
picks = _picks_to_idx(raw.info, 'fnirs', exclude=[], allow_empty=True)
if len(picks) % 2 != 0:
raise ValueError(
'NIRS channels not ordered correctly. An even number of NIRS '
Expand Down

0 comments on commit 86d7988

Please sign in to comment.