Ensemble #204
base: develop
Conversation
This commit adds an Ensemble step to steps.py, and ensemble.py to utils, where the ensemble model (an ArtModule) is stored. Example usage (using our tutorial's code from the MNIST example):

```python
import datasets  # Hugging Face datasets
import torch.nn as nn
from torchmetrics import Accuracy
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

from art.metrics import build_metric_name
from art.project import ArtProject  # import path assumed
from art.steps import Ensemble      # the step added in this PR
from art.utils.enums import TrainingStage, INPUT, TARGET

from dataset import MNISTDataModule  # from the MNIST tutorial
from model import MNISTModel         # from the MNIST tutorial; path may differ


def get_data_module(n_train=200):
    mnist_data = datasets.load_dataset("mnist")
    mnist_data = mnist_data.rename_columns({"image": INPUT, "label": TARGET})
    mnist_data["train"] = mnist_data["train"].select(range(n_train))
    return MNISTDataModule(mnist_data)


datamodule = get_data_module()
project = ArtProject(name="mnist-ensemble", datamodule=datamodule)

accuracy_metric = Accuracy(task="multiclass", num_classes=10)
ce_loss = nn.CrossEntropyLoss()
project.register_metrics([accuracy_metric, ce_loss])

checkpoint = ModelCheckpoint(
    monitor=build_metric_name(accuracy_metric, TrainingStage.VALIDATION.value),
    mode="max",
)
early_stopping = EarlyStopping(
    monitor=build_metric_name(ce_loss, TrainingStage.VALIDATION.value),
    mode="min",
)

project.add_step(
    Ensemble(
        MNISTModel,
        10,
        trainer_kwargs={
            "max_epochs": 6,
            "callbacks": [checkpoint, early_stopping],
            "check_val_every_n_epoch": 5,
        },
    )
)
project.run_all(force_rerun=True)
```
7c5b23b to c9cd71d
The structure is very nice; just fix the bugs and this will be ready to go.
The example is very cool. We can potentially put it somewhere in the documentation later.
Please create a test out of this example. I suggest doing it the following way: write a pytest.fixture that does everything up to `early_stopping = EarlyStopping(monitor=build_metric_name(ce_loss, TrainingStage.VALIDATION.value), mode="min")` and returns a dictionary with the created objects. Then we will be able to reuse this fixture later. In the test itself, just add_step, run everything, and do some inference on the obtained model.
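A minimal sketch of that fixture, with plain stand-in values in place of the real datamodule, project, and callbacks (all names here are illustrative, not the actual tutorial code):

```python
import pytest


def build_mnist_setup(n_train=200):
    # Stand-ins for the real setup from the PR description: datamodule,
    # ArtProject, registered metrics, checkpoint and early-stopping callbacks.
    return {
        "datamodule": f"datamodule({n_train})",
        "project": "art_project",
        "checkpoint": "checkpoint_callback",
        "early_stopping": "early_stopping_callback",
    }


@pytest.fixture
def mnist_setup():
    # Reusable fixture: everything up to (and including) early_stopping.
    return build_mnist_setup()


def test_ensemble_step(mnist_setup):
    # In the real test: project.add_step(Ensemble(...)), project.run_all(),
    # then run inference on the obtained ensemble model.
    assert {"datamodule", "project", "checkpoint", "early_stopping"} <= mnist_setup.keys()
```

Keeping the setup in a plain helper (`build_mnist_setup`) and wrapping it in the fixture keeps the expensive construction reusable outside pytest as well.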
```python
model.eval()
initialized_models.append(model)

self.model = ArtEnsemble(initialized_models)
```
I believe that self.model is never used. The Trainer is initialized and probably None is returned inside initialize_model, so you in fact evaluate the last model of the ensemble. Moreover, this leads to a memory leak.
My suggestion is to add a flag self.ensemble_ready and overwrite initialize_model: if the ensemble is ready, it returns the ensemble; otherwise it calls super().initialize_model().
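A pure-Python sketch of the suggested override; the flag name and the stand-in for super().initialize_model() are illustrative, not the actual step code:

```python
class EnsembleStepSketch:
    """Sketch of the ensemble_ready / initialize_model pattern."""

    def __init__(self):
        self.ensemble_ready = False
        self.ensemble = None

    def _base_initialize_model(self):
        # Stands in for super().initialize_model() in the real step.
        return "fresh_model"

    def initialize_model(self):
        # Once the ensemble has been built, hand it back instead of
        # creating (and leaking) yet another fresh model.
        if self.ensemble_ready:
            return self.ensemble
        return self._base_initialize_model()
```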
```python
Args:
    previous_states (Dict): previous states
"""
models_paths = []
```
I'd prefer this to be a class variable. After the ensemble is trained, some artifact with it should be saved so that the user can later reuse the ensemble. Right now I believe it's impossible to use it without running the entire step once more.
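One way to read that suggestion, sketched with illustrative names (the real step would pull paths from its previous states):

```python
class EnsembleStepSketch:
    # Class-level, so the collected checkpoint paths survive the step run
    # and can be read back later without re-running the whole step.
    models_paths: list = []

    def collect_checkpoints(self, previous_states):
        # Stand-in for the real logic that extracts checkpoint paths.
        type(self).models_paths = [s["checkpoint_path"] for s in previous_states]
```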
```python
initialized_models = []
for path in models_paths:
    model = self.model_class.load_from_checkpoint(path)
```
I think ArtEnsemble should be able to take a list of checkpoints and a model class, and initialize the ensemble by itself. That would be much more convenient for the user.
""" | ||
|
||
def __init__(self, models: List[ArtModule]): | ||
super().__init__() |
Please add an option to provide model_class plus a list of checkpoints, but keep the list of models too. Make everything Optional.
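One possible shape for that constructor, sketched without the real art/Lightning classes (names and error messages are illustrative):

```python
from typing import List, Optional


class ArtEnsembleSketch:
    """Accept either ready models, or model_class + checkpoint paths."""

    def __init__(
        self,
        models: Optional[List] = None,
        model_class=None,
        checkpoint_paths: Optional[List[str]] = None,
    ):
        if models is None:
            if model_class is None or checkpoint_paths is None:
                raise ValueError(
                    "provide either models, or model_class + checkpoint_paths"
                )
            # Let the ensemble load its own members from checkpoints.
            models = [model_class.load_from_checkpoint(p) for p in checkpoint_paths]
        self.models = models
```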
```python
self.models = nn.ModuleList(models)

def predict(self, data):
    predictions = torch.stack([self.predict_on_model_from_dataloader(model, deepcopy(data)) for model in self.models])
```
This won't work: predict is expected to return a dictionary to be compatible with compute_metrics.
What I'd suggest is to write a forward method, which expects a batch and makes a prediction; this is what the user will call later. In predict, use this forward.
Why do you do a deepcopy?
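A sketch of the suggested split, with plain callables standing in for ArtModules and simple mean aggregation (the dictionary keys are illustrative):

```python
class EnsembleForwardSketch:
    def __init__(self, models):
        self.models = models

    def forward(self, batch):
        # Per-model predictions on the same batch, averaged; this is the
        # entry point an end user would call directly on the ensemble.
        preds = [model(batch) for model in self.models]
        return sum(preds) / len(preds)

    def predict(self, data):
        # Return a dictionary so metric computation stays compatible.
        data["preds"] = self.forward(data["input"])
        return data
```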
```python
return torch.mean(predictions, dim=0)

def predict_on_model_from_dataloader(self, model, dataloader):
    predictions = []
```
Why is this even needed? You never get a dataloader from Lightning at any stage.
Closes #147