Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

data analysis #118

Merged
merged 2 commits into from
Oct 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions art/experiment/Experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ def run_all(self):
def get_steps(self):
return self.steps

def get_step(self, step_id):
    """Return the step object stored under ``step_id``.

    Each entry of ``self.steps`` is indexed by ``step_id`` and holds the
    step itself under its ``"step"`` key.
    """
    entry = self.steps[step_id]
    return entry["step"]


def replace_step(self, step: "Step", step_id=-1):
self.steps[step_id]["step"] = step
if step_id == -1:
Expand Down
44 changes: 35 additions & 9 deletions art/step/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,36 @@ class Check(ABC):
description: str
required_files: List[str]

@abstractmethod
def check(self, step) -> ResultOfCheck:
pass


class CheckResult(Check):
    """Base class for checks that look at the results of a step.

    Subclasses implement ``_check_method``; ``check`` hands it whatever
    ``step.get_results()`` returns.
    """

    @abstractmethod
    def _check_method(self, result) -> ResultOfCheck:
        """Evaluate ``result`` (the results of a step) and return the outcome."""
        ...

    def check(self, step) -> ResultOfCheck:
        """Fetch the results of ``step`` and delegate to ``_check_method``."""
        return self._check_method(step.get_results())


class CheckResultExists(CheckResult):
    """Check that passes when a given key is present in a step's results."""

    def __init__(self, required_key):
        """Remember the key that must appear in the results."""
        self.required_key = required_key

    def _check_method(self, result) -> ResultOfCheck:
        """Report whether ``required_key`` is contained in ``result``.

        Returns a positive ``ResultOfCheck`` when the key is present,
        otherwise a negative one carrying an error message.
        """
        # Guard clause: handle the failure first, the success path falls through.
        if self.required_key not in result:
            return ResultOfCheck(
                is_positive=False,
                error=f"Score {self.required_key} is not in results.json",
            )
        return ResultOfCheck(is_positive=True)


class CheckScore(CheckResult):
def __init__(
self,
metric, # This requires an object which was used to calculate metric
Expand All @@ -23,9 +53,6 @@ def __init__(
self.metric = metric
self.value = value

@abstractmethod
def _check_method(self, result) -> ResultOfCheck:
pass

def build_required_key(self, step, metric):
metric = metric.__class__.__name__
Expand All @@ -39,8 +66,7 @@ def check(self, step) -> ResultOfCheck:
self.build_required_key(step, self.metric)
return self._check_method(result)


class CheckScoreExists(Check):
class CheckScoreExists(CheckScore):
def __init__(self, metric):
super().__init__(metric, None)

Expand All @@ -54,7 +80,7 @@ def _check_method(self, result) -> ResultOfCheck:
)


class CheckScoreEqualsTo(Check):
class CheckScoreEqualsTo(CheckScore):
def _check_method(self, result) -> ResultOfCheck:
if result[self.required_key] == self.value:
return ResultOfCheck(is_positive=True)
Expand All @@ -65,7 +91,7 @@ def _check_method(self, result) -> ResultOfCheck:
)


class CheckScoreCloseTo(Check):
class CheckScoreCloseTo(CheckScore):
def __init__(
self,
metric, # This requires an object which was used to calculate metric
Expand All @@ -92,7 +118,7 @@ def _check_method(self, result) -> ResultOfCheck:
)


class CheckScoreGreaterThan(Check):
class CheckScoreGreaterThan(CheckScore):
def _check_method(self, result) -> ResultOfCheck:
if result[self.required_key] > self.value:
return ResultOfCheck(is_positive=True)
Expand All @@ -103,7 +129,7 @@ def _check_method(self, result) -> ResultOfCheck:
)


class CheckScoreLessThan(Check):
class CheckScoreLessThan(CheckScore):
def _check_method(self, result) -> ResultOfCheck:
if result[self.required_key] < self.value:
return ResultOfCheck(is_positive=True)
Expand Down
85 changes: 57 additions & 28 deletions art/step/step.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from typing import Any, Dict

import hashlib
import inspect
import lightning as L
from lightning import Trainer

Expand All @@ -15,11 +16,63 @@ class Step(ABC):
description: str
idx: int = -1

