From 8dc2f1154affa76f442f6f8a1c08454b720560dc Mon Sep 17 00:00:00 2001
From: Kacper Trebacz
Date: Mon, 6 Nov 2023 20:56:55 +0100
Subject: [PATCH 1/6] passing model as callable

---
 art/step/step.py        | 20 ++++++++++++++------
 art/step/step_savers.py | 39 ++++++++++++++++++++++++++++++++++++++-
 art/step/steps.py       |  2 +-
 3 files changed, 53 insertions(+), 8 deletions(-)

diff --git a/art/step/step.py b/art/step/step.py
index dae6ccb..7ab5545 100644
--- a/art/step/step.py
+++ b/art/step/step.py
@@ -1,18 +1,20 @@
 import datetime
+import gc
 import hashlib
 import inspect
 import subprocess
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Iterable, Optional, Union
+from typing import Any, Dict, Iterable, Optional, Union, Callable
 
 import lightning as L
+import torch
 from lightning import Trainer
 from lightning.pytorch.loggers import Logger
 
 from art.core.base_components.base_model import ArtModule
 from art.core.exceptions import MissingLogParamsException
 from art.core.MetricCalculator import MetricCalculator
-from art.step.step_savers import JSONStepSaver
+from art.step.step_savers import JSONStepSaver, ModelSaver
 from art.utils.enums import TrainingStage
 
 
@@ -192,7 +194,7 @@ class ModelStep(Step):
 
     def __init__(
         self,
-        model: ArtModule,
+        model_func: Callable[[], ArtModule],
         trainer_kwargs: Dict = {},
         logger: Optional[Union[Logger, Iterable[Logger], bool]] = None,
     ):
@@ -200,15 +202,15 @@ def __init__(
         Initialize a model-based step.
 
         Args:
-            model (ArtModule): The model associated with this step.
+            model_func (Callable[[], ArtModule]): A callable returning the model associated with this step.
             trainer_kwargs (Dict, optional): Arguments to be passed to the trainer. Defaults to {}.
             logger (Optional[Union[Logger, Iterable[Logger], bool]], optional): Logger to be used. Defaults to None.
         """
         super().__init__()
         if logger is not None:
             logger.add_tags(self.name)
-
-        self.model = model
+        assert isinstance(model_func, Callable)
+        self.model = model_func()
         self.trainer = Trainer(**trainer_kwargs, logger=logger)
 
     def __call__(
@@ -227,6 +229,12 @@ def __call__(
         """
        self.model.set_metric_calculator(metric_calculator)
         super().__call__(previous_states, datamodule, metric_calculator)
+        # Save the model to a file, then release it to free memory.
+        ModelSaver().save(self.model, self.get_step_id(), self.name)
+        del self.model
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        gc.collect()
 
     @abstractmethod
     def do(self, previous_states: Dict):
diff --git a/art/step/step_savers.py b/art/step/step_savers.py
index c0a5808..d6127c5 100644
--- a/art/step/step_savers.py
+++ b/art/step/step_savers.py
@@ -2,8 +2,11 @@
 from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Any
-
+import lightning as L
 import matplotlib.pyplot as plt
+import torch
+
+from art.core.base_components.base_model import ArtModule
 
 BASE_PATH = Path("checkpoints")
 
@@ -164,3 +167,37 @@ def load(self, step_id, step_name: str, filename: str):
             NotImplementedError: This method is not implemented.
         """
         raise NotImplementedError()
+
+
+class ModelSaver(StepSaver):
+    def save(self, obj: ArtModule, step_id: str, step_name: str, filename: str = "model.ckpt"):
+        """
+        Save a PyTorch Lightning model.
+
+        Args:
+            obj (ArtModule): The model to save.
+            step_id (str): The ID of the step.
+            step_name (str): The name of the step.
+            filename (str): The name of the file to save the model to.
+        """
+        self.ensure_directory(step_id, step_name)
+        filepath = self.get_path(step_id, step_name, filename)
+        filepath.parent.mkdir(exist_ok=True)
+        torch.save(obj.state_dict(), filepath)
+
+    def load(self, step_id: str, step_name: str, model: ArtModule, filename: str = "model.ckpt"):
+        """
+        Load a PyTorch Lightning model.
+
+        Args:
+            step_id (str): The ID of the step.
+            step_name (str): The name of the step.
+            filename (str): The name of the file containing the model.
+
+        Returns:
+            L.LightningModule: The loaded model.
+        """
+        filepath = self.get_path(step_id, step_name, filename)
+        model.load_state_dict(torch.load(filepath))
+        return model
diff --git a/art/step/steps.py b/art/step/steps.py
index 207ef69..38a007a 100644
--- a/art/step/steps.py
+++ b/art/step/steps.py
@@ -185,7 +185,7 @@ def __init__(
         model: ArtModule,
         logger: Optional[Union[Logger, Iterable[Logger], bool]] = None,
     ):
-        super().__init__(model=model, logger=logger)
+        super().__init__(model_func=model, logger=logger)
 
     def do(self, previous_states: Dict):
         """

From 35d04c3fc58ff8a6b3afea1864ae5f29e53740ff Mon Sep 17 00:00:00 2001
From: SebChw
Date: Wed, 15 Nov 2023 20:33:14 +0100
Subject: [PATCH 2/6] Step now takes Class not object instance

---
 art/step/step.py | 63 ++++++++++++++++++++++++++++++------------------
 1 file changed, 39 insertions(+), 24 deletions(-)

diff --git a/art/step/step.py b/art/step/step.py
index 7ab5545..a115fe6 100644
--- a/art/step/step.py
+++ b/art/step/step.py
@@ -194,24 +194,34 @@ class ModelStep(Step):
 
     def __init__(
         self,
-        model_func: Callable[[], ArtModule],
+        model_class: ArtModule,
         trainer_kwargs: Dict = {},
+        model_kwargs: Dict = {},
+        model_modifiers: List[Callable] = [],
         logger: Optional[Union[Logger, Iterable[Logger], bool]] = None,
     ):
         """
         Initialize a model-based step.
 
         Args:
-            model_func (Callable[[], ArtModule]): A callable returning the model associated with this step.
+            model_class (ArtModule): The class of the model associated with this step.
             trainer_kwargs (Dict, optional): Arguments to be passed to the trainer. Defaults to {}.
+            model_kwargs (Dict, optional): Arguments to be passed to the model. Defaults to {}.
+            model_modifiers (List[Callable], optional): List of functions to be applied to the model. Defaults to [].
             logger (Optional[Union[Logger, Iterable[Logger], bool]], optional): Logger to be used. Defaults to None.
         """
         super().__init__()
         if logger is not None:
             logger.add_tags(self.name)
-        assert isinstance(model_func, Callable)
-        self.model = model_func()
-        self.trainer = Trainer(**trainer_kwargs, logger=logger)
+
+        if not inspect.isclass(model_class):
+            raise ValueError("model_class must be a class inheriting from ArtModule, not an instance; the step instantiates it lazily to avoid memory leaks.")
+
+        self.model_class = model_class
+        self.model_kwargs = model_kwargs
+        self.model_modifiers = model_modifiers
+        self.logger = logger
+        self.trainer_kwargs = trainer_kwargs
 
     def __call__(
         self,
@@ -227,13 +237,10 @@ def __call__(
             datamodule (L.LightningDataModule): Data module to be used.
             metric_calculator (MetricCalculator): Metric calculator for this step.
""" - self.model.set_metric_calculator(metric_calculator) + self.trainer = Trainer(**self.trainer_kwargs, logger=self.logger) + self.metric_calculator = metric_calculator super().__call__(previous_states, datamodule, metric_calculator) - #save model to file - ModelSaver().save(self.model, self.get_step_id(), self.name) - del self.model - if torch.cuda.is_available(): - torch.cuda.empty_cache() + del self.trainer gc.collect() @abstractmethod @@ -246,6 +254,20 @@ def do(self, previous_states: Dict): """ pass + def initialize_model(self,) -> ArtModule: + """ + Initializes the model. + """ + if self.trainer.model is not None: + return None + + model = self.model_class(**self.model_kwargs) + for modifier in self.model_modifiers: + modifier(model) + model.set_metric_calculator(self.metric_calculator) + + return model + def train(self, trainer_kwargs: Dict): """ Train the model using the provided trainer arguments. @@ -253,8 +275,9 @@ def train(self, trainer_kwargs: Dict): Args: trainer_kwargs (Dict): Arguments to be passed to the trainer for training the model. """ - self.trainer.fit(model=self.model, **trainer_kwargs) + self.trainer.fit(model=self.initialize_model(), **trainer_kwargs) logged_metrics = {k: v.item() for k, v in self.trainer.logged_metrics.items()} + self.results["scores"].update(logged_metrics) def validate(self, trainer_kwargs: Dict): @@ -264,8 +287,9 @@ def validate(self, trainer_kwargs: Dict): Args: trainer_kwargs (Dict): Arguments to be passed to the trainer for validating the model. """ - print(f"Validating model {self.get_model_name()}") - result = self.trainer.validate(model=self.model, **trainer_kwargs) + print(f"Validating model {self.model_name}") + + result = self.trainer.validate(model=self.initialize_model(), **trainer_kwargs) self.results["scores"].update(result[0]) def test(self, trainer_kwargs: Dict): @@ -275,18 +299,9 @@ def test(self, trainer_kwargs: Dict): Args: trainer_kwargs (Dict): Arguments to be passed to the trainer for testing the model. """ - result = self.trainer.test(model=self.model, **trainer_kwargs) + result = self.trainer.test(model=self.initialize_model(), **trainer_kwargs) self.results["scores"].update(result[0]) - def get_model_name(self) -> str: - """ - Retrieve the name of the model associated with the step. - - Returns: - str: Name of the model. - """ - return self.model.__class__.__name__ - def get_step_id(self) -> str: """ Retrieve the step ID, combining model name (if available) with the index. From edeaf7d9d4b82a7ca11fef790a8db09d43ad3486 Mon Sep 17 00:00:00 2001 From: SebChw Date: Wed, 15 Nov 2023 20:34:57 +0100 Subject: [PATCH 3/6] adapt art to these changes --- art/core/base_components/base_model.py | 5 +-- .../visualizer/visualization_decorators.py | 6 ++-- art/experiment/Experiment.py | 2 +- art/step/step.py | 33 ++++++------------- art/step/steps.py | 14 ++++---- 5 files changed, 24 insertions(+), 36 deletions(-) diff --git a/art/core/base_components/base_model.py b/art/core/base_components/base_model.py index e5b2840..a094b1c 100644 --- a/art/core/base_components/base_model.py +++ b/art/core/base_components/base_model.py @@ -229,7 +229,8 @@ def ml_train(self, data: Dict): return data - def get_hash(self): + @classmethod + def get_hash(cls): """ Get hash of the model. @@ -237,7 +238,7 @@ def get_hash(self): str: Hash of the model. 
""" return hashlib.md5( - inspect.getsource(self.__class__).encode("utf-8") + inspect.getsource(cls).encode("utf-8") ).hexdigest() def unify_type(self: Any, x: Any): diff --git a/art/core/visualizer/visualization_decorators.py b/art/core/visualizer/visualization_decorators.py index 93939de..d5715d7 100644 --- a/art/core/visualizer/visualization_decorators.py +++ b/art/core/visualizer/visualization_decorators.py @@ -34,7 +34,8 @@ def wrapper_visualize_input(*args, **kwargs): function: Decorated function. """ if visualizing_function_in is not None: - visualizing_function_in(*args, **kwargs) + to_be_passed = args[1:] + visualizing_function_in(*to_be_passed, **kwargs) output = func(*args, **kwargs) if visualizing_function_out is not None: visualizing_function_out(output) @@ -63,6 +64,3 @@ def set_visualization( getattr(obj, method) ) setattr(obj, method, decorated) - - if hasattr(obj, "reset_pipelines"): - obj.reset_pipelines() diff --git a/art/experiment/Experiment.py b/art/experiment/Experiment.py index 7a8a4eb..e7f0b2b 100644 --- a/art/experiment/Experiment.py +++ b/art/experiment/Experiment.py @@ -61,7 +61,7 @@ def fill_step_states(self, step: "Step"): Args: step (Step): The step whose results need to be recorded. """ - self.state.step_states[step.get_model_name()][ + self.state.step_states[step.model_name][ step.get_name_with_id() ] = step.get_latest_run() diff --git a/art/step/step.py b/art/step/step.py index a115fe6..b3afaaf 100644 --- a/art/step/step.py +++ b/art/step/step.py @@ -4,7 +4,7 @@ import inspect import subprocess from abc import ABC, abstractmethod -from typing import Any, Dict, Iterable, Optional, Union, Callable +from typing import Any, Callable, Dict, Iterable, List, Optional, Union import lightning as L import torch @@ -14,7 +14,7 @@ from art.core.base_components.base_model import ArtModule from art.core.exceptions import MissingLogParamsException from art.core.MetricCalculator import MetricCalculator -from art.step.step_savers import JSONStepSaver, ModelSaver +from art.step.step_savers import JSONStepSaver from art.utils.enums import TrainingStage @@ -41,6 +41,7 @@ def __init__(self): "succesfull": False, } self.finalized = False + self.model_name = "" def __call__( self, @@ -154,22 +155,13 @@ def was_run(self) -> bool: ) return path.exists() - def get_model_name(self) -> str: - """ - Retrieve the model name associated with the step. By default, it's empty. - - Returns: - str: Model name. - """ - return "" def __repr__(self) -> str: """Representation of the step""" result_repr = "\n".join( f"\t{k}: {v}" for k, v in self.results["scores"].items() ) - model = self.model.__class__.__name__ - return f"Step: {self.name}, Model: {model}, Passed: {self.results['succesfull']}. Results:\n{result_repr}" + return f"Step: {self.name}, Model: {self.model_name}, Passed: {self.results['succesfull']}. Results:\n{result_repr}" def set_succesfull(self): self.results["succesfull"] = True @@ -224,6 +216,10 @@ def __init__( self.logger = logger self.trainer_kwargs = trainer_kwargs + + self.model_name = model_class.__name__ + self.hash = self.model_class.get_hash() + def __call__( self, previous_states: Dict, @@ -310,20 +306,11 @@ def get_step_id(self) -> str: str: The step ID. """ return ( - f"{self.get_model_name()}_{self.idx}" - if self.get_model_name() != "" + f"{self.model_name}_{self.idx}" + if self.model_name != "" else f"{self.idx}" ) - def get_hash(self) -> str: - """ - Compute a hash for the model associated with the step. - - Returns: - str: Hash of the model. 
- """ - return self.model.get_hash() - def get_current_stage(self) -> str: """ Retrieve the current training stage of the trainer. diff --git a/art/step/steps.py b/art/step/steps.py index 38a007a..1120eca 100644 --- a/art/step/steps.py +++ b/art/step/steps.py @@ -32,8 +32,9 @@ class EvaluateBaseline(ModelStep): def __init__( self, baseline: ArtModule, + device: Optional[str] = "cpu", ): - super().__init__(baseline, {"accelerator": baseline.device.type}) + super().__init__(baseline, {"accelerator": device}) def do(self, previous_states: Dict): """ @@ -42,9 +43,11 @@ def do(self, previous_states: Dict): Args: previous_states (Dict): previous states """ - self.model.ml_train({"dataloader": self.datamodule.train_dataloader()}) - self.validate(trainer_kwargs={"datamodule": self.datamodule}) - + model = self.model_class() + model.ml_train({"dataloader": self.datamodule.train_dataloader()}) + model.set_metric_calculator(self.metric_calculator) + result = self.trainer.validate(model=model, datamodule= self.datamodule) + self.results["scores"].update(result[0]) class CheckLossOnInit(ModelStep): """This step checks whether the loss on init is as expected""" @@ -78,7 +81,7 @@ class OverfitOneBatch(ModelStep): def __init__( self, model: ArtModule, - number_of_steps: int = 100, + number_of_steps: int = 50, ): self.number_of_steps = number_of_steps super().__init__(model, {"overfit_batches": 1, "max_epochs": number_of_steps}) @@ -165,7 +168,6 @@ def do(self, previous_states: Dict): Args: previous_states (Dict): previous states """ - self.model.turn_on_model_regularizations() self.datamodule.turn_on_regularizations() self.train(trainer_kwargs={"datamodule": self.datamodule}) From c66c72227d37fd12207ae13f8860d4254ff3f11a Mon Sep 17 00:00:00 2001 From: SebChw Date: Wed, 15 Nov 2023 20:35:16 +0100 Subject: [PATCH 4/6] save obtained model --- .../visualizer/visualization_decorators.py | 1 + art/step/step.py | 1 + art/step/step_savers.py | 38 +------------------ 3 files changed, 4 insertions(+), 36 deletions(-) diff --git a/art/core/visualizer/visualization_decorators.py b/art/core/visualizer/visualization_decorators.py index d5715d7..b440e18 100644 --- a/art/core/visualizer/visualization_decorators.py +++ b/art/core/visualizer/visualization_decorators.py @@ -34,6 +34,7 @@ def wrapper_visualize_input(*args, **kwargs): function: Decorated function. """ if visualizing_function_in is not None: + # first arguments is the `self` object. 
             to_be_passed = args[1:]
             visualizing_function_in(*to_be_passed, **kwargs)
         output = func(*args, **kwargs)
diff --git a/art/step/step.py b/art/step/step.py
index b3afaaf..37c1a87 100644
--- a/art/step/step.py
+++ b/art/step/step.py
@@ -275,6 +275,7 @@ def train(self, trainer_kwargs: Dict):
         logged_metrics = {k: v.item() for k, v in self.trainer.logged_metrics.items()}
 
         self.results["scores"].update(logged_metrics)
+        self.results["model_path"] = self.trainer.checkpoint_callback.best_model_path
 
     def validate(self, trainer_kwargs: Dict):
diff --git a/art/step/step_savers.py b/art/step/step_savers.py
index d6127c5..aa05fe2 100644
--- a/art/step/step_savers.py
+++ b/art/step/step_savers.py
@@ -2,6 +2,7 @@
 from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Any
+
 import lightning as L
 import matplotlib.pyplot as plt
 import torch
 
 from art.core.base_components.base_model import ArtModule
@@ -109,8 +110,7 @@ def save(self, step: "Step", filename: str = RESULT_NAME):
         if results_file.exists():
             current_results = self.load(step_id, step_name, filename)
         else:
-            model = step.model.__class__.__name__
-            current_results = {"name": step_name, "model": model, "runs": []}
+            current_results = {"name": step_name, "model": step.model_name, "runs": []}
 
         current_results["runs"].insert(0, step.results)
 
@@ -167,37 +167,3 @@ def load(self, step_id, step_name: str, filename: str):
             NotImplementedError: This method is not implemented.
         """
         raise NotImplementedError()
-
-
-class ModelSaver(StepSaver):
-    def save(self, obj: ArtModule, step_id: str, step_name: str, filename: str = "model.ckpt"):
-        """
-        Save a PyTorch Lightning model.
-
-        Args:
-            obj (ArtModule): The model to save.
-            step_id (str): The ID of the step.
-            step_name (str): The name of the step.
-            filename (str): The name of the file to save the model to.
-        """
-        self.ensure_directory(step_id, step_name)
-        filepath = self.get_path(step_id, step_name, filename)
-        filepath.parent.mkdir(exist_ok=True)
-        torch.save(obj.state_dict(), filepath)
-
-    def load(self, step_id: str, step_name: str, model: ArtModule, filename: str = "model.ckpt"):
-        """
-        Load a PyTorch Lightning model.
-
-        Args:
-            step_id (str): The ID of the step.
-            step_name (str): The name of the step.
-            filename (str): The name of the file containing the model.
-
-        Returns:
-            L.LightningModule: The loaded model.
- """ - filepath = self.get_path(step_id, step_name, filename) - model.load_state_dict(torch.load(filepath)) - return model - From 01af40154786e40a38f27052aa3540dd6303ec40 Mon Sep 17 00:00:00 2001 From: SebChw Date: Thu, 16 Nov 2023 10:10:49 +0100 Subject: [PATCH 5/6] fix log params lack of model --- art/step/step.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/art/step/step.py b/art/step/step.py index 37c1a87..00e7ea7 100644 --- a/art/step/step.py +++ b/art/step/step.py @@ -60,7 +60,6 @@ def __call__( self.datamodule = datamodule self.fill_basic_results() self.do(previous_states) - self.log_params() self.finalized = True def set_step_id(self, idx: int): @@ -262,6 +261,7 @@ def initialize_model(self,) -> ArtModule: modifier(model) model.set_metric_calculator(self.metric_calculator) + self.log_params(model) return model def train(self, trainer_kwargs: Dict): @@ -330,9 +330,9 @@ def get_check_stage(self) -> str: """ return TrainingStage.VALIDATION.value - def log_params(self): - if hasattr(self.model, "log_params"): - model_params = self.model.log_params() + def log_params(self, model): + if hasattr(model, "log_params"): + model_params = model.log_params() self.results["parameters"].update(model_params) else: From 2a3a6e897838bc791827406f80b6c23914ccf064 Mon Sep 17 00:00:00 2001 From: SebChw Date: Thu, 16 Nov 2023 10:11:12 +0100 Subject: [PATCH 6/6] rename art_decorator --- ...ualization_decorators.py => decorators.py} | 37 ++++++++++--------- 1 file changed, 19 insertions(+), 18 deletions(-) rename art/{core/visualizer/visualization_decorators.py => decorators.py} (57%) diff --git a/art/core/visualizer/visualization_decorators.py b/art/decorators.py similarity index 57% rename from art/core/visualizer/visualization_decorators.py rename to art/decorators.py index b440e18..08e98cb 100644 --- a/art/core/visualizer/visualization_decorators.py +++ b/art/decorators.py @@ -8,27 +8,27 @@ """ -def visualize(visualizing_function_in=None, visualizing_function_out=None): +def art_decorate_single_func(visualizing_function_in=None, visualizing_function_out=None): """ - Decorator for visualizing input and output of a function. + Decorates input and output of a function. Args: - visualizing_function_in (function, optional): Function to visualize input. Defaults to None. - visualizing_function_out (function, optional): Function to visualize output. Defaults to None. + function_in (function, optional): Function applied on the input. Defaults to None. + function_out (function, optional): Function applied on the output. Defaults to None. Returns: function: Decorated function. """ - def decorator_visualize_input(func): + def decorator(func): """ - Decorator for visualizing input of a function. + Decorator Args: func (function): Function to decorate. """ - def wrapper_visualize_input(*args, **kwargs): + def wrapper(*args, **kwargs): """ - Wrapper for visualizing input of a function. + Wrapper Returns: function: Decorated function. @@ -42,26 +42,27 @@ def wrapper_visualize_input(*args, **kwargs): visualizing_function_out(output) return output - return wrapper_visualize_input + return wrapper - return decorator_visualize_input + return decorator -def set_visualization( +def art_decorate( functions: List[Tuple[object, str]], - visualizing_function_in=None, - visualizing_function_out=None, + function_in=None, + function_out=None, ): """ - Set visualization for a list of functions. - + Decorates list of objects functions. 
+    but can be used for logging additional information during training.
+
     Args:
         functions (List[Tuple[object, str]]): List of tuples of objects and methods to decorate.
-        visualizing_function_in (function, optional): Function to visualize input. Defaults to None.
-        visualizing_function_out (function, optional): Function to visualize output. Defaults to None.
+        function_in (function, optional): Function applied on the input. Defaults to None.
+        function_out (function, optional): Function applied on the output. Defaults to None.
     """
     for obj, method in functions:
-        decorated = visualize(visualizing_function_in, visualizing_function_out)(
+        decorated = art_decorate_single_func(function_in, function_out)(
             getattr(obj, method)
         )
         setattr(obj, method, decorated)
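
Taken together, the series changes ModelStep from holding a live model instance to building it lazily from a class, and renames the visualization helpers to art_decorate. A minimal usage sketch of the resulting API follows; it is an assumption-laden illustration, not code from the diffs: MyModel, TrainStep, and log_inputs are hypothetical stand-ins, and only the import paths, ModelStep's constructor signature, and art_decorate come from the patches above.

from typing import Dict

from art.core.base_components.base_model import ArtModule
from art.decorators import art_decorate
from art.step.step import ModelStep


class MyModel(ArtModule):
    """Hypothetical stand-in; a real ArtModule defines its own training logic."""


class TrainStep(ModelStep):
    """ModelStep is abstract, so a concrete step must implement `do`."""

    def do(self, previous_states: Dict):
        # train() builds the model lazily via initialize_model() and fits it;
        # self.trainer only exists for the duration of __call__.
        self.train(trainer_kwargs={"datamodule": self.datamodule})


def log_inputs(*args, **kwargs):
    # Receives the decorated method's arguments (minus `self`).
    print("forward() received:", args, kwargs)


step = TrainStep(
    MyModel,                           # the class itself; an instance raises ValueError in __init__
    trainer_kwargs={"max_epochs": 5},
    model_kwargs={},                   # forwarded as MyModel(**model_kwargs)
    model_modifiers=[],                # each callable receives the freshly built instance
)

# Observe what flows into MyModel.forward without changing its output.
art_decorate([(MyModel, "forward")], function_in=log_inputs)

The lazy model_class construction is what lets __call__ drop both the model and the trainer once a step finishes, which is the memory-leak fix this series is after.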