From eb9794a9f33e261f20fb4cc225de77d06f7d6b09 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Wed, 9 Nov 2022 15:04:28 -0500 Subject: [PATCH 01/10] ENH: Report along the way --- docs/mkdocs.yml | 2 - docs/source/changes.md | 3 + docs/source/getting_started/basic_usage.md | 5 - docs/source/settings/report/report.md | 3 - docs/source/settings/sensor/statistics.md | 1 + docs/source/settings/source/inverse.md | 1 + mne_bids_pipeline/_config_utils.py | 6 +- mne_bids_pipeline/_report.py | 1098 +++++++++ mne_bids_pipeline/config.py | 58 +- .../preprocessing/_02_frequency_filter.py | 56 +- .../scripts/preprocessing/_03_make_epochs.py | 70 + .../scripts/preprocessing/_04b_run_ssp.py | 42 +- .../scripts/preprocessing/_05a_apply_ica.py | 25 +- .../scripts/preprocessing/_06_ptp_reject.py | 21 +- .../scripts/report/_01_make_reports.py | 2129 ----------------- mne_bids_pipeline/scripts/report/__init__.py | 7 - .../scripts/sensor/_01_make_evoked.py | 50 +- .../sensor/_02_decoding_full_epochs.py | 59 +- .../sensor/_03_decoding_time_by_time.py | 84 +- .../scripts/sensor/_04_time_frequency.py | 46 + .../scripts/sensor/_05_decoding_csp.py | 153 +- .../scripts/sensor/_06_make_cov.py | 52 +- .../scripts/sensor/_99_group_average.py | 15 +- .../scripts/source/_04_make_forward.py | 39 + .../scripts/source/_05_make_inverse.py | 36 +- .../scripts/source/_99_group_average.py | 5 +- mne_bids_pipeline/tests/run_tests.py | 14 +- pyproject.toml | 1 + 28 files changed, 1850 insertions(+), 2231 deletions(-) delete mode 100644 docs/source/settings/report/report.md create mode 100644 mne_bids_pipeline/_report.py delete mode 100644 mne_bids_pipeline/scripts/report/_01_make_reports.py delete mode 100644 mne_bids_pipeline/scripts/report/__init__.py diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 723b6e7bc..5bfff43b5 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -59,8 +59,6 @@ nav: - BEM surface: settings/source/bem.md - Source space & forward solution: settings/source/forward.md - Inverse solution: settings/source/inverse.md - - Report: - - HTML Report: settings/report/report.md - Examples: - Examples Gallery: examples/examples.md - examples/ds003392.md diff --git a/docs/source/changes.md b/docs/source/changes.md index 94595f78f..85958c149 100644 --- a/docs/source/changes.md +++ b/docs/source/changes.md @@ -200,6 +200,9 @@ authors: {{ authors.hoechenberger }}, and {{ authors.larsoner }}) - Add progress bar for time-by-time decoding ({{ gh(647) by {{ authors.larsoner }}) +- Make report generation happen within relevant steps instead of at the end + of all steps + ({{ gh(652) by {{ authors.larsoner }}) ### Behavior changes diff --git a/docs/source/getting_started/basic_usage.md b/docs/source/getting_started/basic_usage.md index 1de7a23ee..654137c25 100644 --- a/docs/source/getting_started/basic_usage.md +++ b/docs/source/getting_started/basic_usage.md @@ -91,11 +91,6 @@ Run the pipeline mne_bids_pipeline --config=/path/to/your/custom_config.py --steps=source ``` - Only generate the report: - ```shell - mne_bids_pipeline --config=/path/to/your/custom_config.py --steps=report - ``` - (Re-)run ICA: ```shell mne_bids_pipeline --config=/path/to/your/custom_config.py --steps=preprocessing/ica diff --git a/docs/source/settings/report/report.md b/docs/source/settings/report/report.md deleted file mode 100644 index 66f5861d7..000000000 --- a/docs/source/settings/report/report.md +++ /dev/null @@ -1,3 +0,0 @@ - -::: mne_bids_pipeline.config.report_evoked_n_time_points -::: mne_bids_pipeline.config.report_stc_n_time_points 
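The new `_open_report` context manager in `mne_bids_pipeline/_report.py` (added below) is what makes reporting "along the way" work: each step opens the per-subject report, appends its figures, and on exit the report is saved back to disk as both `.h5` and rendered `.html`, so output accumulates as the pipeline runs. A minimal sketch of the pattern, assuming the private module layout from this patch; the environment-variable value, `cfg` attributes, and all paths are illustrative placeholders, not pipeline defaults:

```python
import os
from types import SimpleNamespace

from mne_bids_pipeline._report import _open_report  # private helper from this patch

# _open_report() embeds the active config file in the report; the pipeline
# normally sets MNE_BIDS_STUDY_CONFIG itself, so when experimenting outside
# the pipeline it must point at some existing config file (placeholder here).
os.environ['MNE_BIDS_STUDY_CONFIG'] = '/path/to/config.py'

# Hypothetical minimal cfg; real steps assemble this from the full config.
cfg = SimpleNamespace(task='aud', acq=None, rec=None, space=None,
                      datatype='meg', deriv_root='/data/derivatives')

with _open_report(cfg=cfg, subject='01', session=None) as report:
    # The context manager creates (or re-opens) the subject's report HDF5
    # file under a FileLock, yields it, and on exit saves the .h5 plus a
    # rendered .html alongside it.
    report.add_html('<p>Hello</p>', title='Example')
```

Every step touched by this patch (filtering, epoching, SSP, ICA, PTP rejection, decoding, source estimation) follows this same open-append-save pattern, and the `FileLock` keeps parallel workers from clobbering each other's writes.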
diff --git a/docs/source/settings/sensor/statistics.md b/docs/source/settings/sensor/statistics.md index fc3d03490..0a618c5dd 100644 --- a/docs/source/settings/sensor/statistics.md +++ b/docs/source/settings/sensor/statistics.md @@ -1,4 +1,5 @@ ::: mne_bids_pipeline.config.contrasts +::: mne_bids_pipeline.config.report_evoked_n_time_points ::: mne_bids_pipeline.config.decode ::: mne_bids_pipeline.config.decoding_metric ::: mne_bids_pipeline.config.decoding_n_splits diff --git a/docs/source/settings/source/inverse.md b/docs/source/settings/source/inverse.md index f10b4ed55..af9d43d66 100644 --- a/docs/source/settings/source/inverse.md +++ b/docs/source/settings/source/inverse.md @@ -4,3 +4,4 @@ ::: mne_bids_pipeline.config.noise_cov ::: mne_bids_pipeline.config.source_info_path_update ::: mne_bids_pipeline.config.inverse_targets +::: mne_bids_pipeline.config.report_stc_n_time_points diff --git a/mne_bids_pipeline/_config_utils.py b/mne_bids_pipeline/_config_utils.py index f9232de28..f53d01cf6 100644 --- a/mne_bids_pipeline/_config_utils.py +++ b/mne_bids_pipeline/_config_utils.py @@ -578,14 +578,12 @@ def _get_script_modules() -> Dict[str, Tuple[ModuleType]]: from .scripts import preprocessing from .scripts import sensor from .scripts import source - from .scripts import report from .scripts import freesurfer INIT_SCRIPTS = init.SCRIPTS PREPROCESSING_SCRIPTS = preprocessing.SCRIPTS SENSOR_SCRIPTS = sensor.SCRIPTS SOURCE_SCRIPTS = source.SCRIPTS - REPORT_SCRIPTS = report.SCRIPTS FREESURFER_SCRIPTS = freesurfer.SCRIPTS SCRIPT_MODULES = { @@ -594,7 +592,6 @@ def _get_script_modules() -> Dict[str, Tuple[ModuleType]]: 'preprocessing': PREPROCESSING_SCRIPTS, 'sensor': SENSOR_SCRIPTS, 'source': SOURCE_SCRIPTS, - 'report': REPORT_SCRIPTS, } # Do not include the FreeSurfer scripts in "all" – we don't intend to run @@ -603,8 +600,7 @@ def _get_script_modules() -> Dict[str, Tuple[ModuleType]]: SCRIPT_MODULES['init'] + SCRIPT_MODULES['preprocessing'] + SCRIPT_MODULES['sensor'] + - SCRIPT_MODULES['source'] + - SCRIPT_MODULES['report'] + SCRIPT_MODULES['source'] ) return SCRIPT_MODULES diff --git a/mne_bids_pipeline/_report.py b/mne_bids_pipeline/_report.py new file mode 100644 index 000000000..c1f1f2c8e --- /dev/null +++ b/mne_bids_pipeline/_report.py @@ -0,0 +1,1098 @@ +import contextlib +import os +import os.path as op +from pathlib import Path +from typing import Optional, List, Literal +from types import SimpleNamespace + +from filelock import FileLock +import matplotlib.transforms +import numpy as np +import pandas as pd +from scipy.io import loadmat + +import mne +from mne.utils import _pl +from mne_bids import BIDSPath +from mne_bids.stats import count_events + +from ._config_utils import ( + sanitize_cond_name, get_subjects, _restrict_analyze_channels) +from ._decoding import _handle_csp_args +from ._logging import logger, gen_log_kwargs +from ._viz import plot_auto_scores + + +@contextlib.contextmanager +def _open_report( + *, + cfg: SimpleNamespace, + subject: str, + session: Optional[str] +): + fname_report = BIDSPath( + subject=subject, + session=session, + task=cfg.task, + acquisition=cfg.acq, + run=None, + recording=cfg.rec, + space=cfg.space, + extension='.h5', + datatype=cfg.datatype, + root=cfg.deriv_root, + suffix='report', + check=False + ).fpath + # prevent parallel file access + with FileLock(f'{fname_report}.lock'), _agg_backend(): + if not fname_report.is_file(): + msg = 'Initializing report HDF5 file' + logger.info( + **gen_log_kwargs(message=msg, subject=subject, 
session=session) + ) + report = _gen_empty_report( + cfg=cfg, + subject=subject, + session=session, + ) + report.save(fname_report) + try: + report = mne.open_report(fname_report) + except OSError as exc: + raise OSError( + f'Could not open report HDF5 file:\n{fname_report}\n' + f'Got error:\n{exc}\nPerhaps you need to delete it?') from None + add_system_info(report) + try: + yield report + finally: + report.save(fname_report, overwrite=True) + report.save( + fname_report.with_suffix('.html'), overwrite=True, + open_browser=False) + + +def plot_auto_scores_(cfg, subject, session): + """Plot automated bad channel detection scores. + """ + import json_tricks + fname_scores = BIDSPath(subject=subject, + session=session, + task=cfg.task, + acquisition=cfg.acq, + run=None, + processing=cfg.proc, + recording=cfg.rec, + space=cfg.space, + suffix='scores', + extension='.json', + datatype=cfg.datatype, + root=cfg.deriv_root, + check=False) + + all_figs = [] + all_captions = [] + for run in cfg.runs: + fname_scores.update(run=run) + auto_scores = json_tricks.loads( + fname_scores.fpath.read_text(encoding='utf-8-sig') + ) + + figs = plot_auto_scores(auto_scores, ch_types=cfg.ch_types) + all_figs.extend(figs) + + # Could be more than 1 fig, e.g. "grad" and "mag" + captions = [f'Run {run}'] * len(figs) + all_captions.extend(captions) + + assert all_figs + return all_figs, all_captions + + +# def plot_full_epochs_decoding_scores( +# contrast: str, +# cross_val_scores: np.ndarray, +# metric: str +# ): +# """Plot cross-validation results from full-epochs decoding. +# """ +# import matplotlib.pyplot as plt # nested import to help joblib +# import seaborn as sns + +# cross_val_scores = cross_val_scores.squeeze() # Make it a 1D array +# data = pd.DataFrame({ +# 'contrast': [contrast] * len(cross_val_scores), +# 'scores': cross_val_scores, +# 'metric': [metric] * len(cross_val_scores)} +# ) +# fig, ax = plt.subplots(constrained_layout=True) + +# sns.swarmplot(x='contrast', y='scores', data=data, color='0.25', +# label='cross-val. scores') +# ax.set_xticklabels([]) + +# ax.plot(cross_val_scores.mean(), '+', color='red', ms=15, +# label='mean score', zorder=99) +# ax.axhline(0.5, ls='--', lw=0.5, color='black', label='chance') + +# ax.set_xlabel(f'{contrast[0]} vs. {contrast[1]}') +# if metric == 'roc_auc': +# metric = 'ROC AUC' +# ax.set_ylabel(f'Score ({metric})') +# ax.legend(loc='center right') + +# return fig + + +def _plot_full_epochs_decoding_scores( + contrast_names: List[str], + scores: List[np.ndarray], + metric: str, + kind: Literal['single-subject', 'grand-average'] = 'single-subject', +): + """Plot cross-validation results from full-epochs decoding. + """ + import matplotlib.pyplot as plt # nested import to help joblib + import seaborn as sns + + if metric == 'roc_auc': + metric = 'ROC AUC' + score_label = f'Score ({metric})' + + data = pd.DataFrame({ + 'Contrast': np.array([ + [c] * len(scores[0]) + for c in contrast_names + ]).flatten(), + score_label: np.hstack(scores), + }) + + if kind == 'grand-average': + # First create a grid of boxplots … + g = sns.catplot( + data=data, y=score_label, kind='box', + col='Contrast', col_wrap=3, aspect=0.33 + ) + # … and now add swarmplots on top to visualize every single data point. + g.map_dataframe(sns.swarmplot, y=score_label, color='black') + caption = ( + f'Based on N={len(scores[0])} ' + f'subjects. Each dot represents the mean cross-validation score ' + f'for a single subject. The dashed line is expected chance ' + f'performance.' 
+ ) + else: + # First create a grid of swarmplots to visualize every single + # cross-validation score. + g = sns.catplot( + data=data, y=score_label, kind='swarm', + col='Contrast', col_wrap=3, aspect=0.33, color='black' + ) + + # And now add the mean CV score on top. + def _plot_mean_cv_score(x, **kwargs): + plt.plot(x.mean(), **kwargs) + + g.map( + _plot_mean_cv_score, score_label, marker='+', color='red', + ms=15, label='mean score', zorder=99 + ) + caption = ( + 'Each black dot represents the single cross-validation score. ' + f'The red cross is the mean of all {len(scores[0])} ' + 'cross-validation scores. ' + 'The dashed line is expected chance performance.' + ) + plt.xlim([-0.1, 0.1]) + + g.map(plt.axhline, y=0.5, ls='--', lw=0.5, color='black', zorder=99) + g.set_titles('{col_name}') # use this argument literally! + g.set_xlabels('') + + fig = g.fig + return fig, caption + + +def _plot_time_by_time_decoding_scores( + *, + times: np.ndarray, + cross_val_scores: np.ndarray, + metric: str, + time_generalization: bool, + decim: int, +): + """Plot cross-validation results from time-by-time decoding. + """ + import matplotlib.pyplot as plt # nested import to help joblib + + mean_scores = cross_val_scores.mean(axis=0) + max_scores = cross_val_scores.max(axis=0) + min_scores = cross_val_scores.min(axis=0) + + if time_generalization: + # Only use the diagonal values (classifiers trained and tested on the + # same time points). + mean_scores = np.diag(mean_scores) + max_scores = np.diag(max_scores) + min_scores = np.diag(min_scores) + + fig, ax = plt.subplots(constrained_layout=True) + ax.axhline(0.5, ls='--', lw=0.5, color='black', label='chance') + if times.min() < 0 < times.max(): + ax.axvline(0, ls='-', lw=0.5, color='black') + ax.fill_between(x=times, y1=min_scores, y2=max_scores, color='lightgray', + alpha=0.5, label='range [min, max]') + ax.plot(times, mean_scores, ls='-', lw=2, color='black', + label='mean') + + _label_time_by_time(ax, xlabel='Time (s)', decim=decim) + if metric == 'roc_auc': + metric = 'ROC AUC' + ax.set_ylabel(f'Score ({metric})') + ax.set_ylim((-0.025, 1.025)) + ax.legend(loc='lower right') + + return fig + + +def _label_time_by_time(ax, *, decim, xlabel=None, ylabel=None): + extra = '' + if decim > 1: + extra = f' (decim={decim})' + if xlabel is not None: + ax.set_xlabel(f'{xlabel}{extra}') + if ylabel is not None: + ax.set_ylabel(f'{ylabel}{extra}') + + +def _plot_time_by_time_decoding_scores_gavg(*, cfg, decoding_data): + """Plot the grand-averaged decoding scores. + """ + import matplotlib.pyplot as plt # nested import to help joblib + + # We squeeze() to make Matplotlib happy. + times = decoding_data['times'].squeeze() + mean_scores = decoding_data['mean'].squeeze() + se_lower = mean_scores - decoding_data['mean_se'].squeeze() + se_upper = mean_scores + decoding_data['mean_se'].squeeze() + ci_lower = decoding_data['mean_ci_lower'].squeeze() + ci_upper = decoding_data['mean_ci_upper'].squeeze() + decim = decoding_data['decim'].item() + + if cfg.decoding_time_generalization: + # Only use the diagonal values (classifiers trained and tested on the + # same time points). 
+ mean_scores = np.diag(mean_scores) + se_lower = np.diag(se_lower) + se_upper = np.diag(se_upper) + ci_lower = np.diag(ci_lower) + ci_upper = np.diag(ci_upper) + + metric = cfg.decoding_metric + clusters = np.atleast_1d(decoding_data['clusters'].squeeze()) + + fig, ax = plt.subplots(constrained_layout=True) + ax.set_ylim((-0.025, 1.025)) + + # Start with plotting the significant time periods according to the + # cluster-based permutation test + n_significant_clusters_plotted = 0 + for cluster in clusters: + cluster_times = np.atleast_1d(cluster['times'][0][0].squeeze()) + cluster_p = cluster['p_value'][0][0].item() + if cluster_p >= cfg.cluster_permutation_p_threshold: + continue + + # Only add the label once + if n_significant_clusters_plotted == 0: + label = (f'$p$ < {cfg.cluster_permutation_p_threshold} ' + f'(cluster pemutation)') + else: + label = None + + ax.fill_betweenx( + y=ax.get_ylim(), + x1=cluster_times[0], + x2=cluster_times[-1], + facecolor='orange', + alpha=0.15, + label=label + ) + n_significant_clusters_plotted += 1 + + ax.axhline(0.5, ls='--', lw=0.5, color='black', label='chance') + if times.min() < 0 < times.max(): + ax.axvline(0, ls='-', lw=0.5, color='black') + ax.fill_between(x=times, y1=ci_lower, y2=ci_upper, color='lightgray', + alpha=0.5, label='95% confidence interval') + + ax.plot(times, mean_scores, ls='-', lw=2, color='black', + label='mean') + ax.plot(times, se_lower, ls='-.', lw=0.5, color='gray', + label='mean ± standard error') + ax.plot(times, se_upper, ls='-.', lw=0.5, color='gray') + ax.text(0.05, 0.05, s=f'$N$={decoding_data["N"].squeeze()}', + fontsize='x-large', horizontalalignment='left', + verticalalignment='bottom', transform=ax.transAxes) + + _label_time_by_time(ax, xlabel='Time (s)', decim=decim) + if metric == 'roc_auc': + metric = 'ROC AUC' + ax.set_ylabel(f'Score ({metric})') + ax.legend(loc='lower right') + + return fig + + +def plot_time_by_time_decoding_t_values(decoding_data): + """Plot the t-values used to form clusters for the permutation test. + """ + import matplotlib.pyplot as plt # nested import to help joblib + + # We squeeze() to make Matplotlib happy. + all_times = decoding_data['cluster_all_times'].squeeze() + all_t_values = decoding_data['cluster_all_t_values'].squeeze() + t_threshold = decoding_data['cluster_t_threshold'] + decim = decoding_data['decim'] + + fig, ax = plt.subplots(constrained_layout=True) + ax.plot(all_times, all_t_values, ls='-', color='black', + label='observed $t$-values') + ax.axhline(t_threshold, ls='--', color='red', label='threshold') + + ax.text(0.05, 0.05, s=f'$N$={decoding_data["N"].squeeze()}', + fontsize='x-large', horizontalalignment='left', + verticalalignment='bottom', transform=ax.transAxes) + + _label_time_by_time(ax, xlabel='Time (s)', decim=decim) + ax.set_ylabel('$t$-value') + ax.legend(loc='lower right') + + if all_t_values.min() < 0 and all_t_values.max() > 0: + # center the y axis around 0 + y_max = np.abs(ax.get_ylim()).max() + ax.set_ylim(ymin=-y_max, ymax=y_max) + elif all_t_values.min() > 0 and all_t_values.max() > 0: + # start y axis at zero + ax.set_ylim(ymin=0, ymax=all_t_values.max()) + elif all_t_values.min() < 0 and all_t_values.max() < 0: + # start y axis at zero + ax.set_ylim(ymin=all_t_values.min(), ymax=0) + + return fig + + +def _plot_decoding_time_generalization( + decoding_data, + metric: str, + kind: Literal['single-subject', 'grand-average'] +): + """Plot time generalization matrix. 
+ """ + import matplotlib.pyplot as plt # nested import to help joblib + + # We squeeze() to make Matplotlib happy. + times = decoding_data['times'].squeeze() + decim = decoding_data['decim'].item() + if kind == 'single-subject': + # take the mean across CV scores + mean_scores = decoding_data['scores'].mean(axis=0) + else: + mean_scores = decoding_data['mean'] + + fig, ax = plt.subplots(constrained_layout=True) + im = ax.imshow( + mean_scores, + extent=times[[0, -1, 0, -1]], + interpolation='nearest', + origin='lower', + cmap='RdBu_r', + vmin=0, + vmax=1 + ) + + # Indicate time point zero + if times.min() < 0 < times.max(): + ax.axvline(0, ls='--', lw=0.5, color='black') + ax.axhline(0, ls='--', lw=0.5, color='black') + + # Indicate diagonal + ax.plot(times[[0, -1]], times[[0, -1]], ls='--', lw=0.5, color='black') + + # Axis labels + _label_time_by_time( + ax, + xlabel='Testing time (s)', + ylabel='Training time (s)', + decim=decim, + ) + + # Color bar + cbar = plt.colorbar(im, ax=ax) + if metric == 'roc_auc': + metric = 'ROC AUC' + cbar.set_label(f'Score ({metric})') + + return fig + + +def _gen_empty_report( + *, + cfg: SimpleNamespace, + subject: str, + session: Optional[str] +) -> mne.Report: + title = f'sub-{subject}' + if session is not None: + title += f', ses-{session}' + if cfg.task is not None: + title += f', task-{cfg.task}' + + report = mne.Report(title=title, raw_psd=True) + return report + + +def _contrasts_to_names(contrasts: List[List[str]]) -> List[str]: + return [f'{c[0]} vs.\n{c[1]}' for c in contrasts] + + +def add_event_counts(*, + cfg, + subject: Optional[str], + session: Optional[str], + report: mne.Report) -> None: + try: + df_events = count_events(BIDSPath(root=cfg.bids_root, + session=session)) + except ValueError: + msg = 'Could not read events.' + logger.warning( + **gen_log_kwargs(message=msg, subject=subject, session=session) + ) + df_events = None + + if df_events is not None: + css_classes = ('table', 'table-striped', 'table-borderless', + 'table-hover') + report.add_html( + f'
<div class="event-counts">\n' + f'{df_events.to_html(classes=css_classes, border=0)}\n' + f'</div>
', + title='Event counts', + tags=('events',) + ) + css = ('.event-counts {\n' + ' display: -webkit-box;\n' + ' display: -ms-flexbox;\n' + ' display: -webkit-flex;\n' + ' display: flex;\n' + ' justify-content: center;\n' + ' text-align: center;\n' + '}\n\n' + 'th, td {\n' + ' text-align: center;\n' + '}\n') + report.add_custom_css(css=css) + + +def add_system_info(report: mne.Report): + """Add system information and the pipeline configuration to the report.""" + + config_path = Path(os.environ['MNE_BIDS_STUDY_CONFIG']) + report.add_code( + code=config_path, + title='Configuration file', + tags=('configuration',) + ) + report.add_sys_info(title='System information') + + +def _all_conditions(*, cfg): + if isinstance(cfg.conditions, dict): + conditions = list(cfg.conditions.keys()) + else: + conditions = cfg.conditions.copy() + conditions.extend([contrast["name"] for contrast in cfg.all_contrasts]) + return conditions + + +def run_report_average_sensor(*, cfg, session: str) -> None: + subject = 'average' + msg = 'Generating grand average report …' + logger.info( + **gen_log_kwargs(message=msg, subject=subject, session=session) + ) + + evoked_fname = BIDSPath( + subject=subject, + session=session, + task=cfg.task, + acquisition=cfg.acq, + run=None, + recording=cfg.rec, + space=cfg.space, + suffix='ave', + extension='.fif', + datatype=cfg.datatype, + root=cfg.deriv_root, + check=False + ) + + title = f'sub-{subject}' + if session is not None: + title += f', ses-{session}' + if cfg.task is not None: + title += f', task-{cfg.task}' + + report = mne.Report( + title=title, + raw_psd=True + ) + evokeds = mne.read_evokeds(evoked_fname) + for evoked in evokeds: + _restrict_analyze_channels(evoked, cfg) + + conditions = _all_conditions(cfg=cfg) + + ####################################################################### + # + # Add event stats. + # + add_event_counts(cfg=cfg, report=report, subject=subject, session=session) + + ####################################################################### + # + # Visualize evoked responses. + # + for condition, evoked in zip(conditions, evokeds): + if condition in cfg.conditions: + title = f'Average: {condition}' + tags = ( + 'evoked', + _sanitize_cond_tag(condition) + ) + else: # It's a contrast of two conditions. + # XXX Will change once we process contrasts here too + continue + + report.add_evokeds( + evokeds=evoked, + titles=title, + projs=False, + tags=tags, + n_time_points=cfg.report_evoked_n_time_points, + # captions=evoked.comment # TODO upstream + ) + + ####################################################################### + # + # Visualize decoding results. + # + if cfg.decode and cfg.decoding_contrasts: + add_decoding_grand_average( + session=session, cfg=cfg, report=report + ) + + if cfg.decode and cfg.decoding_csp: + add_csp_grand_average( + session=session, cfg=cfg, report=report + ) + + +def run_report_average_source(*, cfg, session: str) -> None: + ####################################################################### + # + # Visualize forward solution, inverse operator, and inverse solutions. + # + subject = 'average' + evoked_fname = BIDSPath( + subject=subject, + session=session, + task=cfg.task, + acquisition=cfg.acq, + run=None, + recording=cfg.rec, + space=cfg.space, + suffix='ave', + extension='.fif', + datatype=cfg.datatype, + root=cfg.deriv_root, + check=False + ) + evokeds = mne.read_evokeds(evoked_fname) + method = cfg.inverse_method + inverse_str = method + hemi_str = 'hemi' # MNE will auto-append '-lh' and '-rh'. 
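+ # Morphed STCs are expected on disk with a
+ # '<condition>+<method>+morph2fsaverage+hemi' suffix; these pieces are
+ # reassembled below to locate each file before adding it to the report.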
+ morph_str = 'morph2fsaverage' + conditions = _all_conditions(cfg=cfg) + with _open_report(cfg=cfg, subject=subject, session=session) as report: + for condition, evoked in zip(conditions, evokeds): + tags = ( + 'source-estimate', + _sanitize_cond_tag(condition), + ) + if condition in cfg.conditions: + title = f'Average: {condition}' + else: # It's a contrast of two conditions. + title = f'Average contrast: {condition}' + tags = tags + ('contrast',) + cond_str = sanitize_cond_name(condition) + fname_stc_avg = evoked_fname.copy().update( + suffix=f'{cond_str}+{inverse_str}+{morph_str}+{hemi_str}', + extension=None) + if not Path(f'{fname_stc_avg.fpath}-lh.stc').exists(): + continue + report.add_stc( + stc=fname_stc_avg, + title=title, + subject='fsaverage', + subjects_dir=cfg.fs_subjects_dir, + n_time_points=cfg.report_stc_n_time_points, + tags=tags + ) + + +def add_decoding_grand_average( + *, + session: Optional[str], + cfg: SimpleNamespace, + report: mne.Report, +): + """Add decoding results to the grand average report.""" + import matplotlib.pyplot as plt # nested import to help joblib + + bids_path = BIDSPath( + subject='average', + session=session, + task=cfg.task, + acquisition=cfg.acq, + run=None, + recording=cfg.rec, + space=cfg.space, + suffix='ave', + extension='.fif', + datatype=cfg.datatype, + root=cfg.deriv_root, + check=False + ) + + # Full-epochs decoding + all_decoding_scores = [] + for contrast in cfg.decoding_contrasts: + cond_1, cond_2 = contrast + a_vs_b = f'{cond_1}+{cond_2}'.replace(op.sep, '') + processing = f'{a_vs_b}+FullEpochs+{cfg.decoding_metric}' + processing = processing.replace('_', '-').replace('-', '') + fname_decoding = bids_path.copy().update( + processing=processing, + suffix='decoding', + extension='.mat' + ) + decoding_data = loadmat(fname_decoding) + all_decoding_scores.append( + np.atleast_1d(decoding_data['scores'].squeeze()) + ) + del fname_decoding, processing, a_vs_b, decoding_data + + fig, caption = _plot_full_epochs_decoding_scores( + contrast_names=_contrasts_to_names(cfg.decoding_contrasts), + scores=all_decoding_scores, + metric=cfg.decoding_metric, + kind='grand-average' + ) + title = f'Full-epochs decoding: {cond_1} vs. {cond_2}' + report.add_figure( + fig=fig, + title=title, + section='Decoding: full-epochs', + caption=caption, + tags=( + 'epochs', + 'contrast', + 'decoding', + *[f'{_sanitize_cond_tag(cond_1)}–{_sanitize_cond_tag(cond_2)}' + for cond_1, cond_2 in cfg.decoding_contrasts] + ) + ) + # close figure to save memory + plt.close(fig) + del fig, caption, title + + # Time-by-time decoding + for contrast in cfg.decoding_contrasts: + cond_1, cond_2 = contrast + a_vs_b = f'{cond_1}+{cond_2}'.replace(op.sep, '') + section = 'Decoding: time-by-time' + tags = ( + 'epochs', + 'contrast', + 'decoding', + f'{_sanitize_cond_tag(cond_1)}–{_sanitize_cond_tag(cond_2)}' + ) + processing = f'{a_vs_b}+TimeByTime+{cfg.decoding_metric}' + processing = processing.replace('_', '-').replace('-', '') + fname_decoding = bids_path.copy().update( + processing=processing, + suffix='decoding', + extension='.mat' + ) + decoding_data = loadmat(fname_decoding) + del fname_decoding, processing, a_vs_b + + # Plot scores + fig = _plot_time_by_time_decoding_scores_gavg( + cfg=cfg, + decoding_data=decoding_data, + ) + caption = ( + f'Based on N={decoding_data["N"].squeeze()} ' + f'subjects. Standard error and confidence interval ' + f'of the mean were bootstrapped with {cfg.n_boot} ' + f'resamples. 
CI must not be used for statistical inference here, ' + f'as it is not corrected for multiple testing.' + ) + if len(get_subjects(cfg)) > 1: + caption += ( + f' Time periods with decoding performance significantly above ' + f'chance, if any, were derived with a one-tailed ' + f'cluster-based permutation test ' + f'({decoding_data["cluster_n_permutations"].squeeze()} ' + f'permutations) and are highlighted in yellow.' + ) + title = f'Decoding over time: {cond_1} vs. {cond_2}' + report.add_figure( + fig=fig, + title=title, + caption=caption, + section=section, + tags=tags, + ) + plt.close(fig) + + # Plot t-values used to form clusters + if len(get_subjects(cfg)) > 1: + fig = plot_time_by_time_decoding_t_values( + decoding_data=decoding_data + ) + t_threshold = np.round( + decoding_data['cluster_t_threshold'], + 3 + ).item() + caption = ( + f'Observed t-values. Time points with ' + f't-values > {t_threshold} were used to form clusters.' + ) + report.add_figure( + fig=fig, + title=f't-values across time: {cond_1} vs. {cond_2}', + caption=caption, + section=section, + tags=tags, + ) + plt.close(fig) + + if cfg.decoding_time_generalization: + fig = _plot_decoding_time_generalization( + decoding_data=decoding_data, + metric=cfg.decoding_metric, + kind='grand-average' + ) + caption = ( + f'Time generalization (generalization across time, GAT): ' + f'each classifier is trained on each time point, and tested ' + f'on all other time points. The results were averaged across ' + f'N={decoding_data["N"].item()} subjects.' + ) + title = f'Time generalization: {cond_1} vs. {cond_2}' + report.add_figure( + fig=fig, + title=title, + caption=caption, + section=section, + tags=tags, + ) + plt.close(fig) + + +def _sanitize_cond_tag(cond): + return cond.lower().replace(' ', '-') + + +def _imshow_tf(vals, ax, *, tmin, tmax, fmin, fmax, vmin, vmax, cmap='RdBu_r', + mask=None, cmap_masked=None): + """Plot CSP TF decoding scores.""" + # XXX Add support for more metrics + assert len(vals) == len(tmin) == len(tmax) == len(fmin) == len(fmax) + mask = np.zeros(vals.shape, dtype=bool) if mask is None else mask + assert len(vals) == len(mask) + assert vals.ndim == mask.ndim == 1 + img = None + for v, t1, t2, f1, f2, m in zip(vals, tmin, tmax, fmin, fmax, mask): + use_cmap = cmap_masked if m else cmap + img = ax.imshow( + np.array([[v]], float), + cmap=use_cmap, extent=[t1, t2, f1, f2], aspect='auto', + interpolation='none', origin='lower', vmin=vmin, vmax=vmax, + ) + return img + + +def add_csp_grand_average( + *, + session: str, + cfg: SimpleNamespace, + report: mne.Report, +): + """Add CSP decoding results to the grand average report.""" + import matplotlib.pyplot as plt # nested import to help joblib + + bids_path = BIDSPath( + subject='average', + session=session, + task=cfg.task, + acquisition=cfg.acq, + run=None, + recording=cfg.rec, + space=cfg.space, + suffix='decoding', + datatype=cfg.datatype, + root=cfg.deriv_root, + check=False + ) + + # First, plot decoding scores across frequency bins (entire epochs). 
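+ # _handle_csp_args() returns a dict mapping each frequency-range name to
+ # the list of (fmin, fmax) bins it was split into; the bar plot below
+ # draws one bar per bin.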
+ section = 'Decoding: CSP' + freq_name_to_bins_map = _handle_csp_args( + cfg.decoding_csp_times, + cfg.decoding_csp_freqs, + cfg.decoding_metric, + ) + for contrast in cfg.decoding_contrasts: + cond_1, cond_2 = contrast + a_vs_b = f'{cond_1}+{cond_2}'.replace(op.sep, '') + processing = f'{a_vs_b}+CSP+{cfg.decoding_metric}' + processing = processing.replace('_', '-').replace('-', '') + fname_csp_freq_results = bids_path.copy().update( + processing=processing, + extension='.xlsx', + ) + csp_freq_results = pd.read_excel( + fname_csp_freq_results, + sheet_name='CSP Frequency' + ) + + freq_bin_starts = list() + freq_bin_widths = list() + decoding_scores = list() + error_bars = list() + for freq_range_name, freq_bins in freq_name_to_bins_map.items(): + results = csp_freq_results.loc[ + csp_freq_results['freq_range_name'] == freq_range_name, : + ] + results.reset_index(drop=True, inplace=True) + assert len(results) == len(freq_bins) + for bi, freq_bin in enumerate(freq_bins): + freq_bin_starts.append(freq_bin[0]) + freq_bin_widths.append(np.diff(freq_bin)[0]) + decoding_scores.append(results['mean'][bi]) + cis_lower = results['mean_ci_lower'][bi] + cis_upper = results['mean_ci_upper'][bi] + error_bars_lower = decoding_scores[-1] - cis_lower + error_bars_upper = cis_upper - decoding_scores[-1] + error_bars.append( + np.stack([error_bars_lower, error_bars_upper])) + assert len(error_bars[-1]) == 2 # lower, upper + del cis_lower, cis_upper, error_bars_lower, error_bars_upper + error_bars = np.array(error_bars, float).T + + if cfg.decoding_metric == 'roc_auc': + metric = 'ROC AUC' + + fig, ax = plt.subplots(constrained_layout=True) + ax.bar( + x=freq_bin_starts, + width=freq_bin_widths, + height=decoding_scores, + align='edge', + yerr=error_bars, + edgecolor='black', + ) + ax.set_ylim([0, 1.02]) + offset = matplotlib.transforms.offset_copy( + ax.transData, fig, 0, 5, units='points') + for freq_range_name, freq_bins in freq_name_to_bins_map.items(): + start = freq_bins[0][0] + stop = freq_bins[-1][1] + width = stop - start + ax.text( + x=start + width / 2, + y=0., + transform=offset, + s=freq_range_name, + ha='center', + va='bottom', + ) + ax.axhline(0.5, color='black', linestyle='--', label='chance') + ax.legend() + ax.set_xlabel('Frequency (Hz)') + ax.set_ylabel(f'Mean decoding score ({metric})') + tags = ( + 'epochs', + 'contrast', + 'decoding', + 'csp', + f"{_sanitize_cond_tag(cond_1)}–{_sanitize_cond_tag(cond_2)}", + ) + title = f'CSP decoding: {cond_1} vs. {cond_2}' + report.add_figure( + fig=fig, + title=title, + section=section, + caption='Mean decoding scores. Error bars represent ' + 'bootstrapped 95% confidence intervals.', + tags=tags, + ) + + # Now, plot decoding scores across time-frequency bins. 
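+ # Two panels are drawn per contrast: ax[0] shows the mean cross-validation
+ # score for every time-frequency bin, and ax[1] repeats the image with the
+ # bins outside significant clusters grayed out via the masked colormap.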
+ for contrast in cfg.decoding_contrasts: + cond_1, cond_2 = contrast + a_vs_b = f'{cond_1}+{cond_2}'.replace(op.sep, '') + processing = f'{a_vs_b}+CSP+{cfg.decoding_metric}' + processing = processing.replace('_', '-').replace('-', '') + fname_csp_cluster_results = bids_path.copy().update( + processing=processing, + extension='.mat', + ) + csp_cluster_results = loadmat(fname_csp_cluster_results) + + fig, ax = plt.subplots( + nrows=1, ncols=2, sharex=True, sharey=True, + constrained_layout=True) + n_clu = 0 + cbar = None + lims = [np.inf, -np.inf, np.inf, -np.inf] + for freq_range_name, bins in freq_name_to_bins_map.items(): + results = csp_cluster_results[freq_range_name][0][0] + mean_crossval_scores = results['mean_crossval_scores'].ravel() + # t_vals = results['t_vals'] + clusters = results['clusters'] + cluster_p_vals = results['cluster_p_vals'].squeeze() + tmin = results['time_bin_edges'].ravel() + tmin, tmax = tmin[:-1], tmin[1:] + fmin = results['freq_bin_edges'].ravel() + fmin, fmax = fmin[:-1], fmin[1:] + lims[0] = min(lims[0], tmin.min()) + lims[1] = max(lims[1], tmax.max()) + lims[2] = min(lims[2], fmin.min()) + lims[3] = max(lims[3], fmax.max()) + # replicate, matching time-frequency order during clustering + fmin, fmax = np.tile(fmin, len(tmin)), np.tile(fmax, len(tmax)) + tmin, tmax = np.repeat(tmin, len(bins)), np.repeat(tmax, len(bins)) + assert fmin.shape == fmax.shape == tmin.shape == tmax.shape + assert fmin.shape == mean_crossval_scores.shape + cluster_t_threshold = results['cluster_t_threshold'].ravel().item() + + significant_cluster_idx = np.where( + cluster_p_vals < cfg.cluster_permutation_p_threshold + )[0] + significant_clusters = clusters[significant_cluster_idx] + n_clu += len(significant_cluster_idx) + + # XXX Add support for more metrics + assert cfg.decoding_metric == 'roc_auc' + metric = 'ROC AUC' + vmax = max( + np.abs(mean_crossval_scores.min() - 0.5), + np.abs(mean_crossval_scores.max() - 0.5) + ) + 0.5 + vmin = 0.5 - (vmax - 0.5) + # For diverging gray colormap, we need to combine two existing + # colormaps, as there is no diverging colormap with gray/black at + # both endpoints. + from matplotlib.cm import gray, gray_r + from matplotlib.colors import ListedColormap + + black_to_white = gray( + np.linspace(start=0, stop=1, endpoint=False, num=128) + ) + white_to_black = gray_r( + np.linspace(start=0, stop=1, endpoint=False, num=128) + ) + black_to_white_to_black = np.vstack( + (black_to_white, white_to_black) + ) + diverging_gray_cmap = ListedColormap( + black_to_white_to_black, name='DivergingGray' + ) + cmap_gray = diverging_gray_cmap + img = _imshow_tf( + mean_crossval_scores, ax[0], + tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, + vmin=vmin, vmax=vmax) + if cbar is None: + ax[0].set_xlabel('Time (s)') + ax[0].set_ylabel('Frequency (Hz)') + ax[1].set_xlabel('Time (s)') + cbar = fig.colorbar( + ax=ax[1], shrink=0.75, orientation='vertical', + mappable=img) + cbar.set_label(f'Mean decoding score ({metric})') + offset = matplotlib.transforms.offset_copy( + ax[0].transData, fig, 6, 0, units='points') + ax[0].text(tmin.min(), + 0.5 * fmin.min() + 0.5 * fmax.max(), + freq_range_name, transform=offset, + ha='left', va='center', rotation=90) + + if len(significant_clusters): + # Create a masked array that only shows the T-values for + # time-frequency bins that belong to significant clusters. 
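+ # Mask entries are True for bins falling outside every significant
+ # cluster; _imshow_tf() draws those bins with the gray colormap instead.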
+ if len(significant_clusters) == 1: + mask = ~significant_clusters[0].astype(bool) + else: + mask = ~np.logical_or( + *significant_clusters + ) + mask = mask.ravel() + else: + mask = np.ones(mean_crossval_scores.shape, dtype=bool) + _imshow_tf( + mean_crossval_scores, ax[1], + tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, + vmin=vmin, vmax=vmax, mask=mask, cmap_masked=cmap_gray) + + ax[0].set_xlim(lims[:2]) + ax[0].set_ylim(lims[2:]) + ax[0].set_title('Scores') + ax[1].set_title('Masked') + tags = ( + 'epochs', + 'contrast', + 'decoding', + 'csp', + f"{_sanitize_cond_tag(cond_1)}–{_sanitize_cond_tag(cond_2)}", + ) + title = f'CSP TF decoding: {cond_1} vs. {cond_2}' + report.add_figure( + fig=fig, + title=title, + section=section, + caption=f'Found {n_clu} ' + f'cluster{_pl(n_clu)} with ' + f'p < {cfg.cluster_permutation_p_threshold} ' + f'(clustering bins with absolute t-values > ' + f'{round(cluster_t_threshold, 3)}).', + tags=tags, + ) + + +@contextlib.contextmanager +def _agg_backend(): + import matplotlib + backend = matplotlib.get_backend() + matplotlib.use('Agg', force=True) + try: + yield + finally: + matplotlib.use(backend, force=True) diff --git a/mne_bids_pipeline/config.py b/mne_bids_pipeline/config.py index 49bb5dc46..8647f6c16 100644 --- a/mne_bids_pipeline/config.py +++ b/mne_bids_pipeline/config.py @@ -698,6 +698,20 @@ Keep it None if no lowpass filtering should be applied. """ +l_trans_bandwidth: Union[float, Literal['auto']] = 'auto' +""" +Specifies the transition bandwidth of the +highpass filter. By default it's `'auto'` and uses default MNE +parameters. +""" + +h_trans_bandwidth: Union[float, Literal['auto']] = 'auto' +""" +Specifies the transition bandwidth of the +lowpass filter. By default it's `'auto'` and uses default MNE +parameters. +""" + resample_sfreq: Optional[float] = None """ Specifies at which sampling frequency the data should be resampled. @@ -984,6 +998,18 @@ ``` """ +report_evoked_n_time_points: Optional[int] = None +""" +Specifies the number of time points to display for each evoked +in the report. If None it defaults to the current default in MNE-Python. + +???+ example "Example" + Only display 5 time points per evoked + ```python + report_evoked_n_time_points = 5 + ``` +""" + ############################################################################### # ARTIFACT REMOVAL # ---------------- @@ -1879,22 +1905,6 @@ def noise_cov(bids_path): ``` """ -############################################################################### -# ADVANCED -# -------- - -report_evoked_n_time_points: Optional[int] = None -""" -Specifies the number of time points to display for each evoked -in the report. If None it defaults to the current default in MNE-Python. - -???+ example "Example" - Only display 5 time points per evoked - ```python - report_evoked_n_time_points = 5 - ``` -""" - report_stc_n_time_points: Optional[int] = None """ Specifies the number of time points to display for each source estimates @@ -1907,19 +1917,9 @@ def noise_cov(bids_path): ``` """ -l_trans_bandwidth: Union[float, Literal['auto']] = 'auto' -""" -Specifies the transition bandwidth of the -highpass filter. By default it's `'auto'` and uses default MNE -parameters. -""" - -h_trans_bandwidth: Union[float, Literal['auto']] = 'auto' -""" -Specifies the transition bandwidth of the -lowpass filter. By default it's `'auto'` and uses default MNE -parameters. 
-""" +############################################################################### +# Execution +# --------- N_JOBS: int = 1 """ diff --git a/mne_bids_pipeline/scripts/preprocessing/_02_frequency_filter.py b/mne_bids_pipeline/scripts/preprocessing/_02_frequency_filter.py index c4f38a85f..92c23a8ec 100644 --- a/mne_bids_pipeline/scripts/preprocessing/_02_frequency_filter.py +++ b/mne_bids_pipeline/scripts/preprocessing/_02_frequency_filter.py @@ -32,6 +32,7 @@ from ..._io import _read_json, _empty_room_match_path from ..._logging import gen_log_kwargs, logger from ..._parallel import parallel_func, get_parallel_backend +from ..._report import _open_report, plot_auto_scores_ from ..._run import failsafe_run, save_logs, _update_for_splits from ..._typing import Literal @@ -173,7 +174,7 @@ def filter_data( in_files: dict, ) -> None: """Filter data from a single subject.""" - + import matplotlib.pyplot as plt out_files = dict() bids_path = in_files.pop(f"raw_run-{run}") @@ -223,7 +224,7 @@ def filter_data( msg = (f'Reading {data_type} recording: ' f'{bids_path_noise.basename}') logger.info(**gen_log_kwargs(message=msg, subject=subject, - session=session)) + session=session, run=task)) raw_noise = mne.io.read_raw_fif(bids_path_noise) elif data_type == 'empty-room': raw_noise = import_er_data( @@ -265,6 +266,56 @@ def filter_data( raw_noise.plot_psd(fmax=fmax) assert len(in_files) == 0, in_files.keys() + + # Report + with _open_report(cfg=cfg, subject=subject, session=session) as report: + # This is a weird place for this, but it's the first place we actually + # start a report, so let's leave it here. We could put it in + # _import_data (where the auto-bad-finding is called), but that seems + # worse. + if run == cfg.runs[0] and cfg.find_noisy_channels_meg: + msg = 'Adding visualization of noisy channel detection to report.' + logger.info( + **gen_log_kwargs(message=msg, subject=subject, session=session) + ) + figs, captions = plot_auto_scores_( + cfg=cfg, + subject=subject, + session=session, + ) + tags = ('raw', 'data-quality', *[f'run-{i}' for i in cfg.runs]) + report.add_figure( + fig=figs, + caption=captions, + title='Data Quality', + tags=tags + ) + for fig in figs: + plt.close(fig) + + msg = 'Adding filtered raw data to report.' 
+ logger.info( + **gen_log_kwargs( + message=msg, subject=subject, session=session, run=run + ) + ) + for fname in out_files.values(): + title = 'Raw' + if fname.run is not None: + title += f', run {fname.run}' + plot_raw_psd = ( + cfg.plot_psd_for_runs == 'all' or + fname.run in cfg.plot_psd_for_runs + ) + report.add_raw( + raw=fname, + title=title, + butterfly=5, + psd=plot_raw_psd, + tags=('raw', 'filtered', f'run-{fname.run}'), + # caption=fname.basename # TODO upstream + ) + return out_files @@ -313,6 +364,7 @@ def get_config( ch_types=config.ch_types, eog_channels=config.eog_channels, on_rename_missing_events=config.on_rename_missing_events, + plot_psd_for_runs=config.plot_psd_for_runs, _raw_split_size=config._raw_split_size, ) return cfg diff --git a/mne_bids_pipeline/scripts/preprocessing/_03_make_epochs.py b/mne_bids_pipeline/scripts/preprocessing/_03_make_epochs.py index e88dae18a..3afb5d91d 100644 --- a/mne_bids_pipeline/scripts/preprocessing/_03_make_epochs.py +++ b/mne_bids_pipeline/scripts/preprocessing/_03_make_epochs.py @@ -19,6 +19,7 @@ ) from ..._import_data import make_epochs, annotations_to_events from ..._logging import gen_log_kwargs, logger +from ..._report import _open_report from ..._run import ( failsafe_run, save_logs, _update_for_splits, _sanitize_callable) from ..._parallel import parallel_func, get_parallel_backend @@ -215,6 +216,45 @@ def run_epochs(*, cfg, subject, session, in_files): split_size=cfg._epochs_split_size) _update_for_splits(out_files, 'epochs') + # Report + with _open_report(cfg=cfg, subject=subject, session=session) as report: + if not cfg.task_is_rest: + msg = 'Adding events plot to report.' + logger.info( + **gen_log_kwargs( + message=msg, subject=subject, session=session + ) + ) + events, event_id, sfreq, first_samp = _get_events( + cfg=cfg, subject=subject, session=session + ) + report.add_events( + events=events, + event_id=event_id, + sfreq=sfreq, + first_samp=first_samp, + title='Events', + # caption='Events in filtered continuous data' # TODO upstream + ) + msg = 'Adding uncleaned epochs to report.' + logger.info( + **gen_log_kwargs( + message=msg, subject=subject, session=session + ) + ) + # Add PSD plots for 30s of data or all epochs if we have less available + if len(epochs) * (epochs.tmax - epochs.tmin) < 30: + psd = True + else: + psd = 30 + report.add_epochs( + epochs=epochs, + title='Epochs: before cleaning', + psd=psd, + drop_log_ignore=() + ) + + # Interactive if cfg.interactive: epochs.plot() epochs.plot_image(combine='gfp', sigma=2., cmap='YlGnBu_r') @@ -222,6 +262,36 @@ def run_epochs(*, cfg, subject, session, in_files): return out_files +# TODO: ideally we wouldn't need this anymore and could refactor the code above +def _get_events(cfg, subject, session): + raws_filt = [] + raw_fname = BIDSPath(subject=subject, + session=session, + task=cfg.task, + acquisition=cfg.acq, + recording=cfg.rec, + space=cfg.space, + processing='filt', + suffix='raw', + extension='.fif', + datatype=cfg.datatype, + root=cfg.deriv_root, + check=False) + + for run in cfg.runs: + this_raw_fname = raw_fname.copy().update(run=run) + this_raw_fname = _update_for_splits(this_raw_fname, None, single=True) + raw_filt = mne.io.read_raw_fif(this_raw_fname) + raws_filt.append(raw_filt) + del this_raw_fname + + # Concatenate the filtered raws and extract the events. 
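+ # on_mismatch='warn' tolerates runs whose measurement info (e.g. the
+ # device-to-head transform) differs slightly, instead of raising.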
+ raw_filt_concat = mne.concatenate_raws(raws_filt, on_mismatch='warn') + events, event_id = mne.events_from_annotations(raw=raw_filt_concat) + return (events, event_id, raw_filt_concat.info['sfreq'], + raw_filt_concat.first_samp) + + def get_config( *, config, diff --git a/mne_bids_pipeline/scripts/preprocessing/_04b_run_ssp.py b/mne_bids_pipeline/scripts/preprocessing/_04b_run_ssp.py index 813e87c1d..f3999f9fd 100644 --- a/mne_bids_pipeline/scripts/preprocessing/_04b_run_ssp.py +++ b/mne_bids_pipeline/scripts/preprocessing/_04b_run_ssp.py @@ -17,9 +17,10 @@ get_deriv_root, ) from ..._logging import gen_log_kwargs, logger -from ..._run import failsafe_run, _update_for_splits, _script_path, save_logs from ..._parallel import parallel_func, get_parallel_backend from ..._reject import _get_reject +from ..._report import _open_report +from ..._run import failsafe_run, _update_for_splits, _script_path, save_logs def get_input_fnames_run_ssp(**kwargs): @@ -50,6 +51,7 @@ def get_input_fnames_run_ssp(**kwargs): @failsafe_run(script_path=__file__, get_input_fnames=get_input_fnames_run_ssp) def run_ssp(*, cfg, subject, session, in_files): + import matplotlib.pyplot as plt # compute SSP on first run of raw raw_fnames = [in_files.pop(f'raw_run-{run}') for run in cfg.runs] @@ -131,6 +133,44 @@ def run_ssp(*, cfg, subject, session, in_files): mne.write_proj(out_files['proj'], sum(projs.values(), []), overwrite=True) assert len(in_files) == 0, in_files.keys() + + # Report + with _open_report(cfg=cfg, subject=subject, session=session) as report: + for kind in proj_kinds: + if f'epochs_{kind}' not in out_files: + continue + + msg = f'Adding {kind.upper()} SSP to report.' + logger.info( + **gen_log_kwargs( + message=msg, subject=subject, session=session + ) + ) + proj_epochs = mne.read_epochs(out_files[f'epochs_{kind}']) + projs = mne.read_proj(out_files['proj']) + projs = [p for p in projs if kind.upper() in p['desc']] + assert len(projs), len(projs) # should exist if the epochs do + picks_trace = None + if kind == 'ecg': + if 'ecg' in proj_epochs: + picks_trace = 'ecg' + else: + assert kind == 'eog' + if cfg.eog_channels: + picks_trace = cfg.eog_channels + elif 'eog' in proj_epochs: + picks_trace = 'eog' + fig = mne.viz.plot_projs_joint( + projs, proj_epochs.average(picks='all'), + picks_trace=picks_trace) + caption = ( + f'Computed using {len(proj_epochs)} epochs ' + f'(from {len(proj_epochs.drop_log)} original events)' + ) + report.add_figure( + fig, title=f'SSP: {kind.upper()}', caption=caption, + tags=('ssp', kind)) + plt.close(fig) return out_files diff --git a/mne_bids_pipeline/scripts/preprocessing/_05a_apply_ica.py b/mne_bids_pipeline/scripts/preprocessing/_05a_apply_ica.py index 6a10a3c07..9fd7d5013 100644 --- a/mne_bids_pipeline/scripts/preprocessing/_05a_apply_ica.py +++ b/mne_bids_pipeline/scripts/preprocessing/_05a_apply_ica.py @@ -25,9 +25,10 @@ get_subjects, get_sessions, get_task, get_datatype, get_deriv_root, ) from ..._logging import gen_log_kwargs, logger -from ..._run import failsafe_run, _update_for_splits, _script_path, save_logs from ..._parallel import parallel_func, get_parallel_backend from ..._reject import _get_reject +from ..._report import _open_report +from ..._run import failsafe_run, _update_for_splits, _script_path, save_logs def get_input_fnames_apply_ica(**kwargs): @@ -135,6 +136,28 @@ def apply_ica(*, cfg, subject, session, in_files): out_files['report'], overwrite=True, open_browser=cfg.interactive) assert len(in_files) == 0, in_files.keys() + + # Report + if 
ica.exclude: + msg = 'Adding ICA to report.' + else: + msg = 'Skipping ICA addition to report, no components marked as bad.' + logger.info( + **gen_log_kwargs(message=msg, subject=subject, session=session) + ) + if ica.exclude: + with _open_report(cfg=cfg, subject=subject, session=session) as report: + report.add_ica( + ica=ica, + title='ICA', + inst=epochs, + picks=ica.exclude + # TODO upstream + # captions=f'Evoked response (across all epochs) ' + # f'before and after ICA ' + # f'({len(ica.exclude)} ICs removed)' + ) + return out_files diff --git a/mne_bids_pipeline/scripts/preprocessing/_06_ptp_reject.py b/mne_bids_pipeline/scripts/preprocessing/_06_ptp_reject.py index b5afb268b..b9292eeff 100644 --- a/mne_bids_pipeline/scripts/preprocessing/_06_ptp_reject.py +++ b/mne_bids_pipeline/scripts/preprocessing/_06_ptp_reject.py @@ -20,8 +20,9 @@ ) from ..._logging import gen_log_kwargs, logger from ..._parallel import parallel_func, get_parallel_backend -from ..._run import failsafe_run, _update_for_splits, save_logs from ..._reject import _get_reject +from ..._report import _open_report +from ..._run import failsafe_run, _update_for_splits, save_logs def get_input_fnames_drop_ptp(**kwargs): @@ -121,6 +122,24 @@ def drop_ptp(*, cfg, subject, session, in_files): split_size=cfg._epochs_split_size) _update_for_splits(out_files, 'epochs') assert len(in_files) == 0, in_files.keys() + + # Report + msg = 'Adding cleaned epochs to report.' + logger.info( + **gen_log_kwargs(message=msg, subject=subject, session=session) + ) + # Add PSD plots for 30s of data or all epochs if we have less available + if len(epochs) * (epochs.tmax - epochs.tmin) < 30: + psd = True + else: + psd = 30 + with _open_report(cfg=cfg, subject=subject, session=session) as report: + report.add_epochs( + epochs=epochs, + title='Epochs: after cleaning', + psd=psd, + drop_log_ignore=() + ) return out_files diff --git a/mne_bids_pipeline/scripts/report/_01_make_reports.py b/mne_bids_pipeline/scripts/report/_01_make_reports.py deleted file mode 100644 index 4c8fbf4d5..000000000 --- a/mne_bids_pipeline/scripts/report/_01_make_reports.py +++ /dev/null @@ -1,2129 +0,0 @@ -"""Make reports. - -Builds an HTML report for each subject containing all the relevant analysis -plots. 
-""" - -import contextlib -import os -import os.path as op -from pathlib import Path -from typing import Optional, List, Literal -from types import SimpleNamespace - -from scipy.io import loadmat -import numpy as np -import pandas as pd -import matplotlib.transforms - -import mne -from mne.utils import _pl -from mne_bids import BIDSPath -from mne_bids.stats import count_events - -from ..._config_utils import ( - get_noise_cov_bids_path, get_subjects, sanitize_cond_name, - get_task, get_datatype, get_deriv_root, get_sessions, - _restrict_analyze_channels, get_fs_subjects_dir, get_fs_subject, - get_runs, get_bids_root, get_decoding_contrasts, get_all_contrasts) -from ..._decoding import _handle_csp_args -from ..._logging import logger, gen_log_kwargs -from ..._parallel import get_parallel_backend, parallel_func -from ..._run import ( - failsafe_run, save_logs, _update_for_splits, _sanitize_callable) -from ..._reject import _get_reject -from ..._viz import plot_auto_scores - - -def get_events(cfg, subject, session): - raws_filt = [] - raw_fname = BIDSPath(subject=subject, - session=session, - task=cfg.task, - acquisition=cfg.acq, - recording=cfg.rec, - space=cfg.space, - processing='filt', - suffix='raw', - extension='.fif', - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False) - - for run in cfg.runs: - this_raw_fname = raw_fname.copy().update(run=run) - this_raw_fname = _update_for_splits(this_raw_fname, None, single=True) - raw_filt = mne.io.read_raw_fif(this_raw_fname) - raws_filt.append(raw_filt) - del this_raw_fname - - # Concatenate the filtered raws and extract the events. - raw_filt_concat = mne.concatenate_raws(raws_filt, on_mismatch='warn') - events, event_id = mne.events_from_annotations(raw=raw_filt_concat) - return (events, event_id, raw_filt_concat.info['sfreq'], - raw_filt_concat.first_samp) - - -def get_er_path(cfg, subject, session): - raw_fname = BIDSPath(subject=subject, - session=session, - acquisition=cfg.acq, - run=None, - recording=cfg.rec, - space=cfg.space, - task='noise', - processing='filt', - suffix='raw', - extension='.fif', - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False) - raw_fname = _update_for_splits( - raw_fname, None, single=True, allow_missing=True) - return raw_fname - - -def plot_auto_scores_(cfg, subject, session): - """Plot automated bad channel detection scores. - """ - import json_tricks - - fname_scores = BIDSPath(subject=subject, - session=session, - task=cfg.task, - acquisition=cfg.acq, - run=None, - processing=cfg.proc, - recording=cfg.rec, - space=cfg.space, - suffix='scores', - extension='.json', - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False) - - all_figs = [] - all_captions = [] - for run in cfg.runs: - fname_scores.update(run=run) - auto_scores = json_tricks.loads( - fname_scores.fpath.read_text(encoding='utf-8-sig') - ) - - figs = plot_auto_scores(auto_scores, ch_types=cfg.ch_types) - all_figs.extend(figs) - - # Could be more than 1 fig, e.g. "grad" and "mag" - captions = [f'Run {run}'] * len(figs) - all_captions.extend(captions) - - assert all_figs - return all_figs, all_captions - - -# def plot_full_epochs_decoding_scores( -# contrast: str, -# cross_val_scores: np.ndarray, -# metric: str -# ): -# """Plot cross-validation results from full-epochs decoding. 
-# """ -# import matplotlib.pyplot as plt # nested import to help joblib -# import seaborn as sns - -# cross_val_scores = cross_val_scores.squeeze() # Make it a 1D array -# data = pd.DataFrame({ -# 'contrast': [contrast] * len(cross_val_scores), -# 'scores': cross_val_scores, -# 'metric': [metric] * len(cross_val_scores)} -# ) -# fig, ax = plt.subplots(constrained_layout=True) - -# sns.swarmplot(x='contrast', y='scores', data=data, color='0.25', -# label='cross-val. scores') -# ax.set_xticklabels([]) - -# ax.plot(cross_val_scores.mean(), '+', color='red', ms=15, -# label='mean score', zorder=99) -# ax.axhline(0.5, ls='--', lw=0.5, color='black', label='chance') - -# ax.set_xlabel(f'{contrast[0]} vs. {contrast[1]}') -# if metric == 'roc_auc': -# metric = 'ROC AUC' -# ax.set_ylabel(f'Score ({metric})') -# ax.legend(loc='center right') - -# return fig - - -def _plot_full_epochs_decoding_scores( - contrast_names: List[str], - scores: List[np.ndarray], - metric: str, - kind: Literal['single-subject', 'grand-average'] = 'single-subject', -): - """Plot cross-validation results from full-epochs decoding. - """ - import matplotlib.pyplot as plt # nested import to help joblib - import seaborn as sns - - if metric == 'roc_auc': - metric = 'ROC AUC' - score_label = f'Score ({metric})' - - data = pd.DataFrame({ - 'Contrast': np.array([ - [c] * len(scores[0]) - for c in contrast_names - ]).flatten(), - score_label: np.hstack(scores), - }) - - if kind == 'grand-average': - # First create a grid of boxplots … - g = sns.catplot( - data=data, y=score_label, kind='box', - col='Contrast', col_wrap=3, aspect=0.33 - ) - # … and now add swarmplots on top to visualize every single data point. - g.map_dataframe(sns.swarmplot, y=score_label, color='black') - caption = ( - f'Based on N={len(scores[0])} ' - f'subjects. Each dot represents the mean cross-validation score ' - f'for a single subject. The dashed line is expected chance ' - f'performance.' - ) - else: - # First create a grid of swarmplots to visualize every single - # cross-validation score. - g = sns.catplot( - data=data, y=score_label, kind='swarm', - col='Contrast', col_wrap=3, aspect=0.33, color='black' - ) - - # And now add the mean CV score on top. - def _plot_mean_cv_score(x, **kwargs): - plt.plot(x.mean(), **kwargs) - - g.map( - _plot_mean_cv_score, score_label, marker='+', color='red', - ms=15, label='mean score', zorder=99 - ) - caption = ( - 'Each black dot represents the single cross-validation score. ' - f'The red cross is the mean of all {len(scores[0])} ' - 'cross-validation scores. ' - 'The dashed line is expected chance performance.' - ) - plt.xlim([-0.1, 0.1]) - - g.map(plt.axhline, y=0.5, ls='--', lw=0.5, color='black', zorder=99) - g.set_titles('{col_name}') # use this argument literally! - g.set_xlabels('') - - fig = g.fig - return fig, caption - - -def _plot_time_by_time_decoding_scores( - *, - times: np.ndarray, - cross_val_scores: np.ndarray, - metric: str, - time_generalization: bool, - decim: int, -): - """Plot cross-validation results from time-by-time decoding. - """ - import matplotlib.pyplot as plt # nested import to help joblib - - mean_scores = cross_val_scores.mean(axis=0) - max_scores = cross_val_scores.max(axis=0) - min_scores = cross_val_scores.min(axis=0) - - if time_generalization: - # Only use the diagonal values (classifiers trained and tested on the - # same time points). 
- mean_scores = np.diag(mean_scores) - max_scores = np.diag(max_scores) - min_scores = np.diag(min_scores) - - fig, ax = plt.subplots(constrained_layout=True) - ax.axhline(0.5, ls='--', lw=0.5, color='black', label='chance') - if times.min() < 0 < times.max(): - ax.axvline(0, ls='-', lw=0.5, color='black') - ax.fill_between(x=times, y1=min_scores, y2=max_scores, color='lightgray', - alpha=0.5, label='range [min, max]') - ax.plot(times, mean_scores, ls='-', lw=2, color='black', - label='mean') - - _label_time_by_time(ax, xlabel='Time (s)', decim=decim) - if metric == 'roc_auc': - metric = 'ROC AUC' - ax.set_ylabel(f'Score ({metric})') - ax.set_ylim((-0.025, 1.025)) - ax.legend(loc='lower right') - - return fig - - -def _label_time_by_time(ax, *, decim, xlabel=None, ylabel=None): - extra = '' - if decim > 1: - extra = f' (decim={decim})' - if xlabel is not None: - ax.set_xlabel(f'{xlabel}{extra}') - if ylabel is not None: - ax.set_ylabel(f'{ylabel}{extra}') - - -def _plot_time_by_time_decoding_scores_gavg(*, cfg, decoding_data): - """Plot the grand-averaged decoding scores. - """ - import matplotlib.pyplot as plt # nested import to help joblib - - # We squeeze() to make Matplotlib happy. - times = decoding_data['times'].squeeze() - mean_scores = decoding_data['mean'].squeeze() - se_lower = mean_scores - decoding_data['mean_se'].squeeze() - se_upper = mean_scores + decoding_data['mean_se'].squeeze() - ci_lower = decoding_data['mean_ci_lower'].squeeze() - ci_upper = decoding_data['mean_ci_upper'].squeeze() - decim = decoding_data['decim'].item() - - if cfg.decoding_time_generalization: - # Only use the diagonal values (classifiers trained and tested on the - # same time points). - mean_scores = np.diag(mean_scores) - se_lower = np.diag(se_lower) - se_upper = np.diag(se_upper) - ci_lower = np.diag(ci_lower) - ci_upper = np.diag(ci_upper) - - metric = cfg.decoding_metric - clusters = np.atleast_1d(decoding_data['clusters'].squeeze()) - - fig, ax = plt.subplots(constrained_layout=True) - ax.set_ylim((-0.025, 1.025)) - - # Start with plotting the significant time periods according to the - # cluster-based permutation test - n_significant_clusters_plotted = 0 - for cluster in clusters: - cluster_times = np.atleast_1d(cluster['times'][0][0].squeeze()) - cluster_p = cluster['p_value'][0][0].item() - if cluster_p >= cfg.cluster_permutation_p_threshold: - continue - - # Only add the label once - if n_significant_clusters_plotted == 0: - label = (f'$p$ < {cfg.cluster_permutation_p_threshold} ' - f'(cluster pemutation)') - else: - label = None - - ax.fill_betweenx( - y=ax.get_ylim(), - x1=cluster_times[0], - x2=cluster_times[-1], - facecolor='orange', - alpha=0.15, - label=label - ) - n_significant_clusters_plotted += 1 - - ax.axhline(0.5, ls='--', lw=0.5, color='black', label='chance') - if times.min() < 0 < times.max(): - ax.axvline(0, ls='-', lw=0.5, color='black') - ax.fill_between(x=times, y1=ci_lower, y2=ci_upper, color='lightgray', - alpha=0.5, label='95% confidence interval') - - ax.plot(times, mean_scores, ls='-', lw=2, color='black', - label='mean') - ax.plot(times, se_lower, ls='-.', lw=0.5, color='gray', - label='mean ± standard error') - ax.plot(times, se_upper, ls='-.', lw=0.5, color='gray') - ax.text(0.05, 0.05, s=f'$N$={decoding_data["N"].squeeze()}', - fontsize='x-large', horizontalalignment='left', - verticalalignment='bottom', transform=ax.transAxes) - - _label_time_by_time(ax, xlabel='Time (s)', decim=decim) - if metric == 'roc_auc': - metric = 'ROC AUC' - ax.set_ylabel(f'Score 
({metric})') - ax.legend(loc='lower right') - - return fig - - -def plot_time_by_time_decoding_t_values(decoding_data): - """Plot the t-values used to form clusters for the permutation test. - """ - import matplotlib.pyplot as plt # nested import to help joblib - - # We squeeze() to make Matplotlib happy. - all_times = decoding_data['cluster_all_times'].squeeze() - all_t_values = decoding_data['cluster_all_t_values'].squeeze() - t_threshold = decoding_data['cluster_t_threshold'] - decim = decoding_data['decim'] - - fig, ax = plt.subplots(constrained_layout=True) - ax.plot(all_times, all_t_values, ls='-', color='black', - label='observed $t$-values') - ax.axhline(t_threshold, ls='--', color='red', label='threshold') - - ax.text(0.05, 0.05, s=f'$N$={decoding_data["N"].squeeze()}', - fontsize='x-large', horizontalalignment='left', - verticalalignment='bottom', transform=ax.transAxes) - - _label_time_by_time(ax, xlabel='Time (s)', decim=decim) - ax.set_ylabel('$t$-value') - ax.legend(loc='lower right') - - if all_t_values.min() < 0 and all_t_values.max() > 0: - # center the y axis around 0 - y_max = np.abs(ax.get_ylim()).max() - ax.set_ylim(ymin=-y_max, ymax=y_max) - elif all_t_values.min() > 0 and all_t_values.max() > 0: - # start y axis at zero - ax.set_ylim(ymin=0, ymax=all_t_values.max()) - elif all_t_values.min() < 0 and all_t_values.max() < 0: - # start y axis at zero - ax.set_ylim(ymin=all_t_values.min(), ymax=0) - - return fig - - -def _plot_decoding_time_generalization( - decoding_data, - metric: str, - kind: Literal['single-subject', 'grand-average'] -): - """Plot time generalization matrix. - """ - import matplotlib.pyplot as plt # nested import to help joblib - - # We squeeze() to make Matplotlib happy. - times = decoding_data['times'].squeeze() - decim = decoding_data['decim'].item() - if kind == 'single-subject': - # take the mean across CV scores - mean_scores = decoding_data['scores'].mean(axis=0) - else: - mean_scores = decoding_data['mean'] - - fig, ax = plt.subplots(constrained_layout=True) - im = ax.imshow( - mean_scores, - extent=times[[0, -1, 0, -1]], - interpolation='nearest', - origin='lower', - cmap='RdBu_r', - vmin=0, - vmax=1 - ) - - # Indicate time point zero - if times.min() < 0 < times.max(): - ax.axvline(0, ls='--', lw=0.5, color='black') - ax.axhline(0, ls='--', lw=0.5, color='black') - - # Indicate diagonal - ax.plot(times[[0, -1]], times[[0, -1]], ls='--', lw=0.5, color='black') - - # Axis labels - _label_time_by_time( - ax, - xlabel='Testing time (s)', - ylabel='Training time (s)', - decim=decim, - ) - - # Color bar - cbar = plt.colorbar(im, ax=ax) - if metric == 'roc_auc': - metric = 'ROC AUC' - cbar.set_label(f'Score ({metric})') - - return fig - - -def _gen_empty_report( - *, - cfg: SimpleNamespace, - subject: str, - session: Optional[str] -) -> mne.Report: - title = f'sub-{subject}' - if session is not None: - title += f', ses-{session}' - if cfg.task is not None: - title += f', task-{cfg.task}' - - report = mne.Report(title=title, raw_psd=True) - return report - - -def _contrasts_to_names(contrasts: List[List[str]]) -> List[str]: - return [f'{c[0]} vs.\n{c[1]}' for c in contrasts] - - -def run_report_preprocessing( - *, - cfg: SimpleNamespace, - subject: str, - session: Optional[str], - report: Optional[mne.Report] -) -> mne.Report: - import matplotlib.pyplot as plt # nested import to help joblib - - msg = 'Generating preprocessing report …' - logger.info( - **gen_log_kwargs(message=msg, subject=subject, session=session) - ) - - if report is None: - 
report = _gen_empty_report( - cfg=cfg, - subject=subject, - session=session - ) - - bids_path = BIDSPath( - subject=subject, - session=session, - task=cfg.task, - acquisition=cfg.acq, - run=None, - recording=cfg.rec, - space=cfg.space, - extension='.fif', - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False - ) - - fnames_raw_filt = [] - for run in cfg.runs: - fname = bids_path.copy().update( - run=run, processing='filt', - suffix='raw', check=False - ) - fname = _update_for_splits(fname, None, single=True) - fnames_raw_filt.append(fname) - - fname_epo_not_clean = bids_path.copy().update(suffix='epo') - fname_epo_clean = bids_path.copy().update(processing='clean', suffix='epo') - fname_ica = bids_path.copy().update(suffix='ica') - fname_ssp = bids_path.copy().update(suffix='proj') - fname_eog_epochs = bids_path.copy().update(suffix='eog-epo') - fname_ecg_epochs = bids_path.copy().update(suffix='ecg-epo') - - for fname in fnames_raw_filt: - msg = 'Adding filtered raw data to report.' - logger.info( - **gen_log_kwargs( - message=msg, subject=subject, session=session, run=fname.run - ) - ) - - title = 'Raw' - if fname.run is not None: - title += f', run {fname.run}' - - if ( - cfg.plot_psd_for_runs == 'all' or - fname.run in cfg.plot_psd_for_runs - ): - plot_raw_psd = True - else: - plot_raw_psd = False - - report.add_raw( - raw=fname, - title=title, - butterfly=5, - psd=plot_raw_psd, - tags=('raw', 'filtered', f'run-{fname.run}'), - # caption=fname.basename # TODO upstream - ) - del plot_raw_psd - - er_path = get_er_path(cfg=cfg, subject=subject, session=session) - if er_path.fpath.exists(): - msg = 'Adding filtered empty-room raw data to report.' - logger.info( - **gen_log_kwargs( - message=msg, subject=subject, session=session - ) - ) - - report.add_raw( - raw=er_path, - title='Empty-Room', - butterfly=5, - tags=('raw', 'empty-room') - # caption=er_path.basename # TODO upstream - ) - - # Visualize automated noisy channel detection. - if cfg.find_noisy_channels_meg: - msg = 'Adding visualization of noisy channel detection to report.' - logger.info( - **gen_log_kwargs( - message=msg, subject=subject, session=session - ) - ) - - figs, captions = plot_auto_scores_( - cfg=cfg, - subject=subject, - session=session, - ) - - tags = ('raw', 'data-quality', *[f'run-{i}' for i in cfg.runs]) - report.add_figure( - fig=figs, - caption=captions, - title='Data Quality', - tags=tags - ) - for fig in figs: - plt.close(fig) - - # Visualize events. - if not cfg.task_is_rest: - msg = 'Adding events plot to report.' - logger.info( - **gen_log_kwargs( - message=msg, subject=subject, session=session - ) - ) - - events, event_id, sfreq, first_samp = get_events( - cfg=cfg, subject=subject, session=session - ) - report.add_events( - events=events, - event_id=event_id, - sfreq=sfreq, - first_samp=first_samp, - title='Events', - # caption='Events in filtered continuous data', # TODO upstream - ) - - ########################################################################### - # - # Visualize uncleaned epochs. - # - msg = 'Adding uncleaned epochs to report.' 
- logger.info( - **gen_log_kwargs( - message=msg, subject=subject, session=session - ) - ) - epochs = mne.read_epochs(fname_epo_not_clean) - # Add PSD plots for 30s of data or all epochs if we have less available - if len(epochs) * (epochs.tmax - epochs.tmin) < 30: - psd = True - else: - psd = 30 - report.add_epochs( - epochs=epochs, - title='Epochs: before cleaning', - psd=psd, - drop_log_ignore=() - ) - - ########################################################################### - # - # Visualize effect of ICA artifact rejection. - # - if cfg.spatial_filter == 'ica': - msg = 'Adding ICA to report.' - logger.info( - **gen_log_kwargs( - message=msg, subject=subject, session=session - ) - ) - epochs = mne.read_epochs(fname_epo_not_clean) - ica = mne.preprocessing.read_ica(fname_ica) - ica_reject = _get_reject( - subject=subject, - session=session, - reject=cfg.ica_reject, - ch_types=cfg.ch_types, - param='ica_reject', - ) - # TODO: Ref is set during ICA epochs fitting, we should ensure we do - # it here, too - epochs.drop_bad(ica_reject) - - if ica.exclude: - report.add_ica( - ica=ica, - title='ICA', - inst=epochs, - picks=ica.exclude - # TODO upstream - # captions=f'Evoked response (across all epochs) ' - # f'before and after ICA ' - # f'({len(ica.exclude)} ICs removed)' - ) - - ########################################################################### - # - # Visualize effect of SSP artifact rejection. - # - - if cfg.spatial_filter == 'ssp': - fnames = dict(ecg=fname_ecg_epochs, eog=fname_eog_epochs) - for kind, fname in fnames.items(): - if not fname.fpath.is_file(): - continue - msg = f'Adding {kind.upper()} SSP to report.' - logger.info( - **gen_log_kwargs( - message=msg, subject=subject, session=session - ) - ) - # Eventually we should add this to report somehow - epochs = mne.read_epochs(fname) - projs = mne.read_proj(fname_ssp) - projs = [p for p in projs if kind.upper() in p['desc']] - assert len(projs), len(projs) # should exist if the epochs do - picks_trace = None - if kind == 'ecg': - if 'ecg' in epochs: - picks_trace = 'ecg' - else: - assert kind == 'eog' - if cfg.eog_channels: - picks_trace = cfg.eog_channels - elif 'eog' in epochs: - picks_trace = 'eog' - fig = mne.viz.plot_projs_joint( - projs, epochs.average(picks='all'), picks_trace=picks_trace) - caption = ( - f'Computed using {len(epochs)} epochs ' - f'(from {len(epochs.drop_log)} original events)' - ) - report.add_figure( - fig, title=f'SSP: {kind.upper()}', caption=caption, - tags=('ssp', kind)) - plt.close(fig) - - ########################################################################### - # - # Visualize cleaned epochs. - # - msg = 'Adding cleaned epochs to report.' 
- logger.info( - **gen_log_kwargs( - message=msg, subject=subject, session=session - ) - ) - epochs = mne.read_epochs(fname_epo_clean) - # Add PSD plots for 30s of data or all epochs if we have less available - if len(epochs) * (epochs.tmax - epochs.tmin) < 30: - psd = True - else: - psd = 30 - report.add_epochs( - epochs=epochs, - title='Epochs: after cleaning', - psd=psd, - drop_log_ignore=() - ) - - return report - - -def run_report_sensor( - *, - cfg: SimpleNamespace, - subject: str, - session: Optional[str], - report: mne.Report -) -> mne.Report: - import matplotlib.pyplot as plt # nested import to help joblib - - msg = 'Generating sensor-space analysis report …' - logger.info( - **gen_log_kwargs(message=msg, subject=subject, session=session) - ) - - if report is None: - report = _gen_empty_report( - cfg=cfg, - subject=subject, - session=session - ) - - bids_path = BIDSPath( - subject=subject, - session=session, - task=cfg.task, - acquisition=cfg.acq, - run=None, - recording=cfg.rec, - space=cfg.space, - extension='.fif', - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False - ) - fname_epo_clean = bids_path.copy().update( - processing='clean', - suffix='epo' - ) - fname_ave = bids_path.copy().update(suffix='ave') - fname_decoding = bids_path.copy().update( - processing=None, - suffix='decoding', - extension='.mat' - ) - fname_tfr_pow = bids_path.copy().update( - suffix='power+condition+tfr', - extension='.h5' - ) - fname_tfr_itc = bids_path.copy().update( - suffix='itc+condition+tfr', - extension='.h5' - ) - fname_noise_cov = get_noise_cov_bids_path( - cfg=cfg, - subject=subject, - session=session - ) - - ########################################################################### - # - # Visualize evoked responses. - # - if cfg.conditions is None: - conditions = [] - elif isinstance(cfg.conditions, dict): - conditions = list(cfg.conditions.keys()) - else: - conditions = cfg.conditions.copy() - - conditions.extend([contrast["name"] for contrast in cfg.all_contrasts]) - - if conditions: - evokeds = mne.read_evokeds(fname_ave) - else: - evokeds = [] - - if evokeds: - msg = (f'Adding {len(conditions)} evoked signals and contrasts to the ' - f'report.') - else: - msg = 'No evoked conditions or contrasts found.' - - logger.info( - **gen_log_kwargs(message=msg, subject=subject, session=session) - ) - - if fname_noise_cov.fpath.exists(): - msg = f'Reading noise covariance: {fname_noise_cov.basename}' - logger.info( - **gen_log_kwargs(message=msg, subject=subject, session=session) - ) - noise_cov = fname_noise_cov - else: - msg = 'No noise covariance matrix found, not rendering whitened data' - logger.info( - **gen_log_kwargs(message=msg, subject=subject, session=session) - ) - noise_cov = None - - for condition, evoked in zip(conditions, evokeds): - _restrict_analyze_channels(evoked, cfg) - - tags = ('evoked', _sanitize_cond_tag(condition)) - if condition in cfg.conditions: - title = f'Condition: {condition}' - else: # It's a contrast of two conditions. - title = f'Contrast: {condition}' - tags = tags + ('contrast',) - - report.add_evokeds( - evokeds=evoked, - titles=title, - noise_cov=noise_cov, - n_time_points=cfg.report_evoked_n_time_points, - tags=tags, - ) - - ########################################################################### - # - # Visualize full-epochs decoding results. - # - decode = cfg.decode and cfg.decoding_contrasts - if decode: - msg = 'Adding full-epochs decoding results to the report.' 
- logger.info( - **gen_log_kwargs(message=msg, subject=subject, session=session) - ) - - all_decoding_scores = [] - for contrast in cfg.decoding_contrasts: - cond_1, cond_2 = contrast - a_vs_b = f'{cond_1}+{cond_2}'.replace(op.sep, '') - processing = f'{a_vs_b}+FullEpochs+{cfg.decoding_metric}' - processing = processing.replace('_', '-').replace('-', '') - fname_decoding = bids_path.copy().update( - processing=processing, - suffix='decoding', - extension='.mat' - ) - decoding_data = loadmat(fname_decoding) - all_decoding_scores.append( - np.atleast_1d(decoding_data['scores'].squeeze()) - ) - del fname_decoding, processing, a_vs_b, decoding_data - - fig, caption = _plot_full_epochs_decoding_scores( - contrast_names=_contrasts_to_names(cfg.decoding_contrasts), - scores=all_decoding_scores, - metric=cfg.decoding_metric, - ) - title = f'Full-epochs decoding: {cond_1} vs. {cond_2}' - report.add_figure( - fig=fig, - title=title, - caption=caption, - section='Decoding: full-epochs', - tags=( - 'epochs', - 'contrast', - 'decoding', - *[f'{_sanitize_cond_tag(cond_1)}–' - f'{_sanitize_cond_tag(cond_2)}' - for cond_1, cond_2 in cfg.decoding_contrasts] - ) - ) - # close figure to save memory - plt.close(fig) - del fig, caption, title - - ########################################################################### - # - # Visualize time-by-time decoding results. - # - if decode: - msg = 'Adding time-by-time decoding results to the report.' - logger.info( - **gen_log_kwargs(message=msg, subject=subject, session=session) - ) - - epochs = mne.read_epochs(fname_epo_clean) - - section = 'Decoding: time-by-time' - for contrast in cfg.decoding_contrasts: - cond_1, cond_2 = contrast - a_vs_b = f'{cond_1}+{cond_2}'.replace(op.sep, '') - tags = ( - 'epochs', - 'contrast', - 'decoding', - f"{_sanitize_cond_tag(contrast[0])}–" - f"{_sanitize_cond_tag(contrast[1])}" - ) - - processing = f'{a_vs_b}+TimeByTime+{cfg.decoding_metric}' - processing = processing.replace('_', '-').replace('-', '') - fname_decoding = bids_path.copy().update( - processing=processing, - suffix='decoding', - extension='.mat' - ) - decoding_data = loadmat(fname_decoding) - del fname_decoding, processing, a_vs_b - - fig = _plot_time_by_time_decoding_scores( - times=decoding_data['times'].ravel(), - cross_val_scores=decoding_data['scores'], - metric=cfg.decoding_metric, - time_generalization=cfg.decoding_time_generalization, - decim=decoding_data['decim'].item(), - ) - caption = ( - f'Time-by-time decoding: ' - f'{len(epochs[cond_1])} × {cond_1} vs. ' - f'{len(epochs[cond_2])} × {cond_2}' - ) - title = f'Decoding over time: {cond_1} vs. {cond_2}' - report.add_figure( - fig=fig, - title=title, - caption=caption, - section=section, - tags=tags, - ) - plt.close(fig) - - if cfg.decoding_time_generalization: - fig = _plot_decoding_time_generalization( - decoding_data=decoding_data, - metric=cfg.decoding_metric, - kind='single-subject' - ) - caption = ( - 'Time generalization (generalization across time, GAT): ' - 'each classifier is trained on each time point, and ' - 'tested on all other time points.' - ) - title = f'Time generalization: {cond_1} vs. {cond_2}' - report.add_figure( - fig=fig, - title=title, - caption=caption, - section=section, - tags=tags, - ) - plt.close(fig) - - del decoding_data, cond_1, cond_2, caption - - del epochs - - ########################################################################### - # - # Visualize CSP decoding results. - # - - if decode and cfg.decoding_csp: - msg = 'Adding CSP decoding results to the report.' 
- logger.info( - **gen_log_kwargs(message=msg, subject=subject, session=session) - ) - section = 'Decoding: CSP' - freq_name_to_bins_map = _handle_csp_args( - cfg.decoding_csp_times, - cfg.decoding_csp_freqs, - cfg.decoding_metric, - ) - all_csp_tf_results = dict() - for contrast in cfg.decoding_contrasts: - cond_1, cond_2 = contrast - a_vs_b = f'{cond_1}+{cond_2}'.replace(op.sep, '') - tags = ( - 'epochs', - 'contrast', - 'decoding', - 'csp', - f"{_sanitize_cond_tag(cond_1)}–{_sanitize_cond_tag(cond_2)}" - ) - processing = f'{a_vs_b}+CSP+{cfg.decoding_metric}' - processing = processing.replace('_', '-').replace('-', '') - fname_decoding = bids_path.copy().update( - processing=processing, - suffix='decoding', - extension='.xlsx' - ) - assert fname_decoding.fpath.is_file(), fname_decoding.fpath - csp_freq_results = pd.read_excel( - fname_decoding, - sheet_name='CSP Frequency' - ) - csp_freq_results['scores'] = csp_freq_results['scores'].apply( - lambda x: np.array(x[1:-1].split(), float)) - csp_tf_results = pd.read_excel( - fname_decoding, - sheet_name='CSP Time-Frequency' - ) - csp_tf_results['scores'] = csp_tf_results['scores'].apply( - lambda x: np.array(x[1:-1].split(), float)) - all_csp_tf_results[contrast] = csp_tf_results - del csp_tf_results - - all_decoding_scores = list() - contrast_names = list() - for freq_range_name, freq_bins in freq_name_to_bins_map.items(): - results = csp_freq_results.loc[ - csp_freq_results['freq_range_name'] == freq_range_name - ] - results.reset_index(drop=True, inplace=True) - assert len(results['scores']) == len(freq_bins) - for bi, freq_bin in enumerate(freq_bins): - all_decoding_scores.append(results['scores'][bi]) - f_min = float(freq_bin[0]) - f_max = float(freq_bin[1]) - contrast_names.append( - f'{freq_range_name}\n' - f'({f_min:0.1f}-{f_max:0.1f} Hz)' - ) - fig, caption = _plot_full_epochs_decoding_scores( - contrast_names=contrast_names, - scores=all_decoding_scores, - metric=cfg.decoding_metric, - ) - title = f'CSP decoding: {cond_1} vs. {cond_2}' - report.add_figure( - fig=fig, - title=title, - section=section, - caption=caption, - tags=tags, - ) - # close figure to save memory - plt.close(fig) - del fig, caption, title - - # Now, plot decoding scores across time-frequency bins. 
- for contrast in cfg.decoding_contrasts: - cond_1, cond_2 = contrast - tags = ( - 'epochs', - 'contrast', - 'decoding', - 'csp', - f"{_sanitize_cond_tag(cond_1)}–{_sanitize_cond_tag(cond_2)}", - ) - results = all_csp_tf_results[contrast] - mean_crossval_scores = list() - tmin, tmax, fmin, fmax = list(), list(), list(), list() - mean_crossval_scores.extend( - results['mean_crossval_score'].ravel()) - tmin.extend(results['t_min'].ravel()) - tmax.extend(results['t_max'].ravel()) - fmin.extend(results['f_min'].ravel()) - fmax.extend(results['f_max'].ravel()) - mean_crossval_scores = np.array(mean_crossval_scores, float) - fig, ax = plt.subplots(constrained_layout=True) - # XXX Add support for more metrics - assert cfg.decoding_metric == 'roc_auc' - metric = 'ROC AUC' - vmax = max( - np.abs(mean_crossval_scores.min() - 0.5), - np.abs(mean_crossval_scores.max() - 0.5) - ) + 0.5 - vmin = 0.5 - (vmax - 0.5) - img = _imshow_tf( - mean_crossval_scores, ax, - tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, - vmin=vmin, vmax=vmax) - offset = matplotlib.transforms.offset_copy( - ax.transData, fig, 6, 0, units='points') - for freq_range_name, bins in freq_name_to_bins_map.items(): - ax.text(tmin[0], - 0.5 * bins[0][0] + 0.5 * bins[-1][1], - freq_range_name, transform=offset, - ha='left', va='center', rotation=90) - ax.set_xlim([np.min(tmin), np.max(tmax)]) - ax.set_ylim([np.min(fmin), np.max(fmax)]) - ax.set_xlabel('Time (s)') - ax.set_ylabel('Frequency (Hz)') - cbar = fig.colorbar( - ax=ax, shrink=0.75, orientation='vertical', mappable=img) - cbar.set_label(f'Mean decoding score ({metric})') - title = f'CSP TF decoding: {cond_1} vs. {cond_2}' - report.add_figure( - fig=fig, - title=title, - section=section, - tags=tags, - ) - - ########################################################################### - # - # Visualize TFR as topography. - # - if cfg.time_frequency_conditions is None: - conditions = [] - elif isinstance(cfg.time_frequency_conditions, dict): - conditions = list(cfg.time_frequency_conditions.keys()) - else: - conditions = cfg.time_frequency_conditions.copy() - - if conditions: - msg = 'Adding TFR analysis results to the report.' 
- logger.info( - **gen_log_kwargs(message=msg, subject=subject, session=session) - ) - - for condition in conditions: - cond = sanitize_cond_name(condition) - fname_tfr_pow_cond = str(fname_tfr_pow.copy()).replace("+condition+", - f"+{cond}+") - fname_tfr_itc_cond = str(fname_tfr_itc.copy()).replace("+condition+", - f"+{cond}+") - with mne.use_log_level('error'): # filename convention - power = mne.time_frequency.read_tfrs( - fname_tfr_pow_cond, condition=0) - power.apply_baseline( - baseline=cfg.time_frequency_baseline, - mode=cfg.time_frequency_baseline_mode) - if cfg.time_frequency_crop: - power.crop(**cfg.time_frequency_crop) - kwargs = dict( - show=False, fig_facecolor='w', font_color='k', border='k' - ) - fig_power = power.plot_topo(**kwargs) - report.add_figure( - fig=fig_power, - title=f'TFR Power: {condition}', - caption=f'TFR Power: {condition}', - tags=('time-frequency', _sanitize_cond_tag(condition)) - ) - plt.close(fig_power) - del power - - with mne.use_log_level('error'): # filename convention - itc = mne.time_frequency.read_tfrs( - fname_tfr_itc_cond, condition=0) - fig_itc = itc.plot_topo(**kwargs) - report.add_figure( - fig=fig_itc, - title=f'TFR ITC: {condition}', - caption=f'TFR Inter-Trial Coherence: {condition}', - tags=('time-frequency', _sanitize_cond_tag(condition)) - ) - plt.close(fig_power) - del itc - - return report - - -def run_report_source( - *, - cfg: SimpleNamespace, - subject: str, - session: Optional[str], - report: mne.Report -) -> mne.Report: - import matplotlib.pyplot as plt # nested import to help joblib - - msg = 'Generating source-space analysis report …' - logger.info( - **gen_log_kwargs(message=msg, subject=subject, session=session) - ) - - if report is None: - report = _gen_empty_report( - cfg=cfg, - subject=subject, - session=session - ) - - bids_path = BIDSPath( - subject=subject, - session=session, - task=cfg.task, - acquisition=cfg.acq, - run=None, - recording=cfg.rec, - space=cfg.space, - extension='.fif', - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False - ) - - # Use this as a source for the Info dictionary - fname_info = bids_path.copy().update( - processing='clean', - suffix='epo' - ) - - fname_trans = bids_path.copy().update(suffix='trans') - if not fname_trans.fpath.exists(): - msg = 'No coregistration found, skipping source space report.' - logger.info(**gen_log_kwargs(message=msg, - subject=subject, session=session)) - return report - - fname_noise_cov = get_noise_cov_bids_path( - cfg=cfg, - subject=subject, - session=session - ) - - ########################################################################### - # - # Visualize coregistration, noise covariance matrix, & inverse solutions. - # - - if cfg.conditions is None: - conditions = [] - elif isinstance(cfg.conditions, dict): - conditions = list(cfg.conditions.keys()) - else: - conditions = cfg.conditions.copy() - - conditions.extend([contrast["name"] for contrast in cfg.all_contrasts]) - - msg = 'Rendering MRI slices with BEM contours.' - logger.info(**gen_log_kwargs(message=msg, - subject=subject, session=session)) - report.add_bem( - subject=cfg.fs_subject, - subjects_dir=cfg.fs_subjects_dir, - title='BEM', - width=256, - decim=8 - ) - - msg = 'Rendering sensor alignment (coregistration).' 
- logger.info(**gen_log_kwargs(message=msg, - subject=subject, session=session)) - report.add_trans( - trans=fname_trans, - info=fname_info, - title='Sensor alignment', - subject=cfg.fs_subject, - subjects_dir=cfg.fs_subjects_dir, - alpha=1 - ) - - msg = 'Rendering noise covariance matrix and corresponding SVD.' - logger.info(**gen_log_kwargs(message=msg, - subject=subject, session=session)) - report.add_covariance( - cov=fname_noise_cov, - info=fname_info, - title='Noise covariance' - ) - - for condition in conditions: - msg = f'Rendering inverse solution for {condition}' - logger.info(**gen_log_kwargs(message=msg, - subject=subject, session=session)) - - if condition in cfg.conditions: - title = f'Source: {condition}' - else: # It's a contrast of two conditions. - # XXX Will change once we process contrasts here too - continue - - method = cfg.inverse_method - cond_str = sanitize_cond_name(condition) - inverse_str = method - hemi_str = 'hemi' # MNE will auto-append '-lh' and '-rh'. - - fname_stc = bids_path.copy().update( - suffix=f'{cond_str}+{inverse_str}+{hemi_str}', - extension=None) - - tags = ( - 'source-estimate', - _sanitize_cond_tag(condition) - ) - if Path(f'{fname_stc.fpath}-lh.stc').exists(): - report.add_stc( - stc=fname_stc, - title=title, - subject=cfg.fs_subject, - subjects_dir=cfg.fs_subjects_dir, - n_time_points=cfg.report_stc_n_time_points, - tags=tags - ) - - plt.close('all') # close all figures to save memory - return report - - -@failsafe_run(script_path=__file__) -def run_report( - *, - cfg: SimpleNamespace, - subject: str, - session: Optional[str], -): - report = _gen_empty_report( - cfg=cfg, - subject=subject, - session=session - ) - kwargs = dict(cfg=cfg, subject=subject, session=session, report=report) - report = run_report_preprocessing(**kwargs) - report = run_report_sensor(**kwargs) - report = run_report_source(**kwargs) - - ########################################################################### - # - # Add configuration and system info. - # - add_system_info(report) - - ########################################################################### - # - # Save the report. - # - bids_path = BIDSPath( - subject=subject, - session=session, - task=cfg.task, - acquisition=cfg.acq, - run=None, - recording=cfg.rec, - space=cfg.space, - extension='.fif', - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False - ) - fname_report = bids_path.copy().update(suffix='report', extension='.html') - report.save( - fname=fname_report, - open_browser=cfg.interactive, - overwrite=True - ) - - -def add_event_counts(*, - cfg, - subject: Optional[str], - session: Optional[str], - report: mne.Report) -> None: - try: - df_events = count_events(BIDSPath(root=cfg.bids_root, - session=session)) - except ValueError: - msg = 'Could not read events.' - logger.warning( - **gen_log_kwargs(message=msg, subject=subject, session=session) - ) - df_events = None - - if df_events is not None: - css_classes = ('table', 'table-striped', 'table-borderless', - 'table-hover') - report.add_html( - f'
<div class="event-counts">\n'
-            f'{df_events.to_html(classes=css_classes, border=0)}\n'
-            f'</div>
', - title='Event counts', - tags=('events',) - ) - css = ('.event-counts {\n' - ' display: -webkit-box;\n' - ' display: -ms-flexbox;\n' - ' display: -webkit-flex;\n' - ' display: flex;\n' - ' justify-content: center;\n' - ' text-align: center;\n' - '}\n\n' - 'th, td {\n' - ' text-align: center;\n' - '}\n') - report.add_custom_css(css=css) - - -def add_system_info(report: mne.Report): - """Add system information and the pipeline configuration to the report.""" - - config_path = Path(os.environ['MNE_BIDS_STUDY_CONFIG']) - report.add_code( - code=config_path, - title='Configuration file', - tags=('configuration',) - ) - report.add_sys_info(title='System information') - - -@failsafe_run(script_path=__file__) -def run_report_average(*, cfg, subject: str, session: str) -> None: - # Group report - import matplotlib.pyplot as plt # nested import to help joblib - - msg = 'Generating grand average report …' - logger.info( - **gen_log_kwargs(message=msg, subject=subject, session=session) - ) - - evoked_fname = BIDSPath( - subject=subject, - session=session, - task=cfg.task, - acquisition=cfg.acq, - run=None, - recording=cfg.rec, - space=cfg.space, - suffix='ave', - extension='.fif', - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False - ) - - title = f'sub-{subject}' - if session is not None: - title += f', ses-{session}' - if cfg.task is not None: - title += f', task-{cfg.task}' - - report = mne.Report( - title=title, - raw_psd=True - ) - evokeds = mne.read_evokeds(evoked_fname) - for evoked in evokeds: - _restrict_analyze_channels(evoked, cfg) - - method = cfg.inverse_method - inverse_str = method - hemi_str = 'hemi' # MNE will auto-append '-lh' and '-rh'. - morph_str = 'morph2fsaverage' - - if isinstance(cfg.conditions, dict): - conditions = list(cfg.conditions.keys()) - else: - conditions = cfg.conditions.copy() - - conditions.extend([contrast["name"] for contrast in cfg.all_contrasts]) - - ####################################################################### - # - # Add event stats. - # - add_event_counts(cfg=cfg, report=report, subject=subject, session=session) - - ####################################################################### - # - # Visualize evoked responses. - # - for condition, evoked in zip(conditions, evokeds): - if condition in cfg.conditions: - title = f'Average: {condition}' - tags = ( - 'evoked', - _sanitize_cond_tag(condition) - ) - else: # It's a contrast of two conditions. - # XXX Will change once we process contrasts here too - continue - - report.add_evokeds( - evokeds=evoked, - titles=title, - projs=False, - tags=tags, - n_time_points=cfg.report_evoked_n_time_points, - # captions=evoked.comment # TODO upstream - ) - - ####################################################################### - # - # Visualize decoding results. - # - if cfg.decode and cfg.decoding_contrasts: - add_decoding_grand_average( - session=session, cfg=cfg, report=report - ) - - if cfg.decode and cfg.decoding_csp: - add_csp_grand_average( - session=session, cfg=cfg, report=report - ) - - ####################################################################### - # - # Visualize forward solution, inverse operator, and inverse solutions. - # - - for condition, evoked in zip(conditions, evokeds): - tags = ( - 'source-estimate', - _sanitize_cond_tag(condition), - ) - if condition in cfg.conditions: - title = f'Average: {condition}' - else: # It's a contrast of two conditions. 
- title = f'Average contrast: {condition}' - tags = tags + ('contrast',) - - cond_str = sanitize_cond_name(condition) - fname_stc_avg = evoked_fname.copy().update( - suffix=f'{cond_str}+{inverse_str}+{morph_str}+{hemi_str}', - extension=None) - - if Path(f'{fname_stc_avg.fpath}-lh.stc').exists(): - report.add_stc( - stc=fname_stc_avg, - title=title, - subject='fsaverage', - subjects_dir=cfg.fs_subjects_dir, - n_time_points=cfg.report_stc_n_time_points, - tags=tags - ) - - ########################################################################### - # - # Add configuration and system info. - # - add_system_info(report) - - ########################################################################### - # - # Save the report. - # - fname_report = evoked_fname.copy().update( - task=cfg.task, suffix='report', extension='.html') - report.save(fname=fname_report, open_browser=False, overwrite=True) - - plt.close('all') - - -def add_decoding_grand_average( - *, - session: Optional[str], - cfg: SimpleNamespace, - report: mne.Report, -): - """Add decoding results to the grand average report.""" - import matplotlib.pyplot as plt # nested import to help joblib - - bids_path = BIDSPath( - subject='average', - session=session, - task=cfg.task, - acquisition=cfg.acq, - run=None, - recording=cfg.rec, - space=cfg.space, - suffix='ave', - extension='.fif', - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False - ) - - # Full-epochs decoding - all_decoding_scores = [] - for contrast in cfg.decoding_contrasts: - cond_1, cond_2 = contrast - a_vs_b = f'{cond_1}+{cond_2}'.replace(op.sep, '') - processing = f'{a_vs_b}+FullEpochs+{cfg.decoding_metric}' - processing = processing.replace('_', '-').replace('-', '') - fname_decoding = bids_path.copy().update( - processing=processing, - suffix='decoding', - extension='.mat' - ) - decoding_data = loadmat(fname_decoding) - all_decoding_scores.append( - np.atleast_1d(decoding_data['scores'].squeeze()) - ) - del fname_decoding, processing, a_vs_b, decoding_data - - fig, caption = _plot_full_epochs_decoding_scores( - contrast_names=_contrasts_to_names(cfg.decoding_contrasts), - scores=all_decoding_scores, - metric=cfg.decoding_metric, - kind='grand-average' - ) - title = f'Full-epochs decoding: {cond_1} vs. {cond_2}' - report.add_figure( - fig=fig, - title=title, - section='Decoding: full-epochs', - caption=caption, - tags=( - 'epochs', - 'contrast', - 'decoding', - *[f'{_sanitize_cond_tag(cond_1)}–{_sanitize_cond_tag(cond_2)}' - for cond_1, cond_2 in cfg.decoding_contrasts] - ) - ) - # close figure to save memory - plt.close(fig) - del fig, caption, title - - # Time-by-time decoding - for contrast in cfg.decoding_contrasts: - cond_1, cond_2 = contrast - a_vs_b = f'{cond_1}+{cond_2}'.replace(op.sep, '') - section = 'Decoding: time-by-time' - tags = ( - 'epochs', - 'contrast', - 'decoding', - f'{_sanitize_cond_tag(cond_1)}–{_sanitize_cond_tag(cond_2)}' - ) - processing = f'{a_vs_b}+TimeByTime+{cfg.decoding_metric}' - processing = processing.replace('_', '-').replace('-', '') - fname_decoding = bids_path.copy().update( - processing=processing, - suffix='decoding', - extension='.mat' - ) - decoding_data = loadmat(fname_decoding) - del fname_decoding, processing, a_vs_b - - # Plot scores - fig = _plot_time_by_time_decoding_scores_gavg( - cfg=cfg, - decoding_data=decoding_data, - ) - caption = ( - f'Based on N={decoding_data["N"].squeeze()} ' - f'subjects. Standard error and confidence interval ' - f'of the mean were bootstrapped with {cfg.n_boot} ' - f'resamples. 
CI must not be used for statistical inference here, ' - f'as it is not corrected for multiple testing.' - ) - if len(get_subjects(cfg)) > 1: - caption += ( - f' Time periods with decoding performance significantly above ' - f'chance, if any, were derived with a one-tailed ' - f'cluster-based permutation test ' - f'({decoding_data["cluster_n_permutations"].squeeze()} ' - f'permutations) and are highlighted in yellow.' - ) - title = f'Decoding over time: {cond_1} vs. {cond_2}' - report.add_figure( - fig=fig, - title=title, - caption=caption, - section=section, - tags=tags, - ) - plt.close(fig) - - # Plot t-values used to form clusters - if len(get_subjects(cfg)) > 1: - fig = plot_time_by_time_decoding_t_values( - decoding_data=decoding_data - ) - t_threshold = np.round( - decoding_data['cluster_t_threshold'], - 3 - ).item() - caption = ( - f'Observed t-values. Time points with ' - f't-values > {t_threshold} were used to form clusters.' - ) - report.add_figure( - fig=fig, - title=f't-values across time: {cond_1} vs. {cond_2}', - caption=caption, - section=section, - tags=tags, - ) - plt.close(fig) - - if cfg.decoding_time_generalization: - fig = _plot_decoding_time_generalization( - decoding_data=decoding_data, - metric=cfg.decoding_metric, - kind='grand-average' - ) - caption = ( - f'Time generalization (generalization across time, GAT): ' - f'each classifier is trained on each time point, and tested ' - f'on all other time points. The results were averaged across ' - f'N={decoding_data["N"].item()} subjects.' - ) - title = f'Time generalization: {cond_1} vs. {cond_2}' - report.add_figure( - fig=fig, - title=title, - caption=caption, - section=section, - tags=tags, - ) - plt.close(fig) - - -def _sanitize_cond_tag(cond): - return cond.lower().replace(' ', '-') - - -def _imshow_tf(vals, ax, *, tmin, tmax, fmin, fmax, vmin, vmax, cmap='RdBu_r', - mask=None, cmap_masked=None): - """Plot CSP TF decoding scores.""" - # XXX Add support for more metrics - assert len(vals) == len(tmin) == len(tmax) == len(fmin) == len(fmax) - mask = np.zeros(vals.shape, dtype=bool) if mask is None else mask - assert len(vals) == len(mask) - assert vals.ndim == mask.ndim == 1 - img = None - for v, t1, t2, f1, f2, m in zip(vals, tmin, tmax, fmin, fmax, mask): - use_cmap = cmap_masked if m else cmap - img = ax.imshow( - np.array([[v]], float), - cmap=use_cmap, extent=[t1, t2, f1, f2], aspect='auto', - interpolation='none', origin='lower', vmin=vmin, vmax=vmax, - ) - return img - - -def add_csp_grand_average( - *, - session: str, - cfg: SimpleNamespace, - report: mne.Report, -): - """Add CSP decoding results to the grand average report.""" - import matplotlib.pyplot as plt # nested import to help joblib - - bids_path = BIDSPath( - subject='average', - session=session, - task=cfg.task, - acquisition=cfg.acq, - run=None, - recording=cfg.rec, - space=cfg.space, - suffix='decoding', - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False - ) - - # First, plot decoding scores across frequency bins (entire epochs). 
- section = 'Decoding: CSP' - freq_name_to_bins_map = _handle_csp_args( - cfg.decoding_csp_times, - cfg.decoding_csp_freqs, - cfg.decoding_metric, - ) - for contrast in cfg.decoding_contrasts: - cond_1, cond_2 = contrast - a_vs_b = f'{cond_1}+{cond_2}'.replace(op.sep, '') - processing = f'{a_vs_b}+CSP+{cfg.decoding_metric}' - processing = processing.replace('_', '-').replace('-', '') - fname_csp_freq_results = bids_path.copy().update( - processing=processing, - extension='.xlsx', - ) - csp_freq_results = pd.read_excel( - fname_csp_freq_results, - sheet_name='CSP Frequency' - ) - - freq_bin_starts = list() - freq_bin_widths = list() - decoding_scores = list() - error_bars = list() - for freq_range_name, freq_bins in freq_name_to_bins_map.items(): - results = csp_freq_results.loc[ - csp_freq_results['freq_range_name'] == freq_range_name, : - ] - results.reset_index(drop=True, inplace=True) - assert len(results) == len(freq_bins) - for bi, freq_bin in enumerate(freq_bins): - freq_bin_starts.append(freq_bin[0]) - freq_bin_widths.append(np.diff(freq_bin)[0]) - decoding_scores.append(results['mean'][bi]) - cis_lower = results['mean_ci_lower'][bi] - cis_upper = results['mean_ci_upper'][bi] - error_bars_lower = decoding_scores[-1] - cis_lower - error_bars_upper = cis_upper - decoding_scores[-1] - error_bars.append( - np.stack([error_bars_lower, error_bars_upper])) - assert len(error_bars[-1]) == 2 # lower, upper - del cis_lower, cis_upper, error_bars_lower, error_bars_upper - error_bars = np.array(error_bars, float).T - - if cfg.decoding_metric == 'roc_auc': - metric = 'ROC AUC' - - fig, ax = plt.subplots(constrained_layout=True) - ax.bar( - x=freq_bin_starts, - width=freq_bin_widths, - height=decoding_scores, - align='edge', - yerr=error_bars, - edgecolor='black', - ) - ax.set_ylim([0, 1.02]) - offset = matplotlib.transforms.offset_copy( - ax.transData, fig, 0, 5, units='points') - for freq_range_name, freq_bins in freq_name_to_bins_map.items(): - start = freq_bins[0][0] - stop = freq_bins[-1][1] - width = stop - start - ax.text( - x=start + width / 2, - y=0., - transform=offset, - s=freq_range_name, - ha='center', - va='bottom', - ) - ax.axhline(0.5, color='black', linestyle='--', label='chance') - ax.legend() - ax.set_xlabel('Frequency (Hz)') - ax.set_ylabel(f'Mean decoding score ({metric})') - tags = ( - 'epochs', - 'contrast', - 'decoding', - 'csp', - f"{_sanitize_cond_tag(cond_1)}–{_sanitize_cond_tag(cond_2)}", - ) - title = f'CSP decoding: {cond_1} vs. {cond_2}' - report.add_figure( - fig=fig, - title=title, - section=section, - caption='Mean decoding scores. Error bars represent ' - 'bootstrapped 95% confidence intervals.', - tags=tags, - ) - - # Now, plot decoding scores across time-frequency bins. 
- for contrast in cfg.decoding_contrasts: - cond_1, cond_2 = contrast - a_vs_b = f'{cond_1}+{cond_2}'.replace(op.sep, '') - processing = f'{a_vs_b}+CSP+{cfg.decoding_metric}' - processing = processing.replace('_', '-').replace('-', '') - fname_csp_cluster_results = bids_path.copy().update( - processing=processing, - extension='.mat', - ) - csp_cluster_results = loadmat(fname_csp_cluster_results) - - fig, ax = plt.subplots( - nrows=1, ncols=2, sharex=True, sharey=True, - constrained_layout=True) - n_clu = 0 - cbar = None - lims = [np.inf, -np.inf, np.inf, -np.inf] - for freq_range_name, bins in freq_name_to_bins_map.items(): - results = csp_cluster_results[freq_range_name][0][0] - mean_crossval_scores = results['mean_crossval_scores'].ravel() - # t_vals = results['t_vals'] - clusters = results['clusters'] - cluster_p_vals = results['cluster_p_vals'].squeeze() - tmin = results['time_bin_edges'].ravel() - tmin, tmax = tmin[:-1], tmin[1:] - fmin = results['freq_bin_edges'].ravel() - fmin, fmax = fmin[:-1], fmin[1:] - lims[0] = min(lims[0], tmin.min()) - lims[1] = max(lims[1], tmax.max()) - lims[2] = min(lims[2], fmin.min()) - lims[3] = max(lims[3], fmax.max()) - # replicate, matching time-frequency order during clustering - fmin, fmax = np.tile(fmin, len(tmin)), np.tile(fmax, len(tmax)) - tmin, tmax = np.repeat(tmin, len(bins)), np.repeat(tmax, len(bins)) - assert fmin.shape == fmax.shape == tmin.shape == tmax.shape - assert fmin.shape == mean_crossval_scores.shape - cluster_t_threshold = results['cluster_t_threshold'].ravel().item() - - significant_cluster_idx = np.where( - cluster_p_vals < cfg.cluster_permutation_p_threshold - )[0] - significant_clusters = clusters[significant_cluster_idx] - n_clu += len(significant_cluster_idx) - - # XXX Add support for more metrics - assert cfg.decoding_metric == 'roc_auc' - metric = 'ROC AUC' - vmax = max( - np.abs(mean_crossval_scores.min() - 0.5), - np.abs(mean_crossval_scores.max() - 0.5) - ) + 0.5 - vmin = 0.5 - (vmax - 0.5) - # For diverging gray colormap, we need to combine two existing - # colormaps, as there is no diverging colormap with gray/black at - # both endpoints. - from matplotlib.cm import gray, gray_r - from matplotlib.colors import ListedColormap - - black_to_white = gray( - np.linspace(start=0, stop=1, endpoint=False, num=128) - ) - white_to_black = gray_r( - np.linspace(start=0, stop=1, endpoint=False, num=128) - ) - black_to_white_to_black = np.vstack( - (black_to_white, white_to_black) - ) - diverging_gray_cmap = ListedColormap( - black_to_white_to_black, name='DivergingGray' - ) - cmap_gray = diverging_gray_cmap - img = _imshow_tf( - mean_crossval_scores, ax[0], - tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, - vmin=vmin, vmax=vmax) - if cbar is None: - ax[0].set_xlabel('Time (s)') - ax[0].set_ylabel('Frequency (Hz)') - ax[1].set_xlabel('Time (s)') - cbar = fig.colorbar( - ax=ax[1], shrink=0.75, orientation='vertical', - mappable=img) - cbar.set_label(f'Mean decoding score ({metric})') - offset = matplotlib.transforms.offset_copy( - ax[0].transData, fig, 6, 0, units='points') - ax[0].text(tmin.min(), - 0.5 * fmin.min() + 0.5 * fmax.max(), - freq_range_name, transform=offset, - ha='left', va='center', rotation=90) - - if len(significant_clusters): - # Create a masked array that only shows the T-values for - # time-frequency bins that belong to significant clusters. 
- if len(significant_clusters) == 1: - mask = ~significant_clusters[0].astype(bool) - else: - mask = ~np.logical_or( - *significant_clusters - ) - mask = mask.ravel() - else: - mask = np.ones(mean_crossval_scores.shape, dtype=bool) - _imshow_tf( - mean_crossval_scores, ax[1], - tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, - vmin=vmin, vmax=vmax, mask=mask, cmap_masked=cmap_gray) - - ax[0].set_xlim(lims[:2]) - ax[0].set_ylim(lims[2:]) - ax[0].set_title('Scores') - ax[1].set_title('Masked') - tags = ( - 'epochs', - 'contrast', - 'decoding', - 'csp', - f"{_sanitize_cond_tag(cond_1)}–{_sanitize_cond_tag(cond_2)}", - ) - title = f'CSP TF decoding: {cond_1} vs. {cond_2}' - report.add_figure( - fig=fig, - title=title, - section=section, - caption=f'Found {n_clu} ' - f'cluster{_pl(n_clu)} with ' - f'p < {cfg.cluster_permutation_p_threshold} ' - f'(clustering bins with absolute t-values > ' - f'{round(cluster_t_threshold, 3)}).', - tags=tags, - ) - - -def get_config( - *, - config, - subject: str, -) -> SimpleNamespace: - # Deal with configurations where `deriv_root` was specified, but not - # `fs_subjects_dir`. We normally raise an exception in this case in - # `get_fs_subjects_dir()`. However, in situations where users only run the - # sensor-space scripts, we never call this function, so everything works - # totally fine at first (which is expected). Yet, when creating the - # reports, the pipeline would fail with an exception – which is - # unjustified, as it would not make sense to force users to provide an - # `fs_subjects_dir` if they don't care about source analysis anyway! So - # simply assign a dummy value in such cases. - # `get_fs_subject()` calls `get_fs_subjects_dir()`, so take care of this - # too. - try: - fs_subjects_dir = get_fs_subjects_dir(config) - except ValueError: - fs_subjects_dir = None - fs_subject = None - else: - fs_subject = get_fs_subject(config=config, subject=subject) - - dtg_decim = config.decoding_time_generalization_decim - cfg = SimpleNamespace( - task=get_task(config), - task_is_rest=config.task_is_rest, - runs=get_runs(config=config, subject=subject), - datatype=get_datatype(config), - acq=config.acq, - rec=config.rec, - space=config.space, - proc=config.proc, - analyze_channels=config.analyze_channels, - find_noisy_channels_meg=config.find_noisy_channels_meg, - h_freq=config.h_freq, - spatial_filter=config.spatial_filter, - conditions=config.conditions, - all_contrasts=get_all_contrasts(config), - decoding_contrasts=get_decoding_contrasts(config), - ica_reject=config.ica_reject, - ch_types=config.ch_types, - time_frequency_conditions=config.time_frequency_conditions, - time_frequency_baseline=config.time_frequency_baseline, - time_frequency_baseline_mode=config.time_frequency_baseline_mode, - time_frequency_crop=config.time_frequency_crop, - decode=config.decode, - decoding_metric=config.decoding_metric, - decoding_time_generalization=config.decoding_time_generalization, - decoding_time_generalization_decim=dtg_decim, - decoding_csp=config.decoding_csp, - decoding_csp_freqs=config.decoding_csp_freqs, - decoding_csp_times=config.decoding_csp_times, - n_boot=config.n_boot, - cluster_permutation_p_threshold=config.cluster_permutation_p_threshold, - cluster_forming_t_threshold=config.cluster_forming_t_threshold, - inverse_method=config.inverse_method, - report_stc_n_time_points=config.report_stc_n_time_points, - report_evoked_n_time_points=config.report_evoked_n_time_points, - fs_subject=fs_subject, - fs_subjects_dir=fs_subjects_dir, - 
deriv_root=get_deriv_root(config), - bids_root=get_bids_root(config), - use_template_mri=config.use_template_mri, - interactive=config.interactive, - plot_psd_for_runs=config.plot_psd_for_runs, - eog_channels=config.eog_channels, - noise_cov=_sanitize_callable(config.noise_cov), - data_type=config.data_type, - subjects=config.subjects, - exclude_subjects=config.exclude_subjects, - ) - return cfg - - -@contextlib.contextmanager -def _agg_backend(): - import matplotlib - backend = matplotlib.get_backend() - matplotlib.use('Agg', force=True) - try: - yield - finally: - matplotlib.use(backend, force=True) - - -def main(*, config) -> None: - """Make reports.""" - with get_parallel_backend(config), _agg_backend(): - parallel, run_func = parallel_func(run_report, config=config) - sessions = get_sessions(config=config) - logs = parallel( - run_func( - cfg=get_config( - config=config, - subject=subject, - ), - subject=subject, - session=session - ) - for subject in get_subjects(config=config) - for session in sessions - ) - - if config.task_is_rest: - msg = ' … skipping "average" report for "rest" task.' - logger.info(**gen_log_kwargs(message=msg)) - avg_subjects = [] - else: - avg_subjects = ['average'] - - parallel, run_func = parallel_func(run_report_average, config=config) - logs.extend(parallel( - run_func( - cfg=get_config( - config=config, - subject=subject, - ), - subject=subject, - session=session, - ) - for subject in avg_subjects - for session in sessions - )) - save_logs(logs=logs, config=config) diff --git a/mne_bids_pipeline/scripts/report/__init__.py b/mne_bids_pipeline/scripts/report/__init__.py deleted file mode 100644 index 3ed1815d8..000000000 --- a/mne_bids_pipeline/scripts/report/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Analysis report generation.""" - -from . import _01_make_reports - -SCRIPTS = ( - _01_make_reports, -) diff --git a/mne_bids_pipeline/scripts/sensor/_01_make_evoked.py b/mne_bids_pipeline/scripts/sensor/_01_make_evoked.py index 8ffbaac33..189eaeeaa 100644 --- a/mne_bids_pipeline/scripts/sensor/_01_make_evoked.py +++ b/mne_bids_pipeline/scripts/sensor/_01_make_evoked.py @@ -8,11 +8,13 @@ from ..._config_utils import ( get_sessions, get_subjects, get_task, get_datatype, - get_deriv_root, get_all_contrasts, + get_deriv_root, get_all_contrasts, _restrict_analyze_channels, + get_noise_cov_bids_path, ) from ..._logging import gen_log_kwargs, logger -from ..._run import failsafe_run, save_logs from ..._parallel import parallel_func, get_parallel_backend +from ..._report import _open_report, _sanitize_cond_tag +from ..._run import failsafe_run, save_logs, _sanitize_callable def get_input_fnames_evoked(**kwargs): @@ -36,6 +38,14 @@ def get_input_fnames_evoked(**kwargs): check=False) in_files = dict() in_files['epochs'] = fname_epochs + + fname_noise_cov = get_noise_cov_bids_path( + cfg=cfg, + subject=subject, + session=session + ) + if fname_noise_cov.fpath.is_file(): + in_files['noise_cov'] = fname_noise_cov return in_files @@ -87,6 +97,36 @@ def run_evoked(*, cfg, subject, session, in_files): evoked.nave = int(round(evoked.nave)) # avoid a warning mne.write_evokeds(out_files['evoked'], evokeds, overwrite=True) + # Report + if evokeds: + msg = (f'Adding {len(evokeds)} evoked signals and contrasts to the ' + f'report.') + else: + msg = 'No evoked conditions or contrasts found.' 
+ logger.info( + **gen_log_kwargs(message=msg, subject=subject, session=session) + ) + noise_cov = in_files.pop('noise_cov', None) + with _open_report(cfg=cfg, subject=subject, session=session) as report: + for condition, evoked in all_evoked.items(): + _restrict_analyze_channels(evoked, cfg) + + tags = ('evoked', _sanitize_cond_tag(condition)) + if condition in cfg.conditions: + title = f'Condition: {condition}' + else: # It's a contrast of two conditions. + title = f'Contrast: {condition}' + tags = tags + ('contrast',) + + report.add_evokeds( + evokeds=evoked, + titles=title, + noise_cov=noise_cov, + n_time_points=cfg.report_evoked_n_time_points, + tags=tags, + ) + + # Interaction if cfg.interactive: for evoked in evokeds: evoked.plot() @@ -99,6 +139,7 @@ def run_evoked(*, cfg, subject, session, in_files): # evoked.plot_joint(title=condition, ts_args=ts_args, # topomap_args=topomap_args) assert len(in_files) == 0, in_files.keys() + return out_files @@ -116,6 +157,11 @@ def get_config( conditions=config.conditions, contrasts=get_all_contrasts(config), interactive=config.interactive, + proc=config.proc, + noise_cov=_sanitize_callable(config.noise_cov), + analyze_channels=config.analyze_channels, + ch_types=config.ch_types, + report_evoked_n_time_points=config.report_evoked_n_time_points, ) return cfg diff --git a/mne_bids_pipeline/scripts/sensor/_02_decoding_full_epochs.py b/mne_bids_pipeline/scripts/sensor/_02_decoding_full_epochs.py index 702e3b95a..8ae600fe1 100644 --- a/mne_bids_pipeline/scripts/sensor/_02_decoding_full_epochs.py +++ b/mne_bids_pipeline/scripts/sensor/_02_decoding_full_epochs.py @@ -14,7 +14,7 @@ import numpy as np import pandas as pd -from scipy.io import savemat +from scipy.io import savemat, loadmat from sklearn.model_selection import cross_val_score from sklearn.pipeline import make_pipeline @@ -29,9 +29,12 @@ get_deriv_root, _restrict_analyze_channels, get_decoding_contrasts ) from ..._logging import gen_log_kwargs, logger -from ..._run import failsafe_run, save_logs from ..._decoding import LogReg from ..._parallel import parallel_func, get_parallel_backend +from ..._run import failsafe_run, save_logs +from ..._report import ( + _open_report, _contrasts_to_names, _plot_full_epochs_decoding_scores, + _sanitize_cond_tag) def get_input_fnames_epochs_decoding(**kwargs): @@ -64,6 +67,7 @@ def get_input_fnames_epochs_decoding(**kwargs): get_input_fnames=get_input_fnames_epochs_decoding) def run_epochs_decoding(*, cfg, subject, condition1, condition2, session, in_files): + import matplotlib.pyplot as plt msg = f'Contrasting conditions: {condition1} – {condition2}' logger.info(**gen_log_kwargs(message=msg, subject=subject, session=session)) @@ -142,6 +146,57 @@ def run_epochs_decoding(*, cfg, subject, condition1, condition2, session, ) tabular_data = pd.DataFrame(tabular_data).T tabular_data.to_csv(out_files[tsv_key], sep='\t', index=False) + + # Report + with _open_report(cfg=cfg, subject=subject, session=session) as report: + msg = 'Adding full-epochs decoding results to the report.' 
+ logger.info( + **gen_log_kwargs(message=msg, subject=subject, session=session) + ) + + all_decoding_scores = [] + all_contrasts = [] + for contrast in cfg.contrasts: + cond_1, cond_2 = contrast + a_vs_b = f'{cond_1}+{cond_2}'.replace(op.sep, '') + processing = f'{a_vs_b}+FullEpochs+{cfg.decoding_metric}' + processing = processing.replace('_', '-').replace('-', '') + fname_decoding = bids_path.copy().update( + processing=processing, + suffix='decoding', + extension='.mat' + ) + if not fname_decoding.fpath.is_file(): # not done yet + continue + decoding_data = loadmat(fname_decoding) + all_decoding_scores.append( + np.atleast_1d(decoding_data['scores'].squeeze()) + ) + all_contrasts.append(contrast) + del fname_decoding, processing, a_vs_b, decoding_data + + fig, caption = _plot_full_epochs_decoding_scores( + contrast_names=_contrasts_to_names(all_contrasts), + scores=all_decoding_scores, + metric=cfg.decoding_metric, + ) + report.add_figure( + fig=fig, + title='Full-epochs decoding', + caption=caption, + section='Decoding: full-epochs', + tags=( + 'epochs', + 'contrast', + 'decoding', + *[f'{_sanitize_cond_tag(cond_1)}–' + f'{_sanitize_cond_tag(cond_2)}' + for cond_1, cond_2 in cfg.contrasts] + ) + ) + # close figure to save memory + plt.close(fig) + assert len(in_files) == 0, in_files.keys() return out_files diff --git a/mne_bids_pipeline/scripts/sensor/_03_decoding_time_by_time.py b/mne_bids_pipeline/scripts/sensor/_03_decoding_time_by_time.py index 8321a16be..fadfb8d11 100644 --- a/mne_bids_pipeline/scripts/sensor/_03_decoding_time_by_time.py +++ b/mne_bids_pipeline/scripts/sensor/_03_decoding_time_by_time.py @@ -17,7 +17,7 @@ import numpy as np import pandas as pd -from scipy.io import savemat +from scipy.io import savemat, loadmat import mne from mne.decoding import ( @@ -39,6 +39,10 @@ from ..._run import failsafe_run, save_logs from ..._parallel import ( get_parallel_backend, get_n_jobs, get_parallel_backend_name) +from ..._report import ( + _open_report, _plot_decoding_time_generalization, _sanitize_cond_tag, + _plot_time_by_time_decoding_scores, +) def get_input_fnames_time_decoding(**kwargs): @@ -72,6 +76,7 @@ def get_input_fnames_time_decoding(**kwargs): get_input_fnames=get_input_fnames_time_decoding) def run_time_decoding(*, cfg, subject, condition1, condition2, session, in_files): + import matplotlib.pyplot as plt if cfg.decoding_time_generalization: kind = 'time generalization' else: @@ -180,6 +185,83 @@ def run_time_decoding(*, cfg, subject, condition1, condition2, session, ) tabular_data.to_csv( out_files[f'tsv_{processing}'], sep='\t', index=False) + + # Report + with _open_report(cfg=cfg, subject=subject, session=session) as report: + msg = 'Adding time-by-time decoding results to the report.' 
+        logger.info(
+            **gen_log_kwargs(message=msg, subject=subject, session=session)
+        )
+
+        section = 'Decoding: time-by-time'
+        for contrast in cfg.contrasts:
+            cond_1, cond_2 = contrast
+            a_vs_b = f'{cond_1}+{cond_2}'.replace(op.sep, '')
+            tags = (
+                'epochs',
+                'contrast',
+                'decoding',
+                f"{_sanitize_cond_tag(contrast[0])}–"
+                f"{_sanitize_cond_tag(contrast[1])}"
+            )
+
+            processing = f'{a_vs_b}+TimeByTime+{cfg.decoding_metric}'
+            processing = processing.replace('_', '-').replace('-', '')
+            fname_decoding = bids_path.copy().update(
+                processing=processing,
+                suffix='decoding',
+                extension='.mat'
+            )
+            if not fname_decoding.fpath.is_file():
+                continue
+            decoding_data = loadmat(fname_decoding)
+            del fname_decoding, processing, a_vs_b
+
+            fig = _plot_time_by_time_decoding_scores(
+                times=decoding_data['times'].ravel(),
+                cross_val_scores=decoding_data['scores'],
+                metric=cfg.decoding_metric,
+                time_generalization=cfg.decoding_time_generalization,
+                decim=decoding_data['decim'].item(),
+            )
+            caption = (
+                f'Time-by-time decoding: '
+                f'{len(epochs[cond_1])} × {cond_1} vs. '
+                f'{len(epochs[cond_2])} × {cond_2}'
+            )
+            title = f'Decoding over time: {cond_1} vs. {cond_2}'
+            report.add_figure(
+                fig=fig,
+                title=title,
+                caption=caption,
+                section=section,
+                tags=tags,
+            )
+            plt.close(fig)
+
+            if cfg.decoding_time_generalization:
+                fig = _plot_decoding_time_generalization(
+                    decoding_data=decoding_data,
+                    metric=cfg.decoding_metric,
+                    kind='single-subject'
+                )
+                caption = (
+                    'Time generalization (generalization across time, GAT): '
+                    'each classifier is trained on each time point, and '
+                    'tested on all other time points.'
+                )
+                title = f'Time generalization: {cond_1} vs. {cond_2}'
+                report.add_figure(
+                    fig=fig,
+                    title=title,
+                    caption=caption,
+                    section=section,
+                    tags=tags,
+                )
+                plt.close(fig)
+
+            del decoding_data, cond_1, cond_2, caption
+
     assert len(in_files) == 0, in_files.keys()
     return out_files
diff --git a/mne_bids_pipeline/scripts/sensor/_04_time_frequency.py b/mne_bids_pipeline/scripts/sensor/_04_time_frequency.py
index 63b6a5917..a2cafffc7 100644
--- a/mne_bids_pipeline/scripts/sensor/_04_time_frequency.py
+++ b/mne_bids_pipeline/scripts/sensor/_04_time_frequency.py
@@ -20,6 +20,7 @@
 from ..._logging import gen_log_kwargs, logger
 from ..._run import failsafe_run, save_logs, _script_path
 from ..._parallel import get_parallel_backend, parallel_func
+from ..._report import _open_report, _sanitize_cond_tag
 
 
 def get_input_fnames_time_frequency(**kwargs):
@@ -52,6 +53,7 @@
 @failsafe_run(script_path=__file__,
               get_input_fnames=get_input_fnames_time_frequency)
 def run_time_frequency(*, cfg, subject, session, in_files):
+    import matplotlib.pyplot as plt
     msg = f'Input: {in_files["epochs"].basename}'
     logger.info(**gen_log_kwargs(message=msg, subject=subject,
                                  session=session))
@@ -93,6 +95,50 @@ def run_time_frequency(*, cfg, subject, session, in_files):
     power.save(out_files[power_key], overwrite=True, verbose='error')
     itc.save(out_files[itc_key], overwrite=True, verbose='error')
 
+    # Report
+    with _open_report(cfg=cfg, subject=subject, session=session) as report:
+        msg = 'Adding TFR analysis results to the report.'
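
The time-generalization figures added above come from training one classifier per time point and testing it at every other time point. A minimal, self-contained sketch of that idea with synthetic data, using MNE's `GeneralizingEstimator`; the shapes, seed, and classifier choice are illustrative only, not the pipeline's exact setup:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from mne.decoding import GeneralizingEstimator, cross_val_multiscore

rng = np.random.default_rng(0)
n_epochs, n_channels, n_times = 40, 8, 20
X = rng.standard_normal((n_epochs, n_channels, n_times))
y = rng.integers(0, 2, n_epochs)  # two conditions

clf = make_pipeline(StandardScaler(), LogisticRegression())
gen = GeneralizingEstimator(clf, scoring='roc_auc', n_jobs=1)

# scores has shape (n_splits, n_train_times, n_test_times); the diagonal
# of the mean corresponds to plain time-by-time decoding.
scores = cross_val_multiscore(gen, X, y, cv=3)
print(scores.mean(axis=0).shape)  # (20, 20)
```
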
+        logger.info(
+            **gen_log_kwargs(message=msg, subject=subject, session=session)
+        )
+        for condition in cfg.time_frequency_conditions:
+            cond = sanitize_cond_name(condition)
+            fname_tfr_pow_cond = out_files[f'power-{cond}']
+            fname_tfr_itc_cond = out_files[f'itc-{cond}']
+            with mne.use_log_level('error'):  # filename convention
+                power = mne.time_frequency.read_tfrs(
+                    fname_tfr_pow_cond, condition=0)
+            power.apply_baseline(
+                baseline=cfg.time_frequency_baseline,
+                mode=cfg.time_frequency_baseline_mode)
+            if cfg.time_frequency_crop:
+                power.crop(**cfg.time_frequency_crop)
+            kwargs = dict(
+                show=False, fig_facecolor='w', font_color='k', border='k'
+            )
+            fig_power = power.plot_topo(**kwargs)
+            report.add_figure(
+                fig=fig_power,
+                title=f'TFR Power: {condition}',
+                caption=f'TFR Power: {condition}',
+                tags=('time-frequency', _sanitize_cond_tag(condition))
+            )
+            plt.close(fig_power)
+            del power
+
+            with mne.use_log_level('error'):  # filename convention
+                itc = mne.time_frequency.read_tfrs(
+                    fname_tfr_itc_cond, condition=0)
+            fig_itc = itc.plot_topo(**kwargs)
+            report.add_figure(
+                fig=fig_itc,
+                title=f'TFR ITC: {condition}',
+                caption=f'TFR Inter-Trial Coherence: {condition}',
+                tags=('time-frequency', _sanitize_cond_tag(condition))
+            )
+            plt.close(fig_itc)
+            del itc
+
     assert len(in_files) == 0, in_files.keys()
     return out_files
diff --git a/mne_bids_pipeline/scripts/sensor/_05_decoding_csp.py b/mne_bids_pipeline/scripts/sensor/_05_decoding_csp.py
index 5cd179c25..b182698cc 100644
--- a/mne_bids_pipeline/scripts/sensor/_05_decoding_csp.py
+++ b/mne_bids_pipeline/scripts/sensor/_05_decoding_csp.py
@@ -9,7 +9,7 @@
 import mne
 import numpy as np
 import pandas as pd
-from mne import BaseEpochs
+import matplotlib.transforms
 from mne.decoding import CSP, UnsupervisedSpatialFilter
 from mne_bids import BIDSPath
 from sklearn.decomposition import PCA
@@ -24,11 +24,15 @@
 from ..._logging import logger, gen_log_kwargs
 from ..._parallel import parallel_func, get_parallel_backend
 from ..._run import failsafe_run, _script_path, save_logs
+from ..._report import (
+    _open_report, _sanitize_cond_tag, _plot_full_epochs_decoding_scores,
+    _imshow_tf,
+)
 
 
 def _prepare_labels(
     *,
-    epochs: BaseEpochs,
+    epochs: mne.BaseEpochs,
     contrast: Tuple[str, str]
 ) -> np.ndarray:
     """Return the projection of the events_id on a boolean vector.
@@ -69,12 +73,12 @@ def _prepare_labels(
 
 def prepare_epochs_and_y(
     *,
-    epochs: BaseEpochs,
+    epochs: mne.BaseEpochs,
     contrast: Tuple[str, str],
     cfg,
     fmin: float,
     fmax: float
-) -> Tuple[BaseEpochs, np.ndarray]:
+) -> Tuple[mne.BaseEpochs, np.ndarray]:
     """Band-pass between, sub-select the desired epochs, and prepare y."""
     epochs_filt = (
         epochs
@@ -150,6 +154,7 @@ def one_subject_decoding(
     1. The frequency analysis.
     2. The time-frequency analysis.
     """
+    import matplotlib.pyplot as plt
     condition1, condition2 = contrast
     msg = f'Contrasting conditions: {condition1} – {condition2}'
     logger.info(**gen_log_kwargs(msg, subject=subject, session=session))
@@ -366,10 +371,146 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name,
         tf_decoding_table.to_excel(
             w, sheet_name='CSP Time-Frequency', index=False
         )
+    out_files = {'csp-excel': fname_results}
 
-    assert len(in_files) == 0, in_files.keys()
+    # Report
+    with _open_report(cfg=cfg, subject=subject, session=session) as report:
+        msg = 'Adding CSP decoding results to the report.'
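
One detail worth knowing about the Excel round trip in the CSP hunks: pandas stringifies cell values it cannot map to a native Excel type, so a NumPy array stored in a `scores` cell is written as its `str()` form and must be re-parsed on read, which is what the `x[1:-1].split()` lambdas in the following hunk do. A small demonstration (assumes `openpyxl` is available for `.xlsx` I/O; the file name is arbitrary):

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({'scores': [np.array([0.8, 0.9, 0.7])]})
df.to_excel('demo_csp.xlsx', index=False)  # cell lands as '[0.8 0.9 0.7]'

back = pd.read_excel('demo_csp.xlsx')
print(repr(back['scores'][0]))  # '[0.8 0.9 0.7]' (a plain string)

# Drop the brackets, split on whitespace, cast to float:
parsed = back['scores'].apply(lambda x: np.array(x[1:-1].split(), float))
print(parsed[0])  # array([0.8, 0.9, 0.7])
```
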
+ logger.info( + **gen_log_kwargs(message=msg, subject=subject, session=session) + ) + section = 'Decoding: CSP' + freq_name_to_bins_map = _handle_csp_args( + cfg.decoding_csp_times, + cfg.decoding_csp_freqs, + cfg.decoding_metric, + ) + all_csp_tf_results = dict() + for contrast in cfg.decoding_contrasts: + cond_1, cond_2 = contrast + a_vs_b = f'{cond_1}+{cond_2}'.replace(op.sep, '') + tags = ( + 'epochs', + 'contrast', + 'decoding', + 'csp', + f"{_sanitize_cond_tag(cond_1)}–{_sanitize_cond_tag(cond_2)}" + ) + processing = f'{a_vs_b}+CSP+{cfg.decoding_metric}' + processing = processing.replace('_', '-').replace('-', '') + fname_decoding = bids_path.copy().update( + processing=processing, + suffix='decoding', + extension='.xlsx' + ) + if not fname_decoding.fpath.is_file(): + continue # not done yet + csp_freq_results = pd.read_excel( + fname_decoding, + sheet_name='CSP Frequency' + ) + csp_freq_results['scores'] = csp_freq_results['scores'].apply( + lambda x: np.array(x[1:-1].split(), float)) + csp_tf_results = pd.read_excel( + fname_decoding, + sheet_name='CSP Time-Frequency' + ) + csp_tf_results['scores'] = csp_tf_results['scores'].apply( + lambda x: np.array(x[1:-1].split(), float)) + all_csp_tf_results[contrast] = csp_tf_results + del csp_tf_results + + all_decoding_scores = list() + contrast_names = list() + for freq_range_name, freq_bins in freq_name_to_bins_map.items(): + results = csp_freq_results.loc[ + csp_freq_results['freq_range_name'] == freq_range_name + ] + results.reset_index(drop=True, inplace=True) + assert len(results['scores']) == len(freq_bins) + for bi, freq_bin in enumerate(freq_bins): + all_decoding_scores.append(results['scores'][bi]) + f_min = float(freq_bin[0]) + f_max = float(freq_bin[1]) + contrast_names.append( + f'{freq_range_name}\n' + f'({f_min:0.1f}-{f_max:0.1f} Hz)' + ) + fig, caption = _plot_full_epochs_decoding_scores( + contrast_names=contrast_names, + scores=all_decoding_scores, + metric=cfg.decoding_metric, + ) + title = f'CSP decoding: {cond_1} vs. {cond_2}' + report.add_figure( + fig=fig, + title=title, + section=section, + caption=caption, + tags=tags, + ) + # close figure to save memory + plt.close(fig) + del fig, caption, title + + # Now, plot decoding scores across time-frequency bins. 
+ for contrast in cfg.decoding_contrasts: + if contrast not in all_csp_tf_results: + continue + cond_1, cond_2 = contrast + tags = ( + 'epochs', + 'contrast', + 'decoding', + 'csp', + f"{_sanitize_cond_tag(cond_1)}–{_sanitize_cond_tag(cond_2)}", + ) + results = all_csp_tf_results[contrast] + mean_crossval_scores = list() + tmin, tmax, fmin, fmax = list(), list(), list(), list() + mean_crossval_scores.extend( + results['mean_crossval_score'].ravel()) + tmin.extend(results['t_min'].ravel()) + tmax.extend(results['t_max'].ravel()) + fmin.extend(results['f_min'].ravel()) + fmax.extend(results['f_max'].ravel()) + mean_crossval_scores = np.array(mean_crossval_scores, float) + fig, ax = plt.subplots(constrained_layout=True) + # XXX Add support for more metrics + assert cfg.decoding_metric == 'roc_auc' + metric = 'ROC AUC' + vmax = max( + np.abs(mean_crossval_scores.min() - 0.5), + np.abs(mean_crossval_scores.max() - 0.5) + ) + 0.5 + vmin = 0.5 - (vmax - 0.5) + img = _imshow_tf( + mean_crossval_scores, ax, + tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, + vmin=vmin, vmax=vmax) + offset = matplotlib.transforms.offset_copy( + ax.transData, fig, 6, 0, units='points') + for freq_range_name, bins in freq_name_to_bins_map.items(): + ax.text(tmin[0], + 0.5 * bins[0][0] + 0.5 * bins[-1][1], + freq_range_name, transform=offset, + ha='left', va='center', rotation=90) + ax.set_xlim([np.min(tmin), np.max(tmax)]) + ax.set_ylim([np.min(fmin), np.max(fmax)]) + ax.set_xlabel('Time (s)') + ax.set_ylabel('Frequency (Hz)') + cbar = fig.colorbar( + ax=ax, shrink=0.75, orientation='vertical', mappable=img) + cbar.set_label(f'Mean decoding score ({metric})') + title = f'CSP TF decoding: {cond_1} vs. {cond_2}' + report.add_figure( + fig=fig, + title=title, + section=section, + tags=tags, + ) - out_files = {'csp-excel': fname_results} + assert len(in_files) == 0, in_files.keys() return out_files diff --git a/mne_bids_pipeline/scripts/sensor/_06_make_cov.py b/mne_bids_pipeline/scripts/sensor/_06_make_cov.py index c48353c7c..c3d96ad54 100644 --- a/mne_bids_pipeline/scripts/sensor/_06_make_cov.py +++ b/mne_bids_pipeline/scripts/sensor/_06_make_cov.py @@ -16,8 +16,9 @@ ) from ..._config_import import _import_config from ..._logging import gen_log_kwargs, logger -from ..._run import failsafe_run, save_logs, _sanitize_callable from ..._parallel import get_parallel_backend, parallel_func +from ..._report import _open_report +from ..._run import failsafe_run, save_logs, _sanitize_callable def get_input_fnames_cov(**kwargs): @@ -29,6 +30,24 @@ def get_input_fnames_cov(**kwargs): # short circuit to say: always re-run cov_type = _get_cov_type(cfg) in_files = dict() + processing = 'clean' if cfg.spatial_filter is not None else None + fname_epochs = BIDSPath(subject=subject, + session=session, + task=cfg.task, + acquisition=cfg.acq, + run=None, + recording=cfg.rec, + space=cfg.space, + extension='.fif', + suffix='epo', + processing=processing, + datatype=cfg.datatype, + root=cfg.deriv_root, + check=False) + in_files['report_info'] = fname_epochs.copy().update( + processing='clean', + suffix='epo' + ) if cov_type == 'custom': in_files['__unknown_inputs__'] = 'custom noise_cov callable' return in_files @@ -57,22 +76,6 @@ def get_input_fnames_cov(**kwargs): in_files['raw'] = bids_path_raw_noise else: assert cov_type == 'epochs', cov_type - processing = None - if cfg.spatial_filter is not None: - processing = 'clean' - fname_epochs = BIDSPath(subject=subject, - session=session, - task=cfg.task, - acquisition=cfg.acq, - run=None, - 
recording=cfg.rec, - space=cfg.space, - extension='.fif', - suffix='epo', - processing=processing, - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False) in_files['epochs'] = fname_epochs return in_files @@ -183,6 +186,7 @@ def run_covariance(*, cfg, subject, session, in_files): kwargs = dict( cfg=cfg, subject=subject, session=session, in_files=in_files, out_files=out_files) + fname_info = in_files.pop('report_info') if cov_type == 'custom': cov = retrieve_custom_cov(**kwargs) elif cov_type == 'raw': @@ -191,6 +195,20 @@ def run_covariance(*, cfg, subject, session, in_files): tmin, tmax = cfg.noise_cov cov = compute_cov_from_epochs(tmin=tmin, tmax=tmax, **kwargs) cov.save(out_files['cov'], overwrite=True) + + # Report + with _open_report(cfg=cfg, subject=subject, session=session) as report: + msg = 'Rendering noise covariance matrix and corresponding SVD.' + logger.info( + **gen_log_kwargs(message=msg, subject=subject, session=session) + ) + report.add_covariance( + cov=cov, + info=fname_info, + title='Noise covariance' + ) + + assert len(in_files) == 0, in_files return out_files diff --git a/mne_bids_pipeline/scripts/sensor/_99_group_average.py b/mne_bids_pipeline/scripts/sensor/_99_group_average.py index a79d1768f..d60267c5d 100644 --- a/mne_bids_pipeline/scripts/sensor/_99_group_average.py +++ b/mne_bids_pipeline/scripts/sensor/_99_group_average.py @@ -19,11 +19,13 @@ from ..._config_utils import ( get_sessions, get_subjects, get_task, get_datatype, get_deriv_root, get_eeg_reference, get_decoding_contrasts, get_bids_root, + get_all_contrasts, ) from ..._decoding import _handle_csp_args from ..._logging import gen_log_kwargs, logger from ..._parallel import get_parallel_backend, parallel_func from ..._run import failsafe_run, save_logs +from ..._report import run_report_average_sensor def average_evokeds(cfg, session): @@ -155,7 +157,7 @@ def average_time_by_time_decoding( subjects = cfg.subjects del epochs, fname_epo - for contrast in cfg.contrasts: + for contrast in cfg.decoding_contrasts: cond_1, cond_2 = contrast if cfg.decoding_time_generalization: time_points_shape = (len(times), len(times)) @@ -298,7 +300,7 @@ def average_full_epochs_decoding( cfg: SimpleNamespace, session: str ): - for contrast in cfg.contrasts: + for contrast in cfg.decoding_contrasts: cond_1, cond_2 = contrast n_subjects = len(cfg.subjects) @@ -590,7 +592,6 @@ def get_config( proc=config.proc, deriv_root=get_deriv_root(config), conditions=config.conditions, - contrasts=get_decoding_contrasts(config), decode=config.decode, decoding_metric=config.decoding_metric, decoding_n_splits=config.decoding_n_splits, @@ -599,6 +600,7 @@ def get_config( decoding_csp=config.decoding_csp, decoding_csp_freqs=config.decoding_csp_freqs, decoding_csp_times=config.decoding_csp_times, + decoding_contrasts=get_decoding_contrasts(config), random_state=config.random_state, n_boot=config.n_boot, cluster_forming_t_threshold=config.cluster_forming_t_threshold, @@ -614,13 +616,15 @@ def get_config( parallel_backend=config.parallel_backend, N_JOBS=config.N_JOBS, exclude_subjects=config.exclude_subjects, + all_contrasts=get_all_contrasts(config), + report_evoked_n_time_points=config.report_evoked_n_time_points, ) return cfg # pass 'average' subject for logging @failsafe_run(script_path=__file__) -def run_group_average_sensor(*, cfg, subject='average'): +def run_group_average_sensor(*, cfg): if cfg.task_is_rest: msg = ' … skipping: for "rest" task.' 
logger.info(**gen_log_kwargs(message=msg)) @@ -653,11 +657,12 @@ def run_group_average_sensor(*, cfg, subject='average'): for session in get_sessions(config=cfg) for contrast in get_decoding_contrasts(config=cfg) ) + for session in sessions: + run_report_average_sensor(cfg=cfg, session=session) def main(*, config) -> None: log = run_group_average_sensor( cfg=get_config(config=config), - subject='average', ) save_logs(config=config, logs=[log]) diff --git a/mne_bids_pipeline/scripts/source/_04_make_forward.py b/mne_bids_pipeline/scripts/source/_04_make_forward.py index 0798f0139..3408144aa 100644 --- a/mne_bids_pipeline/scripts/source/_04_make_forward.py +++ b/mne_bids_pipeline/scripts/source/_04_make_forward.py @@ -18,6 +18,7 @@ from ..._config_import import _import_config from ..._logging import logger, gen_log_kwargs from ..._parallel import get_parallel_backend, parallel_func +from ..._report import _open_report from ..._run import failsafe_run, save_logs @@ -155,6 +156,44 @@ def run_forward(*, cfg, subject, session, in_files): out_files['forward'] = bids_path.copy().update(suffix='fwd') mne.write_trans(out_files['trans'], fwd['mri_head_t'], overwrite=True) mne.write_forward_solution(out_files['forward'], fwd, overwrite=True) + + # Report + with _open_report(cfg=cfg, subject=subject, session=session) as report: + msg = 'Rendering MRI slices with BEM contours.' + logger.info( + **gen_log_kwargs(message=msg, subject=subject, session=session) + ) + report.add_bem( + subject=cfg.fs_subject, + subjects_dir=cfg.fs_subjects_dir, + title='BEM', + width=256, + decim=8 + ) + msg = 'Rendering sensor alignment (coregistration).' + logger.info( + **gen_log_kwargs(message=msg, subject=subject, session=session) + ) + report.add_trans( + trans=trans, + info=info, + title='Sensor alignment', + subject=cfg.fs_subject, + subjects_dir=cfg.fs_subjects_dir, + alpha=1 + ) + msg = 'Rendering forward solution.' + logger.info( + **gen_log_kwargs(message=msg, subject=subject, session=session) + ) + report.add_forward( + forward=fwd, + title='Forward solution', + subject=cfg.fs_subject, + subjects_dir=cfg.fs_subjects_dir, + ) + + assert len(in_files) == 0, in_files return out_files diff --git a/mne_bids_pipeline/scripts/source/_05_make_inverse.py b/mne_bids_pipeline/scripts/source/_05_make_inverse.py index 1953b8262..d74d9d559 100644 --- a/mne_bids_pipeline/scripts/source/_05_make_inverse.py +++ b/mne_bids_pipeline/scripts/source/_05_make_inverse.py @@ -17,6 +17,7 @@ get_task, get_datatype, get_deriv_root, get_sessions) from ..._logging import logger, gen_log_kwargs from ..._parallel import get_parallel_backend, parallel_func +from ..._report import _open_report, _sanitize_cond_tag from ..._run import failsafe_run, save_logs, _sanitize_callable @@ -79,18 +80,15 @@ def run_inverse(*, cfg, subject, session, in_files): else: conditions = cfg.conditions + method = cfg.inverse_method if 'evoked' in in_files: fname_ave = in_files.pop('evoked') evokeds = mne.read_evokeds(fname_ave) for condition, evoked in zip(conditions, evokeds): - method = cfg.inverse_method pick_ori = None - cond_str = sanitize_cond_name(condition) - inverse_str = method - hemi_str = 'hemi' # MNE will auto-append '-lh' and '-rh'. 
- key = f'{cond_str}+{inverse_str}+{hemi_str}' + key = f'{cond_str}+{method}+hemi' out_files[key] = fname_ave.copy().update( suffix=key, extension=None) @@ -106,6 +104,33 @@ def run_inverse(*, cfg, subject, session, in_files): ) stc.save(out_files[key], overwrite=True) out_files[key] = pathlib.Path(str(out_files[key]) + '-lh.stc') + + with _open_report(cfg=cfg, subject=subject, session=session) as report: + for condition in conditions: + cond_str = sanitize_cond_name(condition) + key = f'{cond_str}+{method}+hemi' + if key not in out_files: + continue + msg = f'Rendering inverse solution for {condition}' + logger.info( + **gen_log_kwargs( + message=msg, subject=subject, session=session) + ) + fname_stc = out_files[key] + tags = ( + 'source-estimate', + _sanitize_cond_tag(condition) + ) + report.add_stc( + stc=fname_stc, + title=f'Source: {condition}', + subject=cfg.fs_subject, + subjects_dir=cfg.fs_subjects_dir, + n_time_points=cfg.report_stc_n_time_points, + tags=tags + ) + + assert len(in_files) == 0, in_files return out_files @@ -129,6 +154,7 @@ def get_config( inverse_method=config.inverse_method, deriv_root=get_deriv_root(config), noise_cov=_sanitize_callable(config.noise_cov), + report_stc_n_time_points=config.report_stc_n_time_points, ) return cfg diff --git a/mne_bids_pipeline/scripts/source/_99_group_average.py b/mne_bids_pipeline/scripts/source/_99_group_average.py index 8b550b092..a80f5395c 100644 --- a/mne_bids_pipeline/scripts/source/_99_group_average.py +++ b/mne_bids_pipeline/scripts/source/_99_group_average.py @@ -13,7 +13,8 @@ from ..._config_utils import ( get_fs_subjects_dir, get_subjects, sanitize_cond_name, get_fs_subject, - get_task, get_datatype, get_deriv_root, get_sessions, get_bids_root) + get_task, get_datatype, get_deriv_root, get_sessions, get_bids_root, + get_all_contrasts) from ..._logging import logger, gen_log_kwargs from ..._parallel import get_parallel_backend, parallel_func from ..._run import failsafe_run, save_logs @@ -125,6 +126,8 @@ def get_config( exclude_subjects=config.exclude_subjects, sessions=get_sessions(config), use_template_mri=config.use_template_mri, + all_contrasts=get_all_contrasts(config), + report_stc_n_time_points=config.report_stc_n_time_points, ) return cfg diff --git a/mne_bids_pipeline/tests/run_tests.py b/mne_bids_pipeline/tests/run_tests.py index 8e3b78e83..f4c40c231 100644 --- a/mne_bids_pipeline/tests/run_tests.py +++ b/mne_bids_pipeline/tests/run_tests.py @@ -39,7 +39,7 @@ class TestOptionsT(TypedDict, total=False): # key: { # 'dataset': key.split('_')[0], # 'config': f'config_{key}.py', -# 'steps': ('preprocessing', 'sensor', 'report'), +# 'steps': ('preprocessing', 'sensor'), # 'env': {}, # 'task': None, # } @@ -57,13 +57,13 @@ class TestOptionsT(TypedDict, total=False): 'ds000246': { 'steps': ('preprocessing', 'preprocessing/make_epochs', # Test the group/step syntax - 'sensor', 'report'), + 'sensor'), }, 'ds000247': { 'task': 'rest', }, 'ds000248': { - 'steps': ('preprocessing', 'sensor', 'source', 'report'), + 'steps': ('preprocessing', 'sensor', 'source'), }, 'ds000248_ica': {}, 'ds000248_T1_BEM': { @@ -76,13 +76,13 @@ class TestOptionsT(TypedDict, total=False): 'steps': ('freesurfer/coreg_surfaces',), }, 'ds000248_no_mri': { - 'steps': ('preprocessing', 'sensor', 'source', 'report'), + 'steps': ('preprocessing', 'sensor', 'source'), }, 'ds001810': { - 'steps': ('preprocessing', 'preprocessing', 'sensor', 'report'), + 'steps': ('preprocessing', 'preprocessing', 'sensor'), }, 'ds003104': { - 'steps': ('preprocessing', 
'sensor', 'source', 'report'), + 'steps': ('preprocessing', 'sensor', 'source'), }, 'ERP_CORE_N400': { 'dataset': 'ERP_CORE', @@ -181,7 +181,7 @@ def run_tests(test_suite, *, download, debug, cache): # Run the tests. steps = test_options.get( - 'steps', ('preprocessing', 'sensor', 'report')) + 'steps', ('preprocessing', 'sensor')) task = test_options.get('task', None) command = [ 'mne_bids_pipeline', diff --git a/pyproject.toml b/pyproject.toml index fcf25f4e4..1227a4769 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,7 @@ tests = [ "openneuro-py >= 2022.2.0", "httpx >= 0.20", "tqdm", + "filelock", "Pygments", "pyyaml", ] From 9d1e41af6bccd05beb6ac9f9d36c2a188c55f601 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Wed, 9 Nov 2022 17:10:22 -0500 Subject: [PATCH 02/10] FIX: Update --- .circleci/setup_bash.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/.circleci/setup_bash.sh b/.circleci/setup_bash.sh index fcdb83e90..2e3731df9 100755 --- a/.circleci/setup_bash.sh +++ b/.circleci/setup_bash.sh @@ -68,6 +68,7 @@ echo "export DOWNLOAD_DATA=\"python $HOME/project/mne_bids_pipeline/tests/downlo # Similar CircleCI setup to mne-python (Xvfb, venv, minimal commands, env vars) wget -q https://raw.githubusercontent.com/mne-tools/mne-python/main/tools/setup_xvfb.sh bash setup_xvfb.sh +sudo apt update sudo apt install -qq tcsh git-annex-standalone python3.10-venv python3-venv libxft2 python3.10 -m venv ~/python_env wget -q https://raw.githubusercontent.com/mne-tools/mne-python/main/tools/get_minimal_commands.sh From f72516bfa162a5b99674fbdea65fd1cfd2948876 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Wed, 9 Nov 2022 17:26:04 -0500 Subject: [PATCH 03/10] FIX: Try again --- mne_bids_pipeline/scripts/sensor/_05_decoding_csp.py | 1 + mne_bids_pipeline/scripts/sensor/_99_group_average.py | 2 ++ mne_bids_pipeline/scripts/source/_05_make_inverse.py | 8 ++++++-- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/mne_bids_pipeline/scripts/sensor/_05_decoding_csp.py b/mne_bids_pipeline/scripts/sensor/_05_decoding_csp.py index b182698cc..8c7fed4b3 100644 --- a/mne_bids_pipeline/scripts/sensor/_05_decoding_csp.py +++ b/mne_bids_pipeline/scripts/sensor/_05_decoding_csp.py @@ -538,6 +538,7 @@ def get_config( decoding_csp_freqs=config.decoding_csp_freqs, decoding_csp_times=config.decoding_csp_times, decoding_n_splits=config.decoding_n_splits, + decoding_contrasts=get_decoding_contrasts(config), n_boot=config.n_boot, random_state=config.random_state, interactive=config.interactive diff --git a/mne_bids_pipeline/scripts/sensor/_99_group_average.py b/mne_bids_pipeline/scripts/sensor/_99_group_average.py index d60267c5d..21d94b04b 100644 --- a/mne_bids_pipeline/scripts/sensor/_99_group_average.py +++ b/mne_bids_pipeline/scripts/sensor/_99_group_average.py @@ -592,6 +592,7 @@ def get_config( proc=config.proc, deriv_root=get_deriv_root(config), conditions=config.conditions, + contrasts=config.contrasts, decode=config.decode, decoding_metric=config.decoding_metric, decoding_n_splits=config.decoding_n_splits, @@ -618,6 +619,7 @@ def get_config( exclude_subjects=config.exclude_subjects, all_contrasts=get_all_contrasts(config), report_evoked_n_time_points=config.report_evoked_n_time_points, + cluster_permutation_p_threshold=config.cluster_permutation_p_threshold, ) return cfg diff --git a/mne_bids_pipeline/scripts/source/_05_make_inverse.py b/mne_bids_pipeline/scripts/source/_05_make_inverse.py index d74d9d559..45c00ff2f 100644 --- a/mne_bids_pipeline/scripts/source/_05_make_inverse.py +++ 
b/mne_bids_pipeline/scripts/source/_05_make_inverse.py @@ -14,7 +14,8 @@ from ..._config_utils import ( get_noise_cov_bids_path, get_subjects, sanitize_cond_name, - get_task, get_datatype, get_deriv_root, get_sessions) + get_task, get_datatype, get_deriv_root, get_sessions, + get_fs_subjects_dir, get_fs_subject) from ..._logging import logger, gen_log_kwargs from ..._parallel import get_parallel_backend, parallel_func from ..._report import _open_report, _sanitize_cond_tag @@ -137,6 +138,7 @@ def run_inverse(*, cfg, subject, session, in_files): def get_config( *, config, + subject, ) -> SimpleNamespace: cfg = SimpleNamespace( task=get_task(config), @@ -155,6 +157,8 @@ def get_config( deriv_root=get_deriv_root(config), noise_cov=_sanitize_callable(config.noise_cov), report_stc_n_time_points=config.report_stc_n_time_points, + fs_subject=get_fs_subject(config=config, subject=subject), + fs_subjects_dir=get_fs_subjects_dir(config), ) return cfg @@ -170,7 +174,7 @@ def main(*, config) -> None: parallel, run_func = parallel_func(run_inverse, config=config) logs = parallel( run_func( - cfg=get_config(config=config), + cfg=get_config(config=config, subject=subject), subject=subject, session=session, ) From d9a21edd068821b1a23da90520cf6b0f4abdd280 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Wed, 9 Nov 2022 17:57:42 -0500 Subject: [PATCH 04/10] FIX: Flake --- .../scripts/sensor/_03_decoding_time_by_time.py | 9 +++++++-- mne_bids_pipeline/scripts/sensor/_04_time_frequency.py | 5 ++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/mne_bids_pipeline/scripts/sensor/_03_decoding_time_by_time.py b/mne_bids_pipeline/scripts/sensor/_03_decoding_time_by_time.py index fadfb8d11..97f8b1384 100644 --- a/mne_bids_pipeline/scripts/sensor/_03_decoding_time_by_time.py +++ b/mne_bids_pipeline/scripts/sensor/_03_decoding_time_by_time.py @@ -98,6 +98,11 @@ def run_time_decoding(*, cfg, subject, condition1, condition2, session, else: epochs_conds = cond_names = [condition1, condition2] epochs_conds = [condition1, condition2] + epoch_counts = dict() + for contrast in cfg.contrasts: + for cond in contrast: + if cond not in epoch_counts: + epoch_counts[cond] = len(epochs[cond]) # We have to use this approach because the conditions could be based on # metadata selection, so simply using epochs[conds[0], conds[1]] would @@ -226,8 +231,8 @@ def run_time_decoding(*, cfg, subject, condition1, condition2, session, ) caption = ( f'Time-by-time decoding: ' - f'{len(epochs[cond_1])} × {cond_1} vs. ' - f'{len(epochs[cond_2])} × {cond_2}' + f'{epoch_counts[cond_1]} × {cond_1} vs. ' + f'{epoch_counts[cond_2]} × {cond_2}' ) title = f'Decoding over time: {cond_1} vs. 
{cond_2}' report.add_figure( diff --git a/mne_bids_pipeline/scripts/sensor/_04_time_frequency.py b/mne_bids_pipeline/scripts/sensor/_04_time_frequency.py index a2cafffc7..b28146158 100644 --- a/mne_bids_pipeline/scripts/sensor/_04_time_frequency.py +++ b/mne_bids_pipeline/scripts/sensor/_04_time_frequency.py @@ -162,7 +162,10 @@ def get_config( time_frequency_freq_min=config.time_frequency_freq_min, time_frequency_freq_max=config.time_frequency_freq_max, time_frequency_cycles=config.time_frequency_cycles, - time_frequency_subtract_evoked=config.time_frequency_subtract_evoked + time_frequency_subtract_evoked=config.time_frequency_subtract_evoked, + time_frequency_baseline=config.time_frequency_baseline, + time_frequency_baseline_mode=config.time_frequency_baseline_mode, + time_frequency_crop=config.time_frequency_crop, ) return cfg From 98d036f9b114ef28741913f55804472191c176d8 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Thu, 10 Nov 2022 15:12:09 -0500 Subject: [PATCH 05/10] FIX: Replace --- mne_bids_pipeline/_config_import.py | 4 +++ mne_bids_pipeline/_import_data.py | 9 +++++-- mne_bids_pipeline/_report.py | 25 +++++++++++++------ mne_bids_pipeline/config.py | 4 +-- .../preprocessing/_02_frequency_filter.py | 6 +++-- .../scripts/preprocessing/_03_make_epochs.py | 6 +++-- .../scripts/preprocessing/_04a_run_ica.py | 6 +++-- .../scripts/preprocessing/_04b_run_ssp.py | 8 ++++-- .../scripts/preprocessing/_05a_apply_ica.py | 6 +++-- .../scripts/preprocessing/_06_ptp_reject.py | 3 ++- .../scripts/sensor/_01_make_evoked.py | 1 + .../sensor/_02_decoding_full_epochs.py | 3 ++- .../sensor/_03_decoding_time_by_time.py | 2 ++ .../scripts/sensor/_04_time_frequency.py | 6 +++-- .../scripts/sensor/_05_decoding_csp.py | 2 ++ .../scripts/sensor/_06_make_cov.py | 3 ++- .../scripts/source/_04_make_forward.py | 7 ++++-- .../scripts/source/_05_make_inverse.py | 3 ++- .../tests/configs/config_ERP_CORE.py | 2 +- 19 files changed, 76 insertions(+), 30 deletions(-) diff --git a/mne_bids_pipeline/_config_import.py b/mne_bids_pipeline/_config_import.py index b177e2e06..3ea88182b 100644 --- a/mne_bids_pipeline/_config_import.py +++ b/mne_bids_pipeline/_config_import.py @@ -179,3 +179,7 @@ def _check_config(config: ModuleType) -> None: 'Please indicate the name of your conditions in your ' 'configuration. Currently the `conditions` parameter is empty. ' 'This is only allowed for resting-state analysis.') + + _check_option( + 'config.on_rename_missing_events', config.on_rename_missing_events, + ('raise', 'warn', 'ignore')) diff --git a/mne_bids_pipeline/_import_data.py b/mne_bids_pipeline/_import_data.py index 8b91597e2..9a889d2f9 100644 --- a/mne_bids_pipeline/_import_data.py +++ b/mne_bids_pipeline/_import_data.py @@ -184,9 +184,14 @@ def _rename_events_func( f'they are not present in the BIDS input data:\n' f'{", ".join(sorted(list(events_not_in_raw)))}') if cfg.on_rename_missing_events == 'warn': - logger.warning(msg) - else: + logger.warning( + **gen_log_kwargs(message=msg, subject=subject, session=session) + ) + elif cfg.on_rename_missing_events == 'raise': raise ValueError(msg) + else: + # should be guaranteed + assert cfg.on_rename_missing_events == 'ignore' # Do the actual event renaming. 
msg = 'Renaming events …' diff --git a/mne_bids_pipeline/_report.py b/mne_bids_pipeline/_report.py index c1f1f2c8e..576728200 100644 --- a/mne_bids_pipeline/_report.py +++ b/mne_bids_pipeline/_report.py @@ -480,7 +480,8 @@ def add_event_counts(*, f'{df_events.to_html(classes=css_classes, border=0)}\n' f'', title='Event counts', - tags=('events',) + tags=('events',), + replace=True, ) css = ('.event-counts {\n' ' display: -webkit-box;\n' @@ -493,7 +494,8 @@ def add_event_counts(*, 'th, td {\n' ' text-align: center;\n' '}\n') - report.add_custom_css(css=css) + if css not in report.include: + report.add_custom_css(css=css) def add_system_info(report: mne.Report): @@ -503,9 +505,10 @@ def add_system_info(report: mne.Report): report.add_code( code=config_path, title='Configuration file', - tags=('configuration',) + tags=('configuration',), + replace=True, ) - report.add_sys_info(title='System information') + report.add_sys_info(title='System information', replace=True) def _all_conditions(*, cfg): @@ -582,7 +585,8 @@ def run_report_average_sensor(*, cfg, session: str) -> None: projs=False, tags=tags, n_time_points=cfg.report_evoked_n_time_points, - # captions=evoked.comment # TODO upstream + # captions=evoked.comment, # TODO upstream + replace=True, ) ####################################################################### @@ -649,7 +653,8 @@ def run_report_average_source(*, cfg, session: str) -> None: subject='fsaverage', subjects_dir=cfg.fs_subjects_dir, n_time_points=cfg.report_stc_n_time_points, - tags=tags + tags=tags, + replace=True, ) @@ -713,7 +718,8 @@ def add_decoding_grand_average( 'decoding', *[f'{_sanitize_cond_tag(cond_1)}–{_sanitize_cond_tag(cond_2)}' for cond_1, cond_2 in cfg.decoding_contrasts] - ) + ), + replace=True, ) # close figure to save memory plt.close(fig) @@ -767,6 +773,7 @@ def add_decoding_grand_average( caption=caption, section=section, tags=tags, + replace=True, ) plt.close(fig) @@ -789,6 +796,7 @@ def add_decoding_grand_average( caption=caption, section=section, tags=tags, + replace=True, ) plt.close(fig) @@ -811,6 +819,7 @@ def add_decoding_grand_average( caption=caption, section=section, tags=tags, + replace=True, ) plt.close(fig) @@ -952,6 +961,7 @@ def add_csp_grand_average( caption='Mean decoding scores. Error bars represent ' 'bootstrapped 95% confidence intervals.', tags=tags, + replace=True, ) # Now, plot decoding scores across time-frequency bins. @@ -1084,6 +1094,7 @@ def add_csp_grand_average( f'(clustering bins with absolute t-values > ' f'{round(cluster_t_threshold, 3)}).', tags=tags, + replace=True, ) diff --git a/mne_bids_pipeline/config.py b/mne_bids_pipeline/config.py index 8647f6c16..8f793e254 100644 --- a/mne_bids_pipeline/config.py +++ b/mne_bids_pipeline/config.py @@ -765,13 +765,13 @@ ``` """ -on_rename_missing_events: Literal['warn', 'raise'] = 'raise' +on_rename_missing_events: Literal['warn', 'raise', 'ignore'] = 'raise' """ How to handle the situation where you specified an event to be renamed via ``rename_events``, but this particular event is not present in the data. By default, we will raise an exception to avoid accidental mistakes due to typos; however, if you're sure what you're doing, you may change this to ``'warn'`` -to only get a warning instead. +to only get a warning instead, or ``'ignore'`` to ignore it completely. 
""" ############################################################################### diff --git a/mne_bids_pipeline/scripts/preprocessing/_02_frequency_filter.py b/mne_bids_pipeline/scripts/preprocessing/_02_frequency_filter.py index 92c23a8ec..d55f655cf 100644 --- a/mne_bids_pipeline/scripts/preprocessing/_02_frequency_filter.py +++ b/mne_bids_pipeline/scripts/preprocessing/_02_frequency_filter.py @@ -288,7 +288,8 @@ def filter_data( fig=figs, caption=captions, title='Data Quality', - tags=tags + tags=tags, + replace=True, ) for fig in figs: plt.close(fig) @@ -313,7 +314,8 @@ def filter_data( butterfly=5, psd=plot_raw_psd, tags=('raw', 'filtered', f'run-{fname.run}'), - # caption=fname.basename # TODO upstream + # caption=fname.basename, # TODO upstream + replace=True, ) return out_files diff --git a/mne_bids_pipeline/scripts/preprocessing/_03_make_epochs.py b/mne_bids_pipeline/scripts/preprocessing/_03_make_epochs.py index 3afb5d91d..bf2be43ac 100644 --- a/mne_bids_pipeline/scripts/preprocessing/_03_make_epochs.py +++ b/mne_bids_pipeline/scripts/preprocessing/_03_make_epochs.py @@ -234,7 +234,8 @@ def run_epochs(*, cfg, subject, session, in_files): sfreq=sfreq, first_samp=first_samp, title='Events', - # caption='Events in filtered continuous data' # TODO upstream + # caption='Events in filtered continuous data', # TODO upstr + replace=True, ) msg = 'Adding uncleaned epochs to report.' logger.info( @@ -251,7 +252,8 @@ def run_epochs(*, cfg, subject, session, in_files): epochs=epochs, title='Epochs: before cleaning', psd=psd, - drop_log_ignore=() + drop_log_ignore=(), + replace=True, ) # Interactive diff --git a/mne_bids_pipeline/scripts/preprocessing/_04a_run_ica.py b/mne_bids_pipeline/scripts/preprocessing/_04a_run_ica.py index b5c8496b8..62420e538 100644 --- a/mne_bids_pipeline/scripts/preprocessing/_04a_run_ica.py +++ b/mne_bids_pipeline/scripts/preprocessing/_04a_run_ica.py @@ -488,7 +488,8 @@ def run_ica(*, cfg, subject, session, in_files): report.add_epochs( epochs=epochs, title='Epochs used for ICA fitting', - drop_log_ignore=() + drop_log_ignore=(), + replace=True, ) ecg_evoked = None if epochs_ecg is None else epochs_ecg.average() @@ -503,7 +504,8 @@ def run_ica(*, cfg, subject, session, in_files): ecg_evoked=ecg_evoked, eog_evoked=eog_evoked, ecg_scores=ecg_scores, - eog_scores=eog_scores + eog_scores=eog_scores, + replace=True, ) msg = (f"ICA completed. 
Please carefully review the extracted ICs in the " diff --git a/mne_bids_pipeline/scripts/preprocessing/_04b_run_ssp.py b/mne_bids_pipeline/scripts/preprocessing/_04b_run_ssp.py index f3999f9fd..0526f8dd2 100644 --- a/mne_bids_pipeline/scripts/preprocessing/_04b_run_ssp.py +++ b/mne_bids_pipeline/scripts/preprocessing/_04b_run_ssp.py @@ -168,8 +168,12 @@ def run_ssp(*, cfg, subject, session, in_files): f'(from {len(proj_epochs.drop_log)} original events)' ) report.add_figure( - fig, title=f'SSP: {kind.upper()}', caption=caption, - tags=('ssp', kind)) + fig, + title=f'SSP: {kind.upper()}', + caption=caption, + tags=('ssp', kind), + replace=True, + ) plt.close(fig) return out_files diff --git a/mne_bids_pipeline/scripts/preprocessing/_05a_apply_ica.py b/mne_bids_pipeline/scripts/preprocessing/_05a_apply_ica.py index 9fd7d5013..7e7a8cad2 100644 --- a/mne_bids_pipeline/scripts/preprocessing/_05a_apply_ica.py +++ b/mne_bids_pipeline/scripts/preprocessing/_05a_apply_ica.py @@ -130,7 +130,8 @@ def apply_ica(*, cfg, subject, session, in_files): ica=ica, title='Effects of ICA cleaning', inst=epochs.copy().apply_baseline(cfg.baseline), - picks=picks + picks=picks, + replace=True, ) report.save( out_files['report'], overwrite=True, open_browser=cfg.interactive) @@ -151,11 +152,12 @@ def apply_ica(*, cfg, subject, session, in_files): ica=ica, title='ICA', inst=epochs, - picks=ica.exclude + picks=ica.exclude, # TODO upstream # captions=f'Evoked response (across all epochs) ' # f'before and after ICA ' # f'({len(ica.exclude)} ICs removed)' + replace=True, ) return out_files diff --git a/mne_bids_pipeline/scripts/preprocessing/_06_ptp_reject.py b/mne_bids_pipeline/scripts/preprocessing/_06_ptp_reject.py index b9292eeff..fe43ad6bd 100644 --- a/mne_bids_pipeline/scripts/preprocessing/_06_ptp_reject.py +++ b/mne_bids_pipeline/scripts/preprocessing/_06_ptp_reject.py @@ -138,7 +138,8 @@ def drop_ptp(*, cfg, subject, session, in_files): epochs=epochs, title='Epochs: after cleaning', psd=psd, - drop_log_ignore=() + drop_log_ignore=(), + replace=True, ) return out_files diff --git a/mne_bids_pipeline/scripts/sensor/_01_make_evoked.py b/mne_bids_pipeline/scripts/sensor/_01_make_evoked.py index 189eaeeaa..ca48a6939 100644 --- a/mne_bids_pipeline/scripts/sensor/_01_make_evoked.py +++ b/mne_bids_pipeline/scripts/sensor/_01_make_evoked.py @@ -124,6 +124,7 @@ def run_evoked(*, cfg, subject, session, in_files): noise_cov=noise_cov, n_time_points=cfg.report_evoked_n_time_points, tags=tags, + replace=True, ) # Interaction diff --git a/mne_bids_pipeline/scripts/sensor/_02_decoding_full_epochs.py b/mne_bids_pipeline/scripts/sensor/_02_decoding_full_epochs.py index 8ae600fe1..f7977c44f 100644 --- a/mne_bids_pipeline/scripts/sensor/_02_decoding_full_epochs.py +++ b/mne_bids_pipeline/scripts/sensor/_02_decoding_full_epochs.py @@ -192,7 +192,8 @@ def run_epochs_decoding(*, cfg, subject, condition1, condition2, session, *[f'{_sanitize_cond_tag(cond_1)}–' f'{_sanitize_cond_tag(cond_2)}' for cond_1, cond_2 in cfg.contrasts] - ) + ), + replace=True, ) # close figure to save memory plt.close(fig) diff --git a/mne_bids_pipeline/scripts/sensor/_03_decoding_time_by_time.py b/mne_bids_pipeline/scripts/sensor/_03_decoding_time_by_time.py index 97f8b1384..42f564ca2 100644 --- a/mne_bids_pipeline/scripts/sensor/_03_decoding_time_by_time.py +++ b/mne_bids_pipeline/scripts/sensor/_03_decoding_time_by_time.py @@ -241,6 +241,7 @@ def run_time_decoding(*, cfg, subject, condition1, condition2, session, caption=caption, section=section, 
tags=tags, + replace=True, ) plt.close(fig) @@ -262,6 +263,7 @@ def run_time_decoding(*, cfg, subject, condition1, condition2, session, caption=caption, section=section, tags=tags, + replace=True, ) plt.close(fig) diff --git a/mne_bids_pipeline/scripts/sensor/_04_time_frequency.py b/mne_bids_pipeline/scripts/sensor/_04_time_frequency.py index b28146158..89c2fff25 100644 --- a/mne_bids_pipeline/scripts/sensor/_04_time_frequency.py +++ b/mne_bids_pipeline/scripts/sensor/_04_time_frequency.py @@ -121,7 +121,8 @@ def run_time_frequency(*, cfg, subject, session, in_files): fig=fig_power, title=f'TFR Power: {condition}', caption=f'TFR Power: {condition}', - tags=('time-frequency', _sanitize_cond_tag(condition)) + tags=('time-frequency', _sanitize_cond_tag(condition)), + replace=True, ) plt.close(fig_power) del power @@ -134,7 +135,8 @@ def run_time_frequency(*, cfg, subject, session, in_files): fig=fig_itc, title=f'TFR ITC: {condition}', caption=f'TFR Inter-Trial Coherence: {condition}', - tags=('time-frequency', _sanitize_cond_tag(condition)) + tags=('time-frequency', _sanitize_cond_tag(condition)), + replace=True, ) plt.close(fig_power) del itc diff --git a/mne_bids_pipeline/scripts/sensor/_05_decoding_csp.py b/mne_bids_pipeline/scripts/sensor/_05_decoding_csp.py index 8c7fed4b3..eb956c55a 100644 --- a/mne_bids_pipeline/scripts/sensor/_05_decoding_csp.py +++ b/mne_bids_pipeline/scripts/sensor/_05_decoding_csp.py @@ -448,6 +448,7 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, section=section, caption=caption, tags=tags, + replace=True, ) # close figure to save memory plt.close(fig) @@ -508,6 +509,7 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, title=title, section=section, tags=tags, + replace=True, ) assert len(in_files) == 0, in_files.keys() diff --git a/mne_bids_pipeline/scripts/sensor/_06_make_cov.py b/mne_bids_pipeline/scripts/sensor/_06_make_cov.py index c3d96ad54..4080d5c7d 100644 --- a/mne_bids_pipeline/scripts/sensor/_06_make_cov.py +++ b/mne_bids_pipeline/scripts/sensor/_06_make_cov.py @@ -205,7 +205,8 @@ def run_covariance(*, cfg, subject, session, in_files): report.add_covariance( cov=cov, info=fname_info, - title='Noise covariance' + title='Noise covariance', + replace=True, ) assert len(in_files) == 0, in_files diff --git a/mne_bids_pipeline/scripts/source/_04_make_forward.py b/mne_bids_pipeline/scripts/source/_04_make_forward.py index 3408144aa..3d0fbda6a 100644 --- a/mne_bids_pipeline/scripts/source/_04_make_forward.py +++ b/mne_bids_pipeline/scripts/source/_04_make_forward.py @@ -168,7 +168,8 @@ def run_forward(*, cfg, subject, session, in_files): subjects_dir=cfg.fs_subjects_dir, title='BEM', width=256, - decim=8 + decim=8, + replace=True, ) msg = 'Rendering sensor alignment (coregistration).' logger.info( @@ -180,7 +181,8 @@ def run_forward(*, cfg, subject, session, in_files): title='Sensor alignment', subject=cfg.fs_subject, subjects_dir=cfg.fs_subjects_dir, - alpha=1 + alpha=1, + replace=True, ) msg = 'Rendering forward solution.' 
logger.info( @@ -191,6 +193,7 @@ def run_forward(*, cfg, subject, session, in_files): title='Forward solution', subject=cfg.fs_subject, subjects_dir=cfg.fs_subjects_dir, + replace=True, ) assert len(in_files) == 0, in_files diff --git a/mne_bids_pipeline/scripts/source/_05_make_inverse.py b/mne_bids_pipeline/scripts/source/_05_make_inverse.py index 45c00ff2f..304baca8e 100644 --- a/mne_bids_pipeline/scripts/source/_05_make_inverse.py +++ b/mne_bids_pipeline/scripts/source/_05_make_inverse.py @@ -128,7 +128,8 @@ def run_inverse(*, cfg, subject, session, in_files): subject=cfg.fs_subject, subjects_dir=cfg.fs_subjects_dir, n_time_points=cfg.report_stc_n_time_points, - tags=tags + tags=tags, + replace=True, ) assert len(in_files) == 0, in_files diff --git a/mne_bids_pipeline/tests/configs/config_ERP_CORE.py b/mne_bids_pipeline/tests/configs/config_ERP_CORE.py index c1b333a6d..7b33c8b07 100644 --- a/mne_bids_pipeline/tests/configs/config_ERP_CORE.py +++ b/mne_bids_pipeline/tests/configs/config_ERP_CORE.py @@ -69,7 +69,7 @@ run_source_estimation = False on_error = 'abort' -on_rename_missing_events = 'warn' +on_rename_missing_events = 'ignore' parallel_backend = 'dask' dask_worker_memory_limit = '2G' From fdd1bf6174765601203a1a89ce029666e72fbd8a Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Thu, 10 Nov 2022 15:55:56 -0500 Subject: [PATCH 06/10] FIX: Append --- mne_bids_pipeline/_report.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/mne_bids_pipeline/_report.py b/mne_bids_pipeline/_report.py index 576728200..ce0fb5b82 100644 --- a/mne_bids_pipeline/_report.py +++ b/mne_bids_pipeline/_report.py @@ -500,15 +500,20 @@ def add_event_counts(*, def add_system_info(report: mne.Report): """Add system information and the pipeline configuration to the report.""" - config_path = Path(os.environ['MNE_BIDS_STUDY_CONFIG']) + # ensure they are always appended + titles = ['Configuration file', 'System information'] + for title in titles: + report.remove(title=title, remove_all=True) + # No longer need replace=True in these report.add_code( code=config_path, - title='Configuration file', + title=titles[0], tags=('configuration',), - replace=True, ) - report.add_sys_info(title='System information', replace=True) + report.add_sys_info( + title=titles[1], + ) def _all_conditions(*, cfg): From 55063c74d294c6022940d36a652bc17236f58ec0 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Thu, 10 Nov 2022 16:07:45 -0500 Subject: [PATCH 07/10] FIX: Bug --- .circleci/config.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.circleci/config.yml b/.circleci/config.yml index 9b5a8f135..4f653e515 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -45,6 +45,7 @@ jobs: command: | pip install --upgrade --progress-bar off pip setuptools pip install -ve .[tests,dev] + pip install https://github.com/larsoner/mne-python/zipball/report # fix Report bug pip install PyQt6 - run: name: Check Qt From d797adf6f99d94132c0c36ef209f49cc290ad616 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Thu, 10 Nov 2022 17:44:09 -0500 Subject: [PATCH 08/10] FIX: Parse --- docs/source/changes.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/changes.md b/docs/source/changes.md index 5ee882fa3..9029c5a6f 100644 --- a/docs/source/changes.md +++ b/docs/source/changes.md @@ -204,7 +204,7 @@ authors: ({{ gh(653) }} by {{ authors.hoechenberger }}) - Make report generation happen within relevant steps instead of at the end of all steps - ({{ gh(652) by {{ authors.larsoner }}) 
+ ({{ gh(652) }} by {{ authors.larsoner }}) ### Behavior changes From 8f353ddc91ff2062c2b044c47a59c5e99d1592c3 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Thu, 10 Nov 2022 19:42:04 -0500 Subject: [PATCH 09/10] FIX: Render --- docs/mkdocs.yml | 1 + mne_bids_pipeline/_config.py | 175 +++++++++++++++++------------------ 2 files changed, 88 insertions(+), 88 deletions(-) diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index b32182ace..d14b004aa 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -93,6 +93,7 @@ plugins: show_root_heading: true show_root_full_path: false separate_signature: true + line_length: 80 # needed for long param entries show_bases: false selection: docstring_style: numpy diff --git a/mne_bids_pipeline/_config.py b/mne_bids_pipeline/_config.py index 1a4cd9cfb..832c2fd95 100644 --- a/mne_bids_pipeline/_config.py +++ b/mne_bids_pipeline/_config.py @@ -29,22 +29,22 @@ bids_root: Optional[PathLike] = None """ -Specify the BIDS root directory. Pass an empty string or ```None`` to use -the value specified in the ``BIDS_ROOT`` environment variable instead. +Specify the BIDS root directory. Pass an empty string or ```None` to use +the value specified in the `BIDS_ROOT` environment variable instead. Raises an exception if the BIDS root has not been specified. ???+ example "Example" ``` python bids_root = '/path/to/your/bids_root' # Use this to specify a path here. - bids_root = None # Make use of the ``BIDS_ROOT`` environment variable. + bids_root = None # Make use of the `BIDS_ROOT` environment variable. ``` """ deriv_root: Optional[PathLike] = None """ The root of the derivatives directory in which the pipeline will store -the processing results. If ``None``, this will be -``derivatives/mne-bids-pipeline`` inside the BIDS root. +the processing results. If `None`, this will be +`derivatives/mne-bids-pipeline` inside the BIDS root. Note: Note If specified and you wish to run the source analysis steps, you must @@ -54,16 +54,16 @@ subjects_dir: Optional[PathLike] = None """ Path to the directory that contains the FreeSurfer reconstructions of all -subjects. Specifically, this defines the ``SUBJECTS_DIR`` that is used by +subjects. Specifically, this defines the `SUBJECTS_DIR` that is used by FreeSurfer. -- When running the ``freesurfer`` processing step to create the +- When running the `freesurfer` processing step to create the reconstructions from anatomical scans in the BIDS dataset, the output will be stored in this directory. - When running the source analysis steps, we will look for the surfaces in this directory and also store the BEM surfaces there. -If ``None``, this will default to +If `None`, this will default to [`bids_root`][mne_bids_pipeline.config.bids_root]`/derivatives/freesurfer/subjects`. Note: Note @@ -85,7 +85,7 @@ sessions: Union[List, Literal['all']] = 'all' """ -The sessions to process. If ``'all'``, will process all sessions found in the +The sessions to process. If `'all'`, will process all sessions found in the BIDS dataset. """ @@ -96,7 +96,7 @@ runs: Union[Iterable, Literal['all']] = 'all' """ -The runs to process. If ``'all'``, will process all runs found in the +The runs to process. If `'all'`, will process all runs found in the BIDS dataset. """ @@ -118,9 +118,9 @@ crop_runs: Optional[Tuple[float, float]] = None """ -Crop the raw data of each run to the specified time interval ``[tmin, tmax]``, +Crop the raw data of each run to the specified time interval `[tmin, tmax]`, in seconds. 
The runs will be cropped before Maxwell or frequency filtering is -applied. If ``None``, do not crop the data. +applied. If `None`, do not crop the data. """ acq: Optional[str] = None @@ -153,14 +153,14 @@ subjects: Union[Iterable[str], Literal['all']] = 'all' """ -Subjects to analyze. If ``'all'``, include all subjects. To only +Subjects to analyze. If `'all'`, include all subjects. To only include a subset of subjects, pass a list of their identifiers. Even if you plan on analyzing only a single subject, pass their identifier as a list. Please note that if you intend to EXCLUDE only a few subjects, you -should consider setting ``subjects = 'all'`` and adding the -identifiers of the excluded subjects to ``exclude_subjects`` (see next +should consider setting `subjects = 'all'` and adding the +identifiers of the excluded subjects to `exclude_subjects` (see next section). ???+ example "Example" @@ -180,7 +180,7 @@ Keep track of the criteria leading you to exclude a participant (e.g. too many movements, missing blocks, aborted experiment, did not understand the instructions, etc, ...) - The ``emptyroom`` subject will be excluded automatically. + The `emptyroom` subject will be excluded automatically. """ process_empty_room: bool = True @@ -188,7 +188,7 @@ Whether to apply the same pre-processing steps to the empty-room data as to the experimental data (up until including frequency filtering). This is required if you wish to use the empty-room recording to estimate noise -covariance (via ``noise_cov='emptyroom'``). The empty-room recording +covariance (via `noise_cov='emptyroom'`). The empty-room recording corresponding to the processed experimental data will be retrieved automatically. """ @@ -198,7 +198,7 @@ Whether to apply the same pre-processing steps to the resting-state data as to the experimental data (up until including frequency filtering). This is required if you wish to use the resting-state recording to estimate noise -covariance (via ``noise_cov='rest'``). +covariance (via `noise_cov='rest'`). """ ch_types: Iterable[Literal['meg', 'mag', 'grad', 'eeg']] = [] @@ -228,7 +228,7 @@ For MEG recordings, this will usually be 'meg'; and for EEG, 'eeg'. However, if your dataset contains simultaneous recordings of MEG and EEG, stored in a single file, you will typically need to set this to 'meg'. -If ``None``, we will assume that the data type matches the channel type. +If `None`, we will assume that the data type matches the channel type. ???+ example "Example" The dataset contains simultaneous recordings of MEG and EEG, and we only @@ -273,18 +273,18 @@ You can specify one or multiple channel names. Each will be treated as if it were a dedicated EOG channel, without excluding it from any other analyses. -If ``None``, only actual EOG channels will be used for EOG recovery. +If `None`, only actual EOG channels will be used for EOG recovery. If there are multiple actual EOG channels in your data, and you only specify a subset of them here, only this subset will be used during processing. ???+ example "Example" - Treat ``Fp1`` as virtual EOG channel: + Treat `Fp1` as virtual EOG channel: ```python eog_channels = ['Fp1'] ``` - Treat ``Fp1`` and ``Fp2`` as virtual EOG channels: + Treat `Fp1` and `Fp2` as virtual EOG channels: ```python eog_channels = ['Fp1', 'Fp2'] ``` @@ -325,7 +325,7 @@ eeg_reference: Union[Literal['average'], str, Iterable['str']] = 'average' """ -The EEG reference to use. If ``average``, will use the average reference, +The EEG reference to use. 
If `average`, will use the average reference, i.e. the average across all channels. If a string, must be the name of a single channel. To use multiple channels as reference, set to a list of channel names. @@ -357,7 +357,7 @@ Please be aware that the actual cap placement most likely deviated somewhat from the template, and, therefore, source reconstruction may be impaired. -If ``None``, do not apply a template montage. If a string, must be the +If `None`, do not apply a template montage. If a string, must be the name of a built-in template montage in MNE-Python. You can find an overview of supported template montages at https://mne.tools/stable/generated/mne.channels.make_standard_montage.html @@ -544,7 +544,7 @@ If the data were recorded with internal active compensation (MaxShield), they need to be run through Maxwell filter to avoid distortions. Bad channels need to be set through BIDS channels.tsv and / or via the - ``find_flat_channels_meg`` and ``find_noisy_channels_meg`` options above + `find_flat_channels_meg` and `find_noisy_channels_meg` options above before applying Maxwell filter. """ @@ -566,10 +566,10 @@ ???+ info "Good Practice / Advice" If you are interested in low frequency activity (<0.1Hz), avoid using - tSSS and set ``mf_st_duration`` to ``None``. + tSSS and set `mf_st_duration` to `None`. If you are interested in low frequency above 0.1 Hz, you can use the - default ``mf_st_duration`` to 10 s, meaning it acts like a 0.1 Hz + default `mf_st_duration` to 10 s, meaning it acts like a 0.1 Hz high-pass filter. ???+ example "Example" @@ -581,7 +581,7 @@ mf_head_origin: Union[Literal['auto'], ArrayLike] = 'auto' """ -``mf_head_origin`` : array-like, shape (3,) | 'auto' +`mf_head_origin` : array-like, shape (3,) | 'auto' Origin of internal and external multipolar moment space in meters. If 'auto', it will be estimated from headshape points. If automatic fitting fails (e.g., due to having too few digitization @@ -604,7 +604,7 @@ the recording session). Which run to take as the reference for adjusting the head position of all -runs. If ``None``, pick the first run. +runs. If `None`, pick the first run. ???+ example "Example" ```python @@ -726,7 +726,7 @@ """ Says how much to decimate data at the epochs level. It is typically an alternative to the `resample_sfreq` parameter that -can be used for resampling raw data. ``1`` means no decimation. +can be used for resampling raw data. `1` means no decimation. ???+ info "Good Practice / Advice" Decimation requires to lowpass filtered the data to avoid aliasing. @@ -752,20 +752,20 @@ Pass an empty dictionary to not perform any renaming. ???+ example "Example" - Rename ``audio_left`` in the BIDS dataset to ``audio/left`` in the + Rename `audio_left` in the BIDS dataset to `audio/left` in the pipeline: ```python rename_events = {'audio_left': 'audio/left'} ``` """ -on_rename_missing_events: Literal['warn', 'raise', 'ignore'] = 'raise' +on_rename_missing_events: Literal['ignore', 'warn', 'raise'] = 'raise' """ How to handle the situation where you specified an event to be renamed via -``rename_events``, but this particular event is not present in the data. By +`rename_events`, but this particular event is not present in the data. By default, we will raise an exception to avoid accidental mistakes due to typos; -however, if you're sure what you're doing, you may change this to ``'warn'`` -to only get a warning instead, or ``'ignore'`` to ignore it completely. 
@@ -795,13 +795,13 @@
 """
 The beginning of the time window for metadata generation, in seconds,
 relative to the time-locked event of the respective epoch. This may be less
-than or larger than the epoch's first time point. If ``None``, use the first
+than or larger than the epoch's first time point. If `None`, use the first
 time point of the epoch.
 """
 
 epochs_metadata_tmax: Optional[float] = None
 """
-Same as ``epochs_metadata_tmin``, but specifying the **end** of the time
+Same as `epochs_metadata_tmin`, but specifying the **end** of the time
 window for metadata generation.
 """
 
@@ -810,34 +810,34 @@
 Event groupings using hierarchical event descriptors (HEDs) for which to store
 the time of the **first** occurrence of any event of this group in a new column
 with the group name, and the **type** of that event in a column named after the
-group, but with a ``first_`` prefix. If ``None`` (default), no event
+group, but with a `first_` prefix. If `None` (default), no event
 aggregation will take place and no new columns will be created.
 
 ???+ example "Example"
-    Assume you have two response events types, ``response/left`` and
-    ``response/right``; in some trials, both responses occur, because the
+    Assume you have two response event types, `response/left` and
+    `response/right`; in some trials, both responses occur, because the
     participant pressed both buttons. Now, you want to keep the first response
     only. To achieve this, set
     ```python
     epochs_metadata_keep_first = ['response']
     ```
-    This will add two new columns to the metadata: ``response``, indicating
-    the **time** relative to the time-locked event; and ``first_response``,
-    depicting the **type** of event (``'left'`` or ``'right'``).
+    This will add two new columns to the metadata: `response`, indicating
+    the **time** relative to the time-locked event; and `first_response`,
+    depicting the **type** of event (`'left'` or `'right'`).
 
     You may also specify a grouping for multiple event types:
     ```python
     epochs_metadata_keep_first = ['response', 'stimulus']
    ```
-    This will add the columns ``response``, ``first_response``, ``stimulus``,
-    and ``first_stimulus``.
+    This will add the columns `response`, `first_response`, `stimulus`,
+    and `first_stimulus`.
 """
 
 epochs_metadata_keep_last: Optional[Iterable[str]] = None
 """
-Same as ``epochs_metadata_keep_first``, but for keeping the **last**
+Same as `epochs_metadata_keep_first`, but for keeping the **last**
 occurrence of matching event types. The columns indicating the event types
-will be named with a ``last_`` instead of a ``first_`` prefix.
+will be named with a `last_` instead of a `first_` prefix.
 """
 
 epochs_metadata_query: Optional[str] = None
@@ -857,8 +857,8 @@
 """
 The time-locked events based on which to create evoked responses.
 This can either be name of the experimental condition as specified in the
-BIDS ``*_events.tsv`` file; or the name of condition *groups*, if the condition
-names contain the (MNE-specific) group separator, ``/``. See the [Subselecting
+BIDS `*_events.tsv` file; or the name of condition *groups*, if the condition
+names contain the (MNE-specific) group separator, `/`. See the [Subselecting
 epochs
 tutorial](https://mne.tools/stable/auto_tutorials/epochs/plot_10_epochs_overview.html#subselecting-epochs)
 for more information.
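Tying the metadata options above together, a hedged sketch; the event grouping and query values are hypothetical, and `first_response` is the column created by the `epochs_metadata_keep_first` setting as documented above:

```python
# Build metadata from event onset to 1.5 s, keep the first button press,
# and retain only epochs whose first response was a left press.
epochs_metadata_tmin = 0.
epochs_metadata_tmax = 1.5
epochs_metadata_keep_first = ['response']
epochs_metadata_query = "first_response == 'left'"
```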
@@ -917,14 +917,14 @@
 
 rest_epochs_overlap: Optional[float] = None
 """
-Overlap between epochs in seconds. This is used if the task is ``'rest'``
+Overlap between epochs in seconds. This is used if the task is `'rest'`
 and when the annotations do not contain any stimulation or behavior events.
 """
 
 baseline: Optional[Tuple[Optional[float], Optional[float]]] = (None, 0)
 """
 Specifies which time interval to use for baseline correction of epochs;
-if ``None``, no baseline correction is applied.
+if `None`, no baseline correction is applied.
 
 ???+ example "Example"
     ```python
@@ -960,7 +960,7 @@
 
 ???+ example "Example"
     Contrast the "left" and the "right" conditions by calculating
-    ``left - right`` at every time point of the evoked responses:
+    `left - right` at every time point of the evoked responses:
     ```python
     contrasts = [('left', 'right')]  # Note we pass a tuple inside the list!
     ```
@@ -1168,11 +1168,11 @@
 Using a relatively high cutoff like 1 Hz will remove slow drifts from the
 data, yielding improved ICA results. Must be set to 1 Hz or above.
 
-Set to ``None`` to not apply an additional high-pass filter.
+Set to `None` to not apply an additional high-pass filter.
 
 Note:
     The filter will be applied to raw data which was already filtered
-    according to the ``l_freq`` and ``h_freq`` settings. After filtering, the
+    according to the `l_freq` and `h_freq` settings. After filtering, the
     data will be epoched, and the epochs will be submitted to ICA.
 
 !!! info
@@ -1215,7 +1215,7 @@
 explained variance less than the value specified here will be passed to
 ICA.
 
-If ``None``, **all** principal components will be used.
+If `None`, **all** principal components will be used.
 
 This setting may drastically alter the time required to compute ICA.
 """
@@ -1225,7 +1225,7 @@
 The decimation parameter to compute ICA. If 5 it means
 that 1 every 5 sample is used by ICA solver. The higher the faster
 it is to run but the less data you have to compute a good ICA. Set to
-``1`` or ``None`` to not perform any decimation.
+`1` or `None` to not perform any decimation.
 """
 
 ica_ctps_ecg_threshold: float = 0.1
@@ -1293,7 +1293,7 @@
 
 reject_tmin: Optional[float] = None
 """
-Start of the time window used to reject epochs. If ``None``, the window will
+Start of the time window used to reject epochs. If `None`, the window will
 start with the first time point.
 ???+ example "Example"
     ```python
@@ -1303,7 +1303,7 @@
 
 reject_tmax: Optional[float] = None
 """
-End of the time window used to reject epochs. If ``None``, the window will end
+End of the time window used to reject epochs. If `None`, the window will end
 with the last time point.
 ???+ example "Example"
     ```python
@@ -1635,17 +1635,17 @@
 bem_mri_images: Literal['FLASH', 'T1', 'auto'] = 'auto'
 """
 Which types of MRI images to use when creating the BEM model.
-If ``'FLASH'``, use FLASH MRI images, and raise an exception if they cannot be
+If `'FLASH'`, use FLASH MRI images, and raise an exception if they cannot be
 found.
 
 ???+ info "Advice"
     It is recommended to use the FLASH images if available, as the quality
     of the extracted BEM surfaces will be higher.
 
-If ``'T1'``, create the BEM surfaces from the T1-weighted images using the
-``watershed`` algorithm.
+If `'T1'`, create the BEM surfaces from the T1-weighted images using the
+`watershed` algorithm.
 
-If ``'auto'``, use FLASH images if available, and use the ``watershed``
+If `'auto'`, use FLASH images if available, and use the `watershed`
 algorithm with the T1-weighted images otherwise.
 
 *[FLASH MRI]: Fast low angle shot magnetic resonance imaging
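A hedged sketch of the ICA-related options documented above; the values are merely plausible, and `spatial_filter = 'ica'` is the pipeline switch that enables the ICA steps (assumed here, since it is not part of these hunks):

```python
spatial_filter = 'ica'        # assumed: enables the ICA steps
ica_n_components = 0.99       # pass PCs explaining 99% of variance to ICA
ica_decim = 5                 # fit ICA on every 5th sample to save time
ica_ctps_ecg_threshold = 0.1  # CTPS threshold for flagging ECG components
```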
@@ -1654,22 +1654,22 @@
 recreate_bem: bool = False
 """
 Whether to re-create the BEM surfaces, even if existing surfaces have been
-found. If ``False``, the BEM surfaces are only created if they do not exist
-already. ``True`` forces their recreation, overwriting existing BEM surfaces.
+found. If `False`, the BEM surfaces are only created if they do not exist
+already. `True` forces their recreation, overwriting existing BEM surfaces.
 """
 
 recreate_scalp_surface: bool = False
 """
 Whether to re-create the scalp surfaces used for visualization of the
 coregistration in the report and the lower-density coregistration surfaces.
-If ``False``, the scalp surface is only created if it does not exist already.
-If ``True``, forces a re-computation.
+If `False`, the scalp surface is only created if it does not exist already.
+If `True`, forces a re-computation.
 """
 
 freesurfer_verbose: bool = False
 """
 Whether to print the complete output of FreeSurfer commands. Note that if
-``False``, no FreeSurfer output might be displayed at all!"""
+`False`, no FreeSurfer output might be displayed at all!"""
 
 mri_t1_path_generator: Optional[
     Callable[[BIDSPath], BIDSPath]
@@ -1754,9 +1754,9 @@ def mri_landmarks_kind(bids_path):
 spacing: Union[Literal['oct5', 'oct6', 'ico4', 'ico5', 'all'], int] = 'oct6'
 """
-The spacing to use. Can be ``'ico#'`` for a recursively subdivided
-icosahedron, ``'oct#'`` for a recursively subdivided octahedron,
-``'all'`` for all points, or an integer to use approximate
+The spacing to use. Can be `'ico#'` for a recursively subdivided
+icosahedron, `'oct#'` for a recursively subdivided octahedron,
+`'all'` for all points, or an integer to use approximate
 distance-based spacing (in mm).
 See (the respective MNE-Python documentation)
 [https://mne.tools/dev/overview/cookbook.html#setting-up-the-source-space]
 for more info.
@@ -1768,20 +1768,19 @@ def mri_landmarks_kind(bids_path):
 """
 
 loose: Union[float, Literal['auto']] = 0.2
-# ``loose`` : float in [0, 1] | 'auto'
 """
 Value that weights the source variances of the dipole components
-that are parallel (tangential) to the cortical surface. If ``0``, then the
+that are parallel (tangential) to the cortical surface. If `0`, then the
 inverse solution is computed with **fixed orientation.**
-If ``1``, it corresponds to **free orientation.**
-The default value, ``'auto'``, is set to ``0.2`` for surface-oriented source
-spaces, and to ``1.0`` for volumetric, discrete, or mixed source spaces,
-unless ``fixed is True`` in which case the value 0. is used.
+If `1`, it corresponds to **free orientation.**
+The default value, `'auto'`, is set to `0.2` for surface-oriented source
+spaces, and to `1.0` for volumetric, discrete, or mixed source spaces,
+unless `fixed is True` in which case the value 0.0 is used.
 """
 
 depth: Optional[Union[float, dict]] = 0.8
 """
-If float (default 0.8), it acts as the depth weighting exponent (``exp``)
+If float (default 0.8), it acts as the depth weighting exponent (`exp`)
 to use (must be between 0 and 1). None is equivalent to 0, meaning no
 depth weighting is performed. Can also be a `dict` containing additional
 keyword arguments to pass to :func:`mne.forward.compute_depth_prior`
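Per the text above, the `dict` form of `depth` forwards keyword arguments to `mne.forward.compute_depth_prior`. A minimal sketch, where `exp` and `limit` are parameters that function accepts and the values roughly mirror its defaults:

```python
spacing = 'oct6'  # recursively subdivided octahedron for the source space
loose = 0.2       # down-weight tangential dipole components
# Roughly the dict equivalent of depth=0.8, with an explicit depth limit:
depth = dict(exp=0.8, limit=10.)
```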
@@ -1803,22 +1802,22 @@ def mri_landmarks_kind(bids_path):
 Specify how to estimate the noise covariance matrix, which is used in
 inverse modeling.
 
-If a tuple, it takes the form ``(tmin, tmax)`` with the time specified in
-seconds. If the first value of the tuple is ``None``, the considered
+If a tuple, it takes the form `(tmin, tmax)` with the time specified in
+seconds. If the first value of the tuple is `None`, the considered
 period starts at the beginning of the epoch. If the second value of the
-tuple is ``None``, the considered period ends at the end of the epoch.
-The default, ``(None, 0)``, includes the entire period before the event,
+tuple is `None`, the considered period ends at the end of the epoch.
+The default, `(None, 0)`, includes the entire period before the event,
 which is typically the pre-stimulus period.
 
-If ``'emptyroom'``, the noise covariance matrix will be estimated from an
+If `'emptyroom'`, the noise covariance matrix will be estimated from an
 empty-room MEG recording. The empty-room recording will be automatically
 selected based on recording date and time. This cannot be used with EEG
 data.
 
-If ``'rest'``, the noise covariance will be estimated from a resting-state
+If `'rest'`, the noise covariance will be estimated from a resting-state
 recording (i.e., a recording with `task-rest` and without a `run` in the
 filename).
 
-If ``'ad-hoc'``, a diagonal ad-hoc noise covariance matrix will be used.
+If `'ad-hoc'`, a diagonal ad-hoc noise covariance matrix will be used.
 
 You can also pass a function that accepts a `BIDSPath` and returns an
 `mne.Covariance` instance. The `BIDSPath` will point to the file containing
@@ -1957,7 +1956,7 @@ def noise_cov(bids_path):
 """
 You can specify the seed of the random number generator (RNG).
 This setting is passed to the ICA algorithm and to the decoding function,
-ensuring reproducible results. Set to ``None`` to avoid setting the RNG
+ensuring reproducible results. Set to `None` to avoid setting the RNG
 to a defined state.
 """
 
@@ -1989,15 +1988,15 @@ def noise_cov(bids_path):
 """
 If not None (or False), caching will be enabled and the cache files will be
 stored in the given directory. The default (True) will use a
-``'joblib'`` subdirectory in the BIDS derivative root of the dataset.
+`'joblib'` subdirectory in the BIDS derivative root of the dataset.
 """
 MemoryFileMethodT = Literal['mtime', 'hash']
 memory_file_method: MemoryFileMethodT = 'mtime'
 """
 The method to use for cache invalidation (i.e., detecting changes). Using the
-"modified time" reported by the filesystem (``'mtime'``, default) is very fast
+"modified time" reported by the filesystem (`'mtime'`, default) is very fast
 but requires that the filesystem supports proper mtime reporting. Using file
-hashes (``'hash'``) is slower and requires reading all input files but should
+hashes (`'hash'`) is slower and requires reading all input files but should
 work on any filesystem.
 """
 memory_verbose: int = 0

From 35fcb42954fc2fab5c28d1fc3abf4c94baa56d6d Mon Sep 17 00:00:00 2001
From: Eric Larson
Date: Fri, 11 Nov 2022 12:15:03 -0500
Subject: [PATCH 10/10] FIX: URL

---
 .circleci/config.yml | 1 -
 1 file changed, 1 deletion(-)

diff --git a/.circleci/config.yml b/.circleci/config.yml
index fd30ddd5b..bd891fecd 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -46,7 +46,6 @@ jobs:
       command: |
         pip install --upgrade --progress-bar off pip setuptools
         pip install -ve .[tests,dev]
-        pip install https://github.com/larsoner/mne-python/zipball/report  # fix Report bug
         pip install PyQt6
     - run:
         name: Check Qt
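As a closing illustration of the `noise_cov` setting documented in the config.py hunks above: it may also be a callable taking a `BIDSPath` and returning an `mne.Covariance`. A minimal sketch only; the ad-hoc covariance here stands in for whatever custom estimation is actually wanted:

```python
import mne

def noise_cov(bids_path):
    """Hypothetical custom estimator: build a diagonal ad-hoc covariance."""
    # bids_path points at the data file the pipeline needs a covariance for;
    # read its measurement info and derive a simple diagonal covariance.
    info = mne.io.read_info(bids_path.fpath)
    return mne.make_ad_hoc_cov(info)
```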