diff --git a/src/vak/cli/eval.py b/src/vak/cli/eval.py index 2361d6c1e..b33c6ed8f 100644 --- a/src/vak/cli/eval.py +++ b/src/vak/cli/eval.py @@ -43,7 +43,8 @@ def eval(toml_path): logger.info("Logging results to {}".format(cfg.eval.output_dir)) - model_config_map = config.models.map_from_path(toml_path, cfg.eval.models) + model_name = cfg.eval.model + model_config = config.model.config_from_toml_path(toml_path, model_name) if cfg.eval.csv_path is None: raise ValueError( @@ -53,8 +54,9 @@ def eval(toml_path): ) core.eval( - cfg.eval.csv_path, - model_config_map, + model_name=model_name, + model_config=model_config, + csv_path=cfg.eval.csv_path, checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, output_dir=cfg.eval.output_dir, diff --git a/src/vak/cli/learncurve.py b/src/vak/cli/learncurve.py index 1f84b0d89..f98793541 100644 --- a/src/vak/cli/learncurve.py +++ b/src/vak/cli/learncurve.py @@ -50,7 +50,8 @@ def learning_curve(toml_path): log_version(logger) logger.info("Logging results to {}".format(results_path)) - model_config_map = config.models.map_from_path(toml_path, cfg.learncurve.models) + model_name = cfg.learncurve.model + model_config = config.model.config_from_toml_path(toml_path, model_name) if cfg.learncurve.csv_path is None: raise ValueError( @@ -60,7 +61,8 @@ def learning_curve(toml_path): ) core.learning_curve( - model_config_map, + model_name=model_name, + model_config=model_config, train_set_durs=cfg.learncurve.train_set_durs, num_replicates=cfg.learncurve.num_replicates, csv_path=cfg.learncurve.csv_path, diff --git a/src/vak/cli/predict.py b/src/vak/cli/predict.py index 6c5333ce7..9039b0b85 100644 --- a/src/vak/cli/predict.py +++ b/src/vak/cli/predict.py @@ -38,7 +38,8 @@ def predict(toml_path): log_version(logger) logger.info("Logging results to {}".format(cfg.prep.output_dir)) - model_config_map = config.models.map_from_path(toml_path, cfg.predict.models) + model_name = cfg.predict.model + model_config = config.model.config_from_toml_path(toml_path, model_name) if cfg.predict.csv_path is None: raise ValueError( @@ -48,10 +49,11 @@ def predict(toml_path): ) core.predict( + model_name=model_name, + model_config=model_config, csv_path=cfg.predict.csv_path, checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, - model_config_map=model_config_map, window_size=cfg.dataloader.window_size, num_workers=cfg.predict.num_workers, spect_key=cfg.spect_params.spect_key, diff --git a/src/vak/cli/train.py b/src/vak/cli/train.py index a9c071b3b..0cd8a7d08 100644 --- a/src/vak/cli/train.py +++ b/src/vak/cli/train.py @@ -49,7 +49,8 @@ def train(toml_path): log_version(logger) logger.info("Logging results to {}".format(results_path)) - model_config_map = config.models.map_from_path(toml_path, cfg.train.models) + model_name = cfg.train.model + model_config = config.model.config_from_toml_path(toml_path, model_name) if cfg.train.csv_path is None: raise ValueError( @@ -64,7 +65,8 @@ def train(toml_path): labelset, labelmap_path = cfg.prep.labelset, None core.train( - model_config_map=model_config_map, + model_name=model_name, + model_config=model_config, csv_path=cfg.train.csv_path, labelset=labelset, window_size=cfg.dataloader.window_size, diff --git a/src/vak/config/__init__.py b/src/vak/config/__init__.py index 132b5366d..bad9749fc 100644 --- a/src/vak/config/__init__.py +++ b/src/vak/config/__init__.py @@ -4,7 +4,7 @@ dataloader, eval, learncurve, - models, + model, parse, predict, prep, diff --git a/src/vak/config/eval.py b/src/vak/config/eval.py index cc0bbe996..5c7d54be9 100644 --- a/src/vak/config/eval.py +++ b/src/vak/config/eval.py @@ -5,7 +5,7 @@ from .validators import is_valid_model_name from .. import device -from ..converters import comma_separated_list, expanded_user_path +from ..converters import expanded_user_path def convert_post_tfm_kwargs(post_tfm_kwargs: dict) -> dict: @@ -72,8 +72,8 @@ class EvalConfig: Path to location where .csv files with evaluation metrics should be saved. labelmap_path : str path to 'labelmap.json' file. - models : list - of model names. e.g., 'models = TweetyNet, GRUNet, ConvNet' + model : str + Model name, e.g., ``model = "TweetyNet"`` batch_size : int number of samples per batch presented to models during training. num_workers : int @@ -106,9 +106,8 @@ class EvalConfig: output_dir = attr.ib(converter=expanded_user_path) # required, model / dataloader - models = attr.ib( - converter=comma_separated_list, - validator=[instance_of(list), is_valid_model_name], + model = attr.ib( + validator=[instance_of(str), is_valid_model_name], ) batch_size = attr.ib(converter=int, validator=instance_of(int)) diff --git a/src/vak/config/learncurve.py b/src/vak/config/learncurve.py index a70ba22b2..8178e926f 100644 --- a/src/vak/config/learncurve.py +++ b/src/vak/config/learncurve.py @@ -14,8 +14,8 @@ class LearncurveConfig(TrainConfig): Attributes ---------- - models : list - of model names. e.g., 'models = TweetyNet, GRUNet, ConvNet' + model : str + Model name, e.g., ``model = "TweetyNet"`` csv_path : str path to where dataset was saved as a csv. num_epochs : int diff --git a/src/vak/config/model.py b/src/vak/config/model.py new file mode 100644 index 000000000..b3fcf779d --- /dev/null +++ b/src/vak/config/model.py @@ -0,0 +1,88 @@ +from __future__ import annotations +import pathlib + +import toml + +from .. import models + + +MODEL_TABLES = [ + "network", + "optimizer", + "loss", + "metrics", + ] + + +def config_from_toml_dict(toml_dict: dict, model_name: str) -> dict: + """Get configuration for a model from a .toml configuration file + loaded into a ``dict``. + + Parameters + ---------- + toml_dict : dict + Configuration from a .toml file, loaded into a dictionary. + model_name : str + Name of a model, specified as the ``model`` option in a table + (such as TRAIN or PREDICT), + that should have its own corresponding table + specifying its configuration: hyperparameters such as learning rate, etc. + + Returns + ------- + model_config : dict + Model configuration in a ``dict``, + as loaded from a .toml file, + and used by the model method ``from_config``. + """ + if model_name not in models.MODEL_NAMES: + raise ValueError( + f"Invalid model name: {model_name}.\nValid model names are: {models.MODEL_NAMES}" + ) + + try: + model_config = toml_dict[model_name] + except KeyError as e: + raise ValueError( + f"A config section specifies the model name '{model_name}', " + f"but there is no section named '{model_name}' in the config." + ) from e + + # check if config declares parameters for required attributes; + # if not, just put an empty dict that will get passed as the "kwargs" + for attr in MODEL_TABLES: + if attr not in model_config: + model_config[attr] = {} + + return model_config + + +def config_from_toml_path(toml_path: str | pathlib.Path, model_name: str) -> dict: + """Get configuration for a model from a .toml configuration file, + given the path to the file. + + Parameters + ---------- + toml_path : str, Path + to configuration file in .toml format + model_name : str + of str, i.e. names of models specified by a section + (such as TRAIN or PREDICT) that should each have corresponding sections + specifying their configuration: hyperparameters such as learning rate, etc. + + Returns + ------- + model_config : dict + Model configuration in a ``dict``, + as loaded from a .toml file, + and used by the model method ``from_config``. + """ + toml_path = pathlib.Path(toml_path) + if not toml_path.is_file(): + raise FileNotFoundError( + f"File not found, or not recognized as a file: {toml_path}" + ) + + with toml_path.open("r") as fp: + config_dict = toml.load(fp) + return config_from_toml_dict(config_dict, model_name) diff --git a/src/vak/config/models.py b/src/vak/config/models.py deleted file mode 100644 index 98018a6d8..000000000 --- a/src/vak/config/models.py +++ /dev/null @@ -1,115 +0,0 @@ -from pathlib import Path - -import toml - -from .. import models - - -MODEL_TABLES = [ - "network", - "optimizer", - "loss", - "metrics", - ] - - -def map_from_config_dict(config_dict, model_names): - """map a list of model names to model configuration sections from a - config.toml file. - - Given the configuraiton in a dict and a list of model names, returns dict that - maps model names to config sections. If no section in the config.toml file - matches the model name, an error is raised. - - Used to get configuration for only the models specified - in a certain section of config.toml file, e.g. in the TRAIN section. - - The returned model-config map can be used with vak.models.from_model_config_map - (and the number of classes and input shape) to get a list of model instances - ready for training. - - Parameters - ---------- - config_dict : dict - configuration from a .toml file, loaded into a dictionary - model_names : list - of str, i.e. names of models specified by a section - (such as TRAIN or PREDICT) that should each have corresponding sections - specifying their configuration: hyperparameters such as learning rate, etc. - - Returns - ------- - model_config_map : dict - where each key is the name of a model and the corresponding value is - a section from a config.toml file. - """ - # first check whether models in list are installed - - # load entry points within function, not at module level, - # to avoid circular dependencies - # (user would be unable to import models in other packages - # if the module in the other package needed to `import vak`) - MODEL_NAMES = list(models.models.BUILTIN_MODELS.keys()) - for model_name in model_names: - if model_name not in MODEL_NAMES: - raise ValueError( - f"Invalid model name: {model_name}.\nValid model names are: {MODEL_NAMES}" - ) - - # now see if we can find corresponding sections in config - sections = list(config_dict.keys()) - model_config_map = {} - for model_name in model_names: - if model_name in sections: - model_config_dict = config_dict[model_name] - else: - # try appending 'Model' to name - tmp_model_name = f"{model_name}Model" - if tmp_model_name not in sections: - raise ValueError( - f"did not find section named {model_name} or {tmp_model_name} " - f"in config" - ) - model_config_dict = config_dict[tmp_model_name] - - # check if config declares parameters for required attributes; - # if not, just put an empty dict that will get passed as the "kwargs" - for attr in MODEL_TABLES: - if attr not in model_config_dict: - model_config_dict[attr] = {} - - model_config_map[model_name] = config_dict[model_name] - - return model_config_map - - -def map_from_path(toml_path, model_names): - """map a list of model names to sections from a .toml configuration file - that specify parameters for those models. - - Parameters - ---------- - toml_path : str, Path - to configuration file in .toml format - model_names : list - of str, i.e. names of models specified by a section - (such as TRAIN or PREDICT) that should each have corresponding sections - specifying their configuration: hyperparameters such as learning rate, etc. - - Returns - ------- - model_config_map : dict - where each key is the name of a model and the corresponding value is - a section from a config.toml file. - """ - # check config_path is a file, - # because if it doesn't ConfigParser will just return an "empty" instance w/out sections or options - toml_path = Path(toml_path) - if not toml_path.is_file(): - raise FileNotFoundError( - f"file not found, or not recognized as a file: {toml_path}" - ) - - with toml_path.open("r") as fp: - config_dict = toml.load(fp) - return map_from_config_dict(config_dict, model_names) diff --git a/src/vak/config/parse.py b/src/vak/config/parse.py index 46e9086e3..280b8e4c1 100644 --- a/src/vak/config/parse.py +++ b/src/vak/config/parse.py @@ -29,10 +29,10 @@ "checkpoint_path", "labelmap_path", "output_dir", - "models", + "model", ], "LEARNCURVE": [ - "models", + "model", "root_results_dir", "train_set_durs", "num_replicates", @@ -40,7 +40,7 @@ "PREDICT": [ "checkpoint_path", "labelmap_path", - "models", + "model", ], "PREP": [ "data_dir", @@ -48,7 +48,7 @@ ], "SPECT_PARAMS": None, "TRAIN": [ - "models", + "model", "root_results_dir", ], } diff --git a/src/vak/config/predict.py b/src/vak/config/predict.py index 065761b71..92d3eea89 100644 --- a/src/vak/config/predict.py +++ b/src/vak/config/predict.py @@ -8,7 +8,7 @@ from .validators import is_valid_model_name from .. import device -from ..converters import comma_separated_list, expanded_user_path +from ..converters import expanded_user_path @attr.s @@ -23,8 +23,8 @@ class PredictConfig: path to directory with checkpoint files saved by Torch, to reload model labelmap_path : str path to 'labelmap.json' file. - models : list - of model names. e.g., 'models = TweetyNet, GRUNet, ConvNet' + model : str + Model name, e.g., ``model = "TweetyNet"`` batch_size : int number of samples per batch presented to models during training. num_workers : int @@ -75,9 +75,8 @@ class PredictConfig: labelmap_path = attr.ib(converter=expanded_user_path) # required, model / dataloader - models = attr.ib( - converter=comma_separated_list, - validator=[instance_of(list), is_valid_model_name], + model = attr.ib( + validator=[instance_of(str), is_valid_model_name], ) batch_size = attr.ib(converter=int, validator=instance_of(int)) diff --git a/src/vak/config/train.py b/src/vak/config/train.py index 96b0a232b..678ef18cb 100644 --- a/src/vak/config/train.py +++ b/src/vak/config/train.py @@ -5,7 +5,7 @@ from .validators import is_valid_model_name from .. import device -from ..converters import bool_from_str, comma_separated_list, expanded_user_path +from ..converters import bool_from_str, expanded_user_path @attr.s @@ -14,9 +14,8 @@ class TrainConfig: Attributes ---------- - models : list - comma-separated list of model names. - e.g., 'models = TweetyNet, GRUNet, ConvNet' + model : str + Model name, e.g., ``model = "TweetyNet"`` csv_path : str path to where dataset was saved as a csv. num_epochs : int @@ -66,9 +65,8 @@ class TrainConfig: path to 'labelmap.json' file. Default is None. """ # required - models = attr.ib( - converter=comma_separated_list, - validator=[instance_of(list), is_valid_model_name], + model = attr.ib( + validator=[instance_of(str), is_valid_model_name], ) num_epochs = attr.ib(converter=int, validator=instance_of(int)) batch_size = attr.ib(converter=int, validator=instance_of(int)) diff --git a/src/vak/config/valid.toml b/src/vak/config/valid.toml index 8c41e57fe..956b67259 100644 --- a/src/vak/config/valid.toml +++ b/src/vak/config/valid.toml @@ -34,7 +34,7 @@ audio_path_key = 'audio_path' window_size = 88 [TRAIN] -models = 'TweetyNet' +model = 'TweetyNet' root_results_dir = './tests/test_data/results/train' csv_path = 'tests/test_data/prep/train/032312_prep_191224_225912.csv' num_workers = 4 @@ -57,7 +57,7 @@ csv_path = 'tests/test_data/prep/learncurve/032312_prep_191224_225910.csv' checkpoint_path = '/home/user/results_181014_194418/TweetyNet/checkpoints/' labelmap_path = '/home/user/results_181014_194418/labelmap.json' output_dir = './tests/test_data/prep/learncurve' -models = 'TweetyNet' +model = 'TweetyNet' batch_size = 11 num_workers = 4 device = 'cuda' @@ -65,7 +65,7 @@ spect_scaler_path = '/home/user/results_181014_194418/spect_scaler' post_tfm_kwargs = {'majority_vote' = true, 'min_segment_dur' = 0.01} [LEARNCURVE] -models = 'TweetyNet' +model = 'TweetyNet' root_results_dir = './tests/test_data/results/learncurve' batch_size = 11 num_epochs = 2 @@ -90,7 +90,7 @@ checkpoint_path = '/home/user/results_181014_194418/TweetyNet/checkpoints/' labelmap_path = '/home/user/results_181014_194418/labelmap.json' annot_csv_filename = '032312_prep_191224_225910.annot.csv' output_dir = './tests/test_data/prep/learncurve' -models = 'TweetyNet' +model = 'TweetyNet' batch_size = 11 num_workers = 4 device = 'cuda' diff --git a/src/vak/config/validators.py b/src/vak/config/validators.py index 4041970e1..02c3fa5b7 100644 --- a/src/vak/config/validators.py +++ b/src/vak/config/validators.py @@ -25,13 +25,12 @@ def is_a_file(instance, attribute, value): ) -def is_valid_model_name(instance, attribute, value): - MODEL_NAMES = list(models.models.BUILTIN_MODELS.keys()) - for model_name in value: - if model_name not in MODEL_NAMES: - raise ValueError( - f"Invalid model name: {model_name}.\nValid model names are: {MODEL_NAMES}" - ) +def is_valid_model_name(instance, attribute, value: str) -> None: + """Validate model name.""" + if value not in models.MODEL_NAMES: + raise ValueError( + f"Invalid model name: {value}.\nValid model names are: {models.MODEL_NAMES}" + ) def is_audio_format(instance, attribute, value): @@ -91,7 +90,7 @@ def are_sections_valid(config_dict, toml_path=None): f"Please use just one command besides `prep` per .toml configuration file" ) - MODEL_NAMES = list(models.models.BUILTIN_MODELS.keys()) + MODEL_NAMES = list(models.MODEL_NAMES) # add model names to valid sections so users can define model config in sections valid_sections = VALID_SECTIONS + MODEL_NAMES for section in sections: diff --git a/src/vak/converters.py b/src/vak/converters.py index 83885c170..c00d5e69d 100644 --- a/src/vak/converters.py +++ b/src/vak/converters.py @@ -9,17 +9,6 @@ def bool_from_str(value): return bool(strtobool(value)) -def comma_separated_list(value): - if type(value) is list: - return value - elif type(value) is str: - return [element.strip() for element in value.split()] - else: - raise TypeError( - f"unexpected type when converting to comma-separated list: {type(value)}" - ) - - def expanded_user_path(value): return Path(value).expanduser() diff --git a/src/vak/core/eval.py b/src/vak/core/eval.py index 62fe7abb9..f39907a10 100644 --- a/src/vak/core/eval.py +++ b/src/vak/core/eval.py @@ -1,6 +1,5 @@ from collections import OrderedDict from datetime import datetime -import functools import json import logging @@ -24,8 +23,9 @@ def eval( + model_name: str, + model_config: dict, csv_path, - model_config_map, checkpoint_path, labelmap_path, output_dir, @@ -42,10 +42,14 @@ def eval( Parameters ---------- + model_name : str + Model name, must be one of vak.models.MODEL_NAMES. + model_config : dict + Model configuration in a ``dict``, + as loaded from a .toml file, + and used by the model method ``from_config``. csv_path : str, pathlib.Path path to where dataset was saved as a csv. - model_config_map : dict - where each key-value pair is model name : dict of config parameters checkpoint_path : str, pathlib.Path path to directory with checkpoint files saved by Torch, to reload model output_dir : str, pathlib.Path @@ -183,56 +187,59 @@ def eval( else: post_tfm = None - models_map = models.from_model_config_map( - model_config_map, num_classes=len(labelmap), input_shape=input_shape, labelmap=labelmap + model = models.get( + model_name, + model_config, + num_classes=len(labelmap), + input_shape=input_shape, + labelmap=labelmap, ) - for model_name, model in models_map.items(): - logger.info(f"running evaluation for model: {model_name}") - - model.load_state_dict_from_path(checkpoint_path) + logger.info(f"running evaluation for model: {model_name}") - if device == 'cuda': - accelerator = 'gpu' - else: - accelerator = None + model.load_state_dict_from_path(checkpoint_path) - trainer_logger = lightning.loggers.TensorBoardLogger( - save_dir=output_dir - ) - trainer = lightning.Trainer(accelerator=accelerator, logger=trainer_logger) - # TODO: check for hasattr(model, test_step) and if so run test - # below, [0] because validate returns list of dicts, length of no. of val loaders - metric_vals = trainer.validate(model, dataloaders=val_loader)[0] - metric_vals = {f'avg_{k}': v for k, v in metric_vals.items()} - for metric_name, metric_val in metric_vals.items(): - if metric_name.startswith('avg_'): - logger.info( - f'{metric_name}: {metric_val:0.5f}' - ) + if device == 'cuda': + accelerator = 'gpu' + else: + accelerator = None - # create a "DataFrame" with just one row which we will save as a csv; - # the idea is to be able to concatenate csvs from multiple runs of eval - row = OrderedDict( - [ - ("model_name", model_name), - ("checkpoint_path", checkpoint_path), - ("labelmap_path", labelmap_path), - ("spect_scaler_path", spect_scaler_path), - ("csv_path", csv_path), - ] - ) - # TODO: is this still necessary after switching to Lightning? Stop saying "average"? - # order metrics by name to be extra sure they will be consistent across runs - row.update( - sorted([(k, v) for k, v in metric_vals.items() if k.startswith("avg_")]) - ) + trainer_logger = lightning.loggers.TensorBoardLogger( + save_dir=output_dir + ) + trainer = lightning.Trainer(accelerator=accelerator, logger=trainer_logger) + # TODO: check for hasattr(model, test_step) and if so run test + # below, [0] because validate returns list of dicts, length of no. of val loaders + metric_vals = trainer.validate(model, dataloaders=val_loader)[0] + metric_vals = {f'avg_{k}': v for k, v in metric_vals.items()} + for metric_name, metric_val in metric_vals.items(): + if metric_name.startswith('avg_'): + logger.info( + f'{metric_name}: {metric_val:0.5f}' + ) + + # create a "DataFrame" with just one row which we will save as a csv; + # the idea is to be able to concatenate csvs from multiple runs of eval + row = OrderedDict( + [ + ("model_name", model_name), + ("checkpoint_path", checkpoint_path), + ("labelmap_path", labelmap_path), + ("spect_scaler_path", spect_scaler_path), + ("csv_path", csv_path), + ] + ) + # TODO: is this still necessary after switching to Lightning? Stop saying "average"? + # order metrics by name to be extra sure they will be consistent across runs + row.update( + sorted([(k, v) for k, v in metric_vals.items() if k.startswith("avg_")]) + ) - # pass index into dataframe, needed when using all scalar values (a single row) - # throw away index below when saving to avoid extra column - eval_df = pd.DataFrame(row, index=[0]) - eval_csv_path = output_dir.joinpath(f"eval_{model_name}_{timenow}.csv") - logger.info(f"saving csv with evaluation metrics at: {eval_csv_path}") - eval_df.to_csv( - eval_csv_path, index=False - ) # index is False to avoid having "Unnamed: 0" column when loading + # pass index into dataframe, needed when using all scalar values (a single row) + # throw away index below when saving to avoid extra column + eval_df = pd.DataFrame(row, index=[0]) + eval_csv_path = output_dir.joinpath(f"eval_{model_name}_{timenow}.csv") + logger.info(f"saving csv with evaluation metrics at: {eval_csv_path}") + eval_df.to_csv( + eval_csv_path, index=False + ) # index is False to avoid having "Unnamed: 0" column when loading diff --git a/src/vak/core/learncurve/learncurve.py b/src/vak/core/learncurve/learncurve.py index 68e037ec3..9921ed09c 100644 --- a/src/vak/core/learncurve/learncurve.py +++ b/src/vak/core/learncurve/learncurve.py @@ -21,7 +21,8 @@ # TODO: add post_tfm_kwargs here def learning_curve( - model_config_map, + model_name: str, + model_config: dict, train_set_durs, num_replicates, csv_path, @@ -43,13 +44,20 @@ def learning_curve( patience=None, device=None, ): - """generate learning curve, by training models on training sets across a - range of sizes and then measure accuracy of those models on a test set. + """Generate a learning curve. + + Trains a class of model with a range of dataset sizes, + and then evaluates each trained model + with a test set that is held constant (unlike the training sets). Parameters ---------- - model_config_map : dict - where each key-value pair is model name : dict of config parameters + model_name : str + Model name, must be one of vak.models.MODEL_NAMES. + model_config : dict + Model configuration in a ``dict``, + as loaded from a .toml file, + and used by the model method ``from_config``. train_set_durs : list of int, durations in seconds of subsets taken from training data to create a learning curve, e.g. [5, 10, 15, 20]. @@ -252,7 +260,8 @@ def learning_curve( ) train( - model_config_map, + model_name, + model_config, this_train_dur_this_replicate_csv_path, window_size, batch_size, @@ -272,71 +281,68 @@ def learning_curve( ) logger.info( - f"Evaluating models from replicate {replicate_num} " + f"Evaluating model from replicate {replicate_num} " f"using dataset from .csv file: {this_train_dur_this_replicate_results_path}", ) - for model_name in model_config_map.keys(): - logger.info( - f"Evaluating model: {model_name}" - ) - results_model_root = ( - this_train_dur_this_replicate_results_path.joinpath(model_name) - ) - ckpt_root = results_model_root.joinpath("checkpoints") - ckpt_paths = sorted(ckpt_root.glob("*.pt")) - if any(["max-val-acc" in str(ckpt_path) for ckpt_path in ckpt_paths]): - ckpt_paths = [ - ckpt_path - for ckpt_path in ckpt_paths - if "max-val-acc" in str(ckpt_path) - ] - if len(ckpt_paths) != 1: - raise ValueError( - f"did not find a single max-val-acc checkpoint path, instead found:\n{ckpt_paths}" - ) - ckpt_path = ckpt_paths[0] - else: - if len(ckpt_paths) != 1: - raise ValueError( - f"did not find a single checkpoint path, instead found:\n{ckpt_paths}" - ) - ckpt_path = ckpt_paths[0] - logger.info( - f"Using checkpoint: {ckpt_path}" - ) - labelmap_path = this_train_dur_this_replicate_results_path.joinpath( - "labelmap.json" + results_model_root = ( + this_train_dur_this_replicate_results_path.joinpath(model_name) + ) + ckpt_root = results_model_root.joinpath("checkpoints") + ckpt_paths = sorted(ckpt_root.glob("*.pt")) + if any(["max-val-acc" in str(ckpt_path) for ckpt_path in ckpt_paths]): + ckpt_paths = [ + ckpt_path + for ckpt_path in ckpt_paths + if "max-val-acc" in str(ckpt_path) + ] + if len(ckpt_paths) != 1: + raise ValueError( + f"did not find a single max-val-acc checkpoint path, instead found:\n{ckpt_paths}" + ) + ckpt_path = ckpt_paths[0] + else: + if len(ckpt_paths) != 1: + raise ValueError( + f"did not find a single checkpoint path, instead found:\n{ckpt_paths}" + ) + ckpt_path = ckpt_paths[0] + logger.info( + f"Using checkpoint: {ckpt_path}" + ) + labelmap_path = this_train_dur_this_replicate_results_path.joinpath( + "labelmap.json" + ) + logger.info( + f"Using labelmap: {labelmap_path}" + ) + if normalize_spectrograms: + spect_scaler_path = ( + this_train_dur_this_replicate_results_path.joinpath( + "StandardizeSpect" + ) ) logger.info( - f"Using labelmap: {labelmap_path}" + f"Using spect scaler to normalize: {spect_scaler_path}", ) - if normalize_spectrograms: - spect_scaler_path = ( - this_train_dur_this_replicate_results_path.joinpath( - "StandardizeSpect" - ) - ) - logger.info( - f"Using spect scaler to normalize: {spect_scaler_path}", - ) - else: - spect_scaler_path = None + else: + spect_scaler_path = None - eval( - this_train_dur_this_replicate_csv_path, - model_config_map, - checkpoint_path=ckpt_path, - labelmap_path=labelmap_path, - output_dir=this_train_dur_this_replicate_results_path, - window_size=window_size, - num_workers=num_workers, - split="test", - spect_scaler_path=spect_scaler_path, - post_tfm_kwargs=post_tfm_kwargs, - spect_key=spect_key, - timebins_key=timebins_key, - device=device, - ) + eval( + model_name, + model_config, + this_train_dur_this_replicate_csv_path, + checkpoint_path=ckpt_path, + labelmap_path=labelmap_path, + output_dir=this_train_dur_this_replicate_results_path, + window_size=window_size, + num_workers=num_workers, + split="test", + spect_scaler_path=spect_scaler_path, + post_tfm_kwargs=post_tfm_kwargs, + spect_key=spect_key, + timebins_key=timebins_key, + device=device, + ) # ---- make a csv for analysis ------------------------------------------------------------------------------------- reg_exp_num = re.compile( diff --git a/src/vak/core/predict.py b/src/vak/core/predict.py index 702b28eaf..fd83966c2 100644 --- a/src/vak/core/predict.py +++ b/src/vak/core/predict.py @@ -1,4 +1,3 @@ -import functools import json import logging import os @@ -28,10 +27,11 @@ def predict( + model_name: str, + model_config: dict, csv_path, checkpoint_path, labelmap_path, - model_config_map, window_size, num_workers=2, spect_key="s", @@ -48,14 +48,18 @@ def predict( Parameters ---------- + model_name : str + Model name, must be one of vak.models.MODEL_NAMES. + model_config : dict + Model configuration in a ``dict``, + as loaded from a .toml file, + and used by the model method ``from_config``. csv_path : str path to where dataset was saved as a csv. checkpoint_path : str path to directory with checkpoint files saved by Torch, to reload model labelmap_path : str path to 'labelmap.json' file. - model_config_map : dict - where each key-value pair is model name : dict of config parameters window_size : int size of windows taken from spectrograms, in number of time bins, shown to neural networks @@ -183,88 +187,93 @@ def predict( input_shape = input_shape[1:] logger.info(f"shape of input to networks used for predictions: {input_shape}") - logger.info(f"instantiating models from model-config map:/n{model_config_map}") - models_map = models.from_model_config_map( - model_config_map, num_classes=len(labelmap), input_shape=input_shape, labelmap=labelmap + logger.info(f"instantiating model from config:/n{model_name}") + + model = models.get( + model_name, + model_config, + num_classes=len(labelmap), + input_shape=input_shape, + labelmap=labelmap, ) - for model_name, model in models_map.items(): - # ---------------- do the actual predicting -------------------------------------------------------------------- - logger.info(f"loading checkpoint for {model_name} from path: {checkpoint_path}") - model.load_state_dict_from_path(checkpoint_path) - if device == 'cuda': - accelerator = 'gpu' - else: - accelerator = None - trainer_logger = lightning.loggers.TensorBoardLogger( - save_dir=output_dir - ) - trainer = lightning.Trainer(accelerator=accelerator, logger=trainer_logger) + # ---------------- do the actual predicting -------------------------------------------------------------------- + logger.info(f"loading checkpoint for {model_name} from path: {checkpoint_path}") + model.load_state_dict_from_path(checkpoint_path) - logger.info(f"running predict method of {model_name}") - results = trainer.predict(model, pred_loader) - # TODO: figure out how to overload `on_predict_epoch_end` to return dict - pred_dict = { - spect_path: y_pred - for result in results - for spect_path, y_pred in result.items() - } - # ---------------- converting to annotations ------------------------------------------------------------------ - progress_bar = tqdm(pred_loader) + if device == 'cuda': + accelerator = 'gpu' + else: + accelerator = None + trainer_logger = lightning.loggers.TensorBoardLogger( + save_dir=output_dir + ) + trainer = lightning.Trainer(accelerator=accelerator, logger=trainer_logger) - annots = [] - logger.info("converting predictions to annotations") - for ind, batch in enumerate(progress_bar): - padding_mask, spect_path = batch["padding_mask"], batch["spect_path"] - padding_mask = np.squeeze(padding_mask) - if isinstance(spect_path, list) and len(spect_path) == 1: - spect_path = spect_path[0] - y_pred = pred_dict[spect_path] + logger.info(f"running predict method of {model_name}") + results = trainer.predict(model, pred_loader) + # TODO: figure out how to overload `on_predict_epoch_end` to return dict + pred_dict = { + spect_path: y_pred + for result in results + for spect_path, y_pred in result.items() + } + # ---------------- converting to annotations ------------------------------------------------------------------ + progress_bar = tqdm(pred_loader) - if save_net_outputs: - # not sure if there's a better way to get outputs into right shape; - # can't just call y_pred.reshape() because that basically flattens the whole array first - # meaning we end up with elements in the wrong order - # so instead we convert to sequence then stack horizontally, on column axis - net_output = torch.hstack(y_pred.unbind()) - net_output = net_output[:, padding_mask] - net_output = net_output.cpu().numpy() - net_output_path = output_dir.joinpath( - Path(spect_path).stem + f"{model_name}{constants.NET_OUTPUT_SUFFIX}" - ) - np.savez(net_output_path, net_output) + annots = [] + logger.info("converting predictions to annotations") + for ind, batch in enumerate(progress_bar): + padding_mask, spect_path = batch["padding_mask"], batch["spect_path"] + padding_mask = np.squeeze(padding_mask) + if isinstance(spect_path, list) and len(spect_path) == 1: + spect_path = spect_path[0] + y_pred = pred_dict[spect_path] - y_pred = torch.argmax(y_pred, dim=1) # assumes class dimension is 1 - y_pred = torch.flatten(y_pred).cpu().numpy()[padding_mask] + if save_net_outputs: + # not sure if there's a better way to get outputs into right shape; + # can't just call y_pred.reshape() because that basically flattens the whole array first + # meaning we end up with elements in the wrong order + # so instead we convert to sequence then stack horizontally, on column axis + net_output = torch.hstack(y_pred.unbind()) + net_output = net_output[:, padding_mask] + net_output = net_output.cpu().numpy() + net_output_path = output_dir.joinpath( + Path(spect_path).stem + f"{model_name}{constants.NET_OUTPUT_SUFFIX}" + ) + np.savez(net_output_path, net_output) - spect_dict = files.spect.load(spect_path) - t = spect_dict[timebins_key] + y_pred = torch.argmax(y_pred, dim=1) # assumes class dimension is 1 + y_pred = torch.flatten(y_pred).cpu().numpy()[padding_mask] - if majority_vote or min_segment_dur: - y_pred = transforms.labeled_timebins.postprocess( - y_pred, - timebin_dur=timebin_dur, - min_segment_dur=min_segment_dur, - majority_vote=majority_vote, - ) + spect_dict = files.spect.load(spect_path) + t = spect_dict[timebins_key] - labels, onsets_s, offsets_s = transforms.labeled_timebins.to_segments( + if majority_vote or min_segment_dur: + y_pred = transforms.labeled_timebins.postprocess( y_pred, - labelmap=labelmap, - t=t, - ) - if labels is None and onsets_s is None and offsets_s is None: - # handle the case when all time bins are predicted to be unlabeled - # see https://github.com/NickleDave/vak/issues/383 - continue - seq = crowsetta.Sequence.from_keyword( - labels=labels, onsets_s=onsets_s, offsets_s=offsets_s + timebin_dur=timebin_dur, + min_segment_dur=min_segment_dur, + majority_vote=majority_vote, ) - audio_fname = files.spect.find_audio_fname(spect_path) - annot = crowsetta.Annotation( - seq=seq, audio_path=audio_fname, annot_path=annot_csv_path.name - ) - annots.append(annot) + labels, onsets_s, offsets_s = transforms.labeled_timebins.to_segments( + y_pred, + labelmap=labelmap, + t=t, + ) + if labels is None and onsets_s is None and offsets_s is None: + # handle the case when all time bins are predicted to be unlabeled + # see https://github.com/NickleDave/vak/issues/383 + continue + seq = crowsetta.Sequence.from_keyword( + labels=labels, onsets_s=onsets_s, offsets_s=offsets_s + ) + + audio_fname = files.spect.find_audio_fname(spect_path) + annot = crowsetta.Annotation( + seq=seq, audio_path=audio_fname, annot_path=annot_csv_path.name + ) + annots.append(annot) - crowsetta.csv.annot2csv(annot=annots, csv_filename=annot_csv_path) + crowsetta.csv.annot2csv(annot=annots, csv_filename=annot_csv_path) diff --git a/src/vak/core/train.py b/src/vak/core/train.py index 5483ce47d..38814f483 100644 --- a/src/vak/core/train.py +++ b/src/vak/core/train.py @@ -27,7 +27,8 @@ def train( - model_config_map, + model_name, + model_config, csv_path, window_size, batch_size, @@ -51,9 +52,9 @@ def train( patience=None, device=None, ): - """Train models and save results. + """Train a model and save results. - Saves checkpoint files for models, + Saves checkpoint files for model, label map, and spectrogram scaler. These are saved either in ``results_path`` if specified, or a new directory @@ -61,8 +62,12 @@ def train( Parameters ---------- - model_config_map : dict - where each key-value pair is model name : dict of config parameters + model_name : str + Model name, must be one of vak.models.MODEL_NAMES. + model_config : dict + Model configuration in a ``dict``, + as loaded from a .toml file, + and used by the model method ``from_config``. csv_path : str path to where dataset was saved as a csv. window_size : int @@ -324,39 +329,40 @@ def train( if device is None: device = get_default_device() - models_map = models.from_model_config_map( - model_config_map, + model = models.get( + model_name, + model_config, num_classes=len(labelmap), input_shape=train_dataset.shape, labelmap=labelmap, ) - for model_name, model in models_map.items(): - if checkpoint_path is not None: - logger.info( - f"loading checkpoint for {model_name} from path: {checkpoint_path}", - ) - model.load_state_dict_from_path(checkpoint_path) - results_model_root = results_path.joinpath(model_name) - results_model_root.mkdir() - ckpt_root = results_model_root.joinpath("checkpoints") - ckpt_root.mkdir() - logger.info(f"training {model_name}") - max_steps = num_epochs * len(train_loader) - default_callback_kwargs = { - 'ckpt_root': ckpt_root, - 'ckpt_step': ckpt_step, - 'patience': patience, - } - trainer = get_default_trainer( - max_steps=max_steps, - log_save_dir=results_model_root, - val_step=val_step, - default_callback_kwargs=default_callback_kwargs, - device=device, - ) - trainer.fit( - model=model, - train_dataloaders=train_loader, - val_dataloaders=val_loader, + if checkpoint_path is not None: + logger.info( + f"loading checkpoint for {model_name} from path: {checkpoint_path}", ) + model.load_state_dict_from_path(checkpoint_path) + + results_model_root = results_path.joinpath(model_name) + results_model_root.mkdir() + ckpt_root = results_model_root.joinpath("checkpoints") + ckpt_root.mkdir() + logger.info(f"training {model_name}") + max_steps = num_epochs * len(train_loader) + default_callback_kwargs = { + 'ckpt_root': ckpt_root, + 'ckpt_step': ckpt_step, + 'patience': patience, + } + trainer = get_default_trainer( + max_steps=max_steps, + log_save_dir=results_model_root, + val_step=val_step, + default_callback_kwargs=default_callback_kwargs, + device=device, + ) + trainer.fit( + model=model, + train_dataloaders=train_loader, + val_dataloaders=val_loader, + ) diff --git a/src/vak/models/__init__.py b/src/vak/models/__init__.py index 6b56da483..7d7748228 100644 --- a/src/vak/models/__init__.py +++ b/src/vak/models/__init__.py @@ -3,8 +3,9 @@ decorator, definition, ) +from ._api import BUILTIN_MODELS, MODEL_NAMES from .base import Model -from .models import from_model_config_map +from .get import get from .teenytweetynet import TeenyTweetyNet from .tweetynet import TweetyNet from .windowed_frame_classification_model import WindowedFrameClassificationModel @@ -14,7 +15,7 @@ "base", "decorator", "definition", - "from_model_config_map", + "get", "Model", "TeenyTweetyNet", "TweetyNet", diff --git a/src/vak/models/_api.py b/src/vak/models/_api.py new file mode 100644 index 000000000..6ab2b5911 --- /dev/null +++ b/src/vak/models/_api.py @@ -0,0 +1,11 @@ +from .tweetynet import TweetyNet +from .teenytweetynet import TeenyTweetyNet + + +# TODO: Replace constant with decorator that registers models, https://github.com/vocalpy/vak/issues/623 +BUILTIN_MODELS = { + 'TweetyNet': TweetyNet, + 'TeenyTweetyNet': TeenyTweetyNet +} + +MODEL_NAMES = list(BUILTIN_MODELS.keys()) diff --git a/src/vak/models/get.py b/src/vak/models/get.py new file mode 100644 index 000000000..c7f9397db --- /dev/null +++ b/src/vak/models/get.py @@ -0,0 +1,64 @@ +"""Function that gets an instance of a model, +given its name and a configuration as a dict.""" +from __future__ import annotations + +from ._api import MODEL_NAMES + + +def get(name: str, + config: dict, + # TODO: move num_classes / input_shape into model configs + num_classes: int, + input_shape: tuple[int, int, int], + labelmap: dict): + """Get a model instance, given its name and + a configuration as a ``dict``. + + Parameters + ---------- + name : str + Model name, must be one of vak.models.MODEL_NAMES. + config: dict + Model configuration in a ``dict``, + as loaded from a .toml file, + and used by the model method ``from_config``. + num_classes : int + Number of classes model will be trained to classify. + input_shape : tuple + Of int values, sizes of dimensions, + e.g. (channels, height, width). + Batch size is not required for input shape. + post_tfm : callable + Post-processing transform that models applies during evaluation. + Default is None, in which case the model defaults to using + ``vak.transforms.labeled_timebins.ToLabels`` (that does not + apply any post-processing clean-ups). + To be valid, ``post_tfm`` must be either an instance of + ``vak.transforms.labeled_timebins.ToLabels`` or + ``vak.transforms.labeled_timebins.ToLabelsWithPostprocessing``. + + Returns + ------- + model : vak.models.Model + Instance of a sub-class of the base Model class, + e.g. a TweetyNet instance. + """ + import vak.models + + # TODO: move num_classes / input_shape into model configs + # TODO: add labelmap to config dynamically if needed? outside this function + config["network"].update( + num_classes=num_classes, + input_shape=input_shape, + ) + + try: + model_class = getattr(vak.models, name) + except AttributeError as e: + raise ValueError( + f"Invalid model name: '{name}'.\nValid model names are: {MODEL_NAMES}" + ) from e + + model = model_class.from_config(config=config, labelmap=labelmap) + + return model diff --git a/src/vak/models/models.py b/src/vak/models/models.py deleted file mode 100644 index 9d519407e..000000000 --- a/src/vak/models/models.py +++ /dev/null @@ -1,73 +0,0 @@ -"""Helper function to load models""" -from __future__ import annotations - -from .tweetynet import TweetyNet -from .teenytweetynet import TeenyTweetyNet - - -# TODO: Replace constant with decorator that registers models, https://github.com/vocalpy/vak/issues/623 -BUILTIN_MODELS = { - 'TweetyNet': TweetyNet, - 'TeenyTweetyNet': TeenyTweetyNet -} - -MODEL_NAMES = list(BUILTIN_MODELS.keys()) - - -def from_model_config_map(model_config_map: dict[str: dict], - # TODO: move num_classes / input_shape into model configs - num_classes: int, - input_shape: tuple[int, int, int], - labelmap: dict) -> dict: - """Get models that are ready to train, given their names and configurations. - - Given a dictionary that maps model names to configurations, - along with the number of classes they should be trained to discriminate and their input shape, - return a dictionary that maps model names to instances of the model - - Parameters - ---------- - model_config_map : dict - where each key-value pair is model name : dict of config parameters - num_classes : int - number of classes model will be trained to classify - input_shape : tuple, list - e.g. (channels, height, width). - Batch size is not required for input shape. - post_tfm : callable - Post-processing transform that models applies during evaluation. - Default is None, in which case the model defaults to using - ``vak.transforms.labeled_timebins.ToLabels`` (that does not - apply any post-processing clean-ups). - To be valid, ``post_tfm`` must be either an instance of - ``vak.transforms.labeled_timebins.ToLabels`` or - ``vak.transforms.labeled_timebins.ToLabelsWithPostprocessing``. - - Returns - ------- - models_map : dict - where keys are model names and values are instances of the models, ready for training - """ - import vak.models - - models_map = {} - for model_name, model_config in model_config_map.items(): - # pass section dict as kwargs to config parser function - # TODO: move num_classes / input_shape into model configs - # TODO: add labelmap to config dynamically if needed? outside this function - model_config["network"].update( - num_classes=num_classes, - input_shape=input_shape, - ) - - try: - model_class = getattr(vak.models, model_name) - except AttributeError as e: - raise ValueError( - f"Invalid model name: '{model_name}'.\nValid model names are: {MODEL_NAMES}" - ) from e - - model = model_class.from_config(config=model_config, labelmap=labelmap) - models_map[model_name] = model - - return models_map diff --git a/tests/data_for_tests/configs/invalid_option_config.toml b/tests/data_for_tests/configs/invalid_option_config.toml index 32057aecc..9b9e835a1 100644 --- a/tests/data_for_tests/configs/invalid_option_config.toml +++ b/tests/data_for_tests/configs/invalid_option_config.toml @@ -22,7 +22,7 @@ transform_type = 'log_spect' window_size = 88 [TRAIN] -models = ['TweetyNet', ] +model = 'TweetyNet' root_results_dir = '/home/user/data/subdir/' normalize_spectrograms = true num_epochs = 2 diff --git a/tests/data_for_tests/configs/invalid_section_config.toml b/tests/data_for_tests/configs/invalid_section_config.toml index cba003ffa..dfe13c504 100644 --- a/tests/data_for_tests/configs/invalid_section_config.toml +++ b/tests/data_for_tests/configs/invalid_section_config.toml @@ -22,7 +22,7 @@ transform_type = 'log_spect' window_size = 88 [TRIAN] # <-- invalid section 'TRIAN' (instead of 'TRAIN') -models = ['TweetyNet', ] +model = 'TweetyNet' root_results_dir = '/home/user/data/subdir/' normalize_spectrograms = true num_epochs = 2 diff --git a/tests/data_for_tests/configs/invalid_train_and_learncurve_config.toml b/tests/data_for_tests/configs/invalid_train_and_learncurve_config.toml index 40d27e9a1..c5e27c85f 100644 --- a/tests/data_for_tests/configs/invalid_train_and_learncurve_config.toml +++ b/tests/data_for_tests/configs/invalid_train_and_learncurve_config.toml @@ -21,7 +21,7 @@ window_size = 88 # this .toml file should cause 'vak.config.parse.from_toml' to raise a ValueError # because it defines both a TRAIN and a LEARNCURVE section [TRAIN] -models = "TweetyNet" +model = "TweetyNet" normalize_spectrograms = true batch_size = 11 num_epochs = 2 @@ -33,7 +33,7 @@ device = "cuda" root_results_dir = "./tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat" [LEARNCURVE] -models = 'TweetyNet' +model = 'TweetyNet' normalize_spectrograms = true batch_size = 11 num_epochs = 2 diff --git a/tests/data_for_tests/configs/teenytweetynet_eval_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/teenytweetynet_eval_audio_cbin_annot_notmat.toml index 90abdd5b4..646e6edca 100644 --- a/tests/data_for_tests/configs/teenytweetynet_eval_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/teenytweetynet_eval_audio_cbin_annot_notmat.toml @@ -19,7 +19,7 @@ window_size = 44 [EVAL] checkpoint_path = "~/Documents/repos/coding/birdsong/TeenyTweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/TeenyTweetyNet/checkpoints/max-val-acc-checkpoint.pt" labelmap_path = "~/Documents/repos/coding/birdsong/TeenyTweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json" -models = "TeenyTweetyNet" +model = "TeenyTweetyNet" batch_size = 4 num_workers = 2 device = "cuda" diff --git a/tests/data_for_tests/configs/teenytweetynet_learncurve_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/teenytweetynet_learncurve_audio_cbin_annot_notmat.toml index bfc1cfe0b..b23b508ef 100644 --- a/tests/data_for_tests/configs/teenytweetynet_learncurve_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/teenytweetynet_learncurve_audio_cbin_annot_notmat.toml @@ -20,7 +20,7 @@ transform_type = "log_spect" window_size = 44 [LEARNCURVE] -models = "TeenyTweetyNet" +model = "TeenyTweetyNet" normalize_spectrograms = true batch_size = 4 num_epochs = 2 diff --git a/tests/data_for_tests/configs/teenytweetynet_predict_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/teenytweetynet_predict_audio_cbin_annot_notmat.toml index 26614548d..228900188 100644 --- a/tests/data_for_tests/configs/teenytweetynet_predict_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/teenytweetynet_predict_audio_cbin_annot_notmat.toml @@ -18,7 +18,7 @@ window_size = 44 spect_scaler_path = "/home/user/results_181014_194418/spect_scaler" checkpoint_path = "~/Documents/repos/coding/birdsong/TeenyTweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/TeenyTweetyNet/checkpoints/max-val-acc-checkpoint.pt" labelmap_path = "~/Documents/repos/coding/birdsong/TeenyTweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/labelmap.json" -models = "TeenyTweetyNet" +model = "TeenyTweetyNet" batch_size = 4 num_workers = 2 device = "cuda" diff --git a/tests/data_for_tests/configs/teenytweetynet_predict_audio_wav_annot_birdsongrec.toml b/tests/data_for_tests/configs/teenytweetynet_predict_audio_wav_annot_birdsongrec.toml index 9cb219987..a02b34662 100644 --- a/tests/data_for_tests/configs/teenytweetynet_predict_audio_wav_annot_birdsongrec.toml +++ b/tests/data_for_tests/configs/teenytweetynet_predict_audio_wav_annot_birdsongrec.toml @@ -18,7 +18,7 @@ window_size = 44 spect_scaler_path = "/home/user/results_181014_194418/spect_scaler" checkpoint_path = "~/Documents/repos/coding/birdsong/TeenyTweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/TeenyTweetyNet/checkpoints/max-val-acc-checkpoint.pt" labelmap_path = "~/Documents/repos/coding/birdsong/TeenyTweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/labelmap.json" -models = "TeenyTweetyNet" +model = "TeenyTweetyNet" batch_size = 4 num_workers = 2 device = "cuda" diff --git a/tests/data_for_tests/configs/teenytweetynet_train_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/teenytweetynet_train_audio_cbin_annot_notmat.toml index 254fd16d2..ca5066359 100644 --- a/tests/data_for_tests/configs/teenytweetynet_train_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/teenytweetynet_train_audio_cbin_annot_notmat.toml @@ -20,7 +20,7 @@ transform_type = "log_spect" window_size = 44 [TRAIN] -models = "TeenyTweetyNet" +model = "TeenyTweetyNet" normalize_spectrograms = true batch_size = 4 num_epochs = 2 diff --git a/tests/data_for_tests/configs/teenytweetynet_train_audio_wav_annot_birdsongrec.toml b/tests/data_for_tests/configs/teenytweetynet_train_audio_wav_annot_birdsongrec.toml index 928796005..6749ce99e 100644 --- a/tests/data_for_tests/configs/teenytweetynet_train_audio_wav_annot_birdsongrec.toml +++ b/tests/data_for_tests/configs/teenytweetynet_train_audio_wav_annot_birdsongrec.toml @@ -21,7 +21,7 @@ transform_type = "log_spect" window_size = 44 [TRAIN] -models = "TeenyTweetyNet" +model = "TeenyTweetyNet" normalize_spectrograms = true batch_size = 4 num_epochs = 2 diff --git a/tests/data_for_tests/configs/teenytweetynet_train_continue_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/teenytweetynet_train_continue_audio_cbin_annot_notmat.toml index 794fab770..9e83e3322 100644 --- a/tests/data_for_tests/configs/teenytweetynet_train_continue_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/teenytweetynet_train_continue_audio_cbin_annot_notmat.toml @@ -20,7 +20,7 @@ transform_type = "log_spect" window_size = 44 [TRAIN] -models = "TeenyTweetyNet" +model = "TeenyTweetyNet" normalize_spectrograms = true batch_size = 4 num_epochs = 2 diff --git a/tests/data_for_tests/configs/teenytweetynet_train_continue_audio_wav_annot_birdsongrec.toml b/tests/data_for_tests/configs/teenytweetynet_train_continue_audio_wav_annot_birdsongrec.toml index b22610bfa..156f8a4f8 100644 --- a/tests/data_for_tests/configs/teenytweetynet_train_continue_audio_wav_annot_birdsongrec.toml +++ b/tests/data_for_tests/configs/teenytweetynet_train_continue_audio_wav_annot_birdsongrec.toml @@ -21,7 +21,7 @@ transform_type = "log_spect" window_size = 44 [TRAIN] -models = "TeenyTweetyNet" +model = "TeenyTweetyNet" normalize_spectrograms = true batch_size = 4 num_epochs = 2 diff --git a/tests/data_for_tests/configs/teenytweetynet_train_continue_spect_mat_annot_yarden.toml b/tests/data_for_tests/configs/teenytweetynet_train_continue_spect_mat_annot_yarden.toml index ea2f117a7..fa600b106 100644 --- a/tests/data_for_tests/configs/teenytweetynet_train_continue_spect_mat_annot_yarden.toml +++ b/tests/data_for_tests/configs/teenytweetynet_train_continue_spect_mat_annot_yarden.toml @@ -20,7 +20,7 @@ transform_type = "log_spect" window_size = 44 [TRAIN] -models = "TeenyTweetyNet" +model = "TeenyTweetyNet" normalize_spectrograms = false batch_size = 4 num_epochs = 2 diff --git a/tests/data_for_tests/configs/teenytweetynet_train_spect_mat_annot_yarden.toml b/tests/data_for_tests/configs/teenytweetynet_train_spect_mat_annot_yarden.toml index 66bd725e7..190c97dca 100644 --- a/tests/data_for_tests/configs/teenytweetynet_train_spect_mat_annot_yarden.toml +++ b/tests/data_for_tests/configs/teenytweetynet_train_spect_mat_annot_yarden.toml @@ -20,7 +20,7 @@ transform_type = "log_spect" window_size = 44 [TRAIN] -models = "TeenyTweetyNet" +model = "TeenyTweetyNet" normalize_spectrograms = false batch_size = 4 num_epochs = 2 diff --git a/tests/data_for_tests/configs/tweetynet_eval_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/tweetynet_eval_audio_cbin_annot_notmat.toml index 5c4f789b9..4a51f5e48 100644 --- a/tests/data_for_tests/configs/tweetynet_eval_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/tweetynet_eval_audio_cbin_annot_notmat.toml @@ -19,7 +19,7 @@ window_size = 88 [EVAL] checkpoint_path = "~/Documents/repos/coding/birdsong/tweetynet/results/BFSongRepository/gy6or6/results_200620_165308/TweetyNet/checkpoints/max-val-acc-checkpoint.pt" labelmap_path = "~/Documents/repos/coding/birdsong/tweetynet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json" -models = "TweetyNet" +model = "TweetyNet" batch_size = 11 num_workers = 4 device = "cuda" diff --git a/tests/data_for_tests/configs/tweetynet_learncurve_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/tweetynet_learncurve_audio_cbin_annot_notmat.toml index b215ee873..2944e1852 100644 --- a/tests/data_for_tests/configs/tweetynet_learncurve_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/tweetynet_learncurve_audio_cbin_annot_notmat.toml @@ -20,7 +20,7 @@ transform_type = "log_spect" window_size = 88 [LEARNCURVE] -models = "TweetyNet" +model = "TweetyNet" normalize_spectrograms = true batch_size = 11 num_epochs = 2 diff --git a/tests/data_for_tests/configs/tweetynet_predict_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/tweetynet_predict_audio_cbin_annot_notmat.toml index 684a55ee7..1ca5e66dc 100644 --- a/tests/data_for_tests/configs/tweetynet_predict_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/tweetynet_predict_audio_cbin_annot_notmat.toml @@ -18,7 +18,7 @@ window_size = 88 spect_scaler_path = "/home/user/results_181014_194418/spect_scaler" checkpoint_path = "~/Documents/repos/coding/birdsong/tweetynet/results/BFSongRepository/bl26lb16/results_200620_164245/TweetyNet/checkpoints/max-val-acc-checkpoint.pt" labelmap_path = "~/Documents/repos/coding/birdsong/tweetynet/results/BFSongRepository/bl26lb16/results_200620_164245/labelmap.json" -models = "TweetyNet" +model = "TweetyNet" batch_size = 11 num_workers = 4 device = "cuda" diff --git a/tests/data_for_tests/configs/tweetynet_predict_audio_wav_annot_birdsongrec.toml b/tests/data_for_tests/configs/tweetynet_predict_audio_wav_annot_birdsongrec.toml index 67bdac1be..36cb8489e 100644 --- a/tests/data_for_tests/configs/tweetynet_predict_audio_wav_annot_birdsongrec.toml +++ b/tests/data_for_tests/configs/tweetynet_predict_audio_wav_annot_birdsongrec.toml @@ -18,7 +18,7 @@ window_size = 88 spect_scaler_path = "/home/user/results_181014_194418/spect_scaler" checkpoint_path = "~/Documents/repos/coding/birdsong/tweetynet/results/BFSongRepository/bl26lb16/results_200620_164245/TweetyNet/checkpoints/max-val-acc-checkpoint.pt" labelmap_path = "~/Documents/repos/coding/birdsong/tweetynet/results/BFSongRepository/bl26lb16/results_200620_164245/labelmap.json" -models = "TweetyNet" +model = "TweetyNet" batch_size = 11 num_workers = 4 device = "cuda" diff --git a/tests/data_for_tests/configs/tweetynet_train_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/tweetynet_train_audio_cbin_annot_notmat.toml index 4e23fd4f3..0d6422e5e 100644 --- a/tests/data_for_tests/configs/tweetynet_train_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/tweetynet_train_audio_cbin_annot_notmat.toml @@ -20,7 +20,7 @@ transform_type = "log_spect" window_size = 88 [TRAIN] -models = "TweetyNet" +model = "TweetyNet" normalize_spectrograms = true batch_size = 11 num_epochs = 2 diff --git a/tests/data_for_tests/configs/tweetynet_train_audio_wav_annot_birdsongrec.toml b/tests/data_for_tests/configs/tweetynet_train_audio_wav_annot_birdsongrec.toml index e00bdc0f4..5788f85d3 100644 --- a/tests/data_for_tests/configs/tweetynet_train_audio_wav_annot_birdsongrec.toml +++ b/tests/data_for_tests/configs/tweetynet_train_audio_wav_annot_birdsongrec.toml @@ -21,7 +21,7 @@ transform_type = "log_spect" window_size = 88 [TRAIN] -models = "TweetyNet" +model = "TweetyNet" normalize_spectrograms = true batch_size = 11 num_epochs = 2 diff --git a/tests/data_for_tests/configs/tweetynet_train_continue_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/tweetynet_train_continue_audio_cbin_annot_notmat.toml index bb4cf9e5c..97174a6b0 100644 --- a/tests/data_for_tests/configs/tweetynet_train_continue_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/tweetynet_train_continue_audio_cbin_annot_notmat.toml @@ -20,7 +20,7 @@ transform_type = "log_spect" window_size = 88 [TRAIN] -models = "TweetyNet" +model = "TweetyNet" normalize_spectrograms = true batch_size = 11 num_epochs = 2 diff --git a/tests/data_for_tests/configs/tweetynet_train_continue_audio_wav_annot_birdsongrec.toml b/tests/data_for_tests/configs/tweetynet_train_continue_audio_wav_annot_birdsongrec.toml index c55ef78a9..69847e035 100644 --- a/tests/data_for_tests/configs/tweetynet_train_continue_audio_wav_annot_birdsongrec.toml +++ b/tests/data_for_tests/configs/tweetynet_train_continue_audio_wav_annot_birdsongrec.toml @@ -21,7 +21,7 @@ transform_type = "log_spect" window_size = 88 [TRAIN] -models = "TweetyNet" +model = "TweetyNet" normalize_spectrograms = true batch_size = 11 num_epochs = 2 diff --git a/tests/data_for_tests/configs/tweetynet_train_continue_spect_mat_annot_yarden.toml b/tests/data_for_tests/configs/tweetynet_train_continue_spect_mat_annot_yarden.toml index bac3c5d91..6e7b0f707 100644 --- a/tests/data_for_tests/configs/tweetynet_train_continue_spect_mat_annot_yarden.toml +++ b/tests/data_for_tests/configs/tweetynet_train_continue_spect_mat_annot_yarden.toml @@ -20,7 +20,7 @@ transform_type = "log_spect" window_size = 88 [TRAIN] -models = "TweetyNet" +model = "TweetyNet" normalize_spectrograms = false batch_size = 11 num_epochs = 2 diff --git a/tests/data_for_tests/configs/tweetynet_train_spect_mat_annot_yarden.toml b/tests/data_for_tests/configs/tweetynet_train_spect_mat_annot_yarden.toml index 9d9b1c637..6766f4345 100644 --- a/tests/data_for_tests/configs/tweetynet_train_spect_mat_annot_yarden.toml +++ b/tests/data_for_tests/configs/tweetynet_train_spect_mat_annot_yarden.toml @@ -20,7 +20,7 @@ transform_type = "log_spect" window_size = 88 [TRAIN] -models = "TweetyNet" +model = "TweetyNet" normalize_spectrograms = false batch_size = 11 num_epochs = 2 diff --git a/tests/fixtures/config.py b/tests/fixtures/config.py index 0eb764ab4..603512599 100644 --- a/tests/fixtures/config.py +++ b/tests/fixtures/config.py @@ -5,11 +5,14 @@ import pytest import toml -from .test_data import GENERATED_TEST_DATA_ROOT +from .test_data import GENERATED_TEST_DATA_ROOT, TEST_DATA_ROOT + + +TEST_CONFIGS_ROOT = TEST_DATA_ROOT.joinpath("configs") @pytest.fixture -def test_configs_root(test_data_root): +def test_configs_root(): """Path that points to data_for_tests/configs Two types of config files in this directory: @@ -20,7 +23,7 @@ def test_configs_root(test_data_root): This fixture facilitates access to type (2), e.g. in test_config/test_parse """ - return test_data_root.joinpath("configs") + return TEST_CONFIGS_ROOT @pytest.fixture @@ -75,10 +78,13 @@ def generated_test_configs_root(): return GENERATED_TEST_CONFIGS_ROOT +ALL_GENERATED_CONFIGS = sorted(GENERATED_TEST_CONFIGS_ROOT.glob("*toml")) + + # ---- path to config files ---- @pytest.fixture -def all_generated_configs(generated_test_configs_root): - return sorted(generated_test_configs_root.glob("*toml")) +def all_generated_configs(): + return ALL_GENERATED_CONFIGS @pytest.fixture @@ -241,9 +247,12 @@ def _specific_config_toml( return _specific_config_toml +ALL_GENERATED_CONFIGS_TOML = [_return_toml(config) for config in ALL_GENERATED_CONFIGS] + + @pytest.fixture -def all_generated_configs_toml(all_generated_configs): - return [_return_toml(config) for config in all_generated_configs] +def all_generated_configs_toml(): + return ALL_GENERATED_CONFIGS_TOML @pytest.fixture @@ -266,17 +275,26 @@ def all_generated_predict_configs_toml(all_generated_predict_configs): return [_return_toml(config) for config in all_generated_predict_configs] +ALL_GENERATED_CONFIGS_TOML_PATH_PAIRS = list(zip( + [_return_toml(config) for config in ALL_GENERATED_CONFIGS], + ALL_GENERATED_CONFIGS, +)) + + # ---- config toml + path pairs ---- @pytest.fixture -def all_generated_configs_toml_path_pairs(all_generated_configs): +def all_generated_configs_toml_path_pairs(): """zip of tuple pairs: (dict, pathlib.Path) where ``Path`` is path to .toml config file and ``dict`` is the .toml config from that path loaded into a dict with the ``toml`` library """ + # we duplicate the constant above because we need to remake + # the variables for each unit test. Otherwise tests that modify values + # for config options cause other tests to fail return zip( - [_return_toml(config) for config in all_generated_configs], - all_generated_configs, + [_return_toml(config) for config in ALL_GENERATED_CONFIGS], + ALL_GENERATED_CONFIGS ) diff --git a/tests/test_config/test_model.py b/tests/test_config/test_model.py new file mode 100644 index 000000000..7f5e56c4b --- /dev/null +++ b/tests/test_config/test_model.py @@ -0,0 +1,62 @@ +import copy +import pytest + +from ..fixtures import ( + ALL_GENERATED_CONFIGS_TOML, + ALL_GENERATED_CONFIGS_TOML_PATH_PAIRS +) + +import vak.config.model + + +def _make_expected_config(model_config: dict) -> dict: + for attr in vak.config.model.MODEL_TABLES: + if attr not in model_config: + model_config[attr] = {} + return model_config + + +@pytest.mark.parametrize( + 'toml_dict', + ALL_GENERATED_CONFIGS_TOML +) +def test_config_from_toml_dict(toml_dict): + for section_name in ('TRAIN', 'EVAL', 'LEARNCURVE', 'PREDICT'): + try: + section = toml_dict[section_name] + except KeyError: + continue + model_name = section['model'] + # we need to copy so that we don't silently fail to detect mistakes + # by comparing a reference to the dict with itself + expected_model_config = copy.deepcopy( + toml_dict[model_name] + ) + expected_model_config = _make_expected_config(expected_model_config) + + model_config = vak.config.model.config_from_toml_dict(toml_dict, model_name) + + assert model_config == expected_model_config + + +@pytest.mark.parametrize( + 'toml_dict, toml_path', + ALL_GENERATED_CONFIGS_TOML_PATH_PAIRS +) +def test_config_from_toml_path(toml_dict, toml_path): + for section_name in ('TRAIN', 'EVAL', 'LEARNCURVE', 'PREDICT'): + try: + section = toml_dict[section_name] + except KeyError: + continue + model_name = section['model'] + # we need to copy so that we don't silently fail to detect mistakes + # by comparing a reference to the dict with itself + expected_model_config = copy.deepcopy( + toml_dict[model_name] + ) + expected_model_config = _make_expected_config(expected_model_config) + + model_config = vak.config.model.config_from_toml_path(toml_path, model_name) + + assert model_config == expected_model_config diff --git a/tests/test_config/test_parse.py b/tests/test_config/test_parse.py index 8d64398b8..d18c1c71d 100644 --- a/tests/test_config/test_parse.py +++ b/tests/test_config/test_parse.py @@ -91,7 +91,7 @@ def test_parse_config_section_model_not_installed_raises( if section_name.lower() in toml_path.name: break # use these. Only need to test on one - config_toml[section_name]["models"] = "NotInstalledNet, OtherNotInstalledNet" + config_toml[section_name]["model"] = "NotInstalledNet" with pytest.raises(ValueError): vak.config.parse.parse_config_section( config_toml=config_toml, section_name=section_name, toml_path=toml_path diff --git a/tests/test_core/test_eval.py b/tests/test_core/test_eval.py index f9f3dbf62..e71675093 100644 --- a/tests/test_core/test_eval.py +++ b/tests/test_core/test_eval.py @@ -8,12 +8,9 @@ # written as separate function so we can re-use in tests/unit/test_cli/test_eval.py -def eval_output_matches_expected(model_config_map, output_dir): - for model_name in model_config_map.keys(): - eval_csv = sorted(output_dir.glob(f"eval_{model_name}*csv")) - assert len(eval_csv) == 1 - - return True +def assert_eval_output_matches_expected(model_name, output_dir): + eval_csv = sorted(output_dir.glob(f"eval_{model_name}*csv")) + assert len(eval_csv) == 1 # -- we do eval with all possible configurations of post_tfm_kwargs @@ -71,11 +68,12 @@ def test_eval( options_to_change=options_to_change, ) cfg = vak.config.parse.from_toml_path(toml_path) - model_config_map = vak.config.models.map_from_path(toml_path, cfg.eval.models) + model_config = vak.config.model.config_from_toml_path(toml_path, cfg.eval.model) vak.core.eval( - cfg.eval.csv_path, - model_config_map, + model_name=cfg.eval.model, + model_config=model_config, + csv_path=cfg.eval.csv_path, checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, output_dir=cfg.eval.output_dir, @@ -88,7 +86,7 @@ def test_eval( post_tfm_kwargs=post_tfm_kwargs, ) - assert eval_output_matches_expected(model_config_map, output_dir) + assert_eval_output_matches_expected(cfg.eval.model, output_dir) @pytest.mark.parametrize( @@ -130,11 +128,12 @@ def test_eval_raises_file_not_found( options_to_change=options_to_change, ) cfg = vak.config.parse.from_toml_path(toml_path) - model_config_map = vak.config.models.map_from_path(toml_path, cfg.eval.models) + model_config = vak.config.model.config_from_toml_path(toml_path, cfg.eval.model) with pytest.raises(FileNotFoundError): vak.core.eval( + model_name=cfg.eval.model, + model_config=model_config, csv_path=cfg.eval.csv_path, - model_config_map=model_config_map, checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, output_dir=cfg.eval.output_dir, @@ -168,11 +167,12 @@ def test_eval_raises_not_a_directory( options_to_change=options_to_change, ) cfg = vak.config.parse.from_toml_path(toml_path) - model_config_map = vak.config.models.map_from_path(toml_path, cfg.eval.models) + model_config = vak.config.model.config_from_toml_path(toml_path, cfg.eval.model) with pytest.raises(NotADirectoryError): vak.core.eval( + model_name=cfg.eval.model, + model_config=model_config, csv_path=cfg.eval.csv_path, - model_config_map=model_config_map, checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, output_dir=cfg.eval.output_dir, diff --git a/tests/test_core/test_learncurve.py b/tests/test_core/test_learncurve.py index 293f100d6..03e5313c8 100644 --- a/tests/test_core/test_learncurve.py +++ b/tests/test_core/test_learncurve.py @@ -7,7 +7,7 @@ import vak.paths -def learncurve_output_matches_expected(cfg, model_config_map, results_path): +def assert_learncurve_output_matches_expected(cfg, model_name, results_path): assert results_path.joinpath("learning_curve.csv").exists() for train_set_dur in cfg.learncurve.train_set_durs: @@ -30,27 +30,24 @@ def learncurve_output_matches_expected(cfg, model_config_map, results_path): if cfg.learncurve.normalize_spectrograms: assert replicate_path.joinpath("StandardizeSpect").exists() - for model_name in model_config_map.keys(): - eval_csv = sorted(replicate_path.glob(f"eval_{model_name}*csv")) - assert len(eval_csv) == 1 + eval_csv = sorted(replicate_path.glob(f"eval_{model_name}*csv")) + assert len(eval_csv) == 1 - model_path = replicate_path.joinpath(model_name) - assert model_path.exists() + model_path = replicate_path.joinpath(model_name) + assert model_path.exists() - tensorboard_log = sorted( - replicate_path.glob(f"lightning_logs/**/*events*") - ) - assert len(tensorboard_log) == 1 + tensorboard_log = sorted( + replicate_path.glob(f"lightning_logs/**/*events*") + ) + assert len(tensorboard_log) == 1 - checkpoints_path = model_path.joinpath("checkpoints") - assert checkpoints_path.exists() - assert checkpoints_path.joinpath("checkpoint.pt").exists() - if cfg.learncurve.val_step is not None: - assert checkpoints_path.joinpath( - "max-val-acc-checkpoint.pt" - ).exists() - - return True + checkpoints_path = model_path.joinpath("checkpoints") + assert checkpoints_path.exists() + assert checkpoints_path.joinpath("checkpoint.pt").exists() + if cfg.learncurve.val_step is not None: + assert checkpoints_path.joinpath( + "max-val-acc-checkpoint.pt" + ).exists() def test_learncurve(specific_config, tmp_path, model, device): @@ -65,12 +62,13 @@ def test_learncurve(specific_config, tmp_path, model, device): ) cfg = vak.config.parse.from_toml_path(toml_path) - model_config_map = vak.config.models.map_from_path(toml_path, cfg.learncurve.models) + model_config = vak.config.model.config_from_toml_path(toml_path, cfg.learncurve.model) results_path = vak.paths.generate_results_dir_name_as_path(tmp_path) results_path.mkdir() vak.core.learning_curve( - model_config_map, + model_name=cfg.learncurve.model, + model_config=model_config, train_set_durs=cfg.learncurve.train_set_durs, num_replicates=cfg.learncurve.num_replicates, csv_path=cfg.learncurve.csv_path, @@ -92,7 +90,7 @@ def test_learncurve(specific_config, tmp_path, model, device): device=cfg.learncurve.device, ) - assert learncurve_output_matches_expected(cfg, model_config_map, results_path) + assert_learncurve_output_matches_expected(cfg, cfg.learncurve.model, results_path) def test_learncurve_no_results_path(specific_config, tmp_path, model, device): @@ -117,10 +115,11 @@ def test_learncurve_no_results_path(specific_config, tmp_path, model, device): ) cfg = vak.config.parse.from_toml_path(toml_path) - model_config_map = vak.config.models.map_from_path(toml_path, cfg.learncurve.models) + model_config = vak.config.model.config_from_toml_path(toml_path, cfg.learncurve.model) vak.core.learning_curve( - model_config_map, + model_name=cfg.learncurve.model, + model_config=model_config, train_set_durs=cfg.learncurve.train_set_durs, num_replicates=cfg.learncurve.num_replicates, csv_path=cfg.learncurve.csv_path, @@ -146,7 +145,7 @@ def test_learncurve_no_results_path(specific_config, tmp_path, model, device): assert len(results_path) == 1 results_path = results_path[0] - assert learncurve_output_matches_expected(cfg, model_config_map, results_path) + assert_learncurve_output_matches_expected(cfg, cfg.learncurve.model, results_path) @pytest.mark.parametrize("window_size", @@ -184,10 +183,11 @@ def test_learncurve_previous_run_path( ) cfg = vak.config.parse.from_toml_path(toml_path) - model_config_map = vak.config.models.map_from_path(toml_path, cfg.learncurve.models) + model_config = vak.config.model.config_from_toml_path(toml_path, cfg.learncurve.model) vak.core.learning_curve( - model_config_map, + model_name=cfg.learncurve.model, + model_config=model_config, train_set_durs=cfg.learncurve.train_set_durs, num_replicates=cfg.learncurve.num_replicates, csv_path=cfg.learncurve.csv_path, @@ -213,7 +213,7 @@ def test_learncurve_previous_run_path( assert len(results_path) == 1 results_path = results_path[0] - assert learncurve_output_matches_expected(cfg, model_config_map, results_path) + assert_learncurve_output_matches_expected(cfg, cfg.learncurve.model, results_path) def test_learncurve_invalid_csv_path_raises(specific_config, tmp_path, device): @@ -232,14 +232,15 @@ def test_learncurve_invalid_csv_path_raises(specific_config, tmp_path, device): ) cfg = vak.config.parse.from_toml_path(toml_path) - model_config_map = vak.config.models.map_from_path(toml_path, cfg.learncurve.models) + model_config = vak.config.model.config_from_toml_path(toml_path, cfg.learncurve.model) results_path = vak.paths.generate_results_dir_name_as_path(tmp_path) results_path.mkdir() invalid_csv_path = '/obviously/doesnt/exist/dataset.csv' with pytest.raises(FileNotFoundError): vak.core.learning_curve( - model_config_map, + model_name=cfg.learncurve.model, + model_config=model_config, train_set_durs=cfg.learncurve.train_set_durs, num_replicates=cfg.learncurve.num_replicates, csv_path=invalid_csv_path, @@ -288,13 +289,14 @@ def test_learncurve_raises_not_a_directory(dir_option_to_change, options_to_change=options_to_change, ) cfg = vak.config.parse.from_toml_path(toml_path) - model_config_map = vak.config.models.map_from_path(toml_path, cfg.learncurve.models) + model_config = vak.config.model.config_from_toml_path(toml_path, cfg.learncurve.model) # mock behavior of cli.learncurve, building `results_path` from config option `root_results_dir` results_path = cfg.learncurve.root_results_dir / 'results-dir-timestamp' with pytest.raises(NotADirectoryError): vak.core.learning_curve( - model_config_map, + model_name=cfg.learncurve.model, + model_config=model_config, train_set_durs=cfg.learncurve.train_set_durs, num_replicates=cfg.learncurve.num_replicates, csv_path=cfg.learncurve.csv_path, diff --git a/tests/test_core/test_predict.py b/tests/test_core/test_predict.py index b1f2a6a05..a9e1e3086 100644 --- a/tests/test_core/test_predict.py +++ b/tests/test_core/test_predict.py @@ -10,12 +10,10 @@ # written as separate function so we can re-use in tests/unit/test_cli/test_predict.py -def predict_output_matches_expected(output_dir, annot_csv_filename): +def assert_predict_output_matches_expected(output_dir, annot_csv_filename): annot_csv = output_dir.joinpath(annot_csv_filename) assert annot_csv.exists() - return True - @pytest.mark.parametrize( "audio_format, spect_format, annot_format, save_net_outputs", @@ -55,13 +53,14 @@ def test_predict( ) cfg = vak.config.parse.from_toml_path(toml_path) - model_config_map = vak.config.models.map_from_path(toml_path, cfg.predict.models) + model_config = vak.config.model.config_from_toml_path(toml_path, cfg.predict.model) vak.core.predict( + model_name=cfg.predict.model, + model_config=model_config, csv_path=cfg.predict.csv_path, checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, - model_config_map=model_config_map, window_size=cfg.dataloader.window_size, num_workers=cfg.predict.num_workers, spect_key=cfg.spect_params.spect_key, @@ -75,7 +74,7 @@ def test_predict( save_net_outputs=cfg.predict.save_net_outputs, ) - assert predict_output_matches_expected(output_dir, cfg.predict.annot_csv_filename) + assert_predict_output_matches_expected(output_dir, cfg.predict.annot_csv_filename) if save_net_outputs: net_outputs = sorted( Path(output_dir).glob(f"*{vak.constants.NET_OUTPUT_SUFFIX}") @@ -127,14 +126,15 @@ def test_predict_raises_file_not_found( ) cfg = vak.config.parse.from_toml_path(toml_path) - model_config_map = vak.config.models.map_from_path(toml_path, cfg.predict.models) + model_config = vak.config.model.config_from_toml_path(toml_path, cfg.predict.model) with pytest.raises(FileNotFoundError): vak.core.predict( + model_name=cfg.predict.model, + model_config=model_config, csv_path=cfg.predict.csv_path, checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, - model_config_map=model_config_map, window_size=cfg.dataloader.window_size, num_workers=cfg.predict.num_workers, spect_key=cfg.spect_params.spect_key, @@ -168,14 +168,15 @@ def test_predict_raises_not_a_directory( options_to_change=options_to_change, ) cfg = vak.config.parse.from_toml_path(toml_path) - model_config_map = vak.config.models.map_from_path(toml_path, cfg.predict.models) + model_config = vak.config.model.config_from_toml_path(toml_path, cfg.predict.model) with pytest.raises(NotADirectoryError): vak.core.predict( + model_name=cfg.predict.model, + model_config=model_config, csv_path=cfg.predict.csv_path, checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, - model_config_map=model_config_map, window_size=cfg.dataloader.window_size, num_workers=cfg.predict.num_workers, spect_key=cfg.spect_params.spect_key, diff --git a/tests/test_core/test_prep.py b/tests/test_core/test_prep.py index abec79cb3..50300ccd7 100644 --- a/tests/test_core/test_prep.py +++ b/tests/test_core/test_prep.py @@ -14,7 +14,7 @@ # written as separate function so we can re-use in tests/unit/test_cli/test_prep.py -def prep_output_matches_expected(csv_path, df_returned_by_prep): +def assert_prep_output_matches_expected(csv_path, df_returned_by_prep): assert Path(csv_path).exists() df_from_csv_path = pd.read_csv(csv_path) @@ -29,8 +29,6 @@ def prep_output_matches_expected(csv_path, df_returned_by_prep): check_exact=check_exact, ) - return True - @pytest.mark.parametrize( "config_type, audio_format, spect_format, annot_format", @@ -102,7 +100,7 @@ def test_prep( test_dur=cfg.prep.test_dur, ) - assert prep_output_matches_expected(csv_path, vak_df) + assert_prep_output_matches_expected(csv_path, vak_df) @pytest.mark.parametrize( diff --git a/tests/test_core/test_train.py b/tests/test_core/test_train.py index fbbb9111c..3f9df28ce 100644 --- a/tests/test_core/test_train.py +++ b/tests/test_core/test_train.py @@ -7,7 +7,7 @@ import vak.core.train -def train_output_matches_expected(cfg, model_config_map, results_path): +def assert_train_output_matches_expected(cfg, model_name, results_path): assert results_path.joinpath("labelmap.json").exists() if cfg.train.normalize_spectrograms or cfg.train.spect_scaler_path: @@ -15,22 +15,19 @@ def train_output_matches_expected(cfg, model_config_map, results_path): else: assert not results_path.joinpath("StandardizeSpect").exists() - for model_name in model_config_map.keys(): - model_path = results_path.joinpath(model_name) - assert model_path.exists() + model_path = results_path.joinpath(model_name) + assert model_path.exists() - tensorboard_log = sorted( - model_path.glob(f"lightning_logs/**/*events*") - ) - assert len(tensorboard_log) == 1 - - checkpoints_path = model_path.joinpath("checkpoints") - assert checkpoints_path.exists() - assert checkpoints_path.joinpath("checkpoint.pt").exists() - if cfg.train.val_step is not None: - assert checkpoints_path.joinpath("max-val-acc-checkpoint.pt").exists() + tensorboard_log = sorted( + model_path.glob(f"lightning_logs/**/*events*") + ) + assert len(tensorboard_log) == 1 - return True + checkpoints_path = model_path.joinpath("checkpoints") + assert checkpoints_path.exists() + assert checkpoints_path.joinpath("checkpoint.pt").exists() + if cfg.train.val_step is not None: + assert checkpoints_path.joinpath("max-val-acc-checkpoint.pt").exists() @pytest.mark.parametrize( @@ -59,10 +56,11 @@ def test_train( options_to_change=options_to_change, ) cfg = vak.config.parse.from_toml_path(toml_path) - model_config_map = vak.config.models.map_from_path(toml_path, cfg.train.models) + model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) vak.core.train( - model_config_map, + cfg.train.model, + model_config, cfg.train.csv_path, cfg.dataloader.window_size, cfg.train.batch_size, @@ -80,7 +78,7 @@ def test_train( device=cfg.train.device, ) - assert train_output_matches_expected(cfg, model_config_map, results_path) + assert_train_output_matches_expected(cfg, cfg.train.model, results_path) @pytest.mark.parametrize( @@ -109,10 +107,11 @@ def test_continue_training( options_to_change=options_to_change, ) cfg = vak.config.parse.from_toml_path(toml_path) - model_config_map = vak.config.models.map_from_path(toml_path, cfg.train.models) + model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) vak.core.train( - model_config_map=model_config_map, + model_name=cfg.train.model, + model_config=model_config, csv_path=cfg.train.csv_path, window_size=cfg.dataloader.window_size, batch_size=cfg.train.batch_size, @@ -131,7 +130,7 @@ def test_continue_training( device=cfg.train.device, ) - assert train_output_matches_expected(cfg, model_config_map, results_path) + assert_train_output_matches_expected(cfg, cfg.train.model, results_path) @pytest.mark.parametrize( @@ -163,13 +162,14 @@ def test_train_raises_file_not_found( options_to_change=options_to_change, ) cfg = vak.config.parse.from_toml_path(toml_path) - model_config_map = vak.config.models.map_from_path(toml_path, cfg.train.models) + model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) results_path = vak.paths.generate_results_dir_name_as_path(tmp_path) results_path.mkdir() with pytest.raises(FileNotFoundError): vak.core.train( - model_config_map=model_config_map, + model_name=cfg.train.model, + model_config=model_config, csv_path=cfg.train.csv_path, window_size=cfg.dataloader.window_size, batch_size=cfg.train.batch_size, @@ -210,14 +210,15 @@ def test_train_raises_not_a_directory( options_to_change=options_to_change, ) cfg = vak.config.parse.from_toml_path(toml_path) - model_config_map = vak.config.models.map_from_path(toml_path, cfg.train.models) + model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) # mock behavior of cli.train, building `results_path` from config option `root_results_dir` results_path = cfg.train.root_results_dir / 'results-dir-timestamp' with pytest.raises(NotADirectoryError): vak.core.train( - model_config_map=model_config_map, + model_name=cfg.train.model, + model_config=model_config, csv_path=cfg.train.csv_path, window_size=cfg.dataloader.window_size, batch_size=cfg.train.batch_size, @@ -264,13 +265,14 @@ def test_both_labelset_and_labelmap_raises( options_to_change=options_to_change, ) cfg = vak.config.parse.from_toml_path(toml_path) - model_config_map = vak.config.models.map_from_path(toml_path, cfg.train.models) + model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) results_path = vak.paths.generate_results_dir_name_as_path(tmp_path) results_path.mkdir() with pytest.raises(ValueError): vak.core.train( - model_config_map=model_config_map, + model_name=cfg.train.model, + model_config=model_config, csv_path=cfg.train.csv_path, window_size=cfg.dataloader.window_size, batch_size=cfg.train.batch_size, diff --git a/tests/test_models/test_windowed_frame_classification_model.py b/tests/test_models/test_windowed_frame_classification_model.py index 65cb74ad3..0495181a8 100644 --- a/tests/test_models/test_windowed_frame_classification_model.py +++ b/tests/test_models/test_windowed_frame_classification_model.py @@ -98,8 +98,7 @@ def test_from_config(self, vak.models.WindowedFrameClassificationModel, 'definition', definition, raising=False ) - model_config_map = vak.config.models.map_from_path(toml_path, cfg.train.models) - model_name, config = list(model_config_map.items())[0] + config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) config["network"].update( num_classes=len(labelmap), input_shape=self.MOCK_INPUT_SHAPE,