Ensemble #204
base: develop
Changes from all commits
@@ -26,6 +26,7 @@
from art.utils.enums import TrainingStage
from art.utils.paths import get_checkpoint_logs_folder_path
from art.utils.savers import JSONStepSaver
from art.utils.ensemble import ArtEnsemble


class NoModelUsed:
@@ -719,10 +720,6 @@ def do(self, previous_states: Dict):
        # TODO how to solve this?


class Squeeze(ModelStep):
    pass


class TransferLearning(ModelStep):
    """This step tries performing proper transfer learning"""
@@ -833,3 +830,73 @@ def change_lr(model):
            model.lr = self.fine_tune_lr

        self.model_modifiers.append(change_lr)


class Ensemble(ModelStep):
    """This step tries to ensemble models"""

    name = "Ensemble"
    description = "Ensembles models"

    def __init__(
        self,
        model: ArtModule,
        num_models: int = 5,
        logger: Optional[Logger] = None,
        trainer_kwargs: Dict = {},
        model_kwargs: Dict = {},
        model_modifiers: List[Callable] = [],
        datamodule_modifiers: List[Callable] = [],
    ):
        """
        This method initializes the step

        Args:
            model (ArtModule): model
            num_models (int, optional): number of models to train for the ensemble. Defaults to 5.
            logger (Logger, optional): logger. Defaults to None.
            trainer_kwargs (Dict, optional): Kwargs passed to lightning Trainer. Defaults to {}.
            model_kwargs (Dict, optional): Kwargs passed to model. Defaults to {}.
            model_modifiers (List[Callable], optional): model modifiers. Defaults to [].
            datamodule_modifiers (List[Callable], optional): datamodule modifiers. Defaults to [].
        """
        super().__init__(
            model,
            trainer_kwargs,
            model_kwargs,
            model_modifiers,
            datamodule_modifiers,
            logger=logger,
        )
        self.num_models = num_models

    def do(self, previous_states: Dict):
        """
        This method trains the models and ensembles them

        Args:
            previous_states (Dict): previous states
        """
        models_paths = []
        for _ in range(self.num_models):
            self.reset_trainer(
                logger=self.trainer.logger, trainer_kwargs=self.trainer_kwargs
            )
            self.train(trainer_kwargs={"datamodule": self.datamodule})
            models_paths.append(self.trainer.checkpoint_callback.best_model_path)

        initialized_models = []
        for path in models_paths:
            model = self.model_class.load_from_checkpoint(path)
Review comment: I think ArtEnsemble should be able to take a list of checkpoints and a model class, and initialize the ensemble by itself. For the user it would be much more convenient.
            model.eval()
            initialized_models.append(model)

        self.model = ArtEnsemble(initialized_models)
Review comment: I believe that self.model is never being used. The trainer is initialized, and probably None is returned inside initialize model, so you in fact evaluate the last model of the ensemble. Moreover, this leads to a memory leak. My suggestion is to add a flag
        self.validate(trainer_kwargs={"datamodule": self.datamodule})

    def get_check_stage(self):
        """Returns check stage"""
        return TrainingStage.VALIDATION.value

    def log_model_params(self, model):
        self.results["parameters"]["num_models"] = self.num_models
        super().log_model_params(model)
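For illustration, a minimal usage sketch of the new step (not code from the PR; MyArtModule and the trainer settings are assumptions):

# Hypothetical usage sketch of the Ensemble step.
ensemble_step = Ensemble(
    model=MyArtModule(),                 # assumed user-defined ArtModule subclass
    num_models=5,                        # train five members and combine them
    trainer_kwargs={"max_epochs": 10},   # forwarded to the Lightning Trainer
)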
art/utils/ensemble.py (new file)

@@ -0,0 +1,36 @@
from art.core import ArtModule
from art.utils.enums import BATCH, PREDICTION

import torch
from torch import nn

from typing import List
from copy import deepcopy


class ArtEnsemble(ArtModule):
    """
    Base class for ensembles.
    """

    def __init__(self, models: List[ArtModule]):
        super().__init__()
Review comment: Please add an option to provide model_class + a list of checkpoints, but keep the list of models too. Make everything Optional.
        self.models = nn.ModuleList(models)
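A possible shape for that suggestion, as a sketch only (the parameter names model_class and checkpoint_paths are assumptions, not part of the PR):

# Sketch of the suggested constructor (assumed names, not part of the PR).
from typing import List, Optional, Type

class ArtEnsemble(ArtModule):
    def __init__(
        self,
        models: Optional[List[ArtModule]] = None,
        model_class: Optional[Type[ArtModule]] = None,
        checkpoint_paths: Optional[List[str]] = None,
    ):
        super().__init__()
        if models is None:
            # Build the members from checkpoints when explicit models are not given.
            if model_class is None or checkpoint_paths is None:
                raise ValueError(
                    "Provide either `models` or both `model_class` and `checkpoint_paths`."
                )
            models = [model_class.load_from_checkpoint(p) for p in checkpoint_paths]
        self.models = nn.ModuleList(models)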
    def predict(self, data):
        predictions = torch.stack([self.predict_on_model_from_dataloader(model, deepcopy(data)) for model in self.models])
Review comment: This won't work - predict is expected to return a dictionary to be compatible with compute metrics. What I'd suggest is to write a forward method, which expects a batch and makes a prediction; this will be used later by the user. In predict, use this forward. Why do you do a deepcopy?
        return torch.mean(predictions, dim=0)
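A sketch of the suggested forward/predict split (an assumption about the intended design, not the PR's code; it presumes data is a parsed batch dictionary that each member model's predict can consume):

    # Methods on ArtEnsemble (sketch of the reviewer's suggestion, not part of the PR).
    def forward(self, data):
        # Average the member models' predictions for one parsed batch.
        predictions = torch.stack(
            [model.predict(dict(data))[PREDICTION] for model in self.models]
        )
        return torch.mean(predictions, dim=0)

    def predict(self, data):
        # Keep the dictionary contract expected by metric computation.
        data[PREDICTION] = self(data)
        return data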
    def predict_on_model_from_dataloader(self, model, dataloader):
        predictions = []
Review comment: Why is this even needed? You never get a dataloader from Lightning in any stage.
        for batch in dataloader:
            model.to(self.device)
            batch_processed = model.parse_data({BATCH: batch})
            predictions.append(model.predict(batch_processed)[PREDICTION])
        return torch.cat(predictions)
    def log_params(self):
        return {
            "num_models": len(self.models),
            "models": [model.log_params() for model in self.models],
        }
Review comment: I'd prefer this to be a class variable. After the ensemble is trained, some artifact with it should be saved so that the user can later utilize the ensemble. Now I believe it's impossible to utilize it without running the entire step once more.
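One possible way to produce such an artifact at the end of Ensemble.do(), as a sketch relying only on torch.save (the file name and saved fields are assumptions):

# Sketch (assumption): persist the trained ensemble at the end of Ensemble.do().
torch.save(
    {
        "state_dict": self.model.state_dict(),
        "checkpoint_paths": models_paths,  # lets the ensemble be rebuilt from member checkpoints
    },
    "ensemble.ckpt",  # made-up file name/location
)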