-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
CLN: Remove unused multiple model functionality, fix #538
> The config allows instantiating multiple models for training / prediction but this was never fully implemented Squash commits: - Make config just use "model" option, not "models": - Remove `comma_separated_list` from converters - Change option name 'models' -> 'model' in config/valid.toml - Rewrite is_valid_model_name to test a single string, not a list of strings - Change attribute `models` -> `model` in config/eval.py - Change attribute `models` -> `model` in config/learncurve.py - Change attribute `models` -> `model` in config/predict.py - Change attribute `models` -> `model` in config/train.py - Rewrite/rename config.models -> model.config_from_toml_path, model.config_from_toml_dict - Fix option 'models' -> model = 'str', in all .toml files in tests/data_for_tests/configs - Rewrite `models.from_model_config_map` as `models.get`: - Add src/vak/models/_api.py with BUILTIN_MODELS and MODEL_NAMES to use for validation in `models.get` - Rewrite models/models.py `from_model_config_map` as `models.get` - Import get and _api in vak/models/__init__.py - Rewrite core/train.py to take model_name and model_config then use models.get - Fix cli/train.py to pass model_name and model_config into core.train - Rewrite core/eval.py to take model_name and model_config then use models.get - Fix cli/eval.py to pass model_name and model_config into core.eval - Rewrite core/learncurve.py to take model_name and model_config - Fix cli/learncurve.py to pass model_name and model_config into core.learncurve - Rewrite core/predict.py to take model_name and model_config then use models.get - Fix cli/predict.py to pass model_name and model_config into core.predict - Make 'model' not 'models' required in src/vak/config/parse.py - Use models.MODEL_NAMES in src/vak/config/validators.py - Use models.MODEL_NAMES in config/model.py - Fix tests - Fix tests to use vak.config.model.config_from_toml_path: - tests/test_models/test_windowed_frame_classification_model.py - tests/test_core/test_train.py - tests/test_core/test_predict.py - tests/test_core/test_learncurve.py - tests/test_core/test_eval.py - Fix test to use 'model' option not 'models' in tests/test_config/test_parse.py - Fix assert helper function in tests/test_core - test_eval.py - test_learncurve.py - test_predict.py - test_prep.py - test_train.py - Rewrite fixture module with constants we can import in test modules to parametrize: tests/fixtures/config.py - Add tests/test_config/test_model.py
- Loading branch information
1 parent
f8c2535
commit b206669
Showing
55 changed files
with
677 additions
and
599 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ | |
dataloader, | ||
eval, | ||
learncurve, | ||
models, | ||
model, | ||
parse, | ||
predict, | ||
prep, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
from __future__ import annotations | ||
import pathlib | ||
|
||
import toml | ||
|
||
from .. import models | ||
|
||
|
||
MODEL_TABLES = [ | ||
"network", | ||
"optimizer", | ||
"loss", | ||
"metrics", | ||
] | ||
|
||
|
||
def config_from_toml_dict(toml_dict: dict, model_name: str) -> dict: | ||
"""Get configuration for a model from a .toml configuration file | ||
loaded into a ``dict``. | ||
Parameters | ||
---------- | ||
toml_dict : dict | ||
Configuration from a .toml file, loaded into a dictionary. | ||
model_name : str | ||
Name of a model, specified as the ``model`` option in a table | ||
(such as TRAIN or PREDICT), | ||
that should have its own corresponding table | ||
specifying its configuration: hyperparameters such as learning rate, etc. | ||
Returns | ||
------- | ||
model_config : dict | ||
Model configuration in a ``dict``, | ||
as loaded from a .toml file, | ||
and used by the model method ``from_config``. | ||
""" | ||
if model_name not in models.MODEL_NAMES: | ||
raise ValueError( | ||
f"Invalid model name: {model_name}.\nValid model names are: {models.MODEL_NAMES}" | ||
) | ||
|
||
try: | ||
model_config = toml_dict[model_name] | ||
except KeyError as e: | ||
raise ValueError( | ||
f"A config section specifies the model name '{model_name}', " | ||
f"but there is no section named '{model_name}' in the config." | ||
) from e | ||
|
||
# check if config declares parameters for required attributes; | ||
# if not, just put an empty dict that will get passed as the "kwargs" | ||
for attr in MODEL_TABLES: | ||
if attr not in model_config: | ||
model_config[attr] = {} | ||
|
||
return model_config | ||
|
||
|
||
def config_from_toml_path(toml_path: str | pathlib.Path, model_name: str) -> dict: | ||
"""Get configuration for a model from a .toml configuration file, | ||
given the path to the file. | ||
Parameters | ||
---------- | ||
toml_path : str, Path | ||
to configuration file in .toml format | ||
model_name : str | ||
of str, i.e. names of models specified by a section | ||
(such as TRAIN or PREDICT) that should each have corresponding sections | ||
specifying their configuration: hyperparameters such as learning rate, etc. | ||
Returns | ||
------- | ||
model_config : dict | ||
Model configuration in a ``dict``, | ||
as loaded from a .toml file, | ||
and used by the model method ``from_config``. | ||
""" | ||
toml_path = pathlib.Path(toml_path) | ||
if not toml_path.is_file(): | ||
raise FileNotFoundError( | ||
f"File not found, or not recognized as a file: {toml_path}" | ||
) | ||
|
||
with toml_path.open("r") as fp: | ||
config_dict = toml.load(fp) | ||
return config_from_toml_dict(config_dict, model_name) |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.