From 257dcd20759a6888c445b5af0630c7dd17aa0864 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Wed, 16 Nov 2022 16:15:35 -0500 Subject: [PATCH 1/7] MAINT: Split bad-finding step --- docs/source/changes.md | 3 + mne_bids_pipeline/_config_import.py | 5 +- mne_bids_pipeline/_import_data.py | 376 +++++++++--------- mne_bids_pipeline/_report.py | 102 ++--- mne_bids_pipeline/_run.py | 22 +- .../steps/freesurfer/_02_coreg_surfaces.py | 29 +- ...ves_dir.py => _01_init_derivatives_dir.py} | 9 +- ...d_empty_room.py => _02_find_empty_room.py} | 9 +- mne_bids_pipeline/steps/init/__init__.py | 8 +- .../steps/preprocessing/_01_data_quality.py | 319 +++++++++++++++ .../{_01_maxfilter.py => _02_maxfilter.py} | 174 ++++---- ...ency_filter.py => _03_frequency_filter.py} | 197 +++------ ...{_03_make_epochs.py => _04_make_epochs.py} | 40 +- .../{_04a_run_ica.py => _05a_run_ica.py} | 33 +- .../{_04b_run_ssp.py => _05b_run_ssp.py} | 34 +- .../{_05a_apply_ica.py => _06a_apply_ica.py} | 44 +- .../{_05b_apply_ssp.py => _06b_apply_ssp.py} | 32 +- .../{_06_ptp_reject.py => _07_ptp_reject.py} | 38 +- .../steps/preprocessing/__init__.py | 34 +- .../steps/sensor/_01_make_evoked.py | 41 +- .../steps/sensor/_02_decoding_full_epochs.py | 42 +- .../steps/sensor/_03_decoding_time_by_time.py | 58 +-- .../steps/sensor/_04_time_frequency.py | 38 +- .../steps/sensor/_05_decoding_csp.py | 33 +- .../steps/sensor/_06_make_cov.py | 40 +- .../steps/sensor/_99_group_average.py | 44 +- .../steps/source/_01_make_bem_surfaces.py | 26 +- .../steps/source/_04_make_forward.py | 40 +- .../steps/source/_05_make_inverse.py | 23 +- .../steps/source/_99_group_average.py | 43 +- 30 files changed, 1219 insertions(+), 717 deletions(-) rename mne_bids_pipeline/steps/init/{_00_init_derivatives_dir.py => _01_init_derivatives_dir.py} (92%) rename mne_bids_pipeline/steps/init/{_01_find_empty_room.py => _02_find_empty_room.py} (96%) create mode 100644 mne_bids_pipeline/steps/preprocessing/_01_data_quality.py rename mne_bids_pipeline/steps/preprocessing/{_01_maxfilter.py => _02_maxfilter.py} (74%) rename mne_bids_pipeline/steps/preprocessing/{_02_frequency_filter.py => _03_frequency_filter.py} (63%) rename mne_bids_pipeline/steps/preprocessing/{_03_make_epochs.py => _04_make_epochs.py} (95%) rename mne_bids_pipeline/steps/preprocessing/{_04a_run_ica.py => _05a_run_ica.py} (97%) rename mne_bids_pipeline/steps/preprocessing/{_04b_run_ssp.py => _05b_run_ssp.py} (93%) rename mne_bids_pipeline/steps/preprocessing/{_05a_apply_ica.py => _06a_apply_ica.py} (88%) rename mne_bids_pipeline/steps/preprocessing/{_05b_apply_ssp.py => _06b_apply_ssp.py} (85%) rename mne_bids_pipeline/steps/preprocessing/{_06_ptp_reject.py => _07_ptp_reject.py} (89%) diff --git a/docs/source/changes.md b/docs/source/changes.md index 8bfef5123..28946fa84 100644 --- a/docs/source/changes.md +++ b/docs/source/changes.md @@ -205,6 +205,9 @@ authors: - Make report generation happen within relevant steps instead of at the end of all steps ({{ gh(652) }} by {{ authors.larsoner }}) +- Initial raw data plots are now added to reports and bad channel detection + is executed in a dedicated step + ({{ gh(666) }} by {{ authors.larsoner }}) ### Behavior changes diff --git a/mne_bids_pipeline/_config_import.py b/mne_bids_pipeline/_config_import.py index b2ab0d343..f1151b3c9 100644 --- a/mne_bids_pipeline/_config_import.py +++ b/mne_bids_pipeline/_config_import.py @@ -34,6 +34,7 @@ def _import_config( overrides=overrides, log=log, ) + # Check it if check: _check_config(config) @@ -59,9 +60,9 @@ def _import_config( # Caching 'memory_location', 'memory_verbose', 'memory_file_method', # Misc - 'deriv_root', + 'deriv_root', 'config_path', ) - in_both = {'deriv_root', 'interactive'} + in_both = {'deriv_root'} exec_params = SimpleNamespace(**{k: getattr(config, k) for k in keys}) for k in keys: if k not in in_both: diff --git a/mne_bids_pipeline/_import_data.py b/mne_bids_pipeline/_import_data.py index 79f1b0cd2..0168f602c 100644 --- a/mne_bids_pipeline/_import_data.py +++ b/mne_bids_pipeline/_import_data.py @@ -1,16 +1,14 @@ -import copy from types import SimpleNamespace from typing import Dict, Optional, Iterable, Union, List, Literal import mne -from mne_bids import BIDSPath, read_raw_bids +from mne_bids import BIDSPath, read_raw_bids, get_bids_path_from_fname import numpy as np import pandas as pd -from ._config_utils import get_channels_to_analyze, get_task -from ._io import _write_json +from ._io import _read_json, _empty_room_match_path from ._logging import gen_log_kwargs, logger -from ._viz import plot_auto_scores +from ._run import _update_for_splits from .typing import PathLike @@ -205,129 +203,6 @@ def _rename_events_func( raw.annotations.description = descriptions -def _find_bad_channels( - cfg: SimpleNamespace, - raw: mne.io.BaseRaw, - subject: str, - session: Optional[str], - task: Optional[str], - run: Optional[str], -) -> None: - """Find and mark bad MEG channels. - - Modifies ``raw`` in-place. - """ - if not (cfg.find_flat_channels_meg or cfg.find_noisy_channels_meg): - return - - if (cfg.find_flat_channels_meg and - not cfg.find_noisy_channels_meg): - msg = 'Finding flat channels.' - elif (cfg.find_noisy_channels_meg and - not cfg.find_flat_channels_meg): - msg = 'Finding noisy channels using Maxwell filtering.' - else: - msg = ('Finding flat channels, and noisy channels using ' - 'Maxwell filtering.') - - logger.info(**gen_log_kwargs(message=msg)) - - bids_path = BIDSPath(subject=subject, - session=session, - task=task, - run=run, - acquisition=cfg.acq, - processing=cfg.proc, - recording=cfg.rec, - space=cfg.space, - suffix=cfg.datatype, - datatype=cfg.datatype, - root=cfg.deriv_root) - - # Filter the data manually before passing it to find_bad_channels_maxwell() - # This reduces memory usage, as we can control the number of jobs used - # during filtering. - raw_filt = raw.copy().filter(l_freq=None, h_freq=40, n_jobs=1) - - auto_noisy_chs, auto_flat_chs, auto_scores = \ - mne.preprocessing.find_bad_channels_maxwell( - raw=raw_filt, - calibration=cfg.mf_cal_fname, - cross_talk=cfg.mf_ctc_fname, - origin=cfg.mf_head_origin, - coord_frame='head', - return_scores=True, - h_freq=None # we filtered manually above - ) - del raw_filt - - preexisting_bads = raw.info['bads'].copy() - bads = preexisting_bads.copy() - - if cfg.find_flat_channels_meg: - if auto_flat_chs: - msg = (f'Found {len(auto_flat_chs)} flat channels: ' - f'{", ".join(auto_flat_chs)}') - else: - msg = 'Found no flat channels.' - logger.info(**gen_log_kwargs(message=msg)) - bads.extend(auto_flat_chs) - - if cfg.find_noisy_channels_meg: - if auto_noisy_chs: - msg = (f'Found {len(auto_noisy_chs)} noisy channels: ' - f'{", ".join(auto_noisy_chs)}') - else: - msg = 'Found no noisy channels.' - - logger.info(**gen_log_kwargs(message=msg)) - bads.extend(auto_noisy_chs) - - bads = sorted(set(bads)) - raw.info['bads'] = bads - msg = f'Marked {len(raw.info["bads"])} channels as bad.' - logger.info(**gen_log_kwargs(message=msg)) - - if cfg.find_noisy_channels_meg: - auto_scores_fname = bids_path.copy().update( - suffix='scores', extension='.json', check=False) - # TODO: This should be in our list of output files! - _write_json(auto_scores_fname, auto_scores) - - if cfg.interactive: - import matplotlib.pyplot as plt - plot_auto_scores(auto_scores, ch_types=cfg.ch_types) - plt.show() - - # Write the bad channels to disk. - # TODO: This should also be in our list of output files - bads_tsv_fname = bids_path.copy().update(suffix='bads', - extension='.tsv', - check=False) - bads_for_tsv = [] - reasons = [] - - if cfg.find_flat_channels_meg: - bads_for_tsv.extend(auto_flat_chs) - reasons.extend(['auto-flat'] * len(auto_flat_chs)) - preexisting_bads = set(preexisting_bads) - set(auto_flat_chs) - - if cfg.find_noisy_channels_meg: - bads_for_tsv.extend(auto_noisy_chs) - reasons.extend(['auto-noisy'] * len(auto_noisy_chs)) - preexisting_bads = set(preexisting_bads) - set(auto_noisy_chs) - - preexisting_bads = list(preexisting_bads) - if preexisting_bads: - bads_for_tsv.extend(preexisting_bads) - reasons.extend(['pre-existing (before MNE-BIDS-pipeline was run)'] * - len(preexisting_bads)) - - tsv_data = pd.DataFrame(dict(name=bads_for_tsv, reason=reasons)) - tsv_data = tsv_data.sort_values(by='name') - tsv_data.to_csv(bads_tsv_fname, sep='\t', index=False) - - def _load_data(cfg: SimpleNamespace, bids_path: BIDSPath) -> mne.io.BaseRaw: # read_raw_bids automatically # - populates bad channels using the BIDS channels.tsv @@ -338,12 +213,6 @@ def _load_data(cfg: SimpleNamespace, bids_path: BIDSPath) -> mne.io.BaseRaw: raw = read_raw_bids(bids_path=bids_path, extra_params=cfg.reader_extra_params) - # Save only the channel types we wish to analyze (including the - # channels marked as "bad"). - if not cfg.use_maxwell_filter: - picks = get_channels_to_analyze(raw.info, cfg) - raw.pick(picks) - _crop_data(cfg, raw=raw, subject=subject) raw.load_data() @@ -465,6 +334,8 @@ def import_experimental_data( *, cfg: SimpleNamespace, bids_path_in: BIDSPath, + bids_path_bads_in: Optional[BIDSPath], + data_is_rest: Optional[bool], ) -> mne.io.BaseRaw: """Run the data import. @@ -474,6 +345,11 @@ def import_experimental_data( The local configuration. bids_path_in The BIDS path to the data to import. + bids_path_bads_in + The BIDS path to the bad channels file. + data_is_rest : bool | None + Whether the data is resting state data. If ``None``, ``cfg.task`` + is checked. Returns ------- @@ -493,13 +369,20 @@ def import_experimental_data( _drop_channels_func(cfg=cfg, raw=raw, subject=subject, session=session) _find_breaks_func(cfg=cfg, raw=raw, subject=subject, session=session, run=run) - if cfg.task != "rest": + if data_is_rest is None: + data_is_rest = (cfg.task == 'rest') or cfg.task_is_rest + if not data_is_rest: _rename_events_func( cfg=cfg, raw=raw, subject=subject, session=session, run=run ) _fix_stim_artifact_func(cfg=cfg, raw=raw) - _find_bad_channels(cfg=cfg, raw=raw, subject=subject, session=session, - task=get_task(cfg), run=run) + + if bids_path_bads_in is not None: + bads = _read_bads_tsv(cfg=cfg, bids_path_bads=bids_path_bads_in) + msg = f'Marking {len(bads)} channels as bad.' + logger.info(**gen_log_kwargs(message=msg)) + raw.info['bads'] = bads + raw.info._check_consistency() return raw @@ -508,7 +391,9 @@ def import_er_data( *, cfg: SimpleNamespace, bids_path_er_in: BIDSPath, - bids_path_ref_in: BIDSPath, + bids_path_ref_in: Optional[BIDSPath], + bids_path_er_bads_in: Optional[BIDSPath], + bids_path_ref_bads_in: Optional[BIDSPath], ) -> mne.io.BaseRaw: """Import empty-room data. @@ -520,6 +405,10 @@ def import_er_data( The BIDS path to the empty room data. bids_path_ref_in The BIDS path to the reference data. + bids_path_er_bads_in + The BIDS path to the empty room bad channels file. + bids_path_ref_bads_in + The BIDS path to the reference data bad channels file. Returns ------- @@ -530,74 +419,46 @@ def import_er_data( session = bids_path_er_in.session _drop_channels_func(cfg, raw=raw_er, subject='emptyroom', session=session) - - # Only keep MEG channels. + if bids_path_er_bads_in is not None: + raw_er.info['bads'] = _read_bads_tsv( + cfg=cfg, + bids_path_bads=bids_path_er_bads_in, + ) raw_er.pick_types(meg=True, exclude=[]) - # TODO: This 'union' operation should affect the raw runs, too, otherwise - # rank mismatches will still occur (eventually for some configs). - # But at least using the union here should reduce them. - # TODO: We should also uso automatic bad finding on the empty room data + # Don't deal with ref for now (initial data quality / auto bad step) + if bids_path_ref_in is None: + return raw_er + + # Load reference run plus its auto-bads raw_ref = read_raw_bids(bids_path_ref_in, extra_params=cfg.reader_extra_params) + if bids_path_ref_bads_in is not None: + bads = _read_bads_tsv( + cfg=cfg, + bids_path_in=bids_path_ref_bads_in, + ) + raw_ref.info['bads'] = bads + raw_ref.info._check_consistency() + raw_ref.pick_types(meg=True, exclude=[]) + if cfg.use_maxwell_filter: # We need to include any automatically found bad channels, if relevant. - # TODO this is a bit of a hack because we don't use "in_files" access - # here, but this is *in the same step where this file is generated* - # so we cannot / should not put it in `in_files`. - if cfg.find_flat_channels_meg or cfg.find_noisy_channels_meg: - # match filename from _find_bad_channels - bads_tsv_fname = bids_path_ref_in.copy().update( - suffix='bads', extension='.tsv', root=cfg.deriv_root, - check=False) - bads_tsv = pd.read_csv(bads_tsv_fname.fpath, sep='\t', header=0) - bads_tsv = bads_tsv[bads_tsv.columns[0]].tolist() - raw_ref.info['bads'] = sorted( - set(raw_ref.info['bads']) | set(bads_tsv)) - raw_ref.info._check_consistency() + # TODO: This 'union' operation should affect the raw runs, too, + # otherwise rank mismatches will still occur (eventually for some + # configs). But at least using the union here should reduce them. raw_er = mne.preprocessing.maxwell_filter_prepare_emptyroom( raw_er=raw_er, raw=raw_ref, bads='union', ) else: - # Set same set of bads as in the reference run, but only for MEG - # channels (we might not have non-MEG channels in empty-room - # recordings). - raw_er.info['bads'] = [ch for ch in raw_ref.info['bads'] - if ch.startswith('MEG')] + # Take bads from the reference run + raw_er.info['bads'] = raw_ref.info['bads'] return raw_er -def import_rest_data( - *, - cfg: SimpleNamespace, - bids_path_in: BIDSPath, -) -> mne.io.BaseRaw: - """Import resting-state data for use as a noise source. - - Parameters - ---------- - cfg - The local configuration. - bids_path_in : BIDSPath - The path. - - Returns - ------- - raw_rest - The imported data. - """ - cfg = copy.deepcopy(cfg) - cfg.task = 'rest' - - raw_rest = import_experimental_data( - cfg=cfg, bids_path_in=bids_path_in, - ) - return raw_rest - - def _find_breaks_func( *, cfg, @@ -627,3 +488,138 @@ def _find_breaks_func( logger.info(**gen_log_kwargs(message=msg)) raw.set_annotations(raw.annotations + break_annots) # add to existing + + +def _get_raw_paths( + *, + cfg: SimpleNamespace, + subject: str, + session: Optional[str], + run: Optional[str], + kind: Literal['raw', 'sss'], + add_bads: bool = True, +) -> dict: + # Construct the basenames of the files we wish to load, and of the empty- + # room recording we wish to save. + # The basenames of the empty-room recording output file does not contain + # the "run" entity. + path_kwargs = dict( + subject=subject, + run=run, + session=session, + task=cfg.task, + acquisition=cfg.acq, + recording=cfg.rec, + space=cfg.space, + datatype=cfg.datatype, + check=False + ) + if kind == 'sss': + path_kwargs['root'] = cfg.deriv_root + path_kwargs['suffix'] = 'raw' + path_kwargs['extension'] = '.fif' + path_kwargs['processing'] = 'sss' + else: + assert kind == 'orig', kind + path_kwargs['root'] = cfg.bids_root + path_kwargs['suffix'] = None + path_kwargs['extension'] = None + path_kwargs['processing'] = cfg.proc + bids_path_in = BIDSPath(**path_kwargs) + + in_files = dict() + key = f'raw_run-{run}' + in_files[key] = bids_path_in + _update_for_splits(in_files, key, single=True) + if add_bads: + _add_bads_file( + cfg=cfg, + in_files=in_files, + key=key, + ) + + if run == cfg.runs[0]: + do = dict( + rest=cfg.process_rest and not cfg.task_is_rest, + noise=cfg.process_empty_room and cfg.datatype == 'meg', + ) + for task in ('rest', 'noise'): + if not do[task]: + continue + key = f'raw_{task}' + if kind == 'sss': + raw_fname = bids_path_in.copy().update( + run=None, task=task) + else: + if task == 'rest': + raw_fname = bids_path_in.copy().update( + run=None, task=task) + else: + raw_fname = _read_json( + _empty_room_match_path(bids_path_in, cfg))['fname'] + if raw_fname is not None: + raw_fname = get_bids_path_from_fname(raw_fname) + if raw_fname is None: + continue + in_files[key] = raw_fname + _update_for_splits( + in_files, key, single=True, allow_missing=True) + if not in_files[key].fpath.exists(): + in_files.pop(key) + elif add_bads: + _add_bads_file( + cfg=cfg, + in_files=in_files, + key=key, + ) + + return in_files + + +def _add_bads_file( + *, + cfg: SimpleNamespace, + in_files: dict, + key: str, +) -> None: + bids_path_in = in_files[key] + bads_tsv_fname = _bads_path(cfg=cfg, bids_path_in=bids_path_in) + if bads_tsv_fname.fpath.is_file(): + in_files[f'{key}-bads'] = bads_tsv_fname + + +def _auto_scores_path( + *, + cfg: SimpleNamespace, + bids_path_in: BIDSPath, +) -> BIDSPath: + return bids_path_in.copy().update( + suffix='scores', + extension='.json', + root=cfg.deriv_root, + split=None, + check=False, + ) + + +def _bads_path( + *, + cfg: SimpleNamespace, + bids_path_in: BIDSPath, +) -> BIDSPath: + return bids_path_in.copy().update( + suffix='bads', + extension='.tsv', + root=cfg.deriv_root, + split=None, + check=False, + ) + + +def _read_bads_tsv( + *, + cfg: SimpleNamespace, + bids_path_bads: BIDSPath, +) -> List[str]: + bads_tsv = pd.read_csv(bids_path_bads.fpath, sep='\t', header=0) + return bads_tsv[bads_tsv.columns[0]].tolist() diff --git a/mne_bids_pipeline/_report.py b/mne_bids_pipeline/_report.py index 588086d3c..5363e64b0 100644 --- a/mne_bids_pipeline/_report.py +++ b/mne_bids_pipeline/_report.py @@ -19,13 +19,13 @@ sanitize_cond_name, get_subjects, _restrict_analyze_channels) from ._decoding import _handle_csp_args from ._logging import logger, gen_log_kwargs -from ._viz import plot_auto_scores @contextlib.contextmanager def _open_report( *, cfg: SimpleNamespace, + exec_params: SimpleNamespace, subject: str, session: Optional[str], run: Optional[str] = None, @@ -69,7 +69,7 @@ def _open_report( try: msg = 'Adding config and sys info to report' logger.info(**gen_log_kwargs(message=msg)) - _finalize(report=report, cfg=cfg) + _finalize(report=report, exec_params=exec_params) except Exception: pass fname_report_html = fname_report.with_suffix('.html') @@ -81,43 +81,6 @@ def _open_report( open_browser=False) -def plot_auto_scores_(cfg, subject, session): - """Plot automated bad channel detection scores. - """ - import json_tricks - fname_scores = BIDSPath(subject=subject, - session=session, - task=cfg.task, - acquisition=cfg.acq, - run=None, - processing=cfg.proc, - recording=cfg.rec, - space=cfg.space, - suffix='scores', - extension='.json', - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False) - - all_figs = [] - all_captions = [] - for run in cfg.runs: - fname_scores.update(run=run) - auto_scores = json_tricks.loads( - fname_scores.fpath.read_text(encoding='utf-8-sig') - ) - - figs = plot_auto_scores(auto_scores, ch_types=cfg.ch_types) - all_figs.extend(figs) - - # Could be more than 1 fig, e.g. "grad" and "mag" - captions = [f'Run {run}'] * len(figs) - all_captions.extend(captions) - - assert all_figs - return all_figs, all_captions - - # def plot_full_epochs_decoding_scores( # contrast: str, # cross_val_scores: np.ndarray, @@ -504,7 +467,7 @@ def add_event_counts(*, report.add_custom_css(css=css) -def _finalize(*, report: mne.Report, cfg: SimpleNamespace): +def _finalize(*, report: mne.Report, exec_params: SimpleNamespace): """Add system information and the pipeline configuration to the report.""" # ensure they are always appended titles = ['Configuration file', 'System information'] @@ -512,7 +475,7 @@ def _finalize(*, report: mne.Report, cfg: SimpleNamespace): report.remove(title=title, remove_all=True) # No longer need replace=True in these report.add_code( - code=cfg.config_path, + code=exec_params.config_path, title=titles[0], tags=('configuration',), ) @@ -539,8 +502,13 @@ def _all_conditions(*, cfg): return conditions -def run_report_average_sensor(*, cfg, session: str) -> None: - subject = 'average' +def run_report_average_sensor( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: Optional[str], +) -> None: msg = 'Generating grand average report …' logger.info(**gen_log_kwargs(message=msg)) assert matplotlib.get_backend() == 'agg', matplotlib.get_backend() @@ -571,7 +539,11 @@ def run_report_average_sensor(*, cfg, session: str) -> None: _restrict_analyze_channels(evoked, cfg) conditions = _all_conditions(cfg=cfg) - with _open_report(cfg=cfg, subject=subject, session=session) as report: + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session) as report: ####################################################################### # @@ -625,12 +597,17 @@ def run_report_average_sensor(*, cfg, session: str) -> None: ) -def run_report_average_source(*, cfg, session: str) -> None: +def run_report_average_source( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: Optional[str], +) -> None: ####################################################################### # # Visualize forward solution, inverse operator, and inverse solutions. # - subject = 'average' evoked_fname = BIDSPath( subject=subject, session=session, @@ -651,7 +628,11 @@ def run_report_average_source(*, cfg, session: str) -> None: hemi_str = 'hemi' # MNE will auto-append '-lh' and '-rh'. morph_str = 'morph2fsaverage' conditions = _all_conditions(cfg=cfg) - with _open_report(cfg=cfg, subject=subject, session=session) as report: + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session) as report: for condition, evoked in zip(conditions, evokeds): tags = ( 'source-estimate', @@ -1128,3 +1109,30 @@ def _agg_backend(): yield finally: matplotlib.use(backend, force=True) + + +def _add_raw( + *, + cfg: SimpleNamespace, + report: mne.report.Report, + bids_path_in: BIDSPath, + title: str, +): + if bids_path_in.run is not None: + title += f', run {bids_path_in.run}' + elif bids_path_in.task in ('noise', 'rest'): + title += f', run {bids_path_in.task}' + plot_raw_psd = ( + cfg.plot_psd_for_runs == 'all' or + bids_path_in.run in cfg.plot_psd_for_runs or + bids_path_in.task in cfg.plot_psd_for_runs + ) + report.add_raw( + raw=bids_path_in, + title=title, + butterfly=5, + psd=plot_raw_psd, + tags=('raw', 'filtered', f'run-{bids_path_in.run}'), + # caption=bids_path_in.basename, # TODO upstream + replace=True, + ) diff --git a/mne_bids_pipeline/_run.py b/mne_bids_pipeline/_run.py index c2c75e67f..7797bf844 100644 --- a/mne_bids_pipeline/_run.py +++ b/mne_bids_pipeline/_run.py @@ -31,8 +31,7 @@ def failsafe_run( def failsafe_run_decorator(func): @functools.wraps(func) # Preserve "identity" of original function def wrapper(*args, **kwargs): - exec_params = kwargs['cfg'].exec_params - delattr(kwargs['cfg'], 'exec_params') + exec_params = kwargs['exec_params'] on_error = exec_params.on_error memory = ConditionalStepMemory( exec_params=exec_params, @@ -134,6 +133,8 @@ def __init__(self, *, exec_params, get_input_fnames, get_output_fnames): use_location, verbose=exec_params.memory_verbose) else: self.memory = None + # Ignore these as they have no effect on the output + self.ignore = ['exec_params'] self.get_input_fnames = get_input_fnames self.get_output_fnames = get_output_fnames self.memory_file_method = exec_params.memory_file_method @@ -143,10 +144,15 @@ def cache(self, func): def wrapper(*args, **kwargs): in_files = out_files = None force_run = kwargs.pop('force_run', False) + these_kwargs = kwargs.copy() + these_kwargs.pop('exec_params', None) if self.get_output_fnames is not None: - out_files = self.get_output_fnames(**kwargs) + out_files = self.get_output_fnames( + **these_kwargs) if self.get_input_fnames is not None: - in_files = kwargs['in_files'] = self.get_input_fnames(**kwargs) + in_files = kwargs['in_files'] = self.get_input_fnames( + **these_kwargs) + del these_kwargs if self.memory is None: func(*args, **kwargs) return @@ -200,13 +206,17 @@ def hash_(k, v): # Someday we could modify the joblib API to combine this with the # call (https://github.com/joblib/joblib/issues/1342), but our hash # should be plenty fast so let's not bother for now. - memorized_func = self.memory.cache(func) + memorized_func = self.memory.cache(func, ignore=self.ignore) msg = emoji = None short_circuit = False subject = kwargs.get('subject', None) session = kwargs.get('session', None) run = kwargs.get('run', None) - if memorized_func.check_call_in_cache(*args, **kwargs): + try: + done = memorized_func.check_call_in_cache(*args, **kwargs) + except Exception: + done = False + if done: if unknown_inputs: msg = ('Computation forced because input files cannot ' f'be determined ({unknown_inputs}) …') diff --git a/mne_bids_pipeline/steps/freesurfer/_02_coreg_surfaces.py b/mne_bids_pipeline/steps/freesurfer/_02_coreg_surfaces.py index 2a6f64752..2d728c9af 100644 --- a/mne_bids_pipeline/steps/freesurfer/_02_coreg_surfaces.py +++ b/mne_bids_pipeline/steps/freesurfer/_02_coreg_surfaces.py @@ -19,16 +19,19 @@ fs_bids_app = Path(__file__).parent / 'contrib' / 'run.py' -def get_input_fnames_coreg_surfaces(**kwargs): - cfg = kwargs.pop('cfg') - kwargs.pop('subject') # unused - assert len(kwargs) == 0, kwargs.keys() - del kwargs - in_files = _get_scalp_in_files(cfg) - return in_files +def get_input_fnames_coreg_surfaces( + *, + cfg: SimpleNamespace, + subject: str, +) -> dict: + return _get_scalp_in_files(cfg) -def get_output_fnames_coreg_surfaces(*, cfg, subject): +def get_output_fnames_coreg_surfaces( + *, + cfg: SimpleNamespace, + subject: str +) -> dict: out_files = dict() subject_path = Path(cfg.subjects_dir) / cfg.fs_subject out_files['seghead'] = subject_path / 'surf' / 'lh.seghead' @@ -44,6 +47,7 @@ def get_output_fnames_coreg_surfaces(*, cfg, subject): ) def make_coreg_surfaces( cfg: SimpleNamespace, + exec_params: SimpleNamespace, subject: str, in_files: dict, ) -> dict: @@ -63,7 +67,6 @@ def make_coreg_surfaces( def get_config(*, config, subject) -> SimpleNamespace: cfg = SimpleNamespace( - exec_params=config.exec_params, subject=subject, fs_subject=get_fs_subject(config, subject), subjects_dir=get_fs_subjects_dir(config), @@ -83,8 +86,12 @@ def main(*, config) -> None: parallel( run_func( - cfg=get_config(config=config, subject=subject), - subject=subject, + cfg=get_config( + config=config, + subject=subject, + ), + exec_params=config.exec_params, force_run=config.recreate_scalp_surface, + subject=subject, ) for subject in subjects ) diff --git a/mne_bids_pipeline/steps/init/_00_init_derivatives_dir.py b/mne_bids_pipeline/steps/init/_01_init_derivatives_dir.py similarity index 92% rename from mne_bids_pipeline/steps/init/_00_init_derivatives_dir.py rename to mne_bids_pipeline/steps/init/_01_init_derivatives_dir.py index 3daf77c64..17f04be87 100644 --- a/mne_bids_pipeline/steps/init/_00_init_derivatives_dir.py +++ b/mne_bids_pipeline/steps/init/_01_init_derivatives_dir.py @@ -44,7 +44,8 @@ def init_dataset(cfg) -> None: @failsafe_run() def init_subject_dirs( *, - cfg, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, subject: str, session: Optional[str], ) -> None: @@ -63,7 +64,6 @@ def get_config( config, ) -> SimpleNamespace: cfg = SimpleNamespace( - exec_params=config.exec_params, datatype=get_datatype(config), deriv_root=config.deriv_root, PIPELINE_NAME=config.PIPELINE_NAME, @@ -81,7 +81,10 @@ def main(*, config): for subject in get_subjects(config): for session in get_sessions(config): init_subject_dirs( - cfg=get_config(config=config), + cfg=get_config( + config=config, + ), + exec_params=config.exec_params, subject=subject, session=session, ) diff --git a/mne_bids_pipeline/steps/init/_01_find_empty_room.py b/mne_bids_pipeline/steps/init/_02_find_empty_room.py similarity index 96% rename from mne_bids_pipeline/steps/init/_01_find_empty_room.py rename to mne_bids_pipeline/steps/init/_02_find_empty_room.py index bd8bce55e..849a10a4a 100644 --- a/mne_bids_pipeline/steps/init/_01_find_empty_room.py +++ b/mne_bids_pipeline/steps/init/_02_find_empty_room.py @@ -56,11 +56,12 @@ def get_input_fnames_find_empty_room( ) def find_empty_room( *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, subject: str, session: Optional[str], run: Optional[str], in_files: Dict[str, BIDSPath], - cfg: SimpleNamespace ) -> Dict[str, BIDSPath]: raw_path = in_files.pop(f'raw_run-{run}') in_files.pop('sidecar', None) @@ -98,7 +99,6 @@ def get_config( config, ) -> SimpleNamespace: cfg = SimpleNamespace( - exec_params=config.exec_params, proc=config.proc, task=get_task(config), datatype=get_datatype(config), @@ -130,7 +130,10 @@ def main(*, config) -> None: else: run = get_runs(config=config, subject=subject)[0] logs.append(find_empty_room( - cfg=get_config(config=config), + cfg=get_config( + config=config, + ), + exec_params=config.exec_params, subject=subject, session=get_sessions(config)[0], run=run, diff --git a/mne_bids_pipeline/steps/init/__init__.py b/mne_bids_pipeline/steps/init/__init__.py index e3672a8c3..72a80cf13 100644 --- a/mne_bids_pipeline/steps/init/__init__.py +++ b/mne_bids_pipeline/steps/init/__init__.py @@ -1,9 +1,9 @@ """Filesystem initialization and dataset inspection.""" -from . import _00_init_derivatives_dir -from . import _01_find_empty_room +from . import _01_init_derivatives_dir +from . import _02_find_empty_room _STEPS = ( - _00_init_derivatives_dir, - _01_find_empty_room, + _01_init_derivatives_dir, + _02_find_empty_room, ) diff --git a/mne_bids_pipeline/steps/preprocessing/_01_data_quality.py b/mne_bids_pipeline/steps/preprocessing/_01_data_quality.py new file mode 100644 index 000000000..811466095 --- /dev/null +++ b/mne_bids_pipeline/steps/preprocessing/_01_data_quality.py @@ -0,0 +1,319 @@ +"""Assess data quality and find bad (and flat) channels.""" + +from types import SimpleNamespace +from typing import Optional + +import pandas as pd + +import mne +from mne_bids import BIDSPath + +from ..._config_utils import ( + get_mf_cal_fname, get_mf_ctc_fname, get_subjects, get_sessions, + get_runs, get_task, get_datatype, +) +from ..._import_data import ( + _get_raw_paths, import_experimental_data, import_er_data, + _bads_path, _auto_scores_path) +from ..._io import _write_json +from ..._logging import gen_log_kwargs, logger +from ..._parallel import parallel_func, get_parallel_backend +from ..._report import _open_report, _add_raw +from ..._run import failsafe_run, save_logs +from ..._viz import plot_auto_scores + + +def get_input_fnames_data_quality( + *, + cfg: SimpleNamespace, + subject: str, + session: Optional[str], + run: str, +) -> dict: + """Get paths of files required by maxwell_filter function.""" + in_files = _get_raw_paths( + cfg=cfg, + subject=subject, + session=session, + run=run, + kind='orig', + add_bads=False, + ) + return in_files + + +@failsafe_run( + get_input_fnames=get_input_fnames_data_quality, +) +def assess_data_quality( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: Optional[str], + run: str, + in_files: dict, +) -> None: + """Assess data quality and find and mark bad channels.""" + import matplotlib.pyplot as plt + out_files = dict() + for key in list(in_files.keys()): + bids_path_in = in_files.pop(key) + auto_scores = _find_bads( + cfg=cfg, + exec_params=exec_params, + bids_path_in=bids_path_in, + key=key, + subject=subject, + session=session, + run=run, + out_files=out_files, + ) + + # Report + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session, + run=run) as report: + # Original data + kind = 'original' if not cfg.proc else cfg.proc + msg = f'Adding {kind} raw data to report.' + logger.info(**gen_log_kwargs(message=msg)) + _add_raw( + cfg=cfg, + report=report, + bids_path_in=bids_path_in, + title=f'Raw ({kind})', + ) + if cfg.find_noisy_channels_meg: + assert auto_scores is not None + msg = 'Adding noisy channel detection to report.' + logger.info(**gen_log_kwargs(message=msg)) + figs = plot_auto_scores(auto_scores, ch_types=cfg.ch_types) + captions = [f'Run {run}'] * len(figs) + tags = ('raw', 'data-quality', f'run-{run}') + report.add_figure( + fig=figs, + caption=captions, + section='Data quality', + title=f'Bad channel detection: {run}', + tags=tags, + replace=True, + ) + for fig in figs: + plt.close(fig) + + assert len(in_files) == 0, in_files.keys() + return out_files + + +def _find_bads( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + bids_path_in: BIDSPath, + subject: str, + session: Optional[str], + run: str, + key: str, + out_files: dict, +): + if not (cfg.find_noisy_channels_meg or cfg.find_flat_channels_meg): + return None + + if (cfg.find_flat_channels_meg and + not cfg.find_noisy_channels_meg): + msg = 'Finding flat channels.' + elif (cfg.find_noisy_channels_meg and + not cfg.find_flat_channels_meg): + msg = 'Finding noisy channels using Maxwell filtering.' + else: + msg = ('Finding flat channels, and noisy channels using ' + 'Maxwell filtering.') + logger.info(**gen_log_kwargs(message=msg)) + + if key == 'raw_noise': + raw = import_er_data( + bids_path_er_in=bids_path_in, + bids_path_er_bads_in=None, + bids_path_ref_in=None, + bids_path_ref_bads_in=None, + cfg=cfg, + datatype='meg', + ) + else: + data_is_rest = (key == 'raw_rest') + raw = import_experimental_data( + bids_path_in=bids_path_in, + bids_path_bads_in=None, + cfg=cfg, + data_is_rest=data_is_rest, + ) + + # Filter the data manually before passing it to find_bad_channels_maxwell() + # This reduces memory usage, as we can control the number of jobs used + # during filtering. + preexisting_bads = raw.info['bads'].copy() + bads = preexisting_bads.copy() + raw_filt = raw.copy().filter(l_freq=None, h_freq=40, n_jobs=1) + auto_noisy_chs, auto_flat_chs, auto_scores = \ + mne.preprocessing.find_bad_channels_maxwell( + raw=raw_filt, + calibration=cfg.mf_cal_fname, + cross_talk=cfg.mf_ctc_fname, + origin=cfg.mf_head_origin, + coord_frame='head', + return_scores=True, + h_freq=None # we filtered manually above + ) + del raw_filt + + if cfg.find_flat_channels_meg: + if auto_flat_chs: + msg = (f'Found {len(auto_flat_chs)} flat channels: ' + f'{", ".join(auto_flat_chs)}') + else: + msg = 'Found no flat channels.' + logger.info(**gen_log_kwargs(message=msg)) + bads.extend(auto_flat_chs) + + if cfg.find_noisy_channels_meg: + if auto_noisy_chs: + msg = (f'Found {len(auto_noisy_chs)} noisy channels: ' + f'{", ".join(auto_noisy_chs)}') + else: + msg = 'Found no noisy channels.' + + logger.info(**gen_log_kwargs(message=msg)) + bads.extend(auto_noisy_chs) + + bads = sorted(set(bads)) + msg = f'Found {len(bads)} channels as bad.' + raw.info['bads'] = bads + del bads + logger.info(**gen_log_kwargs(message=msg)) + + if cfg.find_noisy_channels_meg: + out_files['auto_scores'] = _auto_scores_path( + cfg=cfg, + bids_path_in=bids_path_in, + ) + _write_json(out_files['auto_scores'], auto_scores) + + # Write the bad channels to disk. + out_files['bads_tsv'] = _bads_path(cfg=cfg, bids_path_in=bids_path_in) + bads_for_tsv = [] + reasons = [] + + if cfg.find_flat_channels_meg: + bads_for_tsv.extend(auto_flat_chs) + reasons.extend(['auto-flat'] * len(auto_flat_chs)) + preexisting_bads = set(preexisting_bads) - set(auto_flat_chs) + + if cfg.find_noisy_channels_meg: + bads_for_tsv.extend(auto_noisy_chs) + reasons.extend(['auto-noisy'] * len(auto_noisy_chs)) + preexisting_bads = set(preexisting_bads) - set(auto_noisy_chs) + + preexisting_bads = list(preexisting_bads) + if preexisting_bads: + bads_for_tsv.extend(preexisting_bads) + reasons.extend(['pre-existing (before MNE-BIDS-pipeline was run)'] * + len(preexisting_bads)) + + tsv_data = pd.DataFrame(dict(name=bads_for_tsv, reason=reasons)) + tsv_data = tsv_data.sort_values(by='name') + tsv_data.to_csv(out_files['bads_tsv'], sep='\t', index=False) + + # Interaction + if exec_params.interactive and cfg.find_noisy_channels_meg: + import matplotlib.pyplot as plt + plot_auto_scores(auto_scores, ch_types=cfg.ch_types) + plt.show() + + return auto_scores + + +def get_config( + *, + config: SimpleNamespace, + subject: str, + session: Optional[str], +) -> SimpleNamespace: + if config.find_noisy_channels_meg or config.find_flat_channels_meg: + mf_cal_fname = get_mf_cal_fname( + config=config, + subject=subject, + session=session, + ) + mf_ctc_fname = get_mf_ctc_fname( + config=config, + subject=subject, + session=session, + ) + else: + mf_cal_fname = mf_ctc_fname = None + cfg = SimpleNamespace( + process_empty_room=config.process_empty_room, + process_rest=config.process_rest, + task_is_rest=config.task_is_rest, + runs=get_runs(config=config, subject=subject), + proc=config.proc, + task=get_task(config), + datatype=get_datatype(config), + acq=config.acq, + rec=config.rec, + space=config.space, + bids_root=config.bids_root, + deriv_root=config.deriv_root, + mf_cal_fname=mf_cal_fname, + mf_ctc_fname=mf_ctc_fname, + mf_head_origin=config.mf_head_origin, + reader_extra_params=config.reader_extra_params, + crop_runs=config.crop_runs, + rename_events=config.rename_events, + eeg_bipolar_channels=config.eeg_bipolar_channels, + eeg_template_montage=config.eeg_template_montage, + fix_stim_artifact=config.fix_stim_artifact, + stim_artifact_tmin=config.stim_artifact_tmin, + stim_artifact_tmax=config.stim_artifact_tmax, + find_flat_channels_meg=config.find_flat_channels_meg, + find_noisy_channels_meg=config.find_noisy_channels_meg, + drop_channels=config.drop_channels, + find_breaks=config.find_breaks, + min_break_duration=config.min_break_duration, + t_break_annot_start_after_previous_event=config.t_break_annot_start_after_previous_event, # noqa:E501 + t_break_annot_stop_before_next_event=config.t_break_annot_stop_before_next_event, # noqa:E501 + data_type=config.data_type, + ch_types=config.ch_types, + eog_channels=config.eog_channels, + on_rename_missing_events=config.on_rename_missing_events, + plot_psd_for_runs=config.plot_psd_for_runs, + ) + return cfg + + +def main(*, config: SimpleNamespace) -> None: + """Run maxwell_filter.""" + with get_parallel_backend(config.exec_params): + parallel, run_func = parallel_func( + assess_data_quality, exec_params=config.exec_params) + logs = parallel( + run_func( + cfg=get_config( + config=config, + subject=subject, + session=session), + exec_params=config.exec_params, + subject=subject, + session=session, + run=run + ) + for subject in get_subjects(config) + for session in get_sessions(config) + for run in get_runs(config=config, subject=subject) + ) + + save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/preprocessing/_01_maxfilter.py b/mne_bids_pipeline/steps/preprocessing/_02_maxfilter.py similarity index 74% rename from mne_bids_pipeline/steps/preprocessing/_01_maxfilter.py rename to mne_bids_pipeline/steps/preprocessing/_02_maxfilter.py index f5f580ba4..5d14721be 100644 --- a/mne_bids_pipeline/steps/preprocessing/_01_maxfilter.py +++ b/mne_bids_pipeline/steps/preprocessing/_02_maxfilter.py @@ -19,67 +19,48 @@ import numpy as np import mne -from mne_bids import BIDSPath, read_raw_bids, get_bids_path_from_fname +from mne_bids import read_raw_bids from ..._config_utils import ( get_mf_cal_fname, get_mf_ctc_fname, get_subjects, get_sessions, get_runs, get_task, get_datatype, get_mf_reference_run, ) from ..._import_data import ( - import_experimental_data, import_er_data, import_rest_data + import_experimental_data, import_er_data, _get_raw_paths, _add_bads_file, ) -from ..._io import _read_json, _empty_room_match_path from ..._logging import gen_log_kwargs, logger from ..._parallel import parallel_func, get_parallel_backend +from ..._report import _open_report, _add_raw from ..._run import failsafe_run, save_logs, _update_for_splits -def get_input_fnames_maxwell_filter(**kwargs): +def get_input_fnames_maxwell_filter( + *, + cfg: SimpleNamespace, + subject: str, + session: Optional[str], + run: str, +) -> dict: """Get paths of files required by maxwell_filter function.""" - cfg = kwargs.pop('cfg') - subject = kwargs.pop('subject') - session = kwargs.pop('session') - run = kwargs.pop('run') - assert len(kwargs) == 0, kwargs.keys() - del kwargs - - bids_path_in = BIDSPath(subject=subject, - session=session, - run=run, - task=cfg.task, - acquisition=cfg.acq, - recording=cfg.rec, - space=cfg.space, - suffix='meg', - datatype=cfg.datatype, - root=cfg.bids_root, - check=False) - - in_files = { - f'raw_run-{bp.run}': bp for bp in bids_path_in.match() - } - - ref_bids_path = bids_path_in.copy().update( + in_files = _get_raw_paths( + cfg=cfg, + subject=subject, + session=session, + run=run, + kind='orig', + ) + ref_bids_path = list(in_files.values())[0].copy().update( run=cfg.mf_reference_run, extension='.fif', check=True ) - - if run == cfg.mf_reference_run: - if cfg.process_rest and not cfg.task_is_rest: - raw_rest = bids_path_in.copy().update(task="rest") - if raw_rest.fpath.exists(): - in_files["raw_rest"] = raw_rest - elif raw_rest.copy().update(run=None).fpath.exists(): - in_files["raw_rest"] = raw_rest.update(run=None) - if cfg.process_empty_room and cfg.datatype == 'meg': - raw_noise = _read_json( - _empty_room_match_path(bids_path_in, cfg))['fname'] - if raw_noise is not None: - raw_noise = get_bids_path_from_fname(raw_noise) - in_files["raw_noise"] = raw_noise - - in_files["raw_ref_run"] = ref_bids_path + key = "raw_ref_run" + in_files[key] = ref_bids_path + _add_bads_file( + cfg=cfg, + in_files=in_files, + key=key, + ) in_files["mf_cal_fname"] = cfg.mf_cal_fname in_files["mf_ctc_fname"] = cfg.mf_ctc_fname return in_files @@ -88,12 +69,22 @@ def get_input_fnames_maxwell_filter(**kwargs): @failsafe_run( get_input_fnames=get_input_fnames_maxwell_filter, ) -def run_maxwell_filter(*, cfg, subject, session, run, in_files): +def run_maxwell_filter( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: Optional[str], + run: str, + in_files: dict, +) -> dict: if cfg.proc and 'sss' in cfg.proc and cfg.use_maxwell_filter: raise ValueError(f'You cannot set use_maxwell_filter to True ' f'if data have already processed with Maxwell-filter.' f' Got proc={cfg.proc}.') - bids_path_in = in_files.pop(f"raw_run-{run}") + in_key = f"raw_run-{run}" + bids_path_in = in_files.pop(in_key) + bids_path_bads_in = in_files.pop(f'{in_key}-bads', None) bids_path_out = bids_path_in.copy().update( processing="sss", suffix="raw", @@ -113,15 +104,18 @@ def run_maxwell_filter(*, cfg, subject, session, run, in_files): msg = f'Loading reference run: {cfg.mf_reference_run}.' logger.info(**gen_log_kwargs(message=msg)) - ref_fname = in_files.pop("raw_ref_run") - raw = read_raw_bids(bids_path=ref_fname, + bids_path_ref_in = in_files.pop("raw_ref_run") + raw = read_raw_bids(bids_path=bids_path_ref_in, extra_params=cfg.reader_extra_params) + bids_path_ref_bads_in = in_files.pop("raw_ref_run-bads", None) dev_head_t = raw.info['dev_head_t'] del raw raw = import_experimental_data( + cfg=cfg, bids_path_in=bids_path_in, - cfg=cfg + bids_path_bads_in=bids_path_bads_in, + data_is_rest=False, ) # Maxwell-filter experimental data. @@ -158,17 +152,28 @@ def run_maxwell_filter(*, cfg, subject, session, run, in_files): logger.info(**gen_log_kwargs(message=msg)) raw_sss.save(out_files['sss_raw'], split_naming='bids', overwrite=True, split_size=cfg._raw_split_size) + del raw # we need to be careful about split files _update_for_splits(out_files, 'sss_raw') - del raw, raw_sss - if cfg.interactive: - # Load the data we have just written, because it contains only - # the relevant channels. - raw_sss = mne.io.read_raw_fif( - out_files['sss_raw'], allow_maxshield=True) + if exec_params.interactive: raw_sss.plot(n_channels=50, butterfly=True, block=True) - del raw_sss + del raw_sss + + # Reporting + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session) as report: + msg = 'Adding Maxwell filtered raw data to report.' + _add_raw( + cfg=cfg, + report=report, + bids_path_in=out_files['sss_raw'], + title='Raw (maxwell filtered)', + ) + del bids_path_in # Noise data processing. nice_names = dict(rest='resting-state', noise='empty-room') @@ -178,24 +183,28 @@ def run_maxwell_filter(*, cfg, subject, session, run, in_files): continue recording_type = nice_names[task] msg = f'Processing {recording_type} recording …' - logger.info(**gen_log_kwargs(message=msg)) - bids_path_in = in_files.pop(in_key) + logger.info(**gen_log_kwargs(message=msg, run=task)) + bids_path_noise = in_files.pop(in_key) + bids_path_noise_bads = in_files.pop(f'{in_key}-bads', None) if task == 'rest': - raw_noise = import_rest_data( + raw_noise = import_experimental_data( cfg=cfg, - bids_path_in=bids_path_in, + bids_path_in=bids_path_noise, + bids_path_bads_in=bids_path_noise_bads, + data_is_rest=True, ) else: raw_noise = import_er_data( cfg=cfg, - bids_path_er_in=bids_path_in, - bids_path_ref_in=ref_fname, + bids_path_er_in=bids_path_noise, + bids_path_ref_in=bids_path_ref_in, + bids_path_er_bads_in=bids_path_noise_bads, + bids_path_ref_bads_in=bids_path_ref_bads_in, ) - del bids_path_in # Maxwell-filter noise data. msg = f'Applying Maxwell filter to {recording_type} recording' - logger.info(**gen_log_kwargs(message=msg)) + logger.info(**gen_log_kwargs(message=msg, run=task)) raw_noise_sss = mne.preprocessing.maxwell_filter( raw_noise, **common_mf_kws ) @@ -220,7 +229,7 @@ def run_maxwell_filter(*, cfg, subject, session, run, in_files): f'take care of this during epoching of the experimental ' f'data.' ) - logger.warning(**gen_log_kwargs(message=msg)) + logger.warning(**gen_log_kwargs(message=msg, run=task)) else: pass # Should cause no problems! elif not np.isclose(rank_exp, rank_noise): @@ -230,8 +239,7 @@ def run_maxwell_filter(*, cfg, subject, session, run, in_files): f'were processed differently.') raise RuntimeError(msg) - out_key = f'sss_{task}' - out_files[out_key] = bids_path_out.copy().update( + out_files[in_key] = bids_path_out.copy().update( task=task, run=None, processing='sss' @@ -240,26 +248,41 @@ def run_maxwell_filter(*, cfg, subject, session, run, in_files): # Save only the channel types we wish to analyze # (same as for experimental data above). msg = ("Writing " - f"{out_files[out_key].fpath.relative_to(cfg.deriv_root)}") - logger.info(**gen_log_kwargs(message=msg)) + f"{out_files[in_key].fpath.relative_to(cfg.deriv_root)}") + logger.info(**gen_log_kwargs(message=msg, run=task)) raw_noise_sss.save( - out_files[out_key], overwrite=True, + out_files[in_key], overwrite=True, split_naming='bids', split_size=cfg._raw_split_size, ) - _update_for_splits(out_files, out_key) + _update_for_splits(out_files, in_key) del raw_noise_sss + + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session) as report: + msg = 'Adding Maxwell filtered raw data to report.' + logger.info(**gen_log_kwargs(message=msg)) + for fname in out_files.values(): + _add_raw( + cfg=cfg, + report=report, + bids_path_in=fname, + title='Raw (maxwell filtered)', + ) + assert len(in_files) == 0, in_files.keys() return out_files def get_config( *, - config, + config: SimpleNamespace, subject: str, session: Optional[str], ) -> SimpleNamespace: cfg = SimpleNamespace( - exec_params=config.exec_params, reader_extra_params=config.reader_extra_params, mf_cal_fname=get_mf_cal_fname( config=config, @@ -288,7 +311,6 @@ def get_config( bids_root=config.bids_root, deriv_root=config.deriv_root, crop_runs=config.crop_runs, - interactive=config.interactive, rename_events=config.rename_events, eeg_template_montage=config.eeg_template_montage, fix_stim_artifact=config.fix_stim_artifact, @@ -305,12 +327,13 @@ def get_config( ch_types=config.ch_types, data_type=config.data_type, on_rename_missing_events=config.on_rename_missing_events, + plot_psd_for_runs=config.plot_psd_for_runs, _raw_split_size=config._raw_split_size, ) return cfg -def main(*, config) -> None: +def main(*, config: SimpleNamespace) -> None: """Run maxwell_filter.""" if not config.use_maxwell_filter: msg = 'Skipping …' @@ -326,6 +349,7 @@ def main(*, config) -> None: config=config, subject=subject, session=session), + exec_params=config.exec_params, subject=subject, session=session, run=run diff --git a/mne_bids_pipeline/steps/preprocessing/_02_frequency_filter.py b/mne_bids_pipeline/steps/preprocessing/_03_frequency_filter.py similarity index 63% rename from mne_bids_pipeline/steps/preprocessing/_02_frequency_filter.py rename to mne_bids_pipeline/steps/preprocessing/_03_frequency_filter.py index eff7c974a..be5412308 100644 --- a/mne_bids_pipeline/steps/preprocessing/_02_frequency_filter.py +++ b/mne_bids_pipeline/steps/preprocessing/_03_frequency_filter.py @@ -15,95 +15,40 @@ """ # noqa: E501 import numpy as np -from typing import Optional, Union, Literal from types import SimpleNamespace +from typing import Optional, Union, Literal import mne -from mne_bids import BIDSPath, get_bids_path_from_fname from ..._config_utils import ( get_sessions, get_runs, get_subjects, get_task, get_datatype, get_mf_reference_run, ) from ..._import_data import ( - import_experimental_data, import_er_data, import_rest_data) -from ..._io import _read_json, _empty_room_match_path + import_experimental_data, import_er_data, _get_raw_paths, +) from ..._logging import gen_log_kwargs, logger from ..._parallel import parallel_func, get_parallel_backend -from ..._report import _open_report, plot_auto_scores_ +from ..._report import _open_report, _add_raw from ..._run import failsafe_run, save_logs, _update_for_splits -def get_input_fnames_frequency_filter(**kwargs): +def get_input_fnames_frequency_filter( + *, + cfg: SimpleNamespace, + subject: str, + session: Optional[str], + run: str, +) -> dict: """Get paths of files required by filter_data function.""" - cfg = kwargs.pop('cfg') - subject = kwargs.pop('subject') - session = kwargs.pop('session') - run = kwargs.pop('run') - assert len(kwargs) == 0, kwargs.keys() - del kwargs - - # Construct the basenames of the files we wish to load, and of the empty- - # room recording we wish to save. - # The basenames of the empty-room recording output file does not contain - # the "run" entity. - path_kwargs = dict( + kind = 'sss' if cfg.use_maxwell_filter else 'orig' + return _get_raw_paths( + cfg=cfg, subject=subject, - run=run, session=session, - task=cfg.task, - acquisition=cfg.acq, - recording=cfg.rec, - space=cfg.space, - datatype=cfg.datatype, - check=False + run=run, + kind=kind, ) - if cfg.use_maxwell_filter: - path_kwargs['root'] = cfg.deriv_root - path_kwargs['suffix'] = 'raw' - path_kwargs['extension'] = '.fif' - path_kwargs['processing'] = 'sss' - else: - path_kwargs['root'] = cfg.bids_root - path_kwargs['suffix'] = None - path_kwargs['extension'] = None - path_kwargs['processing'] = cfg.proc - bids_path_in = BIDSPath(**path_kwargs) - - in_files = dict() - in_files[f'raw_run-{run}'] = bids_path_in - _update_for_splits(in_files, f'raw_run-{run}', single=True) - - if run == cfg.runs[0]: - do = dict( - rest=cfg.process_rest and not cfg.task_is_rest, - noise=cfg.process_empty_room and cfg.datatype == 'meg', - ) - for task in ('rest', 'noise'): - if not do[task]: - continue - key = f'raw_{task}' - if cfg.use_maxwell_filter: - raw_fname = bids_path_in.copy().update( - run=None, task=task) - else: - if task == 'rest': - raw_fname = bids_path_in.copy().update( - run=None, task=task) - else: - raw_fname = _read_json( - _empty_room_match_path(bids_path_in, cfg))['fname'] - if raw_fname is not None: - raw_fname = get_bids_path_from_fname(raw_fname) - if raw_fname is None: - continue - in_files[key] = raw_fname - _update_for_splits( - in_files, key, single=True, allow_missing=True) - if not in_files[key].fpath.exists(): - in_files.pop(key) - - return in_files def filter( @@ -163,16 +108,18 @@ def resample( ) def filter_data( *, - cfg, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, subject: str, session: Optional[str], run: str, in_files: dict, -) -> None: +) -> dict: """Filter data from a single subject.""" - import matplotlib.pyplot as plt out_files = dict() - bids_path = in_files.pop(f"raw_run-{run}") + in_key = f"raw_run-{run}" + bids_path = in_files.pop(in_key) + bids_path_bads_in = in_files.pop(f'{in_key}-bads', None) # Create paths for reading and writing the filtered data. if cfg.use_maxwell_filter: @@ -180,10 +127,14 @@ def filter_data( logger.info(**gen_log_kwargs(message=msg)) raw = mne.io.read_raw_fif(bids_path) else: - raw = import_experimental_data(bids_path_in=bids_path, - cfg=cfg) + raw = import_experimental_data( + cfg=cfg, + bids_path_in=bids_path, + bids_path_bads_in=bids_path_bads_in, + data_is_rest=False, + ) - out_files['raw_filt'] = bids_path.copy().update( + out_files[in_key] = bids_path.copy().update( root=cfg.deriv_root, processing='filt', extension='.fif', suffix='raw', split=None) raw.load_data() @@ -197,10 +148,10 @@ def filter_data( resample(raw=raw, subject=subject, session=session, run=run, sfreq=cfg.resample_sfreq, data_type='experimental') - raw.save(out_files['raw_filt'], overwrite=True, split_naming='bids', + raw.save(out_files[in_key], overwrite=True, split_naming='bids', split_size=cfg._raw_split_size) - _update_for_splits(out_files, 'raw_filt') - if cfg.interactive: + _update_for_splits(out_files, in_key) + if exec_params.interactive: # Plot raw data and power spectral density. raw.plot(n_channels=50, butterfly=True) fmax = 1.5 * cfg.h_freq if cfg.h_freq is not None else np.inf @@ -215,6 +166,7 @@ def filter_data( continue data_type = nice_names[task] bids_path_noise = in_files.pop(in_key) + bids_path_noise_bads = in_files.pop(f'{in_key}-bads', None) if cfg.use_maxwell_filter: msg = (f'Reading {data_type} recording: ' f'{bids_path_noise.basename}') @@ -224,15 +176,19 @@ def filter_data( raw_noise = import_er_data( cfg=cfg, bids_path_er_in=bids_path_noise, - bids_path_ref_in=bids_path, # will take bads from this run (0) + bids_path_ref_in=bids_path, + bids_path_er_bads_in=bids_path_noise_bads, + # take bads from this run (0) + bids_path_ref_bads_in=bids_path_bads_in, ) else: - raw_noise = import_rest_data( + raw_noise = import_experimental_data( cfg=cfg, bids_path_in=bids_path_noise, + bids_path_bads_in=bids_path_noise_bads, + data_is_rest=True, ) - out_key = f'raw_{task}_filt' - out_files[out_key] = \ + out_files[in_key] = \ bids_path.copy().update( root=cfg.deriv_root, processing='filt', extension='.fif', suffix='raw', split=None, task=task, run=None) @@ -249,11 +205,11 @@ def filter_data( sfreq=cfg.resample_sfreq, data_type=data_type) raw_noise.save( - out_files[out_key], overwrite=True, split_naming='bids', + out_files[in_key], overwrite=True, split_naming='bids', split_size=cfg._raw_split_size, ) - _update_for_splits(out_files, out_key) - if cfg.interactive: + _update_for_splits(out_files, in_key) + if exec_params.interactive: # Plot raw data and power spectral density. raw_noise.plot(n_channels=50, butterfly=True) fmax = 1.5 * cfg.h_freq if cfg.h_freq is not None else np.inf @@ -261,49 +217,19 @@ def filter_data( assert len(in_files) == 0, in_files.keys() - # Report - with _open_report(cfg=cfg, subject=subject, session=session) as report: - # This is a weird place for this, but it's the first place we actually - # start a report, so let's leave it here. We could put it in - # _import_data (where the auto-bad-finding is called), but that seems - # worse. - if run == cfg.runs[0] and cfg.find_noisy_channels_meg: - msg = 'Adding visualization of noisy channel detection to report.' - logger.info(**gen_log_kwargs(message=msg)) - figs, captions = plot_auto_scores_( - cfg=cfg, - subject=subject, - session=session, - ) - tags = ('raw', 'data-quality', *[f'run-{i}' for i in cfg.runs]) - report.add_figure( - fig=figs, - caption=captions, - title='Data Quality', - tags=tags, - replace=True, - ) - for fig in figs: - plt.close(fig) - + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session) as report: msg = 'Adding filtered raw data to report.' logger.info(**gen_log_kwargs(message=msg)) for fname in out_files.values(): - title = 'Raw' - if fname.run is not None: - title += f', run {fname.run}' - plot_raw_psd = ( - cfg.plot_psd_for_runs == 'all' or - fname.run in cfg.plot_psd_for_runs - ) - report.add_raw( - raw=fname, - title=title, - butterfly=5, - psd=plot_raw_psd, - tags=('raw', 'filtered', f'run-{fname.run}'), - # caption=fname.basename, # TODO upstream - replace=True, + _add_raw( + cfg=cfg, + report=report, + bids_path_in=fname, + title='Raw (filtered)', ) return out_files @@ -311,11 +237,10 @@ def filter_data( def get_config( *, - config, + config: SimpleNamespace, subject: str, ) -> SimpleNamespace: cfg = SimpleNamespace( - exec_params=config.exec_params, reader_extra_params=config.reader_extra_params, process_empty_room=config.process_empty_room, process_rest=config.process_rest, @@ -336,7 +261,6 @@ def get_config( h_trans_bandwidth=config.h_trans_bandwidth, resample_sfreq=config.resample_sfreq, crop_runs=config.crop_runs, - interactive=config.interactive, rename_events=config.rename_events, eeg_bipolar_channels=config.eeg_bipolar_channels, eeg_template_montage=config.eeg_template_montage, @@ -357,14 +281,11 @@ def get_config( on_rename_missing_events=config.on_rename_missing_events, plot_psd_for_runs=config.plot_psd_for_runs, _raw_split_size=config._raw_split_size, - config_path=config.config_path, ) - if config.sessions == ['N170'] and config.task == 'ERN': - raise RuntimeError return cfg -def main(*, config) -> None: +def main(*, config: SimpleNamespace) -> None: """Run filter.""" with get_parallel_backend(config.exec_params): parallel, run_func = parallel_func( @@ -372,7 +293,11 @@ def main(*, config) -> None: logs = parallel( run_func( - cfg=get_config(config=config, subject=subject), + cfg=get_config( + config=config, + subject=subject, + ), + exec_params=config.exec_params, subject=subject, session=session, run=run, diff --git a/mne_bids_pipeline/steps/preprocessing/_03_make_epochs.py b/mne_bids_pipeline/steps/preprocessing/_04_make_epochs.py similarity index 95% rename from mne_bids_pipeline/steps/preprocessing/_03_make_epochs.py rename to mne_bids_pipeline/steps/preprocessing/_04_make_epochs.py index 6684207e5..3ae35bbc3 100644 --- a/mne_bids_pipeline/steps/preprocessing/_03_make_epochs.py +++ b/mne_bids_pipeline/steps/preprocessing/_04_make_epochs.py @@ -8,6 +8,7 @@ """ from types import SimpleNamespace +from typing import Optional import mne from mne_bids import BIDSPath @@ -25,14 +26,13 @@ from ..._parallel import parallel_func, get_parallel_backend -def get_input_fnames_epochs(**kwargs): +def get_input_fnames_epochs( + *, + cfg: SimpleNamespace, + subject: str, + session: Optional[str], +) -> dict: """Get paths of files required by filter_data function.""" - cfg = kwargs.pop('cfg') - subject = kwargs.pop('subject') - session = kwargs.pop('session') - assert len(kwargs) == 0, kwargs.keys() - del kwargs - # Construct the basenames of the files we wish to load, and of the empty- # room recording we wish to save. # The basenames of the empty-room recording output file does not contain @@ -67,7 +67,14 @@ def get_input_fnames_epochs(**kwargs): @failsafe_run( get_input_fnames=get_input_fnames_epochs, ) -def run_epochs(*, cfg, subject, session, in_files): +def run_epochs( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: Optional[str], + in_files: dict, +) -> dict: """Extract epochs for one subject.""" raw_fnames = [in_files.pop(f'raw_run-{run}') for run in cfg.runs] bids_path_in = raw_fnames[0].copy().update( @@ -147,7 +154,7 @@ def run_epochs(*, cfg, subject, session, in_files): smallest_rank = new_rank smallest_rank_info = epochs.info.copy() - del epochs + del epochs, run # Clean up namespace epochs = epochs_all_runs @@ -211,10 +218,11 @@ def run_epochs(*, cfg, subject, session, in_files): _update_for_splits(out_files, 'epochs') # Report - with _open_report(cfg=cfg, - subject=subject, - session=session, - run=run) as report: + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session) as report: if not cfg.task_is_rest: msg = 'Adding events plot to report.' logger.info(**gen_log_kwargs(message=msg)) @@ -246,7 +254,7 @@ def run_epochs(*, cfg, subject, session, in_files): ) # Interactive - if cfg.interactive: + if exec_params.interactive: epochs.plot() epochs.plot_image(combine='gfp', sigma=2., cmap='YlGnBu_r') assert len(in_files) == 0, in_files.keys() @@ -289,7 +297,6 @@ def get_config( subject: str, ) -> SimpleNamespace: cfg = SimpleNamespace( - exec_params=config.exec_params, runs=get_runs(config=config, subject=subject), use_maxwell_filter=config.use_maxwell_filter, proc=config.proc, @@ -300,7 +307,6 @@ def get_config( space=config.space, bids_root=config.bids_root, deriv_root=config.deriv_root, - interactive=config.interactive, task_is_rest=config.task_is_rest, conditions=config.conditions, epochs_tmin=config.epochs_tmin, @@ -317,7 +323,6 @@ def get_config( eeg_reference=get_eeg_reference(config), rest_epochs_duration=config.rest_epochs_duration, rest_epochs_overlap=config.rest_epochs_overlap, - config_path=config.config_path, _epochs_split_size=config._epochs_split_size, ) return cfg @@ -334,6 +339,7 @@ def main(*, config) -> None: config=config, subject=subject, ), + exec_params=config.exec_params, subject=subject, session=session ) diff --git a/mne_bids_pipeline/steps/preprocessing/_04a_run_ica.py b/mne_bids_pipeline/steps/preprocessing/_05a_run_ica.py similarity index 97% rename from mne_bids_pipeline/steps/preprocessing/_04a_run_ica.py rename to mne_bids_pipeline/steps/preprocessing/_05a_run_ica.py index e9e178b2c..f16d0e6f3 100644 --- a/mne_bids_pipeline/steps/preprocessing/_04a_run_ica.py +++ b/mne_bids_pipeline/steps/preprocessing/_05a_run_ica.py @@ -212,12 +212,12 @@ def detect_bad_components( return inds, scores -def get_input_fnames_run_ica(**kwargs): - cfg = kwargs.pop('cfg') - subject = kwargs.pop('subject') - session = kwargs.pop('session') - assert len(kwargs) == 0, kwargs.keys() - del kwargs +def get_input_fnames_run_ica( + *, + cfg: SimpleNamespace, + subject: str, + session: Optional[str], +) -> dict: bids_basename = BIDSPath(subject=subject, session=session, task=cfg.task, @@ -239,7 +239,14 @@ def get_input_fnames_run_ica(**kwargs): @failsafe_run( get_input_fnames=get_input_fnames_run_ica, ) -def run_ica(*, cfg, subject, session, in_files): +def run_ica( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: Optional[str], + in_files: dict, +) -> dict: """Run ICA.""" raw_fnames = [in_files.pop(f'raw_run-{run}') for run in cfg.runs] bids_basename = raw_fnames[0].copy().update( @@ -488,7 +495,10 @@ def run_ica(*, cfg, subject, session, in_files): logger.info(**gen_log_kwargs(message=msg)) report.save( - out_files['report'], overwrite=True, open_browser=cfg.interactive) + out_files['report'], + overwrite=True, + open_browser=exec_params.interactive, + ) assert len(in_files) == 0, in_files.keys() return out_files @@ -496,12 +506,11 @@ def run_ica(*, cfg, subject, session, in_files): def get_config( *, - config, + config: SimpleNamespace, subject: Optional[str] = None, session: Optional[str] = None ) -> SimpleNamespace: cfg = SimpleNamespace( - exec_params=config.exec_params, conditions=config.conditions, task=get_task(config), task_is_rest=config.task_is_rest, @@ -511,7 +520,6 @@ def get_config( rec=config.rec, space=config.space, deriv_root=config.deriv_root, - interactive=config.interactive, ica_l_freq=config.ica_l_freq, ica_algorithm=config.ica_algorithm, ica_n_components=config.ica_n_components, @@ -541,7 +549,7 @@ def get_config( return cfg -def main(*, config) -> None: +def main(*, config: SimpleNamespace) -> None: """Run ICA.""" if config.spatial_filter != 'ica': msg = 'Skipping …' @@ -556,6 +564,7 @@ def main(*, config) -> None: cfg=get_config( config=config, subject=subject), + exec_params=config.exec_params, subject=subject, session=session, ) diff --git a/mne_bids_pipeline/steps/preprocessing/_04b_run_ssp.py b/mne_bids_pipeline/steps/preprocessing/_05b_run_ssp.py similarity index 93% rename from mne_bids_pipeline/steps/preprocessing/_04b_run_ssp.py rename to mne_bids_pipeline/steps/preprocessing/_05b_run_ssp.py index 8a0de6922..e53db1221 100644 --- a/mne_bids_pipeline/steps/preprocessing/_04b_run_ssp.py +++ b/mne_bids_pipeline/steps/preprocessing/_05b_run_ssp.py @@ -3,6 +3,7 @@ These are often also referred to as PCA vectors. """ +from typing import Optional from types import SimpleNamespace import mne @@ -21,12 +22,12 @@ from ..._run import failsafe_run, _update_for_splits, save_logs -def get_input_fnames_run_ssp(**kwargs): - cfg = kwargs.pop('cfg') - subject = kwargs.pop('subject') - session = kwargs.pop('session') - assert len(kwargs) == 0, kwargs.keys() - del kwargs +def get_input_fnames_run_ssp( + *, + cfg: SimpleNamespace, + subject: str, + session: Optional[str], +) -> dict: bids_basename = BIDSPath(subject=subject, session=session, task=cfg.task, @@ -49,7 +50,14 @@ def get_input_fnames_run_ssp(**kwargs): @failsafe_run( get_input_fnames=get_input_fnames_run_ssp, ) -def run_ssp(*, cfg, subject, session, in_files): +def run_ssp( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: Optional[str], + in_files: dict, +) -> dict: import matplotlib.pyplot as plt # compute SSP on first run of raw raw_fnames = [in_files.pop(f'raw_run-{run}') for run in cfg.runs] @@ -130,7 +138,11 @@ def run_ssp(*, cfg, subject, session, in_files): assert len(in_files) == 0, in_files.keys() # Report - with _open_report(cfg=cfg, subject=subject, session=session) as report: + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session) as report: for kind in proj_kinds: if f'epochs_{kind}' not in out_files: continue @@ -171,11 +183,10 @@ def run_ssp(*, cfg, subject, session, in_files): def get_config( *, - config, + config: SimpleNamespace, subject: str, ) -> SimpleNamespace: cfg = SimpleNamespace( - exec_params=config.exec_params, runs=get_runs(config=config, subject=subject), task=get_task(config), datatype=get_datatype(config), @@ -200,7 +211,7 @@ def get_config( return cfg -def main(*, config) -> None: +def main(*, config: SimpleNamespace) -> None: """Run SSP.""" if config.spatial_filter != 'ssp': msg = 'Skipping …' @@ -216,6 +227,7 @@ def main(*, config) -> None: config=config, subject=subject, ), + exec_params=config.exec_params, subject=subject, session=session, ) diff --git a/mne_bids_pipeline/steps/preprocessing/_05a_apply_ica.py b/mne_bids_pipeline/steps/preprocessing/_06a_apply_ica.py similarity index 88% rename from mne_bids_pipeline/steps/preprocessing/_05a_apply_ica.py rename to mne_bids_pipeline/steps/preprocessing/_06a_apply_ica.py index da5b16f98..d824d5e9d 100644 --- a/mne_bids_pipeline/steps/preprocessing/_05a_apply_ica.py +++ b/mne_bids_pipeline/steps/preprocessing/_06a_apply_ica.py @@ -12,6 +12,7 @@ """ from types import SimpleNamespace +from typing import Optional import pandas as pd import mne @@ -30,12 +31,12 @@ from ..._run import failsafe_run, _update_for_splits, save_logs -def get_input_fnames_apply_ica(**kwargs): - cfg = kwargs.pop('cfg') - subject = kwargs.pop('subject') - session = kwargs.pop('session') - assert len(kwargs) == 0, kwargs.keys() - del kwargs +def get_input_fnames_apply_ica( + *, + cfg: SimpleNamespace, + subject: str, + session: Optional[str], +) -> dict: bids_basename = BIDSPath(subject=subject, session=session, task=cfg.task, @@ -58,7 +59,14 @@ def get_input_fnames_apply_ica(**kwargs): @failsafe_run( get_input_fnames=get_input_fnames_apply_ica, ) -def apply_ica(*, cfg, subject, session, in_files): +def apply_ica( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: Optional[str], + in_files: dict, +) -> dict: bids_basename = in_files['ica'].copy().update(processing=None) out_files = dict() out_files['epochs'] = in_files['epochs'].copy().update(processing='ica') @@ -131,7 +139,10 @@ def apply_ica(*, cfg, subject, session, in_files): n_jobs=1, # avoid automatic parallelization ) report.save( - out_files['report'], overwrite=True, open_browser=cfg.interactive) + out_files['report'], + overwrite=True, + open_browser=exec_params.interactive, + ) assert len(in_files) == 0, in_files.keys() @@ -142,7 +153,11 @@ def apply_ica(*, cfg, subject, session, in_files): msg = 'Skipping ICA addition to report, no components marked as bad.' logger.info(**gen_log_kwargs(message=msg)) if ica.exclude: - with _open_report(cfg=cfg, subject=subject, session=session) as report: + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session) as report: report.add_ica( ica=ica, title='ICA', @@ -160,17 +175,15 @@ def apply_ica(*, cfg, subject, session, in_files): def get_config( *, - config, + config: SimpleNamespace, ) -> SimpleNamespace: cfg = SimpleNamespace( - exec_params=config.exec_params, task=get_task(config), datatype=get_datatype(config), acq=config.acq, rec=config.rec, space=config.space, deriv_root=config.deriv_root, - interactive=config.interactive, baseline=config.baseline, ica_reject=config.ica_reject, ch_types=config.ch_types, @@ -179,7 +192,7 @@ def get_config( return cfg -def main(*, config) -> None: +def main(*, config: SimpleNamespace) -> None: """Apply ICA.""" if not config.spatial_filter == 'ica': msg = 'Skipping …' @@ -191,7 +204,10 @@ def main(*, config) -> None: apply_ica, exec_params=config.exec_params) logs = parallel( run_func( - cfg=get_config(config=config), + cfg=get_config( + config=config, + ), + exec_params=config.exec_params, subject=subject, session=session) for subject in get_subjects(config) diff --git a/mne_bids_pipeline/steps/preprocessing/_05b_apply_ssp.py b/mne_bids_pipeline/steps/preprocessing/_06b_apply_ssp.py similarity index 85% rename from mne_bids_pipeline/steps/preprocessing/_05b_apply_ssp.py rename to mne_bids_pipeline/steps/preprocessing/_06b_apply_ssp.py index 276c89855..827ba3754 100644 --- a/mne_bids_pipeline/steps/preprocessing/_05b_apply_ssp.py +++ b/mne_bids_pipeline/steps/preprocessing/_06b_apply_ssp.py @@ -6,6 +6,7 @@ """ from types import SimpleNamespace +from typing import Optional import mne from mne_bids import BIDSPath @@ -18,12 +19,12 @@ from ..._parallel import parallel_func, get_parallel_backend -def get_input_fnames_apply_ssp(**kwargs): - cfg = kwargs.pop('cfg') - subject = kwargs.pop('subject') - session = kwargs.pop('session') - assert len(kwargs) == 0, kwargs.keys() - del kwargs +def get_input_fnames_apply_ssp( + *, + cfg: SimpleNamespace, + subject: str, + session: Optional[str], +) -> dict: bids_basename = BIDSPath(subject=subject, session=session, task=cfg.task, @@ -43,7 +44,14 @@ def get_input_fnames_apply_ssp(**kwargs): @failsafe_run( get_input_fnames=get_input_fnames_apply_ssp, ) -def apply_ssp(*, cfg, subject, session, in_files): +def apply_ssp( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: Optional[str], + in_files: dict, +) -> dict: # load epochs to reject ICA components # compute SSP on first run of raw out_files = dict() @@ -68,10 +76,9 @@ def apply_ssp(*, cfg, subject, session, in_files): def get_config( *, - config, + config: SimpleNamespace, ) -> SimpleNamespace: cfg = SimpleNamespace( - exec_params=config.exec_params, task=get_task(config), datatype=get_datatype(config), acq=config.acq, @@ -83,7 +90,7 @@ def get_config( return cfg -def main(*, config) -> None: +def main(*, config: SimpleNamespace) -> None: """Apply ssp.""" if not config.spatial_filter == 'ssp': msg = 'Skipping …' @@ -95,7 +102,10 @@ def main(*, config) -> None: apply_ssp, exec_params=config.exec_params) logs = parallel( run_func( - cfg=get_config(config=config), + cfg=get_config( + config=config, + ), + exec_params=config.exec_params, subject=subject, session=session) for subject in get_subjects(config) diff --git a/mne_bids_pipeline/steps/preprocessing/_06_ptp_reject.py b/mne_bids_pipeline/steps/preprocessing/_07_ptp_reject.py similarity index 89% rename from mne_bids_pipeline/steps/preprocessing/_06_ptp_reject.py rename to mne_bids_pipeline/steps/preprocessing/_07_ptp_reject.py index 09c966c26..28c4fa33b 100644 --- a/mne_bids_pipeline/steps/preprocessing/_06_ptp_reject.py +++ b/mne_bids_pipeline/steps/preprocessing/_07_ptp_reject.py @@ -9,6 +9,7 @@ """ from types import SimpleNamespace +from typing import Optional import mne from mne_bids import BIDSPath @@ -23,12 +24,12 @@ from ..._run import failsafe_run, _update_for_splits, save_logs -def get_input_fnames_drop_ptp(**kwargs): - cfg = kwargs.pop('cfg') - subject = kwargs.pop('subject') - session = kwargs.pop('session') - assert len(kwargs) == 0, kwargs.keys() - del kwargs +def get_input_fnames_drop_ptp( + *, + cfg: SimpleNamespace, + subject: str, + session: Optional[str], +) -> dict: bids_path = BIDSPath(subject=subject, session=session, task=cfg.task, @@ -50,7 +51,14 @@ def get_input_fnames_drop_ptp(**kwargs): @failsafe_run( get_input_fnames=get_input_fnames_drop_ptp, ) -def drop_ptp(*, cfg, subject, session, in_files): +def drop_ptp( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: Optional[str], + in_files: dict, +) -> dict: out_files = dict() out_files['epochs'] = in_files['epochs'].copy().update(processing='clean') msg = f'Input: {in_files["epochs"].basename}' @@ -125,7 +133,11 @@ def drop_ptp(*, cfg, subject, session, in_files): psd = True else: psd = 30 - with _open_report(cfg=cfg, subject=subject, session=session) as report: + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session) as report: report.add_epochs( epochs=epochs, title='Epochs: after cleaning', @@ -138,10 +150,9 @@ def drop_ptp(*, cfg, subject, session, in_files): def get_config( *, - config, + config: SimpleNamespace, ) -> SimpleNamespace: cfg = SimpleNamespace( - exec_params=config.exec_params, task=get_task(config), datatype=get_datatype(config), acq=config.acq, @@ -160,7 +171,7 @@ def get_config( return cfg -def main(*, config) -> None: +def main(*, config: SimpleNamespace) -> None: """Run epochs.""" parallel, run_func = parallel_func( drop_ptp, exec_params=config.exec_params) @@ -168,7 +179,10 @@ def main(*, config) -> None: with get_parallel_backend(config.exec_params): logs = parallel( run_func( - cfg=get_config(config=config), + cfg=get_config( + config=config, + ), + exec_params=config.exec_params, subject=subject, session=session) for subject in get_subjects(config) diff --git a/mne_bids_pipeline/steps/preprocessing/__init__.py b/mne_bids_pipeline/steps/preprocessing/__init__.py index c184c8db8..c65b80284 100644 --- a/mne_bids_pipeline/steps/preprocessing/__init__.py +++ b/mne_bids_pipeline/steps/preprocessing/__init__.py @@ -1,21 +1,23 @@ """Preprocessing.""" -from . import _01_maxfilter -from . import _02_frequency_filter -from . import _03_make_epochs -from . import _04a_run_ica -from . import _04b_run_ssp -from . import _05a_apply_ica -from . import _05b_apply_ssp -from . import _06_ptp_reject +from . import _01_data_quality +from . import _02_maxfilter +from . import _03_frequency_filter +from . import _04_make_epochs +from . import _05a_run_ica +from . import _05b_run_ssp +from . import _06a_apply_ica +from . import _06b_apply_ssp +from . import _07_ptp_reject _STEPS = ( - _01_maxfilter, - _02_frequency_filter, - _03_make_epochs, - _04a_run_ica, - _04b_run_ssp, - _05a_apply_ica, - _05b_apply_ssp, - _06_ptp_reject, + _01_data_quality, + _02_maxfilter, + _03_frequency_filter, + _04_make_epochs, + _05a_run_ica, + _05b_run_ssp, + _06a_apply_ica, + _06b_apply_ssp, + _07_ptp_reject, ) diff --git a/mne_bids_pipeline/steps/sensor/_01_make_evoked.py b/mne_bids_pipeline/steps/sensor/_01_make_evoked.py index 3fe64d987..b5a07f83c 100644 --- a/mne_bids_pipeline/steps/sensor/_01_make_evoked.py +++ b/mne_bids_pipeline/steps/sensor/_01_make_evoked.py @@ -1,6 +1,7 @@ """Extract evoked data for each condition.""" from types import SimpleNamespace +from typing import Optional import mne from mne_bids import BIDSPath @@ -15,12 +16,12 @@ from ..._run import failsafe_run, save_logs, _sanitize_callable -def get_input_fnames_evoked(**kwargs): - cfg = kwargs.pop('cfg') - subject = kwargs.pop('subject') - session = kwargs.pop('session') - assert len(kwargs) == 0, kwargs.keys() - del kwargs +def get_input_fnames_evoked( + *, + cfg: SimpleNamespace, + subject: str, + session: Optional[str], +) -> dict: fname_epochs = BIDSPath(subject=subject, session=session, task=cfg.task, @@ -42,7 +43,14 @@ def get_input_fnames_evoked(**kwargs): @failsafe_run( get_input_fnames=get_input_fnames_evoked, ) -def run_evoked(*, cfg, subject, session, in_files): +def run_evoked( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: Optional[str], + in_files: dict, +) -> dict: out_files = dict() out_files['evoked'] = in_files['epochs'].copy().update( suffix='ave', processing=None, check=False) @@ -93,7 +101,11 @@ def run_evoked(*, cfg, subject, session, in_files): logger.info( **gen_log_kwargs(message=msg) ) - with _open_report(cfg=cfg, subject=subject, session=session) as report: + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session) as report: for condition, evoked in all_evoked.items(): _restrict_analyze_channels(evoked, cfg) @@ -114,7 +126,7 @@ def run_evoked(*, cfg, subject, session, in_files): ) # Interaction - if cfg.interactive: + if exec_params.interactive: for evoked in evokeds: evoked.plot() @@ -132,10 +144,9 @@ def run_evoked(*, cfg, subject, session, in_files): def get_config( *, - config, + config: SimpleNamespace, ) -> SimpleNamespace: cfg = SimpleNamespace( - exec_params=config.exec_params, task=get_task(config), datatype=get_datatype(config), acq=config.acq, @@ -144,7 +155,6 @@ def get_config( deriv_root=config.deriv_root, conditions=config.conditions, contrasts=get_all_contrasts(config), - interactive=config.interactive, proc=config.proc, noise_cov=_sanitize_callable(config.noise_cov), analyze_channels=config.analyze_channels, @@ -154,7 +164,7 @@ def get_config( return cfg -def main(*, config) -> None: +def main(*, config: SimpleNamespace) -> None: """Run evoked.""" if config.task_is_rest: msg = ' … skipping: for resting-state task.' @@ -166,7 +176,10 @@ def main(*, config) -> None: run_evoked, exec_params=config.exec_params) logs = parallel( run_func( - cfg=get_config(config=config), + cfg=get_config( + config=config, + ), + exec_params=config.exec_params, subject=subject, session=session, ) diff --git a/mne_bids_pipeline/steps/sensor/_02_decoding_full_epochs.py b/mne_bids_pipeline/steps/sensor/_02_decoding_full_epochs.py index 3c7e5cb00..2a584da11 100644 --- a/mne_bids_pipeline/steps/sensor/_02_decoding_full_epochs.py +++ b/mne_bids_pipeline/steps/sensor/_02_decoding_full_epochs.py @@ -10,6 +10,7 @@ import os.path as op from types import SimpleNamespace +from typing import Optional import numpy as np import pandas as pd @@ -36,15 +37,14 @@ _sanitize_cond_tag) -def get_input_fnames_epochs_decoding(**kwargs): - cfg = kwargs.pop('cfg') - subject = kwargs.pop('subject') - session = kwargs.pop('session') - # TODO: Somehow remove these? - del kwargs['condition1'] - del kwargs['condition2'] - assert len(kwargs) == 0, kwargs.keys() - del kwargs +def get_input_fnames_epochs_decoding( + *, + cfg: SimpleNamespace, + subject: str, + session: Optional[str], + condition1: str, + condition2: str, +) -> dict: fname_epochs = BIDSPath(subject=subject, session=session, task=cfg.task, @@ -65,8 +65,16 @@ def get_input_fnames_epochs_decoding(**kwargs): @failsafe_run( get_input_fnames=get_input_fnames_epochs_decoding, ) -def run_epochs_decoding(*, cfg, subject, condition1, condition2, session, - in_files): +def run_epochs_decoding( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: Optional[str], + condition1: str, + condition2: str, + in_files: dict, +) -> dict: import matplotlib.pyplot as plt msg = f'Contrasting conditions: {condition1} – {condition2}' logger.info(**gen_log_kwargs(message=msg)) @@ -148,7 +156,11 @@ def run_epochs_decoding(*, cfg, subject, condition1, condition2, session, tabular_data.to_csv(out_files[tsv_key], sep='\t', index=False) # Report - with _open_report(cfg=cfg, subject=subject, session=session) as report: + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session) as report: msg = 'Adding full-epochs decoding results to the report.' logger.info(**gen_log_kwargs(message=msg)) @@ -202,10 +214,9 @@ def run_epochs_decoding(*, cfg, subject, condition1, condition2, session, def get_config( *, - config, + config: SimpleNamespace, ) -> SimpleNamespace: cfg = SimpleNamespace( - exec_params=config.exec_params, task=get_task(config), datatype=get_datatype(config), acq=config.acq, @@ -225,7 +236,7 @@ def get_config( return cfg -def main(*, config) -> None: +def main(*, config: SimpleNamespace) -> None: """Run time-by-time decoding.""" if not config.contrasts: msg = 'No contrasts specified; not performing decoding.' @@ -243,6 +254,7 @@ def main(*, config) -> None: logs = parallel( run_func( cfg=get_config(config=config), + exec_params=config.exec_params, subject=subject, condition1=cond_1, condition2=cond_2, diff --git a/mne_bids_pipeline/steps/sensor/_03_decoding_time_by_time.py b/mne_bids_pipeline/steps/sensor/_03_decoding_time_by_time.py index a6ed6b467..49b216bcc 100644 --- a/mne_bids_pipeline/steps/sensor/_03_decoding_time_by_time.py +++ b/mne_bids_pipeline/steps/sensor/_03_decoding_time_by_time.py @@ -13,6 +13,7 @@ import os.path as op from types import SimpleNamespace +from typing import Optional import numpy as np import pandas as pd @@ -36,23 +37,21 @@ from ..._decoding import LogReg from ..._logging import gen_log_kwargs, logger from ..._run import failsafe_run, save_logs -from ..._parallel import ( - get_parallel_backend, get_n_jobs, get_parallel_backend_name) +from ..._parallel import get_parallel_backend, get_parallel_backend_name from ..._report import ( _open_report, _plot_decoding_time_generalization, _sanitize_cond_tag, _plot_time_by_time_decoding_scores, ) -def get_input_fnames_time_decoding(**kwargs): - cfg = kwargs.pop('cfg') - subject = kwargs.pop('subject') - session = kwargs.pop('session') - # TODO: Somehow remove these? - del kwargs['condition1'] - del kwargs['condition2'] - assert len(kwargs) == 0, kwargs.keys() - del kwargs +def get_input_fnames_time_decoding( + *, + cfg: SimpleNamespace, + subject: str, + session: Optional[str], + condition1: str, + condition2: str, +) -> dict: # TODO: Shouldn't this at least use the PTP-rejected epochs if available? fname_epochs = BIDSPath(subject=subject, session=session, @@ -74,8 +73,16 @@ def get_input_fnames_time_decoding(**kwargs): @failsafe_run( get_input_fnames=get_input_fnames_time_decoding, ) -def run_time_decoding(*, cfg, subject, condition1, condition2, session, - in_files): +def run_time_decoding( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: Optional[str], + condition1: str, + condition2: str, + in_files: dict, +) -> dict: import matplotlib.pyplot as plt if cfg.decoding_time_generalization: kind = 'time generalization' @@ -118,10 +125,6 @@ def run_time_decoding(*, cfg, subject, condition1, condition2, session, X = epochs.get_data() y = np.r_[np.ones(n_cond1), np.zeros(n_cond2)] # ProgressBar does not work on dask, so only enable it if not using dask - exec_params = SimpleNamespace( - parallel_backend=cfg.parallel_backend, - N_JOBS=cfg.N_JOBS, - ) verbose = get_parallel_backend_name(exec_params=exec_params) != "dask" with get_parallel_backend(exec_params): clf = make_pipeline( @@ -142,7 +145,7 @@ def run_time_decoding(*, cfg, subject, condition1, condition2, session, estimator = GeneralizingEstimator( clf, scoring=cfg.decoding_metric, - n_jobs=cfg.n_jobs, + n_jobs=exec_params.N_JOBS, ) cv_scoring_n_jobs = 1 else: @@ -151,7 +154,7 @@ def run_time_decoding(*, cfg, subject, condition1, condition2, session, scoring=cfg.decoding_metric, n_jobs=1, ) - cv_scoring_n_jobs = cfg.n_jobs + cv_scoring_n_jobs = exec_params.N_JOBS scores = cross_val_multiscore( estimator, X=X, y=y, cv=cv, n_jobs=cv_scoring_n_jobs, @@ -195,7 +198,11 @@ def run_time_decoding(*, cfg, subject, condition1, condition2, session, out_files[f'tsv_{processing}'], sep='\t', index=False) # Report - with _open_report(cfg=cfg, subject=subject, session=session) as report: + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session) as report: msg = 'Adding time-by-time decoding results to the report.' logger.info(**gen_log_kwargs(message=msg)) @@ -276,10 +283,9 @@ def run_time_decoding(*, cfg, subject, condition1, condition2, session, def get_config( *, - config, + config: SimpleNamespace, ) -> SimpleNamespace: cfg = SimpleNamespace( - exec_params=config.exec_params, task=get_task(config), datatype=get_datatype(config), acq=config.acq, @@ -297,17 +303,11 @@ def get_config( analyze_channels=config.analyze_channels, ch_types=config.ch_types, eeg_reference=get_eeg_reference(config), - # TODO: None of these should affect caching, but they will... - n_jobs=get_n_jobs(exec_params=config.exec_params), - parallel_backend=config.exec_params.parallel_backend, - interactive=config.interactive, - N_JOBS=config.exec_params.N_JOBS, - config_path=config.config_path, ) return cfg -def main(*, config) -> None: +def main(*, config: SimpleNamespace) -> None: """Run time-by-time decoding.""" if not config.contrasts: msg = 'No contrasts specified; not performing decoding.' diff --git a/mne_bids_pipeline/steps/sensor/_04_time_frequency.py b/mne_bids_pipeline/steps/sensor/_04_time_frequency.py index f95ad5162..4cb1301ab 100644 --- a/mne_bids_pipeline/steps/sensor/_04_time_frequency.py +++ b/mne_bids_pipeline/steps/sensor/_04_time_frequency.py @@ -5,6 +5,7 @@ """ from types import SimpleNamespace +from typing import Optional import numpy as np @@ -22,12 +23,12 @@ from ..._report import _open_report, _sanitize_cond_tag -def get_input_fnames_time_frequency(**kwargs): - cfg = kwargs.pop('cfg') - subject = kwargs.pop('subject') - session = kwargs.pop('session') - assert len(kwargs) == 0, kwargs.keys() - del kwargs +def get_input_fnames_time_frequency( + *, + cfg: SimpleNamespace, + subject: str, + session: Optional[str], +) -> dict: processing = None if cfg.spatial_filter is not None: processing = 'clean' @@ -52,7 +53,14 @@ def get_input_fnames_time_frequency(**kwargs): @failsafe_run( get_input_fnames=get_input_fnames_time_frequency, ) -def run_time_frequency(*, cfg, subject, session, in_files): +def run_time_frequency( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: Optional[str], + in_files: dict, +) -> dict: import matplotlib.pyplot as plt msg = f'Input: {in_files["epochs"].basename}' logger.info(**gen_log_kwargs(message=msg)) @@ -95,7 +103,11 @@ def run_time_frequency(*, cfg, subject, session, in_files): itc.save(out_files[itc_key], overwrite=True, verbose='error') # Report - with _open_report(cfg=cfg, subject=subject, session=session) as report: + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session) as report: msg = 'Adding TFR analysis results to the report.' logger.info(**gen_log_kwargs(message=msg)) for condition in cfg.time_frequency_conditions: @@ -144,10 +156,9 @@ def run_time_frequency(*, cfg, subject, session, in_files): def get_config( *, - config, + config: SimpleNamespace, ) -> SimpleNamespace: cfg = SimpleNamespace( - exec_params=config.exec_params, task=get_task(config), datatype=get_datatype(config), acq=config.acq, @@ -170,7 +181,7 @@ def get_config( return cfg -def main(*, config) -> None: +def main(*, config: SimpleNamespace) -> None: """Run Time-frequency decomposition.""" if not config.time_frequency_conditions: msg = 'Skipping …' @@ -182,7 +193,10 @@ def main(*, config) -> None: with get_parallel_backend(config.exec_params): logs = parallel( run_func( - cfg=get_config(config=config), + cfg=get_config( + config=config, + ), + exec_params=config.exec_params, subject=subject, session=session, ) diff --git a/mne_bids_pipeline/steps/sensor/_05_decoding_csp.py b/mne_bids_pipeline/steps/sensor/_05_decoding_csp.py index c5252e7d0..77f6d3198 100644 --- a/mne_bids_pipeline/steps/sensor/_05_decoding_csp.py +++ b/mne_bids_pipeline/steps/sensor/_05_decoding_csp.py @@ -111,14 +111,13 @@ def prepare_epochs_and_y( return epochs_filt, y -def get_input_fnames_csp(**kwargs): - cfg = kwargs.pop('cfg') - subject = kwargs.pop('subject') - session = kwargs.pop('session') - # TODO: Somehow remove this? - del kwargs['contrast'] - assert len(kwargs) == 0, kwargs.keys() - del kwargs +def get_input_fnames_csp( + *, + cfg: SimpleNamespace, + subject: str, + session: Optional[str], + contrast: Tuple[str], +) -> dict: fname_epochs = BIDSPath(subject=subject, session=session, task=cfg.task, @@ -141,12 +140,13 @@ def get_input_fnames_csp(**kwargs): ) def one_subject_decoding( *, - cfg, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, subject: str, session: str, contrast: Tuple[str, str], in_files: Dict[str, BIDSPath] -) -> None: +) -> dict: """Run one subject. There are two steps in this function: @@ -369,7 +369,11 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, out_files = {'csp-excel': fname_results} # Report - with _open_report(cfg=cfg, subject=subject, session=session) as report: + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session) as report: msg = 'Adding CSP decoding results to the report.' logger.info(**gen_log_kwargs(message=msg)) section = 'Decoding: CSP' @@ -511,12 +515,11 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, def get_config( *, - config, + config: SimpleNamespace, subject: str, session: Optional[str] ) -> SimpleNamespace: cfg = SimpleNamespace( - exec_params=config.exec_params, # Data parameters datatype=get_datatype(config), deriv_root=config.deriv_root, @@ -537,12 +540,11 @@ def get_config( decoding_contrasts=get_decoding_contrasts(config), n_boot=config.n_boot, random_state=config.random_state, - interactive=config.interactive ) return cfg -def main(*, config) -> None: +def main(*, config: SimpleNamespace) -> None: """Run all subjects decoding in parallel.""" if not config.contrasts or not config.decoding_csp: if not config.contrasts: @@ -564,6 +566,7 @@ def main(*, config) -> None: subject=subject, session=session ), + exec_params=config.exec_params, subject=subject, session=session, contrast=contrast, diff --git a/mne_bids_pipeline/steps/sensor/_06_make_cov.py b/mne_bids_pipeline/steps/sensor/_06_make_cov.py index bd2f7ab48..9f83c6359 100644 --- a/mne_bids_pipeline/steps/sensor/_06_make_cov.py +++ b/mne_bids_pipeline/steps/sensor/_06_make_cov.py @@ -21,13 +21,12 @@ from ..._run import failsafe_run, save_logs, _sanitize_callable -def get_input_fnames_cov(**kwargs): - cfg = kwargs.pop('cfg') - subject = kwargs.pop('subject') - session = kwargs.pop('session') - assert len(kwargs) == 0, kwargs.keys() - del kwargs - # short circuit to say: always re-run +def get_input_fnames_cov( + *, + cfg: SimpleNamespace, + subject: str, + session: Optional[str], +) -> dict: cov_type = _get_cov_type(cfg) in_files = dict() processing = 'clean' if cfg.spatial_filter is not None else None @@ -124,6 +123,7 @@ def compute_cov_from_raw(*, cfg, subject, session, in_files, out_files): def retrieve_custom_cov( cfg: SimpleNamespace, + exec_params: SimpleNamespace, subject: str, session: Optional[str], in_files: None, @@ -132,7 +132,7 @@ def retrieve_custom_cov( # This should be the only place we use config.noise_cov (rather than cfg.* # entries) config = _import_config( - config_path=cfg.config_path, + config_path=exec_params.config_path, check=False, log=False, ) @@ -182,7 +182,14 @@ def _get_cov_type(cfg): @failsafe_run( get_input_fnames=get_input_fnames_cov, ) -def run_covariance(*, cfg, subject, session, in_files): +def run_covariance( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: Optional[str] = None, + in_files: dict, +) -> dict: import matplotlib.pyplot as plt out_files = dict() out_files['cov'] = get_noise_cov_bids_path( @@ -193,7 +200,7 @@ def run_covariance(*, cfg, subject, session, in_files): cov_type = _get_cov_type(cfg) kwargs = dict( cfg=cfg, subject=subject, session=session, - in_files=in_files, out_files=out_files) + in_files=in_files, out_files=out_files, exec_params=exec_params) fname_info = in_files.pop('report_info') fname_evoked = in_files.pop('evoked', None) if cov_type == 'custom': @@ -206,7 +213,11 @@ def run_covariance(*, cfg, subject, session, in_files): cov.save(out_files['cov'], overwrite=True) # Report - with _open_report(cfg=cfg, subject=subject, session=session) as report: + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session) as report: msg = 'Rendering noise covariance matrix and corresponding SVD.' logger.info(**gen_log_kwargs(message=msg)) report.add_covariance( @@ -246,10 +257,9 @@ def run_covariance(*, cfg, subject, session, in_files): def get_config( *, - config, + config: SimpleNamespace, ) -> SimpleNamespace: cfg = SimpleNamespace( - exec_params=config.exec_params, task=get_task(config), datatype=get_datatype(config), acq=config.acq, @@ -264,12 +274,11 @@ def get_config( conditions=config.conditions, all_contrasts=get_all_contrasts(config), analyze_channels=config.analyze_channels, - config_path=config.config_path, ) return cfg -def main(*, config) -> None: +def main(*, config: SimpleNamespace) -> None: """Run cov.""" if not config.run_source_estimation: msg = 'Skipping, run_source_estimation is set to False …' @@ -290,6 +299,7 @@ def main(*, config) -> None: logs = parallel( run_func( cfg=get_config(config=config), + exec_params=config.exec_params, subject=subject, session=session, ) diff --git a/mne_bids_pipeline/steps/sensor/_99_group_average.py b/mne_bids_pipeline/steps/sensor/_99_group_average.py index 99f0cf1ab..f62033b33 100644 --- a/mne_bids_pipeline/steps/sensor/_99_group_average.py +++ b/mne_bids_pipeline/steps/sensor/_99_group_average.py @@ -27,7 +27,12 @@ from ..._report import run_report_average_sensor -def average_evokeds(cfg, session): +def average_evokeds( + *, + cfg: SimpleNamespace, + subject: str, + session: Optional[str], +) -> List[mne.Evoked]: # Container for all conditions: all_evokeds = defaultdict(list) @@ -59,7 +64,6 @@ def average_evokeds(cfg, session): # Keep condition in comment all_evokeds[idx].comment = 'Grand average: ' + evokeds[0].comment - subject = 'average' fname_out = BIDSPath(subject=subject, session=session, task=cfg.task, @@ -579,7 +583,6 @@ def get_config( ) -> SimpleNamespace: dtg_decim = config.decoding_time_generalization_decim cfg = SimpleNamespace( - exec_params=config.exec_params, subjects=get_subjects(config), task=get_task(config), task_is_rest=config.task_is_rest, @@ -608,7 +611,6 @@ def get_config( interpolate_bads_grand_average=config.interpolate_bads_grand_average, ch_types=config.ch_types, eeg_reference=get_eeg_reference(config), - interactive=config.interactive, sessions=get_sessions(config), bids_root=config.bids_root, data_type=config.data_type, @@ -616,18 +618,18 @@ def get_config( all_contrasts=get_all_contrasts(config), report_evoked_n_time_points=config.report_evoked_n_time_points, cluster_permutation_p_threshold=config.cluster_permutation_p_threshold, - # TODO: This script deviates from our standard procedure, when we - # eventually cache and have in/out_files we'll want to remove this - # since it breaks the "execution params should only be in - # cfg.exec_params" idea - exec_params_inner=config.exec_params, ) return cfg # pass 'average' subject for logging @failsafe_run() -def run_group_average_sensor(*, cfg): +def run_group_average_sensor( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, +) -> None: if cfg.task_is_rest: msg = ' … skipping: for "rest" task.' logger.info(**gen_log_kwargs(message=msg)) @@ -637,10 +639,14 @@ def run_group_average_sensor(*, cfg): if not sessions: sessions = [None] - with get_parallel_backend(cfg.exec_params_inner): + with get_parallel_backend(exec_params): for session in sessions: - evokeds = average_evokeds(cfg, session) - if cfg.interactive: + evokeds = average_evokeds( + cfg=cfg, + subject=subject, + session=session, + ) + if exec_params.interactive: for evoked in evokeds: evoked.plot() @@ -649,7 +655,7 @@ def run_group_average_sensor(*, cfg): average_time_by_time_decoding(cfg, session) if cfg.decode and cfg.decoding_csp: parallel, run_func = parallel_func( - average_csp_decoding, exec_params=cfg.exec_params_inner) + average_csp_decoding) parallel( run_func( cfg=cfg, @@ -662,11 +668,17 @@ def run_group_average_sensor(*, cfg): ) for session in sessions: - run_report_average_sensor(cfg=cfg, session=session) + run_report_average_sensor( + cfg=cfg, + subject=subject, + session=session, + ) -def main(*, config) -> None: +def main(*, config: SimpleNamespace) -> None: log = run_group_average_sensor( cfg=get_config(config=config), + exec_params=config.exec_params, + subject='average', ) save_logs(config=config, logs=[log]) diff --git a/mne_bids_pipeline/steps/source/_01_make_bem_surfaces.py b/mne_bids_pipeline/steps/source/_01_make_bem_surfaces.py index 4c986147e..165f01704 100644 --- a/mne_bids_pipeline/steps/source/_01_make_bem_surfaces.py +++ b/mne_bids_pipeline/steps/source/_01_make_bem_surfaces.py @@ -16,7 +16,7 @@ from ..._run import failsafe_run, save_logs -def _get_bem_params(cfg): +def _get_bem_params(cfg: SimpleNamespace): mri_dir = Path(cfg.fs_subjects_dir) / cfg.fs_subject / 'mri' flash_dir = mri_dir / 'flash' / 'parameter_maps' if cfg.bem_mri_images == 'FLASH' and not flash_dir.exists(): @@ -30,7 +30,11 @@ def _get_bem_params(cfg): return mri_images, mri_dir, flash_dir -def get_input_fnames_make_bem_surfaces(*, cfg, subject): +def get_input_fnames_make_bem_surfaces( + *, + cfg: SimpleNamespace, + subject: str, +) -> dict: in_files = dict() mri_images, mri_dir, flash_dir = _get_bem_params(cfg) in_files['t1'] = mri_dir / 'T1.mgz' @@ -42,7 +46,11 @@ def get_input_fnames_make_bem_surfaces(*, cfg, subject): return in_files -def get_output_fnames_make_bem_surfaces(*, cfg, subject): +def get_output_fnames_make_bem_surfaces( + *, + cfg: SimpleNamespace, + subject: str, +) -> dict: out_files = dict() conductivity, _ = _get_bem_conductivity(cfg) n_layers = len(conductivity) @@ -56,7 +64,13 @@ def get_output_fnames_make_bem_surfaces(*, cfg, subject): get_input_fnames=get_input_fnames_make_bem_surfaces, get_output_fnames=get_output_fnames_make_bem_surfaces, ) -def make_bem_surfaces(*, cfg, subject, in_files): +def make_bem_surfaces( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + in_files: dict, +) -> dict: mri_images, _, _ = _get_bem_params(cfg) in_files.clear() # assume we use everything we add if mri_images == 'FLASH': @@ -67,7 +81,7 @@ def make_bem_surfaces(*, cfg, subject, in_files): 'watershed algorithm') bem_func = mne.bem.make_watershed_bem logger.info(**gen_log_kwargs(message=msg, subject=subject)) - show = True if cfg.interactive else False + show = True if exec_params.interactive else False bem_func( subject=cfg.fs_subject, subjects_dir=cfg.fs_subjects_dir, @@ -90,7 +104,6 @@ def get_config( fs_subject=get_fs_subject(config=config, subject=subject), fs_subjects_dir=get_fs_subjects_dir(config=config), bem_mri_images=config.bem_mri_images, - interactive=config.interactive, freesurfer_verbose=config.freesurfer_verbose, use_template_mri=config.use_template_mri, ch_types=config.ch_types, @@ -122,6 +135,7 @@ def main(*, config) -> None: config=config, subject=subject, ), + exec_params=config.exec_params, subject=subject, force_run=config.recreate_bem) for subject in get_subjects(config) diff --git a/mne_bids_pipeline/steps/source/_04_make_forward.py b/mne_bids_pipeline/steps/source/_04_make_forward.py index 4521ed6aa..208074d53 100644 --- a/mne_bids_pipeline/steps/source/_04_make_forward.py +++ b/mne_bids_pipeline/steps/source/_04_make_forward.py @@ -4,6 +4,7 @@ """ from types import SimpleNamespace +from typing import Optional import mne from mne.coreg import Coregistration @@ -41,7 +42,12 @@ def _prepare_trans_template(cfg, info): return trans -def _prepare_trans(cfg, bids_path): +def _prepare_trans( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + bids_path: BIDSPath, +): # Generate a head ↔ MRI transformation matrix from the # electrophysiological and MRI sidecar files, and save it to an MNE # "trans" file in the derivatives folder. @@ -49,7 +55,7 @@ def _prepare_trans(cfg, bids_path): # TODO: This breaks our encapsulation config = _import_config( - config_path=cfg.config_path, + config_path=exec_params.config_path, check=False, log=False, ) @@ -111,7 +117,14 @@ def get_input_fnames_forward(*, cfg, subject, session): @failsafe_run( get_input_fnames=get_input_fnames_forward, ) -def run_forward(*, cfg, subject, session, in_files): +def run_forward( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: Optional[str], + in_files: dict, +) -> dict: bids_path = BIDSPath(subject=subject, session=session, task=cfg.task, @@ -144,9 +157,13 @@ def run_forward(*, cfg, subject, session, in_files): # trans if cfg.use_template_mri is not None: - trans = _prepare_trans_template(cfg, info) + trans = _prepare_trans_template(cfg, exec_params) else: - trans = _prepare_trans(cfg, bids_path) + trans = _prepare_trans( + cfg=cfg, + exec_params=exec_params, + bids_path=bids_path, + ) msg = 'Calculating forward solution' logger.info(**gen_log_kwargs(message=msg)) @@ -159,7 +176,11 @@ def run_forward(*, cfg, subject, session, in_files): mne.write_forward_solution(out_files['forward'], fwd, overwrite=True) # Report - with _open_report(cfg=cfg, subject=subject, session=session) as report: + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session) as report: msg = 'Rendering MRI slices with BEM contours.' logger.info(**gen_log_kwargs(message=msg)) report.add_bem( @@ -198,11 +219,10 @@ def run_forward(*, cfg, subject, session, in_files): def get_config( *, - config, + config: SimpleNamespace, subject: str, ) -> SimpleNamespace: cfg = SimpleNamespace( - exec_params=config.exec_params, task=get_task(config), runs=get_runs(config=config, subject=subject), datatype=get_datatype(config), @@ -219,12 +239,11 @@ def get_config( fs_subjects_dir=get_fs_subjects_dir(config), deriv_root=config.deriv_root, bids_root=config.bids_root, - config_path=config.config_path, ) return cfg -def main(*, config) -> None: +def main(*, config: SimpleNamespace) -> None: """Run forward.""" if not config.run_source_estimation: msg = 'Skipping, run_source_estimation is set to False …' @@ -239,6 +258,7 @@ def main(*, config) -> None: cfg=get_config( config=config, subject=subject), + exec_params=config.exec_params, subject=subject, session=session, ) diff --git a/mne_bids_pipeline/steps/source/_05_make_inverse.py b/mne_bids_pipeline/steps/source/_05_make_inverse.py index 912d9cb34..b131e2c54 100644 --- a/mne_bids_pipeline/steps/source/_05_make_inverse.py +++ b/mne_bids_pipeline/steps/source/_05_make_inverse.py @@ -5,6 +5,7 @@ import pathlib from types import SimpleNamespace +from typing import Optional import mne from mne.minimum_norm import (make_inverse_operator, apply_inverse, @@ -21,7 +22,12 @@ from ..._run import failsafe_run, save_logs, _sanitize_callable -def get_input_fnames_inverse(*, cfg, subject, session): +def get_input_fnames_inverse( + *, + cfg: SimpleNamespace, + subject: str, + session: Optional[str], +): bids_path = BIDSPath(subject=subject, session=session, task=cfg.task, @@ -50,7 +56,14 @@ def get_input_fnames_inverse(*, cfg, subject, session): @failsafe_run( get_input_fnames=get_input_fnames_inverse, ) -def run_inverse(*, cfg, subject, session, in_files): +def run_inverse( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: Optional[str], + in_files: dict, +) -> dict: # TODO: Eventually we should maybe loop over ch_types, e.g., to create # MEG, EEG, and MEG+EEG inverses and STCs fname_fwd = in_files.pop('forward') @@ -106,7 +119,11 @@ def run_inverse(*, cfg, subject, session, in_files): stc.save(out_files[key], overwrite=True) out_files[key] = pathlib.Path(str(out_files[key]) + '-lh.stc') - with _open_report(cfg=cfg, subject=subject, session=session) as report: + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session) as report: for condition in conditions: cond_str = sanitize_cond_name(condition) key = f'{cond_str}+{method}+hemi' diff --git a/mne_bids_pipeline/steps/source/_99_group_average.py b/mne_bids_pipeline/steps/source/_99_group_average.py index 3d85bfe6d..f1ac6c88d 100644 --- a/mne_bids_pipeline/steps/source/_99_group_average.py +++ b/mne_bids_pipeline/steps/source/_99_group_average.py @@ -4,6 +4,7 @@ """ from types import SimpleNamespace +from typing import Optional, List import numpy as np @@ -16,6 +17,7 @@ ) from ..._logging import logger, gen_log_kwargs from ..._parallel import get_parallel_backend, parallel_func +from ..._report import run_report_average_source from ..._run import failsafe_run, save_logs @@ -67,8 +69,13 @@ def morph_stc(cfg, subject, fs_subject, session=None): return morphed_stcs -def run_average(cfg, session, mean_morphed_stcs): - subject = 'average' +def run_average( + *, + cfg: SimpleNamespace, + subject: str, + session: Optional[str], + mean_morphed_stcs: List[mne.SourceEstimate], +): bids_path = BIDSPath(subject=subject, session=session, task=cfg.task, @@ -100,10 +107,9 @@ def run_average(cfg, session, mean_morphed_stcs): def get_config( *, - config, + config: SimpleNamespace, ) -> SimpleNamespace: cfg = SimpleNamespace( - exec_params=config.exec_params, task=get_task(config), task_is_rest=config.task_is_rest, datatype=get_datatype(config), @@ -125,23 +131,25 @@ def get_config( use_template_mri=config.use_template_mri, all_contrasts=get_all_contrasts(config), report_stc_n_time_points=config.report_stc_n_time_points, - # TODO: This deviates in the same way as the sensor group avg script - exec_params_inner=config.exec_params, - interactive=config.interactive, ) return cfg # pass 'average' subject for logging @failsafe_run() -def run_group_average_source(*, cfg, subject='average'): +def run_group_average_source( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, +) -> None: """Run group average in source space""" mne.datasets.fetch_fsaverage(subjects_dir=get_fs_subjects_dir(cfg)) - with get_parallel_backend(cfg.exec_params_inner): + with get_parallel_backend(exec_params): parallel, run_func = parallel_func( - morph_stc, exec_params=cfg.exec_params_inner) + morph_stc, exec_params=exec_params) all_morphed_stcs = parallel( run_func( cfg=cfg, subject=subject, @@ -163,15 +171,26 @@ def run_group_average_source(*, cfg, subject='average'): run_average( cfg=cfg, session=session, + subject=subject, mean_morphed_stcs=mean_morphed_stcs ) + run_report_average_source( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session, + ) -def main(*, config) -> None: +def main(*, config: SimpleNamespace) -> None: if not config.run_source_estimation: msg = 'Skipping, run_source_estimation is set to False …' logger.info(**gen_log_kwargs(message=msg, emoji='skip')) return - log = run_group_average_source(cfg=get_config(config=config)) + log = run_group_average_source( + cfg=get_config(config=config), + exec_params=config.exec_params, + subject='average', + ) save_logs(config=config, logs=[log]) From 79c745d3e8c04f0158ba7fc08a10a97d6f4f7f00 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Wed, 16 Nov 2022 17:25:35 -0500 Subject: [PATCH 2/7] FIX: Closer --- mne_bids_pipeline/_config_utils.py | 7 ++++--- mne_bids_pipeline/_import_data.py | 4 ++++ .../steps/preprocessing/_01_data_quality.py | 1 - .../steps/sensor/_03_decoding_time_by_time.py | 5 ++++- .../steps/sensor/_06_make_cov.py | 21 +++++++++++++++---- .../steps/sensor/_99_group_average.py | 10 ++++++--- .../steps/source/_99_group_average.py | 4 +++- mne_bids_pipeline/tests/conftest.py | 1 + 8 files changed, 40 insertions(+), 13 deletions(-) diff --git a/mne_bids_pipeline/_config_utils.py b/mne_bids_pipeline/_config_utils.py index 2ae5b1dc5..a5968e2c7 100644 --- a/mne_bids_pipeline/_config_utils.py +++ b/mne_bids_pipeline/_config_utils.py @@ -260,8 +260,9 @@ def get_channels_to_analyze( # Return names of the channels of the channel types we wish to analyze. # We also include channels marked as "bad" here. # `exclude=[]`: keep "bad" channels, too. + kwargs = dict(eog=True, ecg=True, exclude=()) if get_datatype(config) == 'meg' and _meg_in_ch_types(config.ch_types): - pick_idx = mne.pick_types(info, eog=True, ecg=True, exclude=[]) + pick_idx = mne.pick_types(info, **kwargs) if 'mag' in config.ch_types: pick_idx = np.concatenate( @@ -270,8 +271,8 @@ def get_channels_to_analyze( pick_idx = np.concatenate( [pick_idx, mne.pick_types(info, meg='grad', exclude=[])]) if 'meg' in config.ch_types: - pick_idx = mne.pick_types(info, meg=True, eog=True, ecg=True, - exclude=[]) + pick_idx = mne.pick_types(info, meg=True, exclude=[]) + pick_idx.sort() elif config.ch_types == ['eeg']: pick_idx = mne.pick_types(info, meg=False, eeg=True, eog=True, ecg=True, exclude=[]) diff --git a/mne_bids_pipeline/_import_data.py b/mne_bids_pipeline/_import_data.py index 0168f602c..2a08fdcd1 100644 --- a/mne_bids_pipeline/_import_data.py +++ b/mne_bids_pipeline/_import_data.py @@ -6,6 +6,7 @@ import numpy as np import pandas as pd +from ._config_utils import get_channels_to_analyze from ._io import _read_json, _empty_room_match_path from ._logging import gen_log_kwargs, logger from ._run import _update_for_splits @@ -213,6 +214,9 @@ def _load_data(cfg: SimpleNamespace, bids_path: BIDSPath) -> mne.io.BaseRaw: raw = read_raw_bids(bids_path=bids_path, extra_params=cfg.reader_extra_params) + picks = get_channels_to_analyze(raw.info, cfg) + raw.pick(picks) + _crop_data(cfg, raw=raw, subject=subject) raw.load_data() diff --git a/mne_bids_pipeline/steps/preprocessing/_01_data_quality.py b/mne_bids_pipeline/steps/preprocessing/_01_data_quality.py index 811466095..83ae4badc 100644 --- a/mne_bids_pipeline/steps/preprocessing/_01_data_quality.py +++ b/mne_bids_pipeline/steps/preprocessing/_01_data_quality.py @@ -141,7 +141,6 @@ def _find_bads( bids_path_ref_in=None, bids_path_ref_bads_in=None, cfg=cfg, - datatype='meg', ) else: data_is_rest = (key == 'raw_rest') diff --git a/mne_bids_pipeline/steps/sensor/_03_decoding_time_by_time.py b/mne_bids_pipeline/steps/sensor/_03_decoding_time_by_time.py index 49b216bcc..047fc047b 100644 --- a/mne_bids_pipeline/steps/sensor/_03_decoding_time_by_time.py +++ b/mne_bids_pipeline/steps/sensor/_03_decoding_time_by_time.py @@ -323,7 +323,10 @@ def main(*, config: SimpleNamespace) -> None: # so we don't dispatch manually to multiple jobs. logs = [ run_time_decoding( - cfg=get_config(config=config), + cfg=get_config( + config=config, + ), + exec_params=config.exec_params, subject=subject, condition1=cond_1, condition2=cond_2, diff --git a/mne_bids_pipeline/steps/sensor/_06_make_cov.py b/mne_bids_pipeline/steps/sensor/_06_make_cov.py index 9f83c6359..9106f931e 100644 --- a/mne_bids_pipeline/steps/sensor/_06_make_cov.py +++ b/mne_bids_pipeline/steps/sensor/_06_make_cov.py @@ -84,7 +84,13 @@ def get_input_fnames_cov( def compute_cov_from_epochs( - *, cfg, subject, session, tmin, tmax, in_files, out_files): + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: Optional[str], + in_files: dict, + out_files: dict, +) -> mne.Covariance: epo_fname = in_files.pop('epochs') msg = "Computing regularized covariance based on epochs' baseline periods." @@ -106,7 +112,14 @@ def compute_cov_from_epochs( return cov -def compute_cov_from_raw(*, cfg, subject, session, in_files, out_files): +def compute_cov_from_raw( + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: Optional[str], + in_files: dict, + out_files: dict, +) -> mne.Covariance: fname_raw = in_files.pop('raw') data_type = 'resting-state' if fname_raw.task == 'rest' else 'empty-room' msg = f'Computing regularized covariance based on {data_type} recording.' @@ -126,9 +139,9 @@ def retrieve_custom_cov( exec_params: SimpleNamespace, subject: str, session: Optional[str], - in_files: None, + in_files: dict, out_files: dict, -): +) -> mne.Covariance: # This should be the only place we use config.noise_cov (rather than cfg.* # entries) config = _import_config( diff --git a/mne_bids_pipeline/steps/sensor/_99_group_average.py b/mne_bids_pipeline/steps/sensor/_99_group_average.py index f62033b33..b145a1641 100644 --- a/mne_bids_pipeline/steps/sensor/_99_group_average.py +++ b/mne_bids_pipeline/steps/sensor/_99_group_average.py @@ -380,6 +380,7 @@ def average_full_epochs_decoding( def average_csp_decoding( cfg: SimpleNamespace, session: str, + subject: str, condition_1: str, condition_2: str, ): @@ -622,7 +623,6 @@ def get_config( return cfg -# pass 'average' subject for logging @failsafe_run() def run_group_average_sensor( *, @@ -655,11 +655,12 @@ def run_group_average_sensor( average_time_by_time_decoding(cfg, session) if cfg.decode and cfg.decoding_csp: parallel, run_func = parallel_func( - average_csp_decoding) + average_csp_decoding, exec_params=exec_params) parallel( run_func( cfg=cfg, session=session, + subject=subject, condition_1=contrast[0], condition_2=contrast[1] ) @@ -670,6 +671,7 @@ def run_group_average_sensor( for session in sessions: run_report_average_sensor( cfg=cfg, + exec_params=exec_params, subject=subject, session=session, ) @@ -677,7 +679,9 @@ def run_group_average_sensor( def main(*, config: SimpleNamespace) -> None: log = run_group_average_sensor( - cfg=get_config(config=config), + cfg=get_config( + config=config, + ), exec_params=config.exec_params, subject='average', ) diff --git a/mne_bids_pipeline/steps/source/_99_group_average.py b/mne_bids_pipeline/steps/source/_99_group_average.py index f1ac6c88d..91cee4d69 100644 --- a/mne_bids_pipeline/steps/source/_99_group_average.py +++ b/mne_bids_pipeline/steps/source/_99_group_average.py @@ -189,7 +189,9 @@ def main(*, config: SimpleNamespace) -> None: return log = run_group_average_source( - cfg=get_config(config=config), + cfg=get_config( + config=config, + ), exec_params=config.exec_params, subject='average', ) diff --git a/mne_bids_pipeline/tests/conftest.py b/mne_bids_pipeline/tests/conftest.py index 3d105a1a9..01d3bf40c 100644 --- a/mne_bids_pipeline/tests/conftest.py +++ b/mne_bids_pipeline/tests/conftest.py @@ -24,6 +24,7 @@ def pytest_configure(config): always::ResourceWarning ignore:subprocess .* is still running:ResourceWarning ignore:`np.MachAr` is deprecated.*:DeprecationWarning + ignore:The get_cmap function will be deprecated.*: """ for warning_line in warning_lines.split('\n'): warning_line = warning_line.strip() From 3f7aa51d41d1a74275a293db899afbd0f1c17cfa Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Wed, 16 Nov 2022 17:26:37 -0500 Subject: [PATCH 3/7] FIX: Flake --- mne_bids_pipeline/steps/sensor/_06_make_cov.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mne_bids_pipeline/steps/sensor/_06_make_cov.py b/mne_bids_pipeline/steps/sensor/_06_make_cov.py index 9106f931e..23022ef30 100644 --- a/mne_bids_pipeline/steps/sensor/_06_make_cov.py +++ b/mne_bids_pipeline/steps/sensor/_06_make_cov.py @@ -84,6 +84,9 @@ def get_input_fnames_cov( def compute_cov_from_epochs( + *, + tmin: Optional[float], + tmax: Optional[float], cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, @@ -113,6 +116,7 @@ def compute_cov_from_epochs( def compute_cov_from_raw( + *, cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, @@ -135,6 +139,7 @@ def compute_cov_from_raw( def retrieve_custom_cov( + *, cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, From e865552970034160a69759bd9393bc352602fe17 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Wed, 16 Nov 2022 19:14:30 -0500 Subject: [PATCH 4/7] WIP: Closer [ci skip] --- mne_bids_pipeline/_import_data.py | 19 +++++- mne_bids_pipeline/_report.py | 19 +++--- .../steps/preprocessing/_01_data_quality.py | 67 +++++++++++-------- .../steps/preprocessing/_02_maxfilter.py | 4 +- .../preprocessing/_03_frequency_filter.py | 1 + .../steps/sensor/_99_group_average.py | 1 + 6 files changed, 72 insertions(+), 39 deletions(-) diff --git a/mne_bids_pipeline/_import_data.py b/mne_bids_pipeline/_import_data.py index 2a08fdcd1..92248f245 100644 --- a/mne_bids_pipeline/_import_data.py +++ b/mne_bids_pipeline/_import_data.py @@ -398,6 +398,7 @@ def import_er_data( bids_path_ref_in: Optional[BIDSPath], bids_path_er_bads_in: Optional[BIDSPath], bids_path_ref_bads_in: Optional[BIDSPath], + prepare_maxwell_filter: bool, ) -> mne.io.BaseRaw: """Import empty-room data. @@ -446,7 +447,7 @@ def import_er_data( raw_ref.info._check_consistency() raw_ref.pick_types(meg=True, exclude=[]) - if cfg.use_maxwell_filter: + if prepare_maxwell_filter: # We need to include any automatically found bad channels, if relevant. # TODO: This 'union' operation should affect the raw runs, too, # otherwise rank mismatches will still occur (eventually for some @@ -502,6 +503,7 @@ def _get_raw_paths( run: Optional[str], kind: Literal['raw', 'sss'], add_bads: bool = True, + include_mf_ref: bool = True, ) -> dict: # Construct the basenames of the files we wish to load, and of the empty- # room recording we wish to save. @@ -541,6 +543,7 @@ def _get_raw_paths( in_files=in_files, key=key, ) + orig_key = key if run == cfg.runs[0]: do = dict( @@ -576,6 +579,20 @@ def _get_raw_paths( in_files=in_files, key=key, ) + if include_mf_ref and task == 'noise': + key = 'raw_ref_noise' + in_files[key] = in_files[orig_key].copy().update( + run=cfg.mf_reference_run, check=True) + _update_for_splits( + in_files, key, single=True, allow_missing=True) + if not in_files[key].fpath.exists(): + in_files.pop(key) + elif add_bads: + _add_bads_file( + cfg=cfg, + in_files=in_files, + key=key, + ) return in_files diff --git a/mne_bids_pipeline/_report.py b/mne_bids_pipeline/_report.py index 5363e64b0..7861b5b90 100644 --- a/mne_bids_pipeline/_report.py +++ b/mne_bids_pipeline/_report.py @@ -1127,12 +1127,13 @@ def _add_raw( bids_path_in.run in cfg.plot_psd_for_runs or bids_path_in.task in cfg.plot_psd_for_runs ) - report.add_raw( - raw=bids_path_in, - title=title, - butterfly=5, - psd=plot_raw_psd, - tags=('raw', 'filtered', f'run-{bids_path_in.run}'), - # caption=bids_path_in.basename, # TODO upstream - replace=True, - ) + with mne.use_log_level('error'): + report.add_raw( + raw=bids_path_in, + title=title, + butterfly=5, + psd=plot_raw_psd, + tags=('raw', 'filtered', f'run-{bids_path_in.run}'), + # caption=bids_path_in.basename, # TODO upstream + replace=True, + ) diff --git a/mne_bids_pipeline/steps/preprocessing/_01_data_quality.py b/mne_bids_pipeline/steps/preprocessing/_01_data_quality.py index 83ae4badc..31a7324ce 100644 --- a/mne_bids_pipeline/steps/preprocessing/_01_data_quality.py +++ b/mne_bids_pipeline/steps/preprocessing/_01_data_quality.py @@ -10,7 +10,7 @@ from ..._config_utils import ( get_mf_cal_fname, get_mf_ctc_fname, get_subjects, get_sessions, - get_runs, get_task, get_datatype, + get_runs, get_task, get_datatype, get_mf_reference_run, ) from ..._import_data import ( _get_raw_paths, import_experimental_data, import_er_data, @@ -57,18 +57,33 @@ def assess_data_quality( """Assess data quality and find and mark bad channels.""" import matplotlib.pyplot as plt out_files = dict() + orig_run = run for key in list(in_files.keys()): bids_path_in = in_files.pop(key) - auto_scores = _find_bads( - cfg=cfg, - exec_params=exec_params, - bids_path_in=bids_path_in, - key=key, - subject=subject, - session=session, - run=run, - out_files=out_files, - ) + if key == 'raw_noise': + run = 'noise' + elif key == 'raw_rest': + run = 'rest' + else: + run = orig_run + if cfg.find_noisy_channels_meg or cfg.find_flat_channels_meg: + if key == 'raw_noise': + bids_path_ref_in = in_files.pop('raw_ref_noise') + else: + bids_path_ref_in = None + auto_scores = _find_bads_maxwell( + cfg=cfg, + exec_params=exec_params, + bids_path_in=bids_path_in, + bids_path_ref_in=bids_path_ref_in, + key=key, + subject=subject, + session=session, + run=run, + out_files=out_files, + ) + else: + auto_scores = None # Report with _open_report( @@ -85,7 +100,7 @@ def assess_data_quality( cfg=cfg, report=report, bids_path_in=bids_path_in, - title=f'Raw ({kind})', + title=f'Raw ({kind} run {run})', ) if cfg.find_noisy_channels_meg: assert auto_scores is not None @@ -109,20 +124,18 @@ def assess_data_quality( return out_files -def _find_bads( +def _find_bads_maxwell( *, cfg: SimpleNamespace, exec_params: SimpleNamespace, bids_path_in: BIDSPath, + bids_path_ref_in: Optional[BIDSPath], subject: str, session: Optional[str], run: str, key: str, out_files: dict, ): - if not (cfg.find_noisy_channels_meg or cfg.find_flat_channels_meg): - return None - if (cfg.find_flat_channels_meg and not cfg.find_noisy_channels_meg): msg = 'Finding flat channels.' @@ -130,17 +143,18 @@ def _find_bads( not cfg.find_flat_channels_meg): msg = 'Finding noisy channels using Maxwell filtering.' else: - msg = ('Finding flat channels, and noisy channels using ' + msg = ('Finding flat channels and noisy channels using ' 'Maxwell filtering.') logger.info(**gen_log_kwargs(message=msg)) if key == 'raw_noise': raw = import_er_data( + cfg=cfg, bids_path_er_in=bids_path_in, bids_path_er_bads_in=None, - bids_path_ref_in=None, + bids_path_ref_in=bids_path_ref_in, bids_path_ref_bads_in=None, - cfg=cfg, + prepare_maxwell_filter=True, ) else: data_is_rest = (key == 'raw_rest') @@ -241,19 +255,20 @@ def get_config( subject: str, session: Optional[str], ) -> SimpleNamespace: + extra_kwargs = dict() if config.find_noisy_channels_meg or config.find_flat_channels_meg: - mf_cal_fname = get_mf_cal_fname( + extra_kwargs['mf_cal_fname'] = get_mf_cal_fname( config=config, subject=subject, session=session, ) - mf_ctc_fname = get_mf_ctc_fname( + extra_kwargs['mf_ctc_fname'] = get_mf_ctc_fname( config=config, subject=subject, session=session, ) - else: - mf_cal_fname = mf_ctc_fname = None + extra_kwargs['mf_reference_run'] = get_mf_reference_run(config=config) + extra_kwargs['mf_head_origin' ] = config.mf_head_origin cfg = SimpleNamespace( process_empty_room=config.process_empty_room, process_rest=config.process_rest, @@ -267,9 +282,6 @@ def get_config( space=config.space, bids_root=config.bids_root, deriv_root=config.deriv_root, - mf_cal_fname=mf_cal_fname, - mf_ctc_fname=mf_ctc_fname, - mf_head_origin=config.mf_head_origin, reader_extra_params=config.reader_extra_params, crop_runs=config.crop_runs, rename_events=config.rename_events, @@ -290,6 +302,7 @@ def get_config( eog_channels=config.eog_channels, on_rename_missing_events=config.on_rename_missing_events, plot_psd_for_runs=config.plot_psd_for_runs, + **extra_kwargs, ) return cfg @@ -308,7 +321,7 @@ def main(*, config: SimpleNamespace) -> None: exec_params=config.exec_params, subject=subject, session=session, - run=run + run=run, ) for subject in get_subjects(config) for session in get_sessions(config) diff --git a/mne_bids_pipeline/steps/preprocessing/_02_maxfilter.py b/mne_bids_pipeline/steps/preprocessing/_02_maxfilter.py index 5d14721be..cc89813c6 100644 --- a/mne_bids_pipeline/steps/preprocessing/_02_maxfilter.py +++ b/mne_bids_pipeline/steps/preprocessing/_02_maxfilter.py @@ -51,8 +51,7 @@ def get_input_fnames_maxwell_filter( ) ref_bids_path = list(in_files.values())[0].copy().update( run=cfg.mf_reference_run, - extension='.fif', - check=True + check=True, ) key = "raw_ref_run" in_files[key] = ref_bids_path @@ -200,6 +199,7 @@ def run_maxwell_filter( bids_path_ref_in=bids_path_ref_in, bids_path_er_bads_in=bids_path_noise_bads, bids_path_ref_bads_in=bids_path_ref_bads_in, + prepare_maxwell_filter=True, ) # Maxwell-filter noise data. diff --git a/mne_bids_pipeline/steps/preprocessing/_03_frequency_filter.py b/mne_bids_pipeline/steps/preprocessing/_03_frequency_filter.py index be5412308..2417fdcd5 100644 --- a/mne_bids_pipeline/steps/preprocessing/_03_frequency_filter.py +++ b/mne_bids_pipeline/steps/preprocessing/_03_frequency_filter.py @@ -180,6 +180,7 @@ def filter_data( bids_path_er_bads_in=bids_path_noise_bads, # take bads from this run (0) bids_path_ref_bads_in=bids_path_bads_in, + prepare_maxwell_filter=False, ) else: raw_noise = import_experimental_data( diff --git a/mne_bids_pipeline/steps/sensor/_99_group_average.py b/mne_bids_pipeline/steps/sensor/_99_group_average.py index b145a1641..584a488c2 100644 --- a/mne_bids_pipeline/steps/sensor/_99_group_average.py +++ b/mne_bids_pipeline/steps/sensor/_99_group_average.py @@ -373,6 +373,7 @@ def average_full_epochs_decoding( del bootstrapped_means, se, ci_lower, ci_upper fname_out = fname_mat.copy().update(subject='average') + fname_out.parent.mkdir(exist_ok=True, parents=True) savemat(fname_out, contrast_score_stats) del contrast_score_stats, fname_out From 903902dcce5956ce24a11c72b9b4598af3453b8c Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Fri, 18 Nov 2022 14:36:49 -0500 Subject: [PATCH 5/7] FIX: More working --- mne_bids_pipeline/_import_data.py | 15 ++++---- mne_bids_pipeline/_report.py | 17 ++++++++- .../steps/preprocessing/_01_data_quality.py | 37 ++++++++++++++----- .../steps/preprocessing/_02_maxfilter.py | 27 ++++++++------ .../preprocessing/_03_frequency_filter.py | 6 ++- .../steps/sensor/_99_group_average.py | 7 ++-- .../steps/source/_01_make_bem_surfaces.py | 5 +-- .../steps/source/_02_make_bem_solution.py | 24 +++++++++--- .../steps/source/_03_setup_source_space.py | 9 +++-- .../steps/source/_04_make_forward.py | 6 ++- .../steps/source/_05_make_inverse.py | 17 ++++++--- 11 files changed, 115 insertions(+), 55 deletions(-) diff --git a/mne_bids_pipeline/_import_data.py b/mne_bids_pipeline/_import_data.py index 92248f245..70b1ce451 100644 --- a/mne_bids_pipeline/_import_data.py +++ b/mne_bids_pipeline/_import_data.py @@ -6,7 +6,6 @@ import numpy as np import pandas as pd -from ._config_utils import get_channels_to_analyze from ._io import _read_json, _empty_room_match_path from ._logging import gen_log_kwargs, logger from ._run import _update_for_splits @@ -214,9 +213,6 @@ def _load_data(cfg: SimpleNamespace, bids_path: BIDSPath) -> mne.io.BaseRaw: raw = read_raw_bids(bids_path=bids_path, extra_params=cfg.reader_extra_params) - picks = get_channels_to_analyze(raw.info, cfg) - raw.pick(picks) - _crop_data(cfg, raw=raw, subject=subject) raw.load_data() @@ -441,7 +437,7 @@ def import_er_data( if bids_path_ref_bads_in is not None: bads = _read_bads_tsv( cfg=cfg, - bids_path_in=bids_path_ref_bads_in, + bids_path_bads=bids_path_ref_bads_in, ) raw_ref.info['bads'] = bads raw_ref.info._check_consistency() @@ -580,9 +576,9 @@ def _get_raw_paths( key=key, ) if include_mf_ref and task == 'noise': - key = 'raw_ref_noise' + key = 'raw_ref_run' in_files[key] = in_files[orig_key].copy().update( - run=cfg.mf_reference_run, check=True) + run=cfg.mf_reference_run) _update_for_splits( in_files, key, single=True, allow_missing=True) if not in_files[key].fpath.exists(): @@ -604,7 +600,10 @@ def _add_bads_file( key: str, ) -> None: bids_path_in = in_files[key] - bads_tsv_fname = _bads_path(cfg=cfg, bids_path_in=bids_path_in) + bads_tsv_fname = _bads_path( + cfg=cfg, + bids_path_in=bids_path_in, + ) if bads_tsv_fname.fpath.is_file(): in_files[f'{key}-bads'] = bads_tsv_fname diff --git a/mne_bids_pipeline/_report.py b/mne_bids_pipeline/_report.py index 7861b5b90..cbdf73e1e 100644 --- a/mne_bids_pipeline/_report.py +++ b/mne_bids_pipeline/_report.py @@ -69,7 +69,13 @@ def _open_report( try: msg = 'Adding config and sys info to report' logger.info(**gen_log_kwargs(message=msg)) - _finalize(report=report, exec_params=exec_params) + _finalize( + report=report, + exec_params=exec_params, + subject=subject, + session=session, + run=run, + ) except Exception: pass fname_report_html = fname_report.with_suffix('.html') @@ -467,7 +473,14 @@ def add_event_counts(*, report.add_custom_css(css=css) -def _finalize(*, report: mne.Report, exec_params: SimpleNamespace): +def _finalize( + *, + report: mne.Report, + exec_params: SimpleNamespace, + subject: str, + session: Optional[str], + run: Optional[str], +) -> None: """Add system information and the pipeline configuration to the report.""" # ensure they are always appended titles = ['Configuration file', 'System information'] diff --git a/mne_bids_pipeline/steps/preprocessing/_01_data_quality.py b/mne_bids_pipeline/steps/preprocessing/_01_data_quality.py index 31a7324ce..8279d4720 100644 --- a/mne_bids_pipeline/steps/preprocessing/_01_data_quality.py +++ b/mne_bids_pipeline/steps/preprocessing/_01_data_quality.py @@ -6,6 +6,7 @@ import pandas as pd import mne +from mne.utils import _pl from mne_bids import BIDSPath from ..._config_utils import ( @@ -31,6 +32,7 @@ def get_input_fnames_data_quality( run: str, ) -> dict: """Get paths of files required by maxwell_filter function.""" + include_mf_ref = _do_mf_autobad(cfg=cfg) in_files = _get_raw_paths( cfg=cfg, subject=subject, @@ -38,6 +40,7 @@ def get_input_fnames_data_quality( run=run, kind='orig', add_bads=False, + include_mf_ref=include_mf_ref, ) return in_files @@ -58,17 +61,19 @@ def assess_data_quality( import matplotlib.pyplot as plt out_files = dict() orig_run = run - for key in list(in_files.keys()): + # raw_ref_run will be .pop()ed inside this loop, do not include it + raw_keys = list(key for key in in_files.keys() if key != 'raw_ref_run') + for key in raw_keys: bids_path_in = in_files.pop(key) if key == 'raw_noise': run = 'noise' elif key == 'raw_rest': run = 'rest' - else: + else: # raw_run-{run} run = orig_run - if cfg.find_noisy_channels_meg or cfg.find_flat_channels_meg: + if _do_mf_autobad(cfg=cfg): if key == 'raw_noise': - bids_path_ref_in = in_files.pop('raw_ref_noise') + bids_path_ref_in = in_files.pop('raw_ref_run') else: bids_path_ref_in = None auto_scores = _find_bads_maxwell( @@ -94,7 +99,7 @@ def assess_data_quality( run=run) as report: # Original data kind = 'original' if not cfg.proc else cfg.proc - msg = f'Adding {kind} raw data to report.' + msg = f'Adding {kind} raw data to report' logger.info(**gen_log_kwargs(message=msg)) _add_raw( cfg=cfg, @@ -104,7 +109,7 @@ def assess_data_quality( ) if cfg.find_noisy_channels_meg: assert auto_scores is not None - msg = 'Adding noisy channel detection to report.' + msg = 'Adding noisy channel detection to report' logger.info(**gen_log_kwargs(message=msg)) figs = plot_auto_scores(auto_scores, ch_types=cfg.ch_types) captions = [f'Run {run}'] * len(figs) @@ -194,8 +199,11 @@ def _find_bads_maxwell( if cfg.find_noisy_channels_meg: if auto_noisy_chs: - msg = (f'Found {len(auto_noisy_chs)} noisy channels: ' - f'{", ".join(auto_noisy_chs)}') + msg = ( + f'Found {len(auto_noisy_chs)} noisy ' + f'channel{_pl(auto_noisy_chs)}: ' + f'{", ".join(auto_noisy_chs)}' + ) else: msg = 'Found no noisy channels.' @@ -213,10 +221,15 @@ def _find_bads_maxwell( cfg=cfg, bids_path_in=bids_path_in, ) + if not out_files['auto_scores'].fpath.parent.exists(): + out_files['auto_scores'].fpath.parent.mkdir(parents=True) _write_json(out_files['auto_scores'], auto_scores) # Write the bad channels to disk. - out_files['bads_tsv'] = _bads_path(cfg=cfg, bids_path_in=bids_path_in) + out_files['bads_tsv'] = _bads_path( + cfg=cfg, + bids_path_in=bids_path_in, + ) bads_for_tsv = [] reasons = [] @@ -268,7 +281,7 @@ def get_config( session=session, ) extra_kwargs['mf_reference_run'] = get_mf_reference_run(config=config) - extra_kwargs['mf_head_origin' ] = config.mf_head_origin + extra_kwargs['mf_head_origin'] = config.mf_head_origin cfg = SimpleNamespace( process_empty_room=config.process_empty_room, process_rest=config.process_rest, @@ -329,3 +342,7 @@ def main(*, config: SimpleNamespace) -> None: ) save_logs(config=config, logs=logs) + + +def _do_mf_autobad(*, cfg: SimpleNamespace) -> bool: + return cfg.find_noisy_channels_meg or cfg.find_flat_channels_meg diff --git a/mne_bids_pipeline/steps/preprocessing/_02_maxfilter.py b/mne_bids_pipeline/steps/preprocessing/_02_maxfilter.py index cc89813c6..405e30ab9 100644 --- a/mne_bids_pipeline/steps/preprocessing/_02_maxfilter.py +++ b/mne_bids_pipeline/steps/preprocessing/_02_maxfilter.py @@ -164,7 +164,8 @@ def run_maxwell_filter( cfg=cfg, exec_params=exec_params, subject=subject, - session=session) as report: + session=session, + run=run) as report: msg = 'Adding Maxwell filtered raw data to report.' _add_raw( cfg=cfg, @@ -197,7 +198,11 @@ def run_maxwell_filter( cfg=cfg, bids_path_er_in=bids_path_noise, bids_path_ref_in=bids_path_ref_in, - bids_path_er_bads_in=bids_path_noise_bads, + # TODO: This can break processing, need to use union for all, + # otherwise can get for ds003392: + # "Reference run data rank does not match empty-room data rank" + # bids_path_er_bads_in=bids_path_noise_bads, + bids_path_er_bads_in=None, bids_path_ref_bads_in=bids_path_ref_bads_in, prepare_maxwell_filter=True, ) @@ -257,18 +262,18 @@ def run_maxwell_filter( _update_for_splits(out_files, in_key) del raw_noise_sss - with _open_report( - cfg=cfg, - exec_params=exec_params, - subject=subject, - session=session) as report: - msg = 'Adding Maxwell filtered raw data to report.' - logger.info(**gen_log_kwargs(message=msg)) - for fname in out_files.values(): + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session, + run=task) as report: + msg = 'Adding Maxwell filtered raw data to report' + logger.info(**gen_log_kwargs(message=msg, run=task)) _add_raw( cfg=cfg, report=report, - bids_path_in=fname, + bids_path_in=out_files[in_key], title='Raw (maxwell filtered)', ) diff --git a/mne_bids_pipeline/steps/preprocessing/_03_frequency_filter.py b/mne_bids_pipeline/steps/preprocessing/_03_frequency_filter.py index 2417fdcd5..31615882f 100644 --- a/mne_bids_pipeline/steps/preprocessing/_03_frequency_filter.py +++ b/mne_bids_pipeline/steps/preprocessing/_03_frequency_filter.py @@ -48,6 +48,7 @@ def get_input_fnames_frequency_filter( session=session, run=run, kind=kind, + include_mf_ref=False, ) @@ -222,8 +223,9 @@ def filter_data( cfg=cfg, exec_params=exec_params, subject=subject, - session=session) as report: - msg = 'Adding filtered raw data to report.' + session=session, + run=run) as report: + msg = 'Adding filtered raw data to report' logger.info(**gen_log_kwargs(message=msg)) for fname in out_files.values(): _add_raw( diff --git a/mne_bids_pipeline/steps/sensor/_99_group_average.py b/mne_bids_pipeline/steps/sensor/_99_group_average.py index 584a488c2..a527dc13f 100644 --- a/mne_bids_pipeline/steps/sensor/_99_group_average.py +++ b/mne_bids_pipeline/steps/sensor/_99_group_average.py @@ -36,8 +36,8 @@ def average_evokeds( # Container for all conditions: all_evokeds = defaultdict(list) - for subject in cfg.subjects: - fname_in = BIDSPath(subject=subject, + for this_subject in cfg.subjects: + fname_in = BIDSPath(subject=this_subject, session=session, task=cfg.task, acquisition=cfg.acq, @@ -373,7 +373,8 @@ def average_full_epochs_decoding( del bootstrapped_means, se, ci_lower, ci_upper fname_out = fname_mat.copy().update(subject='average') - fname_out.parent.mkdir(exist_ok=True, parents=True) + if not fname_out.fpath.parent.exists(): + os.makedirs(fname_out.fpath.parent) savemat(fname_out, contrast_score_stats) del contrast_score_stats, fname_out diff --git a/mne_bids_pipeline/steps/source/_01_make_bem_surfaces.py b/mne_bids_pipeline/steps/source/_01_make_bem_surfaces.py index 165f01704..66713d5e2 100644 --- a/mne_bids_pipeline/steps/source/_01_make_bem_surfaces.py +++ b/mne_bids_pipeline/steps/source/_01_make_bem_surfaces.py @@ -96,11 +96,10 @@ def make_bem_surfaces( def get_config( *, - config, + config: SimpleNamespace, subject: str, ) -> SimpleNamespace: cfg = SimpleNamespace( - exec_params=config.exec_params, fs_subject=get_fs_subject(config=config, subject=subject), fs_subjects_dir=get_fs_subjects_dir(config=config), bem_mri_images=config.bem_mri_images, @@ -111,7 +110,7 @@ def get_config( return cfg -def main(*, config) -> None: +def main(*, config: SimpleNamespace) -> None: """Run BEM surface extraction.""" if not config.run_source_estimation: msg = 'Skipping, run_source_estimation is set to False …' diff --git a/mne_bids_pipeline/steps/source/_02_make_bem_solution.py b/mne_bids_pipeline/steps/source/_02_make_bem_solution.py index e30b3bda7..57c029f7d 100644 --- a/mne_bids_pipeline/steps/source/_02_make_bem_solution.py +++ b/mne_bids_pipeline/steps/source/_02_make_bem_solution.py @@ -15,7 +15,11 @@ from ..._run import failsafe_run, save_logs -def get_input_fnames_make_bem_solution(*, cfg, subject): +def get_input_fnames_make_bem_solution( + *, + cfg: SimpleNamespace, + subject: str, +) -> dict: in_files = dict() conductivity, _ = _get_bem_conductivity(cfg) n_layers = len(conductivity) @@ -25,7 +29,11 @@ def get_input_fnames_make_bem_solution(*, cfg, subject): return in_files -def get_output_fnames_make_bem_solution(*, cfg, subject): +def get_output_fnames_make_bem_solution( + *, + cfg: SimpleNamespace, + subject: str, +) -> dict: out_files = dict() bem_dir = Path(cfg.fs_subjects_dir) / cfg.fs_subject / 'bem' _, tag = _get_bem_conductivity(cfg) @@ -38,7 +46,13 @@ def get_output_fnames_make_bem_solution(*, cfg, subject): get_input_fnames=get_input_fnames_make_bem_solution, get_output_fnames=get_output_fnames_make_bem_solution, ) -def make_bem_solution(*, cfg, subject, in_files): +def make_bem_solution( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + in_files: dict, +) -> dict: msg = 'Calculating BEM solution' logger.info(**gen_log_kwargs(message=msg, subject=subject)) conductivity, _ = _get_bem_conductivity(cfg) @@ -54,11 +68,10 @@ def make_bem_solution(*, cfg, subject, in_files): def get_config( *, - config, + config: SimpleNamespace, subject: str, ) -> SimpleNamespace: cfg = SimpleNamespace( - exec_params=config.exec_params, fs_subject=get_fs_subject(config=config, subject=subject), fs_subjects_dir=get_fs_subjects_dir(config), ch_types=config.ch_types, @@ -89,6 +102,7 @@ def main(*, config) -> None: logs = parallel( run_func( cfg=get_config(config=config, subject=subject), + exec_params=config.exec_params, subject=subject, force_run=config.recreate_bem) for subject in get_subjects(config) diff --git a/mne_bids_pipeline/steps/source/_03_setup_source_space.py b/mne_bids_pipeline/steps/source/_03_setup_source_space.py index 191b2aed9..31f05a859 100644 --- a/mne_bids_pipeline/steps/source/_03_setup_source_space.py +++ b/mne_bids_pipeline/steps/source/_03_setup_source_space.py @@ -48,11 +48,10 @@ def run_setup_source_space(*, cfg, subject, in_files): def get_config( *, - config, + config: SimpleNamespace, subject: str, ) -> SimpleNamespace: cfg = SimpleNamespace( - exec_params=config.exec_params, spacing=config.spacing, use_template_mri=config.use_template_mri, fs_subject=get_fs_subject(config=config, subject=subject), @@ -61,7 +60,7 @@ def get_config( return cfg -def main(*, config) -> None: +def main(*, config: SimpleNamespace) -> None: """Run forward.""" if not config.run_source_estimation: msg = 'Skipping, run_source_estimation is set to False …' @@ -82,7 +81,9 @@ def main(*, config) -> None: run_func( cfg=get_config( config=config, - subject=subject), + subject=subject, + ), + exec_params=config.exec_params, subject=subject, ) for subject in subjects diff --git a/mne_bids_pipeline/steps/source/_04_make_forward.py b/mne_bids_pipeline/steps/source/_04_make_forward.py index 208074d53..d3e144bcc 100644 --- a/mne_bids_pipeline/steps/source/_04_make_forward.py +++ b/mne_bids_pipeline/steps/source/_04_make_forward.py @@ -181,6 +181,8 @@ def run_forward( exec_params=exec_params, subject=subject, session=session) as report: + msg = 'Adding forward information to report' + logger.info(**gen_log_kwargs(message=msg)) msg = 'Rendering MRI slices with BEM contours.' logger.info(**gen_log_kwargs(message=msg)) report.add_bem( @@ -192,7 +194,7 @@ def run_forward( replace=True, n_jobs=1, # prevent automatic parallelization ) - msg = 'Rendering sensor alignment (coregistration).' + msg = 'Rendering sensor alignment (coregistration)' logger.info(**gen_log_kwargs(message=msg)) report.add_trans( trans=trans, @@ -203,7 +205,7 @@ def run_forward( alpha=1, replace=True, ) - msg = 'Rendering forward solution.' + msg = 'Rendering forward solution' logger.info(**gen_log_kwargs(message=msg)) report.add_forward( forward=fwd, diff --git a/mne_bids_pipeline/steps/source/_05_make_inverse.py b/mne_bids_pipeline/steps/source/_05_make_inverse.py index b131e2c54..a0c24a493 100644 --- a/mne_bids_pipeline/steps/source/_05_make_inverse.py +++ b/mne_bids_pipeline/steps/source/_05_make_inverse.py @@ -66,6 +66,8 @@ def run_inverse( ) -> dict: # TODO: Eventually we should maybe loop over ch_types, e.g., to create # MEG, EEG, and MEG+EEG inverses and STCs + msg = 'Computing inverse solutions' + logger.info(**gen_log_kwargs(message=msg)) fname_fwd = in_files.pop('forward') out_files = dict() out_files['inverse'] = fname_fwd.copy().update(suffix='inv') @@ -124,6 +126,8 @@ def run_inverse( exec_params=exec_params, subject=subject, session=session) as report: + msg = 'Adding inverse information to report' + logger.info(**gen_log_kwargs(message=msg)) for condition in conditions: cond_str = sanitize_cond_name(condition) key = f'{cond_str}+{method}+hemi' @@ -152,11 +156,10 @@ def run_inverse( def get_config( *, - config, - subject, + config: SimpleNamespace, + subject: str, ) -> SimpleNamespace: cfg = SimpleNamespace( - exec_params=config.exec_params, task=get_task(config), datatype=get_datatype(config), acq=config.acq, @@ -179,7 +182,7 @@ def get_config( return cfg -def main(*, config) -> None: +def main(*, config: SimpleNamespace) -> None: """Run inv.""" if not config.run_source_estimation: msg = 'Skipping, run_source_estimation is set to False …' @@ -191,7 +194,11 @@ def main(*, config) -> None: run_inverse, exec_params=config.exec_params) logs = parallel( run_func( - cfg=get_config(config=config, subject=subject), + cfg=get_config( + config=config, + subject=subject, + ), + exec_params=config.exec_params, subject=subject, session=session, ) From b7b2962af30965aaac8baeb6d7d5715b8654dd20 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Fri, 18 Nov 2022 16:34:16 -0500 Subject: [PATCH 6/7] FIX: Better --- mne_bids_pipeline/_import_data.py | 2 -- .../steps/preprocessing/_03_frequency_filter.py | 8 ++++---- mne_bids_pipeline/steps/source/_03_setup_source_space.py | 8 +++++++- mne_bids_pipeline/tests/configs/config_ds000117.py | 2 +- mne_bids_pipeline/tests/configs/config_ds003392.py | 1 + pyproject.toml | 2 +- 6 files changed, 14 insertions(+), 9 deletions(-) diff --git a/mne_bids_pipeline/_import_data.py b/mne_bids_pipeline/_import_data.py index 70b1ce451..45aa89561 100644 --- a/mne_bids_pipeline/_import_data.py +++ b/mne_bids_pipeline/_import_data.py @@ -469,8 +469,6 @@ def _find_breaks_func( run: Optional[str], ) -> None: if not cfg.find_breaks: - msg = 'Finding breaks has been disabled by the user.' - logger.info(**gen_log_kwargs(message=msg)) return msg = (f'Finding breaks with a minimum duration of ' diff --git a/mne_bids_pipeline/steps/preprocessing/_03_frequency_filter.py b/mne_bids_pipeline/steps/preprocessing/_03_frequency_filter.py index 31615882f..0f4d5fca8 100644 --- a/mne_bids_pipeline/steps/preprocessing/_03_frequency_filter.py +++ b/mne_bids_pipeline/steps/preprocessing/_03_frequency_filter.py @@ -144,7 +144,7 @@ def filter_data( h_freq=cfg.h_freq, l_freq=cfg.l_freq, h_trans_bandwidth=cfg.h_trans_bandwidth, l_trans_bandwidth=cfg.l_trans_bandwidth, - data_type='experimental' + data_type='experimental', ) resample(raw=raw, subject=subject, session=session, run=run, sfreq=cfg.resample_sfreq, data_type='experimental') @@ -197,13 +197,13 @@ def filter_data( raw_noise.load_data() filter( - raw=raw_noise, subject=subject, session=session, run=None, + raw=raw_noise, subject=subject, session=session, run=task, h_freq=cfg.h_freq, l_freq=cfg.l_freq, h_trans_bandwidth=cfg.h_trans_bandwidth, l_trans_bandwidth=cfg.l_trans_bandwidth, - data_type=data_type + data_type=data_type, ) - resample(raw=raw_noise, subject=subject, session=session, run=None, + resample(raw=raw_noise, subject=subject, session=session, run=task, sfreq=cfg.resample_sfreq, data_type=data_type) raw_noise.save( diff --git a/mne_bids_pipeline/steps/source/_03_setup_source_space.py b/mne_bids_pipeline/steps/source/_03_setup_source_space.py index 31f05a859..f7419e74e 100644 --- a/mne_bids_pipeline/steps/source/_03_setup_source_space.py +++ b/mne_bids_pipeline/steps/source/_03_setup_source_space.py @@ -34,7 +34,13 @@ def get_output_fnames_setup_source_space(*, cfg, subject): get_input_fnames=get_input_fnames_setup_source_space, get_output_fnames=get_output_fnames_setup_source_space, ) -def run_setup_source_space(*, cfg, subject, in_files): +def run_setup_source_space( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + in_files: dict, +) -> dict: msg = f'Creating source space with spacing {repr(cfg.spacing)}' logger.info(**gen_log_kwargs(message=msg, subject=subject)) src = mne.setup_source_space( diff --git a/mne_bids_pipeline/tests/configs/config_ds000117.py b/mne_bids_pipeline/tests/configs/config_ds000117.py index 918441d8f..b206fed0a 100644 --- a/mne_bids_pipeline/tests/configs/config_ds000117.py +++ b/mne_bids_pipeline/tests/configs/config_ds000117.py @@ -15,7 +15,7 @@ subjects = ['01'] resample_sfreq = 125. -crop_runs = (0, 350) # Reduce memory usage on CI system +crop_runs = (0, 300) # Reduce memory usage on CI system find_flat_channels_meg = False find_noisy_channels_meg = False diff --git a/mne_bids_pipeline/tests/configs/config_ds003392.py b/mne_bids_pipeline/tests/configs/config_ds003392.py index 2503b4233..bfa41d35d 100644 --- a/mne_bids_pipeline/tests/configs/config_ds003392.py +++ b/mne_bids_pipeline/tests/configs/config_ds003392.py @@ -16,6 +16,7 @@ l_freq = 1. h_freq = 40. resample_sfreq = 250 +crop_runs = (0, 180) # Artifact correction. spatial_filter = 'ica' diff --git a/pyproject.toml b/pyproject.toml index b373381b2..782867ec7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ dependencies = [ "joblib >= 0.14", "threadpoolctl", "dask[distributed]", - "bokeh", # for distributed dashboard + "bokeh < 3", # for distributed dashboard "jupyter-server-proxy", # to have dask and jupyter working together "scikit-learn", "pandas", From 7d15066ae86cf9888d1819ffd1aeb3069bdf8acd Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Sat, 19 Nov 2022 11:20:26 -0500 Subject: [PATCH 7/7] FIX: Mem --- mne_bids_pipeline/steps/source/_04_make_forward.py | 13 ++++++++++--- mne_bids_pipeline/tests/configs/config_ds000246.py | 3 ++- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/mne_bids_pipeline/steps/source/_04_make_forward.py b/mne_bids_pipeline/steps/source/_04_make_forward.py index d3e144bcc..2fe0f8b90 100644 --- a/mne_bids_pipeline/steps/source/_04_make_forward.py +++ b/mne_bids_pipeline/steps/source/_04_make_forward.py @@ -21,7 +21,11 @@ from ..._run import failsafe_run, save_logs -def _prepare_trans_template(cfg, info): +def _prepare_trans_template( + *, + cfg: SimpleNamespace, + info: mne.Info, +) -> mne.transforms.Transform: assert isinstance(cfg.use_template_mri, str) assert cfg.use_template_mri == cfg.fs_subject @@ -47,7 +51,7 @@ def _prepare_trans( cfg: SimpleNamespace, exec_params: SimpleNamespace, bids_path: BIDSPath, -): +) -> mne.transforms.Transform: # Generate a head ↔ MRI transformation matrix from the # electrophysiological and MRI sidecar files, and save it to an MNE # "trans" file in the derivatives folder. @@ -157,7 +161,10 @@ def run_forward( # trans if cfg.use_template_mri is not None: - trans = _prepare_trans_template(cfg, exec_params) + trans = _prepare_trans_template( + cfg=cfg, + info=info, + ) else: trans = _prepare_trans( cfg=cfg, diff --git a/mne_bids_pipeline/tests/configs/config_ds000246.py b/mne_bids_pipeline/tests/configs/config_ds000246.py index eff59d4bb..1dfea793a 100644 --- a/mne_bids_pipeline/tests/configs/config_ds000246.py +++ b/mne_bids_pipeline/tests/configs/config_ds000246.py @@ -10,7 +10,7 @@ deriv_root = '~/mne_data/derivatives/mne-bids-pipeline/ds000246' runs = ['01'] -crop_runs = (0, 180) # Reduce memory usage on CI system +crop_runs = (0, 120) # Reduce memory usage on CI system l_freq = 0.3 h_freq = 100 decim = 4 @@ -23,6 +23,7 @@ decoding_time_generalization = True decoding_time_generalization_decim = 4 on_error = 'abort' +plot_psd_for_runs = [] # too much memory on CIs parallel_backend = 'dask' dask_worker_memory_limit = '2G'