def __init__(self, model: ArtModule, trainer: Trainer):
self.results: Dict[str, Any] = {}
def __init__(self):
self.results = {}


def __call__(
    self,
    previous_states: Dict,
    datamodule: L.LightningDataModule,
    metric_calculator: MetricCalculator,
):
    """Run the step and persist its results as JSON.

    Stores ``datamodule`` on the instance, executes ``self.do(previous_states)``
    and then writes ``self.results`` through ``JSONStepSaver``.

    Args:
        previous_states: Passed unchanged to ``self.do``.
        datamodule: Kept as ``self.datamodule``.
        metric_calculator: Not used by this method.
    """
    self.datamodule = datamodule
    self.do(previous_states)
    # Use the saver's own constant rather than a second copy of the literal,
    # so the file written here is the same one ``was_run`` looks for.
    JSONStepSaver().save(
        self.results,
        self.get_step_id(),
        self.name,
        JSONStepSaver.RESULT_NAME,
    )

def set_step_id(self, idx: int):
self.idx = idx

def get_step_id(self) -> str:
return f"{self.idx}"

def get_name_with_id(self) -> str:
return f"{self.idx}_{self.name}"

def get_full_step_name(self) -> str:
return f"{self.get_step_id()}_{self.name}"

def get_hash(self):
return hashlib.md5(
inspect.getsource(self.__class__).encode("utf-8")
).hexdigest()


def add_result(self, name: str, value: Any):
self.results[name] = value

def get_results(self) -> Dict:
return self.results

def load_results(self):
self.results = JSONStepSaver().load(self.get_step_id(), self.name)

def was_run(self):
path = JSONStepSaver().get_path(
self.get_step_id(), self.name, JSONStepSaver.RESULT_NAME
)
return path.exists()

def get_model_name(self) -> str:
return ""

class ModelStep(Step):
def __init__(self, model: ArtModule, trainer: L.Trainer):
super().__init__()
self.model = model
self.trainer = trainer
self.results = {}

def __call__(
self,
Expand Down Expand Up @@ -51,9 +104,6 @@ def test(self, trainer_kwargs: Dict):
result = self.trainer.test(model=self.model, **trainer_kwargs)
self.results.update(result[0])

def set_step_id(self, idx: int):
self.idx = idx

def get_model_name(self) -> str:
return self.model.__class__.__name__

Expand All @@ -64,32 +114,11 @@ def get_step_id(self) -> str:
else f"{self.idx}"
)

def get_name_with_id(self) -> str:
return f"{self.idx}_{self.name}"

def get_full_step_name(self) -> str:
return f"{self.get_step_id()}_{self.name}"

def get_hash(self) -> str:
return self.model.get_hash()

def add_result(self, name: str, value: Any):
self.results[name] = value

def get_results(self) -> Dict:
return self.results

def load_results(self):
self.results = JSONStepSaver().load(self.get_step_id(), self.name)

def get_current_stage(self) -> str:
return self.trainer.state.stage.value

def was_run(self):
path = JSONStepSaver().get_path(
self.get_step_id(), self.name, JSONStepSaver.RESULT_NAME
)
return path.exists()

def get_check_stage(self) -> str:
return TrainingStage.VALIDATION.value
19 changes: 19 additions & 0 deletions art/step/step_savers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
from abc import ABC, abstractmethod
from pathlib import Path
import matplotlib.pyplot as plt
from typing import Any

BASE_PATH = Path("checkpoints")
Expand All @@ -21,6 +22,9 @@ def ensure_directory(self, step_id: str, step_name: str):
def get_path(self, step_id: str, step_name: str, filename: str):
return BASE_PATH / f"{step_id}_{step_name}" / filename

def exists(self, step_id: str, step_name: str, filename: str):
return self.get_path(step_id, step_name, filename).exists()


