Ensemble #204

Open · wants to merge 3 commits into develop
2 changes: 1 addition & 1 deletion README.md
@@ -151,7 +151,7 @@ python -m art.cli bert-transfer-learning-tutorial
```
3. A tutorial showing how to use ART for regularization
```sh
-python -m art.cli regularization_tutorial
+python -m art.cli regularization-tutorial
```

## API Cheatsheet
75 changes: 71 additions & 4 deletions art/steps.py
@@ -26,6 +26,7 @@
from art.utils.enums import TrainingStage
from art.utils.paths import get_checkpoint_logs_folder_path
from art.utils.savers import JSONStepSaver
+from art.utils.ensemble import ArtEnsemble


class NoModelUsed:
@@ -719,10 +720,6 @@ def do(self, previous_states: Dict):
# TODO how to solve this?


-class Squeeze(ModelStep):
-    pass


class TransferLearning(ModelStep):
"""This step tries performing proper transfer learning"""

@@ -833,3 +830,73 @@ def change_lr(model):
model.lr = self.fine_tune_lr

self.model_modifiers.append(change_lr)


class Ensemble(ModelStep):
"""This step tries to ensemble models"""

name = "Ensemble"
description = "Ensembles models"

def __init__(
self,
model: ArtModule,
num_models: int = 5,
logger: Optional[Logger] = None,
trainer_kwargs: Dict = {},
model_kwargs: Dict = {},
model_modifiers: List[Callable] = [],
datamodule_modifiers: List[Callable] = [],
):
"""
This method initializes the step

Args:
    model (ArtModule): model to be trained and ensembled
    num_models (int, optional): number of models in the ensemble. Defaults to 5.
    logger (Logger, optional): logger. Defaults to None.
    trainer_kwargs (Dict, optional): Kwargs passed to lightning Trainer. Defaults to {}.
    model_kwargs (Dict, optional): Kwargs passed to model. Defaults to {}.
    model_modifiers (List[Callable], optional): model modifiers. Defaults to [].
    datamodule_modifiers (List[Callable], optional): datamodule modifiers. Defaults to [].
"""
super().__init__(
model,
trainer_kwargs,
model_kwargs,
model_modifiers,
datamodule_modifiers,
logger=logger,
)
self.num_models = num_models

def do(self, previous_states: Dict):
"""
This method trains num_models models, keeps their best checkpoints, and ensembles them

Args:
previous_states (Dict): previous states
"""
models_paths = []
Owner: I'd prefer this to be a class variable. After the ensemble is trained, an artifact containing it should be saved so that the user can later reuse the ensemble. As it stands, I believe it's impossible to use the ensemble without running the entire step again.
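A minimal sketch of that suggestion, keeping the checkpoint paths on the instance and persisting them through the step's results; that `self.results["parameters"]` (which `log_model_params` writes to in this diff) is already populated inside `do` is an assumption.

```python
class Ensemble(ModelStep):
    def do(self, previous_states: Dict):
        # keep the checkpoint paths on the instance so they survive the step
        self.models_paths = []
        for _ in range(self.num_models):
            self.reset_trainer(
                logger=self.trainer.logger, trainer_kwargs=self.trainer_kwargs
            )
            self.train(trainer_kwargs={"datamodule": self.datamodule})
            self.models_paths.append(self.trainer.checkpoint_callback.best_model_path)
        # persist the list so the ensemble can be rebuilt without retraining
        self.results["parameters"]["models_paths"] = self.models_paths
```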

for _ in range(self.num_models):
self.reset_trainer(
logger=self.trainer.logger, trainer_kwargs=self.trainer_kwargs
)
self.train(trainer_kwargs={"datamodule": self.datamodule})
models_paths.append(self.trainer.checkpoint_callback.best_model_path)

initialized_models = []
for path in models_paths:
model = self.model_class.load_from_checkpoint(path)
Owner: I think ArtEnsemble should be able to take a list of checkpoints and a model class and initialize the ensemble by itself. That would be much more convenient for the user (see the constructor sketch in art/utils/ensemble.py below).

model.eval()
initialized_models.append(model)

self.model = ArtEnsemble(initialized_models)
Owner: I believe that self.model is never used. The trainer is initialized and probably None is returned inside initialize_model, so you in fact evaluate only the last model of the ensemble. Moreover, this leads to a memory leak.

My suggestion is to add a flag self.ensemble_ready and override initialize_model: if the ensemble is ready, return the ensemble; otherwise call super().initialize_model().
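A minimal sketch of that flag, assuming ModelStep exposes an overridable initialize_model(); the method name comes from the comment, its exact signature is an assumption.

```python
class Ensemble(ModelStep):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.ensemble_ready = False

    def initialize_model(self):
        # once built, hand the ensemble to the trainer instead of a fresh model
        if self.ensemble_ready:
            return self.model
        return super().initialize_model()

    def do(self, previous_states: Dict):
        ...  # train the member models and collect initialized_models as above
        self.model = ArtEnsemble(initialized_models)
        self.ensemble_ready = True
        self.validate(trainer_kwargs={"datamodule": self.datamodule})
```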

self.validate(trainer_kwargs={"datamodule": self.datamodule})

def get_check_stage(self):
"""Returns check stage"""
return TrainingStage.VALIDATION.value

def log_model_params(self, model):
self.results["parameters"]["num_models"] = self.num_models
super().log_model_params(model)
36 changes: 36 additions & 0 deletions art/utils/ensemble.py
@@ -0,0 +1,36 @@
from copy import deepcopy
from typing import List

import torch
from torch import nn

from art.core import ArtModule
from art.utils.enums import BATCH, PREDICTION


class ArtEnsemble(ArtModule):
"""
Base class for ensembles.
"""

def __init__(self, models: List[ArtModule]):
super().__init__()
Owner: Please add an option to provide model_class plus a list of checkpoints, but keep the list of models too. Make everything Optional.
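A minimal sketch of such a constructor, assuming the model class exposes Lightning's load_from_checkpoint (which steps.py in this diff already relies on); the parameter names are illustrative.

```python
from typing import List, Optional, Type

from torch import nn

from art.core import ArtModule


class ArtEnsemble(ArtModule):
    def __init__(
        self,
        models: Optional[List[ArtModule]] = None,
        model_class: Optional[Type[ArtModule]] = None,
        checkpoint_paths: Optional[List[str]] = None,
    ):
        super().__init__()
        if models is None:
            if model_class is None or checkpoint_paths is None:
                raise ValueError(
                    "Provide either `models` or both `model_class` and `checkpoint_paths`."
                )
            # build the members from checkpoints so the caller doesn't have to
            models = [model_class.load_from_checkpoint(p) for p in checkpoint_paths]
        self.models = nn.ModuleList(models)
```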

self.models = nn.ModuleList(models)

def predict(self, data):
predictions = torch.stack([self.predict_on_model_from_dataloader(model, deepcopy(data)) for model in self.models])
Owner: This won't work - predict is expected to return a dictionary to be compatible with compute metrics.

What I'd suggest is to write a forward method which expects a batch and makes a prediction - this will be used later by the user. In predict, use this forward (see the sketch after this method).

Also, why do you do a deepcopy?

return torch.mean(predictions, dim=0)
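A minimal sketch of that forward/predict split, assuming the data dict uses the BATCH and PREDICTION keys already imported in this file; the exact dict contract expected by the metric computation is an assumption.

```python
def forward(self, batch):
    # average the member models' predictions for a single batch
    predictions = torch.stack(
        [model.predict(model.parse_data({BATCH: batch}))[PREDICTION]
         for model in self.models]
    )
    return torch.mean(predictions, dim=0)

def predict(self, data):
    # return a dict so downstream metric computation can consume it
    data[PREDICTION] = self(data[BATCH])
    return data
```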

def predict_on_model_from_dataloader(self, model, dataloader):
predictions = []
Owner: Why is this even needed? You never get a dataloader from Lightning in any stage.

for batch in dataloader:
model.to(self.device)
batch_processed = model.parse_data({BATCH: batch})
predictions.append(model.predict(batch_processed)[PREDICTION])
return torch.cat(predictions)

def log_params(self):
return {
"num_models": len(self.models),
"models": [model.log_params() for model in self.models],
}
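For reference, a usage sketch of the utility as written in this diff; MyArtModule and the checkpoint paths are hypothetical placeholders.

```python
# Hypothetical usage; MyArtModule and the paths are placeholders.
checkpoints = ["ckpt_0.ckpt", "ckpt_1.ckpt"]
models = [MyArtModule.load_from_checkpoint(p) for p in checkpoints]
ensemble = ArtEnsemble(models)
ensemble.eval()
```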