From b3c2cd4e821fbbe664957e26c1b7a93dcd18b04f Mon Sep 17 00:00:00 2001
From: Karol Cyganik
Date: Sat, 2 Dec 2023 15:55:17 +0100
Subject: [PATCH 1/3] delete squeeze step

---
 art/steps.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/art/steps.py b/art/steps.py
index ac5f709..5b267d4 100644
--- a/art/steps.py
+++ b/art/steps.py
@@ -719,10 +719,6 @@ def do(self, previous_states: Dict):
         # TODO how to solve this?
 
 
-class Squeeze(ModelStep):
-    pass
-
-
 class TransferLearning(ModelStep):
     """This step tries performing proper transfer learning"""
 

From edd17dd909444cb36351492beb7e3002db3f9f3c Mon Sep 17 00:00:00 2001
From: Karol Cyganik
Date: Sat, 2 Dec 2023 16:07:59 +0100
Subject: [PATCH 2/3] fix readme for regularization tutorial

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index bdce397..9971db3 100644
--- a/README.md
+++ b/README.md
@@ -151,7 +151,7 @@ python -m art.cli bert-transfer-learning-tutorial
 ```
 3. A tutorial showing how to use ART for regularization
 ```sh
-python -m art.cli regularization_tutorial
+python -m art.cli regularization-tutorial
 ```
 
 ## API Cheatsheet

From c9cd71d66190d7f6dfca8fe6bb59f5d912ad09d2 Mon Sep 17 00:00:00 2001
From: Karol Cyganik
Date: Sat, 2 Dec 2023 16:18:52 +0100
Subject: [PATCH 3/3] add ensemble step

This commit adds an Ensemble step to steps.py and an ensemble.py module to
utils, which holds the ensemble model (ArtEnsemble, an ArtModule).

Example usage, based on the MNIST example from our tutorials. MNISTDataModule
and MNISTModel are assumed to come from the tutorial's local dataset.py and
model.py files:

```python
import datasets
import torch.nn as nn
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from torchmetrics import Accuracy

from art.metrics import build_metric_name
from art.project import ArtProject
from art.steps import Ensemble
from art.utils.enums import INPUT, TARGET, TrainingStage

# Tutorial code: adjust these imports to wherever your MNIST
# data module and model live locally.
from dataset import MNISTDataModule
from model import MNISTModel


def get_data_module(n_train=200):
    mnist_data = datasets.load_dataset("mnist")
    mnist_data = mnist_data.rename_columns({"image": INPUT, "label": TARGET})
    mnist_data["train"] = mnist_data["train"].select(range(n_train))
    return MNISTDataModule(mnist_data)


datamodule = get_data_module()
project = ArtProject(name="mnist-ensemble", datamodule=datamodule)

accuracy_metric = Accuracy(task="multiclass", num_classes=10)
ce_loss = nn.CrossEntropyLoss()
project.register_metrics([accuracy_metric, ce_loss])

checkpoint = ModelCheckpoint(
    monitor=build_metric_name(accuracy_metric, TrainingStage.VALIDATION.value),
    mode="max",
)
early_stopping = EarlyStopping(
    monitor=build_metric_name(ce_loss, TrainingStage.VALIDATION.value),
    mode="min",
)

project.add_step(
    Ensemble(
        MNISTModel,
        num_models=10,
        trainer_kwargs={
            "max_epochs": 6,
            "callbacks": [checkpoint, early_stopping],
            "check_val_every_n_epoch": 5,
        },
    )
)
project.run_all(force_rerun=True)
```
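
For reference, ArtEnsemble forms its prediction by stacking the per-model
predictions and taking their mean. A minimal sketch of that reduction,
using hypothetical random tensors in place of real model outputs:

```python
import torch

# Outputs of 3 hypothetical ensemble members for a batch of 2 samples, 10 classes
member_preds = [torch.rand(2, 10) for _ in range(3)]

stacked = torch.stack(member_preds)         # shape: (3, 2, 10)
ensemble_pred = torch.mean(stacked, dim=0)  # shape: (2, 10), same as a single model
```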
---
 art/steps.py          | 75 +++++++++++++++++++++++++++++++++++++++++++
 art/utils/ensemble.py | 40 ++++++++++++++++++++++
 2 files changed, 115 insertions(+)
 create mode 100644 art/utils/ensemble.py

diff --git a/art/steps.py b/art/steps.py
index 5b267d4..b54c276 100644
--- a/art/steps.py
+++ b/art/steps.py
@@ -26,6 +26,7 @@
 from art.utils.enums import TrainingStage
 from art.utils.paths import get_checkpoint_logs_folder_path
 from art.utils.savers import JSONStepSaver
+from art.utils.ensemble import ArtEnsemble
 
 
 class NoModelUsed:
@@ -829,3 +830,77 @@ def change_lr(model):
             model.lr = self.fine_tune_lr
 
         self.model_modifiers.append(change_lr)
+
+
+class Ensemble(ModelStep):
+    """This step trains several copies of a model and ensembles them"""
+
+    name = "Ensemble"
+    description = "Ensembles models"
+
+    def __init__(
+        self,
+        model: ArtModule,
+        num_models: int = 5,
+        logger: Optional[Logger] = None,
+        trainer_kwargs: Dict = {},
+        model_kwargs: Dict = {},
+        model_modifiers: List[Callable] = [],
+        datamodule_modifiers: List[Callable] = [],
+    ):
+        """
+        This method initializes the step
+
+        Args:
+            model (ArtModule): model class used to instantiate each ensemble member
+            num_models (int, optional): number of ensemble members. Defaults to 5.
+            logger (Logger, optional): logger. Defaults to None.
+            trainer_kwargs (Dict, optional): Kwargs passed to lightning Trainer. Defaults to {}.
+            model_kwargs (Dict, optional): Kwargs passed to model. Defaults to {}.
+            model_modifiers (List[Callable], optional): model modifiers. Defaults to [].
+            datamodule_modifiers (List[Callable], optional): datamodule modifiers. Defaults to [].
+        """
+        super().__init__(
+            model,
+            trainer_kwargs,
+            model_kwargs,
+            model_modifiers,
+            datamodule_modifiers,
+            logger=logger,
+        )
+        self.num_models = num_models
+
+    def do(self, previous_states: Dict):
+        """
+        This method trains num_models copies of the model and validates their ensemble
+
+        Args:
+            previous_states (Dict): previous states
+        """
+        # Train each ensemble member independently and keep its best checkpoint
+        models_paths = []
+        for _ in range(self.num_models):
+            self.reset_trainer(
+                logger=self.trainer.logger, trainer_kwargs=self.trainer_kwargs
+            )
+            self.train(trainer_kwargs={"datamodule": self.datamodule})
+            models_paths.append(self.trainer.checkpoint_callback.best_model_path)
+
+        # Load the best checkpoint of every member and switch it to eval mode
+        initialized_models = []
+        for path in models_paths:
+            model = self.model_class.load_from_checkpoint(path)
+            model.eval()
+            initialized_models.append(model)
+
+        # Validate the prediction-averaging ensemble built from the members
+        self.model = ArtEnsemble(initialized_models)
+        self.validate(trainer_kwargs={"datamodule": self.datamodule})
+
+    def get_check_stage(self):
+        """Returns check stage"""
+        return TrainingStage.VALIDATION.value
+
+    def log_model_params(self, model):
+        self.results["parameters"]["num_models"] = self.num_models
+        super().log_model_params(model)
diff --git a/art/utils/ensemble.py b/art/utils/ensemble.py
new file mode 100644
index 0000000..e2505e3
--- /dev/null
+++ b/art/utils/ensemble.py
@@ -0,0 +1,40 @@
+from copy import deepcopy
+from typing import List
+
+import torch
+from torch import nn
+
+from art.core import ArtModule
+from art.utils.enums import BATCH, PREDICTION
+
+
+class ArtEnsemble(ArtModule):
+    """
+    Base class for ensembles: averages the predictions of its member models.
+    """
+
+    def __init__(self, models: List[ArtModule]):
+        super().__init__()
+        self.models = nn.ModuleList(models)
+
+    def predict(self, data):
+        # Every member predicts on its own copy of the dataloader;
+        # the ensemble prediction is the mean over members.
+        predictions = torch.stack(
+            [self.predict_on_model_from_dataloader(model, deepcopy(data)) for model in self.models]
+        )
+        return torch.mean(predictions, dim=0)
+
+    def predict_on_model_from_dataloader(self, model, dataloader):
+        model.to(self.device)
+        predictions = []
+        for batch in dataloader:
+            batch_processed = model.parse_data({BATCH: batch})
+            predictions.append(model.predict(batch_processed)[PREDICTION])
+        return torch.cat(predictions)
+
+    def log_params(self):
+        return {
+            "num_models": len(self.models),
+            "models": [model.log_params() for model in self.models],
+        }