
Ensemble #204

Open: kordc wants to merge 3 commits into develop from kcyganik-ensemble
Conversation

@kordc (Collaborator) commented Dec 2, 2023

Closes #147

Example usage (using our tutorial's code from the MNIST example):

    import datasets  # HuggingFace datasets, provides load_dataset below
    import torch.nn as nn
    from torchmetrics import Accuracy

    from dataset import MNISTDataModule  # tutorial code
    from model import MNISTModel  # tutorial code; exact module path assumed

    from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

    from art.metrics import build_metric_name
    from art.project import ArtProject  # import path assumed
    from art.steps import Ensemble  # the step added in this PR
    from art.utils.enums import INPUT, TARGET, TrainingStage


    def get_data_module(n_train=200):
        mnist_data = datasets.load_dataset("mnist")
        mnist_data = mnist_data.rename_columns({"image": INPUT, "label": TARGET})
        mnist_data["train"] = mnist_data["train"].select(range(n_train))
        return MNISTDataModule(mnist_data)


    datamodule = get_data_module()
    project = ArtProject(name="mnist-ensemble", datamodule=datamodule)

    accuracy_metric = Accuracy(task="multiclass", num_classes=10)
    ce_loss = nn.CrossEntropyLoss()
    project.register_metrics([accuracy_metric, ce_loss])

    checkpoint = ModelCheckpoint(
        monitor=build_metric_name(accuracy_metric, TrainingStage.VALIDATION.value),
        mode="max",
    )
    early_stopping = EarlyStopping(
        monitor=build_metric_name(ce_loss, TrainingStage.VALIDATION.value),
        mode="min",
    )

    project.add_step(
        Ensemble(
            MNISTModel,
            10,  # number of ensemble members
            trainer_kwargs={
                "max_epochs": 6,
                "callbacks": [checkpoint, early_stopping],
                "check_val_every_n_epoch": 5,
            },
        )
    )

    project.run_all(force_rerun=True)

@kordc requested a review from SebChw on December 2, 2023 17:59
This commit adds an Ensemble step to steps.py and an ensemble.py module to utils, where the Ensemble model (an ArtModule) is stored.

@kordc force-pushed the kcyganik-ensemble branch from 7c5b23b to c9cd71d on December 2, 2023 18:03
@SebChw (Owner) left a comment:

The structure is very nice; just fix the bugs and this will be ready to go.

The example is very cool; we can potentially put it somewhere in the documentation later.

Please create a test out of this example. I suggest doing it in the following way: write a pytest.fixture that does everything up to `early_stopping = EarlyStopping(monitor=build_metric_name(ce_loss, TrainingStage.VALIDATION.value), mode="min")` and returns a dictionary with the created objects. Then we will be able to reuse this fixture later. In the test itself, just add_step, run everything, and do some inference on the obtained model.
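Something along these lines could work (a minimal sketch; it reuses `get_data_module` and `MNISTModel` from the example above, and the exact `Ensemble` step API is an assumption):

```python
import pytest
import torch.nn as nn
from torchmetrics import Accuracy
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

from art.metrics import build_metric_name
from art.project import ArtProject  # import path assumed
from art.steps import Ensemble  # import path assumed
from art.utils.enums import TrainingStage


@pytest.fixture
def mnist_setup():
    # Everything up to (and including) the callbacks, returned as a dict
    # so later tests can reuse whichever pieces they need.
    datamodule = get_data_module()  # from the example above
    project = ArtProject(name="mnist-ensemble-test", datamodule=datamodule)
    accuracy_metric = Accuracy(task="multiclass", num_classes=10)
    ce_loss = nn.CrossEntropyLoss()
    project.register_metrics([accuracy_metric, ce_loss])
    checkpoint = ModelCheckpoint(
        monitor=build_metric_name(accuracy_metric, TrainingStage.VALIDATION.value),
        mode="max",
    )
    early_stopping = EarlyStopping(
        monitor=build_metric_name(ce_loss, TrainingStage.VALIDATION.value),
        mode="min",
    )
    return {
        "project": project,
        "datamodule": datamodule,
        "callbacks": [checkpoint, early_stopping],
    }


def test_ensemble(mnist_setup):
    project = mnist_setup["project"]
    project.add_step(
        Ensemble(
            MNISTModel,  # from the tutorial code
            2,  # a small ensemble keeps the test fast
            trainer_kwargs={"max_epochs": 1, "callbacks": mnist_setup["callbacks"]},
        )
    )
    project.run_all(force_rerun=True)
```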

model.eval()
initialized_models.append(model)

self.model = ArtEnsemble(initialized_models)
@SebChw (Owner):

I believe that self.model is never being used. The Trainer is initialized and probably None is returned inside initialize_model, so you in fact evaluate the last model of the ensemble. Moreover, this leads to a memory leak.

My suggestion is to add a flag self.ensemble_ready and overwrite initialize_model: if the ensemble is ready, it returns the ensemble; otherwise it calls super().initialize_model().
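Roughly like this (a sketch only; the ModelStep base class name and the constructor signature are assumptions, not the PR's actual code):

```python
from art.steps import ModelStep  # base class name assumed


class Ensemble(ModelStep):
    def __init__(self, model_class, n_models, trainer_kwargs=None):
        super().__init__(model_class, trainer_kwargs=trainer_kwargs)
        self.n_models = n_models
        self.ensemble = None
        self.ensemble_ready = False  # flipped once all members are trained

    def initialize_model(self):
        # Hand out the assembled ensemble instead of a fresh member model,
        # so evaluation really runs on the ensemble and no extra model
        # lingers in memory.
        if self.ensemble_ready:
            return self.ensemble
        return super().initialize_model()
```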

Args:
previous_states (Dict): previous states
"""
models_paths = []
@SebChw (Owner):

I'd prefer this to be a class variable. After the ensemble is trained, some artifact with it is saved so that the user can later utilize the ensemble. Right now I believe it's impossible to utilize it without running the entire step once more.


initialized_models = []
for path in models_paths:
model = self.model_class.load_from_checkpoint(path)
@SebChw (Owner):

I think ArtEnsemble should be able to take a list of checkpoints and a model class and initialize the ensemble by itself. For the user it would be much more convenient.
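For example (a sketch; `from_checkpoints` is a hypothetical name and the ArtModule import path is assumed):

```python
from typing import List

from art.core import ArtModule  # import path assumed


class ArtEnsemble(ArtModule):
    @classmethod
    def from_checkpoints(cls, model_class, checkpoint_paths: List[str]):
        # Load each member from its checkpoint and wrap them all at once,
        # so the user never calls load_from_checkpoint in a loop themselves.
        models = [model_class.load_from_checkpoint(p) for p in checkpoint_paths]
        return cls(models)
```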

"""

def __init__(self, models: List[ArtModule]):
super().__init__()
@SebChw (Owner):

Please add an option to provide model_class + a list of checkpoints, but keep the list of models too. Make everything Optional.
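Extending the sketch above, the constructor could look roughly like this (argument names are assumptions):

```python
from typing import List, Optional, Type

import torch.nn as nn

from art.core import ArtModule  # import path assumed


class ArtEnsemble(ArtModule):
    def __init__(
        self,
        models: Optional[List[ArtModule]] = None,
        model_class: Optional[Type[ArtModule]] = None,
        checkpoint_paths: Optional[List[str]] = None,
    ):
        super().__init__()
        if models is None:
            # Fall back to rebuilding the members from their checkpoints.
            if model_class is None or checkpoint_paths is None:
                raise ValueError(
                    "Provide either models, or model_class plus checkpoint_paths"
                )
            models = [model_class.load_from_checkpoint(p) for p in checkpoint_paths]
        self.models = nn.ModuleList(models)
```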

self.models = nn.ModuleList(models)

def predict(self, data):
predictions = torch.stack([self.predict_on_model_from_dataloader(model, deepcopy(data)) for model in self.models])
@SebChw (Owner):

This won't work: predict is expected to return a dictionary to be compatible with compute_metrics.

What I'd suggest is to write a forward method, which expects a batch and makes a prediction. This will be used later by the user. In predict, use this forward.

Why do you do a deepcopy?
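Sketched out (assuming art.utils.enums also exposes a PREDICTION key alongside INPUT and TARGET; if it doesn't, substitute the project's actual prediction key):

```python
import torch

from art.core import ArtModule  # import path assumed
from art.utils.enums import INPUT, PREDICTION  # PREDICTION is an assumption


class ArtEnsemble(ArtModule):
    def forward(self, batch):
        # Average the members' predictions for a single batch; nothing here
        # mutates the inputs, so no deepcopy is needed.
        predictions = torch.stack([model(batch[INPUT]) for model in self.models])
        return torch.mean(predictions, dim=0)

    def predict(self, data):
        # Return a dict so the output stays compatible with compute_metrics.
        data[PREDICTION] = self(data)
        return data
```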

return torch.mean(predictions, dim=0)

def predict_on_model_from_dataloader(self, model, dataloader):
predictions = []
@SebChw (Owner):

Why is this even needed? You never get a dataloader from Lightning at any stage.

@kordc changed the title from "Kcyganik ensemble" to "Ensemble" on Dec 13, 2023