Skip to content

Commit

Permalink
CLN: Remove unused multiple model functionality, fix #538
Browse files Browse the repository at this point in the history
> The config allows instantiating multiple models
  for training / prediction but this was never fully implemented

Squash commits:
- Make config just use "model" option, not "models":
  - Remove `comma_separated_list` from converters
  - Change option name 'models' -> 'model' in config/valid.toml
  - Rewrite is_valid_model_name to test a single string,
    not a list of strings
  - Change attribute `models` -> `model` in config/eval.py
  - Change attribute `models` -> `model` in config/learncurve.py
  - Change attribute `models` -> `model` in config/predict.py
  - Change attribute `models` -> `model` in config/train.py
  - Rewrite/rename config.models -> model.config_from_toml_path,
    model.config_from_toml_dict
  - Fix option 'models' -> model = 'str',
    in all .toml files in tests/data_for_tests/configs
- Rewrite `models.from_model_config_map` as `models.get`:
  - Add src/vak/models/_api.py with BUILTIN_MODELS and
    MODEL_NAMES to use for validation in `models.get`
  - Rewrite models/models.py `from_model_config_map` as `models.get`
  - Import get and _api in vak/models/__init__.py
  - Rewrite core/train.py to take model_name and model_config
    then use models.get
  - Fix cli/train.py to pass model_name and model_config into core.train
  - Rewrite core/eval.py to take model_name and model_config
    then use models.get
  - Fix cli/eval.py to pass model_name and model_config into core.eval
  - Rewrite core/learncurve.py to take model_name and model_config
  - Fix cli/learncurve.py to pass model_name and model_config
    into core.learncurve
  - Rewrite core/predict.py to take model_name and model_config
    then use models.get
  - Fix cli/predict.py to pass model_name and model_config
    into core.predict
  - Make 'model' not 'models' required in src/vak/config/parse.py
  - Use models.MODEL_NAMES in src/vak/config/validators.py
  - Use models.MODEL_NAMES in config/model.py
- Fix tests
  - Fix tests to use vak.config.model.config_from_toml_path:
    - tests/test_models/test_windowed_frame_classification_model.py
    - tests/test_core/test_train.py
    - tests/test_core/test_predict.py
    - tests/test_core/test_learncurve.py
    - tests/test_core/test_eval.py
  - Fix test to use 'model' option not 'models' in
    tests/test_config/test_parse.py
  - Fix assert helper function in tests/test_core
    - test_eval.py
    - test_learncurve.py
    - test_predict.py
    - test_prep.py
    - test_train.py
  - Rewrite fixture module with constants we can import
    in test modules to parametrize: tests/fixtures/config.py
  - Add tests/test_config/test_model.py
  • Loading branch information
NickleDave committed Feb 13, 2023
1 parent f8c2535 commit b206669
Show file tree
Hide file tree
Showing 55 changed files with 677 additions and 599 deletions.
8 changes: 5 additions & 3 deletions src/vak/cli/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def eval(toml_path):

logger.info("Logging results to {}".format(cfg.eval.output_dir))

model_config_map = config.models.map_from_path(toml_path, cfg.eval.models)
model_name = cfg.eval.model
model_config = config.model.config_from_toml_path(toml_path, model_name)

if cfg.eval.csv_path is None:
raise ValueError(
Expand All @@ -53,8 +54,9 @@ def eval(toml_path):
)

