Skip to content

Commit

Permalink
ENH: Overlap-add processing for maxwell filter (#13080)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel McCloy <[email protected]>
  • Loading branch information
larsoner and drammock authored Feb 25, 2025
1 parent c570cfc commit 2119b22
Show file tree
Hide file tree
Showing 12 changed files with 593 additions and 280 deletions.
1 change: 1 addition & 0 deletions doc/changes/devel/13080.apichange.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
The backward-compatible defaults in :func:`mne.preprocessing.maxwell_filter` of ``st_overlap=False`` and ``mc_interp=None`` will change to their smooth variants ``True`` and ``"hann"``, respectively, in 1.11, by `Eric Larson`_.
1 change: 1 addition & 0 deletions doc/changes/devel/13080.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add smooth processing of tSSS windows (using overlap-add) and movement compensation (using smooth interpolation of head positions) in :func:`mne.preprocessing.maxwell_filter` via ``st_overlap`` and ``mc_interp`` options, respectively, by `Eric Larson`_.
63 changes: 39 additions & 24 deletions mne/_ola.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
from scipy.signal import get_window

from .utils import _ensure_int, logger, verbose
from .utils import _ensure_int, _validate_type, logger, verbose

###############################################################################
# Class for interpolation between adjacent points
Expand Down Expand Up @@ -42,7 +42,7 @@ class _Interp2:
"""

def __init__(self, control_points, values, interp="hann"):
def __init__(self, control_points, values, interp="hann", *, name="Interp2"):
# set up interpolation
self.control_points = np.array(control_points, int).ravel()
if not np.array_equal(np.unique(self.control_points), self.control_points):
Expand Down Expand Up @@ -79,6 +79,7 @@ def val(pt):
self._position = 0 # start at zero
self._left_idx = 0
self._left = self._right = self._use_interp = None
self.name = name
known_types = ("cos2", "linear", "zero", "hann")
if interp not in known_types:
raise ValueError(f'interp must be one of {known_types}, got "{interp}"')
Expand All @@ -90,10 +91,10 @@ def feed_generator(self, n_pts):
n_pts = _ensure_int(n_pts, "n_pts")
original_position = self._position
stop = self._position + n_pts
logger.debug(f"Feed {n_pts} ({self._position}-{stop})")
logger.debug(f" ~ {self.name} Feed {n_pts} ({self._position}-{stop})")
used = np.zeros(n_pts, bool)
if self._left is None: # first one
logger.debug(f" Eval @ 0 ({self.control_points[0]})")
logger.debug(f" ~ {self.name} Eval @ 0 ({self.control_points[0]})")
self._left = self.values(self.control_points[0])
if len(self.control_points) == 1:
self._right = self._left
Expand All @@ -102,7 +103,7 @@ def feed_generator(self, n_pts):
# Left zero-order hold condition
if self._position < self.control_points[self._left_idx]:
n_use = min(self.control_points[self._left_idx] - self._position, n_pts)
logger.debug(f" Left ZOH {n_use}")
logger.debug(f" ~ {self.name} Left ZOH {n_use}")
this_sl = slice(None, n_use)
assert used[this_sl].size == n_use
assert not used[this_sl].any()
Expand All @@ -127,7 +128,9 @@ def feed_generator(self, n_pts):
self._left_idx += 1
self._use_interp = None # need to recreate it
eval_pt = self.control_points[self._left_idx + 1]
logger.debug(f" Eval @ {self._left_idx + 1} ({eval_pt})")
logger.debug(
f" ~ {self.name} Eval @ {self._left_idx + 1} ({eval_pt})"
)
self._right = self.values(eval_pt)
assert self._right is not None
left_point = self.control_points[self._left_idx]
Expand All @@ -148,7 +151,8 @@ def feed_generator(self, n_pts):
n_use = min(stop, right_point) - self._position
if n_use > 0:
logger.debug(
f" Interp {self._interp} {n_use} ({left_point}-{right_point})"
f" ~ {self.name} Interp {self._interp} {n_use} "
f"({left_point}-{right_point})"
)
interp_start = self._position - left_point
assert interp_start >= 0
Expand All @@ -169,7 +173,7 @@ def feed_generator(self, n_pts):
if self.control_points[self._left_idx] <= self._position:
n_use = stop - self._position
if n_use > 0:
logger.debug(f" Right ZOH {n_use}")
logger.debug(f" ~ {self.name} Right ZOH %s" % n_use)
this_sl = slice(n_pts - n_use, None)
assert not used[this_sl].any()
used[this_sl] = True
Expand Down Expand Up @@ -210,14 +214,13 @@ def feed(self, n_pts):


def _check_store(store):
    """Coerce ``store`` into a callable ``_Storer``.

    Parameters
    ----------
    store : ndarray | list of ndarray | tuple of ndarray | _Storer
        The output array(s) in which results will be stored, or an
        already-constructed ``_Storer`` (returned unchanged).

    Returns
    -------
    store : _Storer
        A callable storer wrapping the given array(s).

    Raises
    ------
    TypeError
        If ``store`` is not one of the accepted types, or if any element
        of a list/tuple ``store`` is not an ndarray.
    """
    _validate_type(store, (np.ndarray, list, tuple, _Storer), "store")
    # Promote a single array to a one-element list for uniform handling
    if isinstance(store, np.ndarray):
        store = [store]
    if not isinstance(store, _Storer):
        # Here store must be a list/tuple; every element must be an ndarray
        if not all(isinstance(s, np.ndarray) for s in store):
            raise TypeError("All instances must be ndarrays")
        store = _Storer(*store)
    return store


Expand All @@ -229,10 +232,8 @@ class _COLA:
process : callable
A function that takes a chunk of input data with shape
``(n_channels, n_samples)`` and processes it.
store : callable | ndarray
A function that takes a completed chunk of output data.
Can also be an ``ndarray``, in which case it is treated as the
output data in which to store the results.
store : ndarray | list of ndarray | _Storer
The output data in which to store the results.
n_total : int
The total number of samples.
n_samples : int
Expand Down Expand Up @@ -276,6 +277,7 @@ def __init__(
window="hann",
tol=1e-10,
*,
name="COLA",
verbose=None,
):
n_samples = _ensure_int(n_samples, "n_samples")
Expand All @@ -302,6 +304,7 @@ def __init__(
self._store = _check_store(store)
self._idx = 0
self._in_buffers = self._out_buffers = None
self.name = name

# Create our window boundaries
window_name = window if isinstance(window, str) else "custom"
Expand Down Expand Up @@ -343,6 +346,7 @@ def feed(self, *datas, verbose=None, **kwargs):
raise ValueError(
f"Got {len(datas)} array(s), needed {len(self._in_buffers)}"
)
current_offset = 0 # should be updated below
for di, data in enumerate(datas):
if not isinstance(data, np.ndarray) or data.ndim < 1:
raise TypeError(
Expand All @@ -363,9 +367,12 @@ def feed(self, *datas, verbose=None, **kwargs):
f"shape[:-1]=={self._in_buffers[di].shape[:-1]}, got dtype "
f"{data.dtype} shape[:-1]={data.shape[:-1]}"
)
# This gets updated on first iteration, so store it before it updates
if di == 0:
current_offset = self._in_offset
logger.debug(
f" + Appending {self._in_offset:d}->"
f"{self._in_offset + data.shape[-1]:d}"
f" + {self.name}[{di}] Appending "
f"{current_offset}:{current_offset + data.shape[-1]}"
)
self._in_buffers[di] = np.concatenate([self._in_buffers[di], data], -1)
if self._in_offset > self.stops[-1]:
Expand All @@ -388,13 +395,18 @@ def feed(self, *datas, verbose=None, **kwargs):
if self._idx == 0:
for offset in range(self._n_samples - self._step, 0, -self._step):
this_window[:offset] += self._window[-offset:]
logger.debug(f" * Processing {start}->{stop}")
this_proc = [in_[..., :this_len].copy() for in_ in self._in_buffers]
logger.debug(
f" * {self.name}[:] Processing {start}:{stop} "
f"(e.g., {this_proc[0].flat[[0, -1]]})"
)
if not all(
proc.shape[-1] == this_len == this_window.size for proc in this_proc
):
raise RuntimeError("internal indexing error")
outs = self._process(*this_proc, **kwargs)
start = self._store.idx
stop = self._store.idx + this_len
outs = self._process(*this_proc, start=start, stop=stop, **kwargs)
if self._out_buffers is None:
max_len = np.max(self.stops - self.starts)
self._out_buffers = [
Expand All @@ -409,9 +421,12 @@ def feed(self, *datas, verbose=None, **kwargs):
else:
next_start = self.stops[-1]
delta = next_start - self.starts[self._idx - 1]
logger.debug(
f" + {self.name}[:] Shifting input and output buffers by "
f"{delta} samples (storing {start}:{stop})"
)
for di in range(len(self._in_buffers)):
self._in_buffers[di] = self._in_buffers[di][..., delta:]
logger.debug(f" - Shifting input/output buffers by {delta:d} samples")
self._store(*[o[..., :delta] for o in self._out_buffers])
for ob in self._out_buffers:
ob[..., :-delta] = ob[..., delta:]
Expand All @@ -430,8 +445,8 @@ def _check_cola(win, nperseg, step, window_name, tol=1e-10):
deviation = np.max(np.abs(binsums - const))
if deviation > tol:
raise ValueError(
f"segment length {nperseg:d} with step {step:d} for {window_name} window "
"type does not provide a constant output "
f"segment length {nperseg} with step {step} for {window_name} "
"window type does not provide a constant output "
f"({100 * deviation / const:g}% deviation)"
)
return const
Expand Down
2 changes: 1 addition & 1 deletion mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4824,7 +4824,7 @@ def average_movements(
del head_pos
_check_usable(epochs, ignore_ref)
origin = _check_origin(origin, epochs.info, "head")
recon_trans = _check_destination(destination, epochs.info, True)
recon_trans = _check_destination(destination, epochs.info, "head")

logger.info(f"Aligning and averaging up to {len(epochs.events)} epochs")
if not np.array_equal(epochs.events[:, 0], np.unique(epochs.events[:, 0])):
Expand Down
12 changes: 2 additions & 10 deletions mne/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1664,24 +1664,16 @@ def _mt_spectrum_remove_win(
n_overlap = (n_samples + 1) // 2
x_out = np.zeros_like(x)
rm_freqs = list()
idx = [0]

# Define how to process a chunk of data
def process(x_):
def process(x_, *, start, stop):
out = _mt_spectrum_remove(
x_, sfreq, line_freqs, notch_widths, window_fun, threshold, get_thresh
)
rm_freqs.append(out[1])
return (out[0],) # must return a tuple

# Define how to store a chunk of fully processed data (it's trivial)
def store(x_):
stop = idx[0] + x_.shape[-1]
x_out[..., idx[0] : stop] += x_
idx[0] = stop

_COLA(process, store, n_times, n_samples, n_overlap, sfreq, verbose=False).feed(x)
assert idx[0] == n_times
_COLA(process, x_out, n_times, n_samples, n_overlap, sfreq, verbose=False).feed(x)
return x_out, rm_freqs


Expand Down
Loading

0 comments on commit 2119b22

Please sign in to comment.