Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

using caching for get_runs_all_subjects to speed up get_runs #716

Merged
merged 7 commits into from
Feb 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions mne_bids_pipeline/_config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import copy
import functools
import pathlib
from typing import List, Optional, Union, Iterable, Tuple, Dict, TypeVar, Literal
from typing import List, Optional, Union, Iterable, Tuple, Dict, TypeVar, Literal, Any
from types import SimpleNamespace, ModuleType

import numpy as np
Expand Down Expand Up @@ -65,7 +65,10 @@ def get_datatype(config: SimpleNamespace) -> Literal["meg", "eeg"]:
return "meg"
else:
raise RuntimeError(
"This probably shouldn't happen. Please contact "
"This probably shouldn't happen, got "
f"config.data_type={repr(config.data_type)} and "
f"config.ch_types={repr(config.ch_types)} "
"but could not determine the datatype. Please contact "
"the MNE-BIDS-pipeline developers. Thank you."
)

Expand Down Expand Up @@ -110,6 +113,16 @@ def get_sessions(config: SimpleNamespace) -> Union[List[None], List[str]]:
return sessions


@functools.lru_cache(maxsize=None)
def _get_runs_all_subjects_cached(
    **config_dict: Any,
) -> Dict[str, Union[List[None], List[str]]]:
    """Memoized wrapper around ``get_runs_all_subjects``.

    The keyword arguments are used as the ``lru_cache`` key, so every value
    must be hashable (pass tuples rather than lists). They are reassembled
    into a ``SimpleNamespace`` and forwarded to ``get_runs_all_subjects``.

    Parameters
    ----------
    **config_dict
        Configuration entries, passed by keyword. ``ch_types`` is required
        and must be iterable.

    Returns
    -------
    The value returned by ``get_runs_all_subjects`` for the rebuilt config.
    The same object is handed back on every cache hit, so callers should
    treat it as read-only.
    """
    config = SimpleNamespace(**config_dict)
    # Sometimes we check list equivalence for ch_types, so convert it back
    config.ch_types = list(config.ch_types)
    return get_runs_all_subjects(config)


def get_runs_all_subjects(
config: SimpleNamespace,
) -> Dict[str, Union[List[None], List[str]]]:
Expand Down Expand Up @@ -173,7 +186,14 @@ def get_runs(

runs = copy.deepcopy(config.runs)

subj_runs = get_runs_all_subjects(config)
subj_runs = _get_runs_all_subjects_cached(
bids_root=config.bids_root,
data_type=config.data_type,
ch_types=tuple(config.ch_types),
subjects=tuple(config.subjects) if config.subjects != "all" else "all",
exclude_subjects=tuple(config.exclude_subjects),
exclude_runs=tuple(config.exclude_runs) if config.exclude_runs else None,
)
valid_runs = subj_runs[subject]

if len(get_subjects(config)) > 1:
Expand Down Expand Up @@ -227,7 +247,7 @@ def get_mf_reference_run(config: SimpleNamespace) -> str:
raise ValueError(
f"The intersection of runs by subjects is empty. "
f"Check the list of runs: "
f"{get_runs_all_subjects()}"
f"{get_runs_all_subjects(config)}"
)


Expand Down
1 change: 0 additions & 1 deletion mne_bids_pipeline/_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,6 @@ def run_report_average_sensor(
with _open_report(
cfg=cfg, exec_params=exec_params, subject=subject, session=session
) as report:

#######################################################################
#
# Add event stats.
Expand Down
4 changes: 4 additions & 0 deletions mne_bids_pipeline/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ def pytest_configure(config):
ignore:The get_cmap function will be deprecated.*:
ignore:make_current is deprecated.*:DeprecationWarning
ignore:`np.*` is a deprecated alias for .*:DeprecationWarning
ignore:.*implicit namespace.*:DeprecationWarning
ignore:Deprecated call to `pkg_resources.*:DeprecationWarning
ignore:.*declare_namespace.*mpl_toolkits.*:DeprecationWarning
ignore:_SixMetaPathImporter\.find_spec.*:ImportWarning
"""
for warning_line in warning_lines.split("\n"):
warning_line = warning_line.strip()
Expand Down
17 changes: 13 additions & 4 deletions mne_bids_pipeline/tests/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,25 @@ class DATASET_OPTIONS_T(TypedDict):
"git": "",
"openneuro": "",
"osf": "", # original dataset: '9f5w7'
"web": "https://osf.io/3zk6n/download",
"web": "https://osf.io/3zk6n/download?version=2",
"include": [],
"exclude": [],
},
"eeg_matchingpennies": {
"git": "https://gin.g-node.org/sappelhoff/eeg_matchingpennies",
# This dataset started out on osf.io as dataset https://osf.io/cj2dr
# then moved to g-node.org. As of 2023/02/28 when we download it via
# datalad it's too slow (~200 kB/sec!) and times out at the end:
#
# "git": "https://gin.g-node.org/sappelhoff/eeg_matchingpennies",
# "web": "",
# "include": ["sub-05"],
#
# So now we mirror this datalad-fetched git repo back on osf.io!
"git": "",
"openneuro": "",
"osf": "", # original dataset: 'cj2dr'
"web": "",
"include": ["sub-05"],
"web": "https://osf.io/download/8rbfk?version=1",
"include": [],
"exclude": [],
},
"ds003104": { # Anonymized "somato" dataset.
Expand Down
2 changes: 1 addition & 1 deletion mne_bids_pipeline/tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def dataset_test(request):
test_options = TEST_SUITE[dataset]
dataset_name = test_options.get("dataset", dataset.split("_")[0])
with capsys.disabled():
if request.config.getoption("--download"): # download requested
if request.config.getoption("--download", False): # download requested
download_main(dataset_name)
yield

Expand Down