core.eval(
cfg.eval.csv_path,
model_config_map,
model_name=model_name,
model_config=model_config,
csv_path=cfg.eval.csv_path,
checkpoint_path=cfg.eval.checkpoint_path,
labelmap_path=cfg.eval.labelmap_path,
output_dir=cfg.eval.output_dir,
Expand Down
6 changes: 4 additions & 2 deletions src/vak/cli/learncurve.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def learning_curve(toml_path):
log_version(logger)
logger.info("Logging results to {}".format(results_path))

model_config_map = config.models.map_from_path(toml_path, cfg.learncurve.models)
model_name = cfg.learncurve.model
model_config = config.model.config_from_toml_path(toml_path, model_name)

if cfg.learncurve.csv_path is None:
raise ValueError(
Expand All @@ -60,7 +61,8 @@ def learning_curve(toml_path):
)

core.learning_curve(
model_config_map,
model_name=model_name,
model_config=model_config,
train_set_durs=cfg.learncurve.train_set_durs,
num_replicates=cfg.learncurve.num_replicates,
csv_path=cfg.learncurve.csv_path,
Expand Down
6 changes: 4 additions & 2 deletions src/vak/cli/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def predict(toml_path):
log_version(logger)
logger.info("Logging results to {}".format(cfg.prep.output_dir))

model_config_map = config.models.map_from_path(toml_path, cfg.predict.models)
model_name = cfg.predict.model
model_config = config.model.config_from_toml_path(toml_path, model_name)

if cfg.predict.csv_path is None:
raise ValueError(
Expand All @@ -48,10 +49,11 @@ def predict(toml_path):
)

core.predict(
model_name=model_name,
model_config=model_config,
csv_path=cfg.predict.csv_path,
checkpoint_path=cfg.predict.checkpoint_path,
labelmap_path=cfg.predict.labelmap_path,
model_config_map=model_config_map,
window_size=cfg.dataloader.window_size,
num_workers=cfg.predict.num_workers,
spect_key=cfg.spect_params.spect_key,
Expand Down
6 changes: 4 additions & 2 deletions src/vak/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def train(toml_path):
log_version(logger)
logger.info("Logging results to {}".format(results_path))

model_config_map = config.models.map_from_path(toml_path, cfg.train.models)
model_name = cfg.train.model
model_config = config.model.config_from_toml_path(toml_path, model_name)

if cfg.train.csv_path is None:
raise ValueError(
Expand All @@ -64,7 +65,8 @@ def train(toml_path):
labelset, labelmap_path = cfg.prep.labelset, None

core.train(
model_config_map=model_config_map,
model_name=model_name,
model_config=model_config,
csv_path=cfg.train.csv_path,
labelset=labelset,
window_size=cfg.dataloader.window_size,
Expand Down
2 changes: 1 addition & 1 deletion src/vak/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
dataloader,
eval,
learncurve,
models,
model,
parse,
predict,
prep,
Expand Down
11 changes: 5 additions & 6 deletions src/vak/config/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from .validators import is_valid_model_name
from .. import device
from ..converters import comma_separated_list, expanded_user_path
from ..converters import expanded_user_path


def convert_post_tfm_kwargs(post_tfm_kwargs: dict) -> dict:
Expand Down Expand Up @@ -72,8 +72,8 @@ class EvalConfig:
Path to location where .csv files with evaluation metrics should be saved.
labelmap_path : str
path to 'labelmap.json' file.
models : list
of model names. e.g., 'models = TweetyNet, GRUNet, ConvNet'
model : str
Model name, e.g., ``model = "TweetyNet"``
batch_size : int
number of samples per batch presented to models during training.
num_workers : int
Expand Down Expand Up @@ -106,9 +106,8 @@ class EvalConfig:
output_dir = attr.ib(converter=expanded_user_path)

# required, model / dataloader
models = attr.ib(
converter=comma_separated_list,
validator=[instance_of(list), is_valid_model_name],
model = attr.ib(
validator=[instance_of(str), is_valid_model_name],
)
batch_size = attr.ib(converter=int, validator=instance_of(int))

Expand Down
4 changes: 2 additions & 2 deletions src/vak/config/learncurve.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ class LearncurveConfig(TrainConfig):
Attributes
----------
models : list
of model names. e.g., 'models = TweetyNet, GRUNet, ConvNet'
model : str
Model name, e.g., ``model = "TweetyNet"``
csv_path : str
path to where dataset was saved as a csv.
num_epochs : int
Expand Down
88 changes: 88 additions & 0 deletions src/vak/config/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from __future__ import annotations
import pathlib

import toml

from .. import models


MODEL_TABLES = [
"network",
"optimizer",
"loss",
"metrics",
]


def config_from_toml_dict(toml_dict: dict, model_name: str) -> dict:
"""Get configuration for a model from a .toml configuration file
loaded into a ``dict``.
Parameters
----------
toml_dict : dict
Configuration from a .toml file, loaded into a dictionary.
model_name : str
Name of a model, specified as the ``model`` option in a table
(such as TRAIN or PREDICT),
that should have its own corresponding table
specifying its configuration: hyperparameters such as learning rate, etc.
Returns
-------
model_config : dict
Model configuration in a ``dict``,
as loaded from a .toml file,
and used by the model method ``from_config``.
"""
if model_name not in models.MODEL_NAMES:
raise ValueError(
f"Invalid model name: {model_name}.\nValid model names are: {models.MODEL_NAMES}"
)

try:
model_config = toml_dict[model_name]
except KeyError as e:
raise ValueError(
f"A config section specifies the model name '{model_name}', "
f"but there is no section named '{model_name}' in the config."
) from e

# check if config declares parameters for required attributes;
# if not, just put an empty dict that will get passed as the "kwargs"
for attr in MODEL_TABLES:
if attr not in model_config:
model_config[attr] = {}

return model_config


def config_from_toml_path(toml_path: str | pathlib.Path, model_name: str) -> dict:
"""Get configuration for a model from a .toml configuration file,
given the path to the file.
Parameters
----------
toml_path : str, Path
to configuration file in .toml format
model_name : str
of str, i.e. names of models specified by a section
(such as TRAIN or PREDICT) that should each have corresponding sections
specifying their configuration: hyperparameters such as learning rate, etc.
Returns
-------
model_config : dict
Model configuration in a ``dict``,
as loaded from a .toml file,
and used by the model method ``from_config``.
"""
toml_path = pathlib.Path(toml_path)
if not toml_path.is_file():
raise FileNotFoundError(
f"File not found, or not recognized as a file: {toml_path}"
)

with toml_path.open("r") as fp:
config_dict = toml.load(fp)
return config_from_toml_dict(config_dict, model_name)
115 changes: 0 additions & 115 deletions src/vak/config/models.py

This file was deleted.

8 changes: 4 additions & 4 deletions src/vak/config/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,26 +29,26 @@
"checkpoint_path",
"labelmap_path",
"output_dir",
"models",
"model",
],
"LEARNCURVE": [
"models",
"model",
"root_results_dir",
"train_set_durs",
"num_replicates",
],
"PREDICT": [
"checkpoint_path",
"labelmap_path",
"models",
"model",
],
"PREP": [
"data_dir",
"output_dir",
],
"SPECT_PARAMS": None,
"TRAIN": [
"models",
"model",
"root_results_dir",
],
}
Expand Down
Loading

0 comments on commit b206669

Please sign in to comment.