WIP: Passing model as callable to step #141 #111 #148

Merged: 7 commits, Nov 16, 2023
5 changes: 3 additions & 2 deletions — art/core/base_components/base_model.py

@@ -229,15 +229,16 @@ def ml_train(self, data: Dict):

         return data

-    def get_hash(self):
+    @classmethod
+    def get_hash(cls):
         """
         Get hash of the model.

         Returns:
             str: Hash of the model.
         """
         return hashlib.md5(
-            inspect.getsource(self.__class__).encode("utf-8")
+            inspect.getsource(cls).encode("utf-8")
         ).hexdigest()

     def unify_type(self: Any, x: Any):
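Promoting get_hash to a classmethod means a model can be fingerprinted from its class alone, before any instance exists — which is exactly what ModelStep now needs, since it receives a class rather than a built model. A minimal sketch; the MyModel subclass is hypothetical:

    from art.core.base_components.base_model import ArtModule

    class MyModel(ArtModule):  # hypothetical subclass, defined in a real source file
        pass

    # No instantiation required: the hash is the MD5 of the class's source,
    # so a step can fingerprint the model before ever constructing it.
    print(MyModel.get_hash())
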
44 changes: 22 additions & 22 deletions — ...re/visualizer/visualization_decorators.py → art/decorators.py

@@ -8,61 +8,61 @@
"""


def visualize(visualizing_function_in=None, visualizing_function_out=None):
def art_decorate_single_func(visualizing_function_in=None, visualizing_function_out=None):
"""
Decorator for visualizing input and output of a function.
Decorates input and output of a function.

Args:
visualizing_function_in (function, optional): Function to visualize input. Defaults to None.
visualizing_function_out (function, optional): Function to visualize output. Defaults to None.
function_in (function, optional): Function applied on the input. Defaults to None.
function_out (function, optional): Function applied on the output. Defaults to None.

Returns:
function: Decorated function.
"""
def decorator_visualize_input(func):
def decorator(func):
"""
Decorator for visualizing input of a function.
Decorator

Args:
func (function): Function to decorate.
"""
def wrapper_visualize_input(*args, **kwargs):
def wrapper(*args, **kwargs):
"""
Wrapper for visualizing input of a function.
Wrapper

Returns:
function: Decorated function.
"""
if visualizing_function_in is not None:
visualizing_function_in(*args, **kwargs)
# first arguments is the `self` object. We don't want to pass it to the visualizing function
to_be_passed = args[1:]
visualizing_function_in(*to_be_passed, **kwargs)
output = func(*args, **kwargs)
if visualizing_function_out is not None:
visualizing_function_out(output)
return output

return wrapper_visualize_input
return wrapper

return decorator_visualize_input
return decorator


def set_visualization(
def art_decorate(
functions: List[Tuple[object, str]],
visualizing_function_in=None,
visualizing_function_out=None,
function_in=None,
function_out=None,
):
"""
Set visualization for a list of functions.

Decorates list of objects functions. It doesn't modify output of a function
put can be used for logging additional information during training.

Args:
functions (List[Tuple[object, str]]): List of tuples of objects and methods to decorate.
visualizing_function_in (function, optional): Function to visualize input. Defaults to None.
visualizing_function_out (function, optional): Function to visualize output. Defaults to None.
function_in (function, optional): Function applied on the input. Defaults to None.
function_out (function, optional): Function applied on the output. Defaults to None.
"""
for obj, method in functions:
decorated = visualize(visualizing_function_in, visualizing_function_out)(
decorated = art_decorate_single_func(function_in, function_out)(
getattr(obj, method)
)
setattr(obj, method, decorated)

if hasattr(obj, "reset_pipelines"):
obj.reset_pipelines()
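
A usage sketch of the renamed API. The MyModel class and its train_step method are hypothetical; art_decorate wraps the method in place on the class, strips `self` before calling function_in, and leaves the return value untouched:

    from art.decorators import art_decorate

    class MyModel:  # hypothetical class with a method worth instrumenting
        def train_step(self, batch):
            return {"loss": 0.1}

    def log_inputs(batch):
        # Receives train_step's arguments with `self` already stripped.
        print(f"train_step received: {batch!r}")

    def log_outputs(output):
        print(f"train_step returned: {output!r}")

    art_decorate([(MyModel, "train_step")], function_in=log_inputs, function_out=log_outputs)
    MyModel().train_step({"x": 1})  # logs input and output; return value unchanged
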
2 changes: 1 addition & 1 deletion — art/experiment/Experiment.py

@@ -63,7 +63,7 @@ def fill_step_states(self, step: "Step"):
         Args:
             step (Step): The step whose results need to be recorded.
         """
-        self.state.step_states[step.get_model_name()][
+        self.state.step_states[step.model_name][
             step.get_name_with_id()
         ] = step.get_latest_run()

105 changes: 59 additions & 46 deletions — art/step/step.py

@@ -1,21 +1,25 @@
 import datetime
+import gc
 import hashlib
 import inspect
 import subprocess
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Iterable, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Union

 import lightning as L
 import torch
 from lightning import Trainer
 from lightning.pytorch.loggers import Logger

-from art.utils.art_logger import art_logger, add_logger, remove_logger, get_new_log_file_name, get_run_id
 from art.core.base_components.base_model import ArtModule
 from art.core.exceptions import MissingLogParamsException
 from art.core.MetricCalculator import MetricCalculator
-from art.utils.paths import get_checkpoint_logs_folder_path
 from art.step.step_savers import JSONStepSaver
+from art.utils.art_logger import (add_logger, art_logger,
+                                  get_new_log_file_name, get_run_id,
+                                  remove_logger)
 from art.utils.enums import TrainingStage
+from art.utils.paths import get_checkpoint_logs_folder_path


class NoModelUsed:

@@ -41,6 +45,7 @@ def __init__(self):
             "succesfull": False,
         }
         self.finalized = False
+        self.model_name = ""

     def __call__(
         self,

@@ -63,7 +68,6 @@ def __call__(
             self.datamodule = datamodule
             self.fill_basic_results()
             self.do(previous_states)
-            self.log_params()
             self.finalized = True
         except Exception as e:
             art_logger.exception(f"Error while executing step {self.name}!")
@@ -164,22 +168,13 @@ def was_run(self) -> bool:
         )
         return path.exists()

-    def get_model_name(self) -> str:
-        """
-        Retrieve the model name associated with the step. By default, it's empty.
-
-        Returns:
-            str: Model name.
-        """
-        return ""
-
     def __repr__(self) -> str:
         """Representation of the step"""
         result_repr = "\n".join(
             f"\t{k}: {v}" for k, v in self.results["scores"].items()
         )
-        model = self.model.__class__.__name__
-        return f"Step: {self.name}, Model: {model}, Passed: {self.results['succesfull']}. Results:\n{result_repr}"
+        return f"Step: {self.name}, Model: {self.model_name}, Passed: {self.results['succesfull']}. Results:\n{result_repr}"

     def set_succesfull(self):
         self.results["succesfull"] = True
@@ -204,24 +199,39 @@ class ModelStep(Step):

     def __init__(
         self,
-        model: ArtModule,
+        model_class: ArtModule,
         trainer_kwargs: Dict = {},
+        model_kwargs: Dict = {},
+        model_modifiers: List[Callable] = [],
         logger: Optional[Union[Logger, Iterable[Logger], bool]] = None,
     ):
         """
         Initialize a model-based step.

         Args:
-            model (ArtModule): The model associated with this step.
+            model_class (ArtModule): The model's class associated with this step.
             trainer_kwargs (Dict, optional): Arguments to be passed to the trainer. Defaults to {}.
+            model_kwargs (Dict, optional): Arguments to be passed to the model. Defaults to {}.
+            model_modifiers (List[Callable], optional): List of functions to be applied to the model. Defaults to [].
             logger (Optional[Union[Logger, Iterable[Logger], bool]], optional): Logger to be used. Defaults to None.
         """
         super().__init__()
         if logger is not None:
             logger.add_tags(self.name)

-        self.model = model
-        self.trainer = Trainer(**trainer_kwargs, logger=logger)
+        if not inspect.isclass(model_class):
+            raise ValueError(
+                "model_class must be a class inheriting from ArtModule, not an instance. "
+                "Instantiating the model inside the step avoids memory leaks."
+            )
+
+        self.model_class = model_class
+        self.model_kwargs = model_kwargs
+        self.model_modifiers = model_modifiers
+        self.logger = logger
+        self.trainer_kwargs = trainer_kwargs
+
+        self.model_name = model_class.__name__
+        self.hash = self.model_class.get_hash()
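
The new contract in action: a step is handed the class plus constructor kwargs and builds the model itself. A hedged sketch — EvaluateModel stands in for a concrete ModelStep subclass, and the kwargs are illustrative:

    step = EvaluateModel(
        model_class=MyModel,                # a class — passing MyModel(...) raises ValueError
        model_kwargs={"hidden_dim": 128},   # forwarded as MyModel(**model_kwargs)
        model_modifiers=[],                 # callables applied to each fresh instance
        trainer_kwargs={"max_epochs": 1},
    )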

     def __call__(
         self,

@@ -238,8 +248,11 @@ def __call__(
             datamodule (L.LightningDataModule): Data module to be used.
             metric_calculator (MetricCalculator): Metric calculator for this step.
         """
-        self.model.set_metric_calculator(metric_calculator)
+        self.trainer = Trainer(**self.trainer_kwargs, logger=self.logger)
+        self.metric_calculator = metric_calculator
         super().__call__(previous_states, datamodule, metric_calculator, run_id)
+        del self.trainer
+        gc.collect()

     @abstractmethod
     def do(self, previous_states: Dict):

@@ -251,16 +264,33 @@ def do(self, previous_states: Dict):
         """
         pass

+    def initialize_model(self) -> Optional[ArtModule]:
+        """
+        Initializes the model, or returns None if the trainer already holds one.
+        """
+        if self.trainer.model is not None:
+            return None
+
+        model = self.model_class(**self.model_kwargs)
+        for modifier in self.model_modifiers:
+            modifier(model)
+        model.set_metric_calculator(self.metric_calculator)
+
+        self.log_params(model)
+        return model

     def train(self, trainer_kwargs: Dict):
         """
         Train the model using the provided trainer arguments.

         Args:
             trainer_kwargs (Dict): Arguments to be passed to the trainer for training the model.
         """
-        self.trainer.fit(model=self.model, **trainer_kwargs)
+        self.trainer.fit(model=self.initialize_model(), **trainer_kwargs)
         logged_metrics = {k: v.item() for k, v in self.trainer.logged_metrics.items()}

         self.results["scores"].update(logged_metrics)
         self.results["model_path"] = self.trainer.checkpoint_callback.best_model_path

     def validate(self, trainer_kwargs: Dict):
         """

@@ -269,8 +299,9 @@ def validate(self, trainer_kwargs: Dict):
         Args:
             trainer_kwargs (Dict): Arguments to be passed to the trainer for validating the model.
         """
-        art_logger.info(f"Validating model {self.get_model_name()}")
-        result = self.trainer.validate(model=self.model, **trainer_kwargs)
+        art_logger.info(f"Validating model {self.model_name}")
+
+        result = self.trainer.validate(model=self.initialize_model(), **trainer_kwargs)
         self.results["scores"].update(result[0])

     def test(self, trainer_kwargs: Dict):

@@ -280,18 +311,9 @@ def test(self, trainer_kwargs: Dict):
         Args:
             trainer_kwargs (Dict): Arguments to be passed to the trainer for testing the model.
         """
-        result = self.trainer.test(model=self.model, **trainer_kwargs)
+        result = self.trainer.test(model=self.initialize_model(), **trainer_kwargs)
         self.results["scores"].update(result[0])

-    def get_model_name(self) -> str:
-        """
-        Retrieve the name of the model associated with the step.
-
-        Returns:
-            str: Name of the model.
-        """
-        return self.model.__class__.__name__

     def get_step_id(self) -> str:
         """
         Retrieve the step ID, combining model name (if available) with the index.

@@ -300,20 +322,11 @@ def get_step_id(self) -> str:
             str: The step ID.
         """
         return (
-            f"{self.get_model_name()}_{self.idx}"
-            if self.get_model_name() != ""
+            f"{self.model_name}_{self.idx}"
+            if self.model_name != ""
             else f"{self.idx}"
         )

-    def get_hash(self) -> str:
-        """
-        Compute a hash for the model associated with the step.
-
-        Returns:
-            str: Hash of the model.
-        """
-        return self.model.get_hash()

     def get_current_stage(self) -> str:
         """
         Retrieve the current training stage of the trainer.

@@ -332,9 +345,9 @@ def get_check_stage(self) -> str:
         """
         return TrainingStage.VALIDATION.value

-    def log_params(self):
-        if hasattr(self.model, "log_params"):
-            model_params = self.model.log_params()
+    def log_params(self, model):
+        if hasattr(model, "log_params"):
+            model_params = model.log_params()
             self.results["parameters"].update(model_params)

         else:
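Because every train/validate/test call now goes through initialize_model, which constructs a fresh instance, persistent tweaks must live in model_modifiers rather than being applied to a pre-built model. A sketch of a weight-loading modifier; the path, checkpoint layout, and EvaluateModel/MyModel names are assumptions:

    import torch

    def load_pretrained(model):
        # Applied by initialize_model to each newly constructed instance.
        # Assumes a Lightning-style checkpoint with a "state_dict" key.
        state = torch.load("checkpoints/best.ckpt")["state_dict"]
        model.load_state_dict(state)

    step = EvaluateModel(model_class=MyModel, model_modifiers=[load_pretrained])
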
7 changes: 5 additions & 2 deletions — art/step/step_savers.py

@@ -3,7 +3,11 @@
 from pathlib import Path
 from typing import Any

+import lightning as L
 import matplotlib.pyplot as plt
+import torch
+
+from art.core.base_components.base_model import ArtModule

 from art.utils.paths import get_checkpoint_step_dir_path

@@ -106,8 +110,7 @@ def save(self, step: "Step", filename: str = RESULT_NAME):
         if results_file.exists():
             current_results = self.load(step_id, step_name, filename)
         else:
-            model = step.model.__class__.__name__
-            current_results = {"name": step_name, "model": model, "runs": []}
+            current_results = {"name": step_name, "model": step.model_name, "runs": []}

         current_results["runs"].insert(0, step.results)
