Frameworks #55

Merged: 6 commits, Oct 30, 2024
52 changes: 10 additions & 42 deletions main.py
@@ -11,11 +11,8 @@

from retinal_rl.framework_interface import TrainingFramework
from retinal_rl.rl.sample_factory.sf_framework import SFFramework
from runner.analyze import analyze
from runner.dataset import get_datasets
from runner.initialize import initialize
from runner.classification.classification_framework import ClassificationFramework
from runner.sweep import launch_sweep
from runner.train import train
from runner.util import create_brain, delete_results

# Load the eval resolver for OmegaConf
@@ -25,7 +22,7 @@
# Hydra entry point
@hydra.main(config_path="config/base", config_name="config", version_base=None)
def _program(cfg: DictConfig):
#TODO: Instead of doing checks of the config here, we should implement
# TODO: Instead of doing checks of the config here, we should implement
# sth like the configstore which ensures config parameters are present

if cfg.command == "clean":
@@ -56,50 +53,21 @@ def _program(cfg: DictConfig):
cache_path = os.path.join(hydra.utils.get_original_cwd(), "cache")
if cfg.framework == "rl":
framework = SFFramework(cfg, data_root=cache_path)
elif cfg.framework == "classification":
framework = ClassificationFramework(cfg)
else:
# TODO: Make ClassifierEngine
train_set, test_set = get_datasets(cfg)

brain, optimizer, histories, completed_epochs = initialize(
cfg,
brain,
optimizer,
raise NotImplementedError(
"only 'rl' or 'classification' framework implemented currently"
)
if cfg.command == "train":
train(
cfg,
device,
brain,
objective,
optimizer,
train_set,
test_set,
completed_epochs,
histories,
)
sys.exit(0)

if cfg.command == "analyze":
analyze(
cfg,
device,
brain,
objective,
histories,
train_set,
test_set,
completed_epochs,
)
sys.exit(0)

raise ValueError("Invalid run_mode")

brain, optimizer = framework.initialize(brain, optimizer)

if cfg.command == "train":
framework.train()
framework.train(device, brain, optimizer, objective)
sys.exit(0)

if cfg.command == "analyze":
framework.analyze(cfg, device, brain, histories, None, None, completed_epochs)
framework.analyze(device, brain, objective)
sys.exit(0)

raise ValueError("Invalid run_mode")
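After this change, main.py only selects a TrainingFramework implementation and delegates to it; the classification-specific train/analyze calls move into ClassificationFramework. A rough sketch of how the new dispatch is meant to grow (the "imitation" branch and ImitationFramework are hypothetical, not part of this PR):

if cfg.framework == "rl":
    framework = SFFramework(cfg, data_root=cache_path)
elif cfg.framework == "classification":
    framework = ClassificationFramework(cfg)
elif cfg.framework == "imitation":  # hypothetical third framework
    framework = ImitationFramework(cfg)
else:
    raise NotImplementedError(
        "only the 'rl' and 'classification' frameworks are currently implemented"
    )

# The remainder of _program is framework-agnostic.
brain, optimizer = framework.initialize(brain, optimizer)
if cfg.command == "train":
    framework.train(device, brain, optimizer, objective)
elif cfg.command == "analyze":
    framework.analyze(device, brain, objective)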
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,5 +1,5 @@
[tool.pyright]
include = ["retinal_rl"]
# include = ["retinal_rl*"]
typeCheckingMode = "strict"
reportMissingTypeStubs = "warning"
reportUnknownVariableType = "warning"
@@ -43,4 +43,4 @@ name = "retinal_rl"
version="0.0.1"

[tool.setuptools.packages.find]
include =["retinal_rl*"]
include =["retinal_rl*"]
29 changes: 15 additions & 14 deletions retinal_rl/framework_interface.py
@@ -1,27 +1,28 @@
from typing import Dict, List, Protocol, Tuple
from typing import Optional, Protocol, Tuple

import torch
from omegaconf import DictConfig
from torch import Tensor
from torch.utils.data import Dataset

from retinal_rl.models.brain import Brain
from retinal_rl.models.loss import ContextT
from retinal_rl.models.objective import Objective


class TrainingFramework(Protocol):
# TODO: Check if all parameters applicable and sort arguments
# Especially get rid of config where possible (train? initialize could store all relevant parameters...)
def train(self): ...
def initialize(
self, brain: Brain, optimizer: torch.optim.Optimizer
) -> Tuple[Brain, torch.optim.Optimizer]: ...

def train(
self,
device: torch.device,
brain: Brain,
optimizer: torch.optim.Optimizer,
objective: Optional[Objective[ContextT]] = None,
): ...

# TODO: make static to be able to evaluate models from other stuff as well?
def analyze(
self,
cfg: DictConfig,
device: torch.device,
brain: Brain,
histories: Dict[str, List[float]],
train_set: Dataset[Tuple[Tensor, int]],
test_set: Dataset[Tuple[Tensor, int]],
epoch: int,
copy_checkpoint: bool = False,
objective: Optional[Objective[ContextT]] = None,
): ...
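TrainingFramework is a typing.Protocol, so under pyright's strict mode any class whose initialize/train/analyze signatures match is accepted wherever a TrainingFramework is expected, with or without explicit inheritance. A minimal structural sketch (DummyFramework is hypothetical and not part of this PR):

from typing import Optional, Tuple

import torch

from retinal_rl.framework_interface import TrainingFramework
from retinal_rl.models.brain import Brain
from retinal_rl.models.loss import ContextT
from retinal_rl.models.objective import Objective


class DummyFramework:
    # No-op framework used only to illustrate structural typing against the Protocol.
    def initialize(
        self, brain: Brain, optimizer: torch.optim.Optimizer
    ) -> Tuple[Brain, torch.optim.Optimizer]:
        return brain, optimizer

    def train(
        self,
        device: torch.device,
        brain: Brain,
        optimizer: torch.optim.Optimizer,
        objective: Optional[Objective[ContextT]] = None,
    ): ...

    def analyze(
        self,
        device: torch.device,
        brain: Brain,
        objective: Optional[Objective[ContextT]] = None,
    ): ...


framework: TrainingFramework = DummyFramework()  # accepted by the type checker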
54 changes: 37 additions & 17 deletions retinal_rl/rl/sample_factory/sf_framework.py
@@ -1,8 +1,9 @@
import argparse
import json
import os
import warnings
from argparse import Namespace
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, Optional

# from retinal_rl.rl.sample_factory.observer import RetinalAlgoObserver
import torch
@@ -20,11 +21,11 @@
from sample_factory.train import make_runner
from sample_factory.utils.attr_dict import AttrDict
from sample_factory.utils.typing import Config
from torch import Tensor
from torch.utils.data import Dataset

from retinal_rl.framework_interface import TrainingFramework
from retinal_rl.models.brain import Brain
from retinal_rl.models.loss import ContextT
from retinal_rl.models.objective import Objective
from retinal_rl.rl.sample_factory.arguments import (
add_retinal_env_args,
add_retinal_env_eval_args,
@@ -36,7 +37,6 @@


class SFFramework(TrainingFramework):

def __init__(self, cfg: DictConfig, data_root: str):
self.data_root = data_root

@@ -48,7 +48,24 @@ def __init__(self, cfg: DictConfig, data_root: str):
register_retinal_env(self.sf_cfg.env, self.data_root, self.sf_cfg.input_satiety)
global_model_factory().register_actor_critic_factory(SampleFactoryBrain)

def train(self):
def initialize(self, brain: Brain, optimizer: torch.optim.Optimizer):
# brain = SFFramework.load_brain_from_checkpoint(...)
# TODO: Implement load brain and optimizer state
return brain, optimizer

def train(
self,
device: torch.device,
brain: Brain,
optimizer: torch.optim.Optimizer,
objective: Optional[Objective[ContextT]] = None,
):
warnings.warn(
"device, brain, optimizer are initialized differently in sample_factory and thus there current state will be ignored"
)
warnings.warn(
"objective is currently not supported for sample factory simulations"
)
# Run simulation
if not (self.sf_cfg.dry_run):
cfg, runner = make_runner(self.sf_cfg)
@@ -109,24 +126,25 @@ def to_sf_cfg(self, cfg: DictConfig) -> Config:
self._set_cfg_cli_argument(sf_cfg, "env", cfg.dataset.env_name)
self._set_cfg_cli_argument(sf_cfg, "input_satiety", cfg.dataset.input_satiety)
self._set_cfg_cli_argument(sf_cfg, "device", cfg.system.device)
optimizer_name = str.lower(str.split(cfg.optimizer.optimizer._target_, sep='.')[-1])
optimizer_name = str.lower(
str.split(cfg.optimizer.optimizer._target_, sep=".")[-1]
)
self._set_cfg_cli_argument(sf_cfg, "optimizer", optimizer_name)

self._set_cfg_cli_argument(sf_cfg, "brain", OmegaConf.to_object(cfg.brain))
return sf_cfg

def analyze(
self,
cfg: DictConfig,
device: torch.device,
brain: Brain,
histories: Dict[str, List[float]],
train_set: Dataset[Tuple[Tensor | int]],
test_set: Dataset[Tuple[Tensor | int]],
epoch: int,
copy_checkpoint: bool = False,
objective: Optional[Objective[ContextT]] = None,
):
return enjoy(self.sf_cfg)
warnings.warn(
"device, brain, optimizer are initialized differently in sample_factory and thus there current state will be ignored"
)
enjoy(self.sf_cfg)
# TODO: Implement analyze function for sf framework

@staticmethod
def _set_cfg_cli_argument(cfg: Namespace, name: str, value: Any):
Expand Down Expand Up @@ -159,18 +177,20 @@ def get_checkpoint(cfg: Config) -> tuple[Dict[str, Any], AttrDict]:
"""
Load the model from checkpoint, initialize the environment, and return both.
"""
#verbose = False
# verbose = False

cfg = load_from_checkpoint(cfg)

device = torch.device("cpu" if cfg.device == "cpu" else "cuda")

policy_id = cfg.policy_index
name_prefix = dict(latest="checkpoint", best="best")[cfg.load_checkpoint_kind]
checkpoints = Learner.get_checkpoints(Learner.checkpoint_dir(cfg, policy_id), f"{name_prefix}_*")
checkpoint_dict:Dict[str, Any] = Learner.load_checkpoint(checkpoints, device)
checkpoints = Learner.get_checkpoints(
Learner.checkpoint_dir(cfg, policy_id), f"{name_prefix}_*"
)
checkpoint_dict: Dict[str, Any] = Learner.load_checkpoint(checkpoints, device)

return checkpoint_dict,cfg
return checkpoint_dict, cfg


def brain_from_actor_critic(actor_critic: SampleFactoryBrain) -> Brain:
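One detail in to_sf_cfg worth noting: the optimizer name handed to Sample Factory is derived from the Hydra _target_ string by taking its last dotted component and lowercasing it. A quick illustration (the "torch.optim.Adam" target is just an example value):

# Example of the optimizer-name mapping in to_sf_cfg; "torch.optim.Adam"
# is an illustrative _target_ value, not something fixed by this PR.
target = "torch.optim.Adam"
optimizer_name = str.lower(str.split(target, sep=".")[-1])
assert optimizer_name == "adam"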
File renamed without changes.
64 changes: 64 additions & 0 deletions runner/classification/classification_framework.py
@@ -0,0 +1,64 @@
from typing import Optional

import torch
from omegaconf import DictConfig

from retinal_rl.framework_interface import TrainingFramework
from retinal_rl.models.brain import Brain
from retinal_rl.models.loss import ContextT
from retinal_rl.models.objective import Objective
from runner.classification.analyze import analyze
from runner.classification.dataset import get_datasets
from runner.classification.initialize import initialize
from runner.classification.train import train


class ClassificationFramework(TrainingFramework):
def __init__(self, cfg: DictConfig):
self.cfg = cfg
self.train_set, self.test_set = get_datasets(self.cfg)

def initialize(self, brain: Brain, optimizer: torch.optim.Optimizer):
brain, optimizer, self.histories, self.completed_epochs = initialize(
self.cfg,
brain,
optimizer,
)
return brain, optimizer

def train(
self,
device: torch.device,
brain: Brain,
optimizer: torch.optim.Optimizer,
objective: Optional[Objective[ContextT]] = None,
):
# TODO: check objective type
train(
self.cfg,
device,
brain,
objective,
optimizer,
self.train_set,
self.test_set,
self.completed_epochs,
self.histories,
)

def analyze(
self,
device: torch.device,
brain: Brain,
objective: Optional[Objective[ContextT]] = None,
):
analyze(
self.cfg,
device,
brain,
objective,
self.histories,
self.train_set,
self.test_set,
self.completed_epochs,
)
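Note the implicit ordering in ClassificationFramework: initialize() stores self.histories and self.completed_epochs, which train() and analyze() later read, so it must be called before either of them (as main.py does). A minimal usage sketch, assuming cfg, device, brain, optimizer, and objective are built by the surrounding runner code as in main.py:

framework = ClassificationFramework(cfg)  # loads the train/test datasets
brain, optimizer = framework.initialize(brain, optimizer)  # also sets histories/completed_epochs
framework.train(device, brain, optimizer, objective)
framework.analyze(device, brain, objective)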
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion runner/train.py → runner/classification/train.py
@@ -15,7 +15,7 @@
from retinal_rl.classification.training import process_dataset, run_epoch
from retinal_rl.models.brain import Brain
from retinal_rl.models.objective import Objective
from runner.analyze import analyze
from runner.classification.analyze import analyze
from runner.util import save_checkpoint

# Initialize the logger
39 changes: 18 additions & 21 deletions tests/ci/lint.sh
@@ -1,6 +1,8 @@
#===============================================================================
# Description: Runs ruff either on all Python files or only on changed files
# compared to master branch using a specified Singularity container
# Runs both the code formatter and the linter, but additional arguments (other than
# --all and --fix) apply only to the linter.
#
# Arguments:
# $1 - Path to Singularity (.sif) container
@@ -23,32 +25,27 @@
CONTAINER="$1"
shift

# Check or fix
check="--check"
if [[ "$@" == *"--fix"* ]]; then
check=""
fi

# Check if --all flag is present
if [ "$1" = "--all" ]; then
changed_files="."
# Remove --all from arguments
shift

# Check or fix
check="--check"
if [[ "$@" == *"--fix"* ]]; then
check=""
fi

# Format
apptainer exec "$CONTAINER" ruff format "$check" .

# Run ruff on all files with any remaining arguments
apptainer exec "$CONTAINER" ruff check . "$@"
else
# Get changed Python files
changed_files=$(tests/ci/changed_py_files.sh)
if [ -n "$changed_files" ]; then
# Format
apptainer exec "$CONTAINER" ruff format "$check" $changed_files
fi

# Run ruff on changed files with any remaining arguments
apptainer exec "$CONTAINER" ruff check $changed_files "$@"
else
echo "No .py files changed"
fi
fi
if [ -n "$changed_files" ]; then
# Format
apptainer exec "$CONTAINER" ruff format $changed_files $check
# Run ruff on changed files with any remaining arguments
apptainer exec "$CONTAINER" ruff check $changed_files "$@"
else
echo "No .py files changed"
fi