Simplify organization of base classes #675

Merged
merged 18 commits on May 4, 2022
Changes from 15 commits
4 changes: 1 addition & 3 deletions docs/api.rst
@@ -139,6 +139,7 @@ For more information about functional characterization analysis, see :doc:`decod
decode.discrete
decode.continuous
decode.encode
decode.base


.. _api_io_ref:
@@ -329,6 +330,3 @@ For more information about fetching data from the internet, see :ref:`fetching t

base.NiMAREBase
base.Estimator
base.MetaEstimator
base.Transformer
base.Decoder
308 changes: 34 additions & 274 deletions nimare/base.py
@@ -2,19 +2,11 @@
import gzip
import inspect
import logging
import multiprocessing as mp
import pickle
from abc import ABCMeta, abstractmethod
from collections import defaultdict
from hashlib import md5

import nibabel as nb
import numpy as np
from nilearn._utils.niimg_conversions import _check_same_fov
from nilearn.image import concat_imgs, resample_to_img

from nimare.results import MetaResult
from nimare.utils import get_masker, mm2vox

LGR = logging.getLogger(__name__)

@@ -25,7 +17,6 @@ class NiMAREBase(metaclass=ABCMeta):
This class contains a few features that are useful throughout the library:

- Custom __repr__ method for printing the object.
- A private _check_ncores method to check if the common n_cores argument is valid.
- get_params from scikit-learn, with which parameters provided at __init__ can be viewed.
- set_params from scikit-learn, with which parameters provided at __init__ can be overwritten.
I'm not sure that this is actually used or usable in NiMARE.
@@ -74,18 +65,6 @@ def __repr__(self):
rep = f"{self.__class__.__name__}({', '.join(param_strs)})"
return rep

def _check_ncores(self, n_cores):
"""Check number of cores used for method."""
if n_cores <= 0:
n_cores = mp.cpu_count()
elif n_cores > mp.cpu_count():
LGR.warning(
f"Desired number of cores ({n_cores}) greater than number "
f"available ({mp.cpu_count()}). Setting to {mp.cpu_count()}."
)
n_cores = mp.cpu_count()
return n_cores

@classmethod
def _get_param_names(cls):
"""Get parameter names for the estimator."""
@@ -241,18 +220,40 @@ class Estimator(NiMAREBase):
"""Estimators take in Datasets and return MetaResults.

All Estimators must have a ``_fit`` method implemented, which applies algorithm-specific
- methods to a dataset and returns a dictionary of arrays to be converted into a MetaResult.
+ methods to a Dataset and returns a dictionary of arrays to be converted into a MetaResult.

Users will interact with the ``_fit`` method by calling the user-facing ``fit`` method.
- ``fit`` takes in a ``Dataset``, calls ``_validate_input``, then ``_preprocess_input``,
+ ``fit`` takes in a ``Dataset``, calls ``_collect_inputs``, then ``_preprocess_input``,
then ``_fit``, and finally converts the dictionary returned by ``_fit`` into a ``MetaResult``.
"""

# Inputs that must be available in the input Dataset. Keys are names of
# attributes to set; values are tuples indicating the type of input and any
# additional specification, in the format accepted by Dataset.get.
_required_inputs = {}

def _validate_input(self, dataset, drop_invalid=True):
"""Search for, and validate, required inputs as necessary."""
def _collect_inputs(self, dataset, drop_invalid=True):
"""Search for, and validate, required inputs as necessary.

This method populates the ``inputs_`` attribute.

.. versionchanged:: 0.0.12

Renamed from ``_validate_input``.

Parameters
----------
dataset : :obj:`~nimare.dataset.Dataset`
drop_invalid : :obj:`bool`, optional
Whether to automatically drop any studies in the Dataset without valid data or not.
Default is True.

Attributes
----------
inputs_ : :obj:`dict`
A dictionary of required inputs for the Estimator, extracted from the Dataset.
The actual inputs collected in this attribute are determined by the
``_required_inputs`` variable that should be specified in each child class.
"""
if not hasattr(dataset, "slice"):
raise ValueError(
f"Argument 'dataset' must be a valid Dataset object, not a {type(dataset)}."
@@ -278,7 +279,13 @@ def _validate_input(self, dataset, drop_invalid=True):
self.inputs_[k] = v

def _preprocess_input(self, dataset):
"""Perform any additional preprocessing steps on data in self.inputs_."""
"""Perform any additional preprocessing steps on data in self.inputs_.

Parameters
----------
dataset : :obj:`~nimare.dataset.Dataset`
The Dataset
"""
pass

def fit(self, dataset, drop_invalid=True):
@@ -309,7 +316,7 @@ def fit(self, dataset, drop_invalid=True):
"fitting" methods are implemented as `_fit`, although users should
call `fit`.
"""
- self._validate_input(dataset, drop_invalid=drop_invalid)
+ self._collect_inputs(dataset, drop_invalid=drop_invalid)
self._preprocess_input(dataset)
maps = self._fit(dataset)

@@ -328,250 +335,3 @@ def _fit(self, dataset):
and values are ndarrays.
"""
pass
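
As a rough sketch of the workflow described in the Estimator docstring above, here is a hypothetical toy subclass (not part of this diff; the tuple values in ``_required_inputs`` follow the ``(type, value)`` convention accepted by ``Dataset.get``):

from nimare.base import Estimator

class CoordinateCountEstimator(Estimator):
    """Toy Estimator that counts reported coordinates per study."""

    # fit() will pull the coordinates table from the Dataset into self.inputs_.
    _required_inputs = {"coordinates": ("coordinates", None)}

    def _fit(self, dataset):
        counts = self.inputs_["coordinates"].groupby("id").size()
        # fit() converts this dictionary of arrays into a MetaResult;
        # a real Estimator would return voxel-wise arrays matching the masker.
        return {"coordinate_counts": counts.values}

# Usage (assuming `dataset` is a Dataset):
#   est = CoordinateCountEstimator()
#   result = est.fit(dataset)  # runs _collect_inputs -> _preprocess_input -> _fit
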


class MetaEstimator(Estimator):
"""Base class for meta-analysis methods in :mod:`~nimare.meta`.

.. versionchanged:: 0.0.8

* [REF] Use saved MA maps, when available.

.. versionadded:: 0.0.3

"""

def __init__(self, *args, **kwargs):
mask = kwargs.get("mask")
if mask is not None:
mask = get_masker(mask)
self.masker = mask

self.resample = kwargs.get("resample", False)
self.memory_limit = kwargs.get("memory_limit", None)

# defaults for resampling images (nilearn's defaults do not work well)
self._resample_kwargs = {"clip": True, "interpolation": "linear"}
self._resample_kwargs.update(
{k.split("resample__")[1]: v for k, v in kwargs.items() if k.startswith("resample__")}
)
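
A standalone sketch (hypothetical values, not part of this diff) of how the ``resample__``-prefixed keywords above are routed into resampling options:

# Hypothetical keyword arguments passed to a MetaEstimator subclass:
kwargs = {"resample__interpolation": "nearest", "resample__clip": False}

resample_kwargs = {"clip": True, "interpolation": "linear"}  # the defaults above
resample_kwargs.update(
    {k.split("resample__")[1]: v for k, v in kwargs.items() if k.startswith("resample__")}
)
print(resample_kwargs)  # {'clip': False, 'interpolation': 'nearest'}
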

def _preprocess_input(self, dataset):
"""Preprocess inputs to the Estimator from the Dataset as needed."""
masker = self.masker or dataset.masker

mask_img = masker.mask_img or masker.labels_img
if isinstance(mask_img, str):
mask_img = nb.load(mask_img)

# Ensure that protected values are not included among _required_inputs
assert "aggressive_mask" not in self._required_inputs.keys(), "This is a protected name."

if "aggressive_mask" in self.inputs_.keys():
LGR.warning("Removing existing 'aggressive_mask' from Estimator.")
self.inputs_.pop("aggressive_mask")

# A dictionary to collect masked image data, to be further reduced by the aggressive mask.
temp_image_inputs = {}

for name, (type_, _) in self._required_inputs.items():
if type_ == "image":
# If no resampling is requested, check if resampling is required
if not self.resample:
check_imgs = {img: nb.load(img) for img in self.inputs_[name]}
_check_same_fov(**check_imgs, reference_masker=mask_img, raise_error=True)
imgs = list(check_imgs.values())
else:
# resampling will only occur if shape/affines are different
# making this harmless if all img shapes/affines are the same as the reference
imgs = [
resample_to_img(nb.load(img), mask_img, **self._resample_kwargs)
for img in self.inputs_[name]
]

# input to NiftiLabelsMasker must be 4D
img4d = concat_imgs(imgs, ensure_ndim=4)

# Mask required input images using either the dataset's mask or the estimator's.
temp_arr = masker.transform(img4d)

# An intermediate step to mask out bad voxels.
# Can be dropped once PyMARE is able to handle masked arrays or missing data.
nonzero_voxels_bool = np.all(temp_arr != 0, axis=0)
nonnan_voxels_bool = np.all(~np.isnan(temp_arr), axis=0)
good_voxels_bool = np.logical_and(nonzero_voxels_bool, nonnan_voxels_bool)

data = masker.transform(img4d)

temp_image_inputs[name] = data
if "aggressive_mask" not in self.inputs_.keys():
self.inputs_["aggressive_mask"] = good_voxels_bool
else:
# Remove any voxels that are bad in any image-based inputs
self.inputs_["aggressive_mask"] = np.logical_or(
self.inputs_["aggressive_mask"],
good_voxels_bool,
)

elif type_ == "coordinates":
# Try to load existing MA maps
if hasattr(self, "kernel_transformer"):
self.kernel_transformer._infer_names(affine=md5(mask_img.affine).hexdigest())
if self.kernel_transformer.image_type in dataset.images.columns:
files = dataset.get_images(
ids=self.inputs_["id"],
imtype=self.kernel_transformer.image_type,
)
if all(f is not None for f in files):
self.inputs_["ma_maps"] = files

# Calculate IJK matrix indices for target mask
# Mask space is assumed to be the same as the Dataset's space
# These indices are used directly by any KernelTransformer
xyz = self.inputs_["coordinates"][["x", "y", "z"]].values
ijk = mm2vox(xyz, mask_img.affine)
self.inputs_["coordinates"][["i", "j", "k"]] = ijk

# Further reduce image-based inputs to remove "bad" voxels
# (voxels with zeros or NaNs in any studies)
if "aggressive_mask" in self.inputs_.keys():
n_bad_voxels = (
self.inputs_["aggressive_mask"].size - self.inputs_["aggressive_mask"].sum()
)
if n_bad_voxels:
LGR.warning(
f"Masking out {n_bad_voxels} additional voxels. "
"The updated masker is available in the Estimator.masker attribute."
)

for name, raw_masked_data in temp_image_inputs.items():
self.inputs_[name] = raw_masked_data[:, self.inputs_["aggressive_mask"]]
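
To illustrate the coordinate-conversion step above, a small sketch using ``nimare.utils.mm2vox`` with a made-up 2 mm affine:

import numpy as np
from nimare.utils import mm2vox

affine = np.array(
    [
        [2.0, 0.0, 0.0, -90.0],
        [0.0, 2.0, 0.0, -126.0],
        [0.0, 0.0, 2.0, -72.0],
        [0.0, 0.0, 0.0, 1.0],
    ]
)  # made-up MNI-like 2 mm affine
xyz = np.array([[-30.0, 22.0, 14.0]])  # one peak coordinate, in mm
ijk = mm2vox(xyz, affine)  # matrix indices: [[30, 74, 43]]
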


class Transformer(NiMAREBase):
"""Transformers take in Datasets and return Datasets.

Initialize with hyperparameters.
"""

def __init__(self):
pass

@abstractmethod
def transform(self, dataset):
"""Add stuff to transformer."""
# Using attribute check instead of type check to allow fake Datasets for testing.
if not hasattr(dataset, "slice"):
raise ValueError(
f"Argument 'dataset' must be a valid Dataset object, not a {type(dataset)}"
)
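
A minimal hypothetical subclass showing the Transformer contract (the import path reflects the pre-change location, since this class is removed from ``nimare.base`` here):

from nimare.base import Transformer

class IdentityTransformer(Transformer):
    """Toy Transformer that returns the Dataset unchanged."""

    def transform(self, dataset):
        # The base transform() only checks that `dataset` looks like a Dataset.
        super().transform(dataset)
        return dataset
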


class Decoder(NiMAREBase):
"""Base class for decoders in :mod:`~nimare.decode`.

.. versionadded:: 0.0.3

"""

__id_cols = ["id", "study_id", "contrast_id"]

def _validate_input(self, dataset, drop_invalid=True):
"""Search for, and validate, required inputs as necessary."""
if not hasattr(dataset, "slice"):
raise ValueError(
f"Argument 'dataset' must be a valid Dataset object, not a {type(dataset)}."
)

if self._required_inputs:
data = dataset.get(self._required_inputs, drop_invalid=drop_invalid)
# Do not overwrite existing inputs_ attribute.
# This is necessary for PairwiseCBMAEstimator, which validates two sets of coordinates
# in the same object.
# It makes the *strong* assumption that required inputs will not change within an
# Estimator across fit calls, so all fields of inputs_ will be overwritten instead of
# retaining outdated fields from previous fit calls.
if not hasattr(self, "inputs_"):
self.inputs_ = {}

for k, v in data.items():
if v is None:
raise ValueError(
f"Estimator {self.__class__.__name__} requires input dataset to contain "
f"{k}, but no matching data were found."
)
self.inputs_[k] = v

def _preprocess_input(self, dataset):
"""Select features for model based on requested features and feature_group.

This also takes into account which features have at least one study in the
Dataset with the feature.
"""
# Reduce feature list as desired
if self.feature_group is not None:
if not self.feature_group.endswith("__"):
self.feature_group += "__"
feature_names = self.inputs_["annotations"].columns.values
feature_names = [f for f in feature_names if f.startswith(self.feature_group)]
if self.features is not None:
features = [f.split("__")[-1] for f in feature_names if f in self.features]
else:
features = feature_names
else:
if self.features is None:
features = self.inputs_["annotations"].columns.values
else:
features = self.features

features = [f for f in features if f not in self.__id_cols]
n_features_orig = len(features)

# At least one study in the dataset must have each label
counts = (self.inputs_["annotations"][features] > self.frequency_threshold).sum(0)
features = counts[counts > 0].index.tolist()
if not len(features):
raise Exception("No features identified in Dataset!")
elif len(features) < n_features_orig:
LGR.info(f"Retaining {len(features)}/({n_features_orig} features.")

self.features_ = features
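
A self-contained sketch of the feature-selection logic above, using a made-up annotations table and threshold:

import pandas as pd

annotations = pd.DataFrame(
    {
        "id": ["study-1", "study-2"],
        "LDA50__topic001": [0.2, 0.0],
        "LDA50__topic002": [0.0, 0.0],
        "terms__pain": [0.8, 0.4],
    }
)
feature_group = "LDA50__"
frequency_threshold = 0.001

features = [c for c in annotations.columns if c.startswith(feature_group)]
counts = (annotations[features] > frequency_threshold).sum(0)
features = counts[counts > 0].index.tolist()
print(features)  # ['LDA50__topic001']; topic002 is dropped (no study exceeds the threshold)
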

def fit(self, dataset, drop_invalid=True):
"""Fit Decoder to Dataset.

Parameters
----------
dataset : :obj:`~nimare.dataset.Dataset`
Dataset object to analyze.
drop_invalid : :obj:`bool`, optional
Whether to automatically ignore any studies without the required data or not.
Default is True.


Returns
-------
:obj:`~nimare.results.MetaResult`
Results of Decoder fitting.

Notes
-----
The `fit` method is a light wrapper that runs input validation and
preprocessing before fitting the actual model. Decoders' individual
"fitting" methods are implemented as `_fit`, although users should
call `fit`.

Selection of features based on requested features and feature group is performed in
`Decoder._preprocess_input`.
"""
self._validate_input(dataset, drop_invalid=drop_invalid)
self._preprocess_input(dataset)
self._fit(dataset)

@abstractmethod
def _fit(self, dataset):
"""Apply decoding to dataset and output results.

Must return a DataFrame, with one row for each feature.
"""
pass
4 changes: 2 additions & 2 deletions nimare/correct.py
@@ -36,7 +36,7 @@ def __init__(self):
def _name_suffix(self):
pass

- def _validate_input(self, result):
+ def _collect_inputs(self, result):
if not isinstance(result, MetaResult):
raise ValueError(
"First argument to transform() must be an "
@@ -132,7 +132,7 @@ def transform(self, result):
)
corr_maps = getattr(est, correction_method)(result, **self.parameters)
else:
- self._validate_input(result)
+ self._collect_inputs(result)
corr_maps = self._transform(result)

# Update corrected map names and add them to maps dict
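
For context, a hedged usage sketch of how a corrector's ``transform`` is typically invoked; the ``FDRCorrector`` name and parameters are assumed from the NiMARE API and are not part of this diff:

from nimare.correct import FDRCorrector

corrector = FDRCorrector(method="indep", alpha=0.05)
# `result` would be a MetaResult from a fitted Estimator:
#   corrected = corrector.transform(result)
# transform() uses the Estimator's own correction method when available; otherwise it
# falls back to _collect_inputs(result) followed by _transform(result), as shown above.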