class JSONStepSaver(StepSaver):
RESULT_NAME = "results.json"
Expand All @@ -33,3 +37,18 @@ def save(self, obj: Any, step_id: str, step_name: str, filename: str = RESULT_NA
def load(self, step_id: str, step_name: str, filename: str = RESULT_NAME):
with open(self.get_path(step_id, step_name, filename), "r") as f:
return json.load(f)

def exists(self, step_id: str, step_name: str, filename: "str | None" = None):
    """Tell whether a saved JSON file exists for the given step.

    Args:
        step_id: Identifier of the step, forwarded to ``get_path``.
        step_name: Name of the step, forwarded to ``get_path``.
        filename: File to look for; falls back to ``self.RESULT_NAME`` when
            omitted or empty.

    Returns:
        ``True`` if the file is present, ``False`` otherwise.
    """
    # The previous body called ``self.exists`` again (unbounded recursion),
    # referenced ``RESULT_NAME`` as a bare name (class attributes are not in
    # scope inside a method) and never returned the answer.
    target = filename or self.RESULT_NAME
    return self.get_path(step_id, step_name, target).exists()


class MatplotLibSaver(StepSaver):
    """Saver that writes matplotlib figures to the path built by ``get_path``."""

    def save(self, obj: plt.Figure, step_id: str, step_name: str, filename: str = ""):
        """Write the figure ``obj`` to disk and close it.

        Args:
            obj: Figure to save; it is closed after a successful save.
            step_id: Identifier of the step, forwarded to ``get_path``.
            step_name: Name of the step, forwarded to ``get_path``.
            filename: Target file name inside the step directory.
        """
        # NOTE(review): with the default empty ``filename`` the path appears to
        # resolve to the step directory itself; presumably callers always pass
        # a real file name - confirm.
        self.ensure_directory(step_id, step_name)
        filepath = self.get_path(step_id, step_name, filename)
        # ``filename`` may contain sub-directories; create every missing level,
        # not only the last one.
        filepath.parent.mkdir(parents=True, exist_ok=True)
        obj.savefig(filepath)
        plt.close(obj)

    def load(self, step_id, step_name: str, filename: str):
        """Loading figures is not supported.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()
24 changes: 14 additions & 10 deletions art/step/steps.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
from typing import Dict, Iterable, Optional, Union

from lightning import Trainer
from lightning import LightningDataModule, Trainer
from lightning.pytorch.loggers import Logger

from art.core.base_components.base_model import ArtModule
from art.step.step import Step
from art.step.step import Step, ModelStep
from art.utils.enums import TrainingStage


class ExploreData(Step):
    """This class checks whether we have some markdown file description of the dataset + we implemented visualizations"""

    name = "Data analysis"
    # Spelling fixed ("necessery" -> "necessary"); this text is user-facing.
    description = "This step allows you to perform data analysis and extract information that is necessary in next steps"

    def get_step_id(self) -> str:
        """Return the fixed identifier ``"data_analysis"``."""
        # Plain literal: the former f-string had no placeholders.
        return "data_analysis"


class EvaluateBaseline(Step):
class EvaluateBaseline(ModelStep):
"""This class takes a baseline and evaluates/trains it on the dataset"""

name = "Evaluate Baseline"
Expand All @@ -30,7 +34,7 @@ def do(self, previous_states: Dict):
self.validate(trainer_kwargs={"datamodule": self.datamodule})


class CheckLossOnInit(Step):
class CheckLossOnInit(ModelStep):
name = "Check Loss On Init"
description = "Checks loss on init"

Expand All @@ -45,7 +49,7 @@ def do(self, previous_states: Dict):
self.validate(trainer_kwargs={"dataloaders": train_loader})


class OverfitOneBatch(Step):
class OverfitOneBatch(ModelStep):
name = "Overfit One Batch"
description = "Overfits one batch"

Expand All @@ -70,7 +74,7 @@ def get_check_stage(self):
return TrainingStage.TRAIN.value


class Overfit(Step):
class Overfit(ModelStep):
name = "Overfit"
description = "Overfits model"

Expand All @@ -92,7 +96,7 @@ def get_check_stage(self):
return TrainingStage.TRAIN.value


class Regularize(Step):
class Regularize(ModelStep):
name = "Regularize"
description = "Regularizes model"

Expand All @@ -111,7 +115,7 @@ def do(self, previous_states: Dict):
self.train(trainer_kwargs={"datamodule": self.datamodule})


class Tune(Step):
class Tune(ModelStep):
name = "Tune"
description = "Tunes model"

Expand All @@ -131,5 +135,5 @@ def do(self, previous_states: Dict):
trainer.tune(model=self.model, datamodule=self.datamodule)


class Squeeze(Step):
class Squeeze(ModelStep):
pass