From d815dd85093a052185aa87c0d946f3f1773a8905 Mon Sep 17 00:00:00 2001 From: Kacper Trebacz Date: Sun, 8 Oct 2023 13:39:00 +0200 Subject: [PATCH] data analysis --- art/experiment/Experiment.py | 4 ++ art/step/checks.py | 44 +++++++++++++++---- art/step/step.py | 83 ++++++++++++++++++++++++------------ art/step/step_savers.py | 19 +++++++++ art/step/steps.py | 22 ++++++---- 5 files changed, 127 insertions(+), 45 deletions(-) diff --git a/art/experiment/Experiment.py b/art/experiment/Experiment.py index 06752ab..f7c15e5 100644 --- a/art/experiment/Experiment.py +++ b/art/experiment/Experiment.py @@ -97,6 +97,10 @@ def run_all(self): def get_steps(self): return self.steps + def get_step(self, step_id): + return self.steps[step_id]["step"] + + def replace_step(self, step: "Step", step_id=-1): self.steps[step_id]["step"] = step if step_id == -1: diff --git a/art/step/checks.py b/art/step/checks.py index f44cf8f..83bf51b 100644 --- a/art/step/checks.py +++ b/art/step/checks.py @@ -18,6 +18,36 @@ class Check(ABC): description: str required_files: List[str] + @abstractmethod + def check(self, step) -> ResultOfCheck: + pass + + +class CheckResult(Check): + + @abstractmethod + def _check_method(self, result) -> ResultOfCheck: + pass + + def check(self, step) -> ResultOfCheck: + result = step.get_results() + return self._check_method(result) + + +class CheckResultExists(CheckResult): + def __init__(self, required_key): + self.required_key = required_key + def _check_method(self, result) -> ResultOfCheck: + if self.required_key in result: + return ResultOfCheck(is_positive=True) + else: + return ResultOfCheck( + is_positive=False, + error=f"Score {self.required_key} is not in results.json", + ) + + +class CheckScore(CheckResult): def __init__( self, metric, # This requires an object which was used to calculate metric @@ -26,9 +56,6 @@ def __init__( self.metric = metric self.value = value - @abstractmethod - def _check_method(self, result) -> ResultOfCheck: - pass def build_required_key(self, step, metric): metric = metric.__class__.__name__ @@ -42,8 +69,7 @@ def check(self, step) -> ResultOfCheck: self.build_required_key(step, self.metric) return self._check_method(result) - -class CheckScoreExists(Check): +class CheckScoreExists(CheckScore): def __init__(self, metric): super().__init__(metric, None) @@ -57,7 +83,7 @@ def _check_method(self, result) -> ResultOfCheck: ) -class CheckScoreEqualsTo(Check): +class CheckScoreEqualsTo(CheckScore): def _check_method(self, result) -> ResultOfCheck: if result[self.required_key] == self.value: return ResultOfCheck(is_positive=True) @@ -68,7 +94,7 @@ def _check_method(self, result) -> ResultOfCheck: ) -class CheckScoreCloseTo(Check): +class CheckScoreCloseTo(CheckScore): def __init__( self, metric, # This requires an object which was used to calculate metric @@ -95,7 +121,7 @@ def _check_method(self, result) -> ResultOfCheck: ) -class CheckScoreGreaterThan(Check): +class CheckScoreGreaterThan(CheckScore): def _check_method(self, result) -> ResultOfCheck: if result[self.required_key] > self.value: return ResultOfCheck(is_positive=True) @@ -106,7 +132,7 @@ def _check_method(self, result) -> ResultOfCheck: ) -class CheckScoreLessThan(Check): +class CheckScoreLessThan(CheckScore): def _check_method(self, result) -> ResultOfCheck: if result[self.required_key] < self.value: return ResultOfCheck(is_positive=True) diff --git a/art/step/step.py b/art/step/step.py index 1d6ee82..e2ad099 100644 --- a/art/step/step.py +++ b/art/step/step.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict - +import hashlib +import inspect import lightning as L from art.core.base_components.base_model import ArtModule @@ -14,11 +15,63 @@ class Step(ABC): description: str idx: int = None - def __init__(self, model: ArtModule, trainer: L.Trainer): + def __init__(self): self.results = {} + + + def __call__( + self, + previous_states: Dict, + datamodule: L.LightningDataModule, + metric_calculator: MetricCalculator, + ): + self.datamodule = datamodule + self.do(previous_states) + JSONStepSaver().save( + self.results, self.get_step_id(), self.name, "results.json" + ) + + def set_step_id(self, idx: int): + self.idx = idx + + def get_step_id(self) -> str: + return f"{self.idx}" + + def get_name_with_id(self) -> str: + return f"{self.idx}_{self.name}" + + def get_full_step_name(self) -> str: + return f"{self.get_step_id()}_{self.name}" + + def get_hash(self): + return hashlib.md5( + inspect.getsource(self.__class__).encode("utf-8") + ).hexdigest() + + + def add_result(self, name: str, value: Any): + self.results[name] = value + + def get_results(self) -> Dict: + return self.results + + def load_results(self): + self.results = JSONStepSaver().load(self.get_step_id(), self.name) + + def was_run(self): + path = JSONStepSaver().get_path( + self.get_step_id(), self.name, JSONStepSaver.RESULT_NAME + ) + return path.exists() + + def get_model_name(self) -> str: + return "" + +class ModelStep(Step): + def __init__(self, model: ArtModule, trainer: L.Trainer): + super().__init__() self.model = model self.trainer = trainer - self.results = {} def __call__( self, @@ -50,9 +103,6 @@ def test(self, trainer_kwargs: Dict): result = self.trainer.test(model=self.model, **trainer_kwargs) self.results.update(result[0]) - def set_step_id(self, idx: int): - self.idx = idx - def get_model_name(self) -> str: return self.model.__class__.__name__ @@ -63,32 +113,11 @@ def get_step_id(self) -> str: else f"{self.idx}" ) - def get_name_with_id(self) -> str: - return f"{self.idx}_{self.name}" - - def get_full_step_name(self) -> str: - return f"{self.get_step_id()}_{self.name}" - def get_hash(self) -> str: return self.model.get_hash() - def add_result(self, name: str, value: Any): - self.results[name] = value - - def get_results(self) -> Dict: - return self.results - - def load_results(self): - self.results = JSONStepSaver().load(self.get_step_id(), self.name) - def get_current_stage(self) -> str: return self.trainer.state.stage.value - def was_run(self): - path = JSONStepSaver().get_path( - self.get_step_id(), self.name, JSONStepSaver.RESULT_NAME - ) - return path.exists() - def get_check_stage(self) -> str: return TrainingStage.VALIDATION.value diff --git a/art/step/step_savers.py b/art/step/step_savers.py index 82243a5..8cc39e5 100644 --- a/art/step/step_savers.py +++ b/art/step/step_savers.py @@ -1,6 +1,7 @@ import json from abc import ABC, abstractmethod from pathlib import Path +import matplotlib.pyplot as plt BASE_PATH = Path("checkpoints") @@ -20,6 +21,9 @@ def ensure_directory(self, step_id: str, step_name: str): def get_path(self, step_id: str, step_name: str, filename: str): return BASE_PATH / f"{step_id}_{step_name}" / filename + def exists(self, step_id: str, step_name: str, filename: str): + return self.get_path(step_id, step_name, filename).exists() + class JSONStepSaver(StepSaver): RESULT_NAME = "results.json" @@ -32,3 +36,18 @@ def save(self, obj: any, step_id: str, step_name: str, filename: str = RESULT_NA def load(self, step_id: str, step_name: str, filename: str = RESULT_NAME): with open(self.get_path(step_id, step_name, filename), "r") as f: return json.load(f) + + def exists(self, step_id: str, step_name: str, filename: str): + self.exists(step_id, step_name, RESULT_NAME) + + +class MatplotLibSaver(StepSaver): + def save(self, obj: plt.Figure, step_id: str, step_name: str, filename: str = ""): + self.ensure_directory(step_id, step_name) + filepath = self.get_path(step_id, step_name, filename) + filepath.parent.mkdir(exist_ok=True) + obj.savefig(filepath) + plt.close(obj) + + def load(self, step_id, step_name: str, filename: str): + raise NotImplementedError() diff --git a/art/step/steps.py b/art/step/steps.py index a2a340d..0467af7 100644 --- a/art/step/steps.py +++ b/art/step/steps.py @@ -1,18 +1,22 @@ from typing import Dict, Iterable, Optional, Union - from lightning import LightningDataModule, Trainer from lightning.pytorch.loggers import Logger from art.core.base_components.base_model import ArtModule -from art.step.step import Step +from art.step.step import Step, ModelStep from art.utils.enums import TrainingStage class ExploreData(Step): """This class checks whether we have some markdown file description of the dataset + we implemented visualizations""" + name = "Data analysis" + description = "This step allows you to perform data analysis and extract information that is necessery in next steps" + + def get_step_id(self) -> str: + return f"data_analysis" -class EvaluateBaseline(Step): +class EvaluateBaseline(ModelStep): """This class takes a baseline and evaluates/trains it on the dataset""" name = "Evaluate Baseline" @@ -30,7 +34,7 @@ def do(self, previous_states: Dict): self.validate(trainer_kwargs={"datamodule": self.datamodule}) -class CheckLossOnInit(Step): +class CheckLossOnInit(ModelStep): name = "Check Loss On Init" description = "Checks loss on init" @@ -45,7 +49,7 @@ def do(self, previous_states: Dict): self.validate(trainer_kwargs={"dataloaders": train_loader}) -class OverfitOneBatch(Step): +class OverfitOneBatch(ModelStep): name = "Overfit One Batch" description = "Overfits one batch" @@ -70,7 +74,7 @@ def get_check_stage(self): return TrainingStage.TRAIN.value -class Overfit(Step): +class Overfit(ModelStep): name = "Overfit" description = "Overfits model" @@ -92,7 +96,7 @@ def get_check_stage(self): return TrainingStage.TRAIN.value -class Regularize(Step): +class Regularize(ModelStep): name = "Regularize" description = "Regularizes model" @@ -111,7 +115,7 @@ def do(self, previous_states: Dict): self.train(trainer_kwargs={"datamodule": self.datamodule}) -class Tune(Step): +class Tune(ModelStep): name = "Tune" description = "Tunes model" @@ -131,5 +135,5 @@ def do(self, previous_states: Dict): trainer.tune(model=self.model, datamodule=self.datamodule) -class Squeeze(Step): +class Squeeze(ModelStep): pass