Skip to content

Commit

Permalink
Churned a bunch of code for an alpha version. Scan runs
Browse files Browse the repository at this point in the history
  • Loading branch information
alex404 committed Oct 16, 2024
1 parent 9bf848e commit 5721a8b
Show file tree
Hide file tree
Showing 9 changed files with 147 additions and 309 deletions.
8 changes: 3 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf

from retinal_rl.classification.loss import ClassificationContext
from retinal_rl.framework_interface import TrainingFramework
from retinal_rl.models.brain import Brain
from retinal_rl.models.goal import Goal
from retinal_rl.rl.sample_factory.sf_framework import SFFramework
from runner.analyze import analyze
from runner.dataset import get_datasets
Expand Down Expand Up @@ -40,7 +38,7 @@ def _program(cfg: DictConfig):

brain = Brain(**cfg.brain).to(device)
if hasattr(cfg, "optimizer"):
goal = Goal[ClassificationContext](brain, dict(cfg.optimizer.goal))
objective = instantiate(cfg.optimizer.losses, brain=brain)
optimizer = instantiate(cfg.optimizer.optimizer, brain.parameters())
else:
warnings.warn("No optimizer config specified, is that wanted?")
Expand Down Expand Up @@ -68,7 +66,7 @@ def _program(cfg: DictConfig):
cfg,
device,
brain,
goal,
objective,
optimizer,
train_set,
test_set,
Expand All @@ -82,7 +80,7 @@ def _program(cfg: DictConfig):
cfg,
device,
brain,
goal,
objective,
histories,
train_set,
test_set,
Expand Down
70 changes: 25 additions & 45 deletions resources/config_templates/user/optimizer/class-recon.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,55 +2,35 @@ optimizer: # torch.optim Class and parameters
_target_: torch.optim.Adam
lr: 0.0003

goal:
recon:
min_epoch: 0 # Epoch to start optimizer
max_epoch: 100 # Epoch to stop optimizer
losses: # Weighted optimizer losses as defined in retinal-rl
- _target_: retinal_rl.models.loss.ReconstructionLoss
weight: ${recon_weight_retina}
- _target_: retinal_rl.classification.loss.ClassificationLoss
weight: ${eval:'1-${recon_weight_retina}'}
losses:
- _target_: retinal_rl.classification.loss.PercentCorrect
- _target_: retinal_rl.classification.loss.ClassificationLoss
target_circuits: # Circuit parameters to optimize with this optimizer. We train the retina and the decoder exclusively to maximize reconstruction
- retina
decode:
min_epoch: 0 # Epoch to start optimizer
max_epoch: 100 # Epoch to stop optimizer
losses: # Weighted optimizer losses as defined in retinal-rl
- _target_: retinal_rl.models.loss.ReconstructionLoss
weight: 1
target_circuits: # Circuit parameters to optimize with this optimizer. We train the retina and the decoder exclusively to maximize reconstruction
- decoder
- inferotemporal_decoder
mixed:
min_epoch: 0
max_epoch: 100
losses:
- _target_: retinal_rl.models.loss.ReconstructionLoss
weight: ${recon_weight_thalamus}
- _target_: retinal_rl.classification.loss.ClassificationLoss
weight: ${eval:'1-${recon_weight_thalamus}'}
target_circuits: # The thalamus is somewhat sensitive to task losses
- thalamus
cortex:
min_epoch: 0
max_epoch: 100
losses:
- _target_: retinal_rl.models.loss.ReconstructionLoss
weight: ${recon_weight_cortex}
- _target_: retinal_rl.classification.loss.ClassificationLoss
weight: ${eval:'1-${recon_weight_cortex}'}
target_circuits: # Visual cortex and downstream layers are driven by the task
- visual_cortex
- inferotemporal
class:
min_epoch: 0
max_epoch: 100
losses:
- _target_: retinal_rl.classification.loss.ClassificationLoss
weight: 1
- _target_: retinal_rl.classification.loss.PercentCorrect
weight: 0
target_circuits: # Visual cortex and downstream layers are driven by the task
- prefrontal
- classifier
weights:
- ${eval:'1-${recon_weight_retina}'}
- ${eval:'1-${recon_weight_thalamus}'}
- ${eval:'1-${recon_weight_cortex}'}
- 1
- 1
- 1
- _target_: retinal_rl.models.loss.ReconstructionLoss
min_epoch: 0 # Epoch to start optimizer
max_epoch: 100 # Epoch to stop optimizer
target_circuits: # Circuit parameters to optimize with this optimizer. We train the retina and the decoder exclusively to maximize reconstruction
- retina
- thalamus
- visual_cortex
- decoder
- inferotemporal_decoder
weights:
- ${recon_weight_retina}
- ${recon_weight_thalamus}
- ${recon_weight_cortex}
- 1
- 1
12 changes: 6 additions & 6 deletions retinal_rl/analysis/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from torchvision.utils import make_grid

from retinal_rl.models.brain import Brain
from retinal_rl.models.goal import ContextT, Goal
from retinal_rl.models.objective import ContextT, Objective
from retinal_rl.util import FloatArray


Expand Down Expand Up @@ -107,7 +107,7 @@ def plot_transforms(
return fig


def plot_brain_and_optimizers(brain: Brain, goal: Goal[ContextT]) -> Figure:
def plot_brain_and_optimizers(brain: Brain, objective: Objective[ContextT]) -> Figure:
"""Visualize the Brain's connectome organized by depth and highlight optimizer targets using border colors.
Args:
Expand Down Expand Up @@ -147,7 +147,7 @@ def plot_brain_and_optimizers(brain: Brain, goal: Goal[ContextT]) -> Figure:
color_map = {"sensor": "lightblue", "circuit": "lightgreen"}

# Generate colors for optimizers
optimizer_colors = sns.color_palette("husl", len(goal.losses))
optimizer_colors = sns.color_palette("husl", len(objective.losses))

# Prepare node colors and edge colors
node_colors: List[str] = []
Expand All @@ -160,8 +160,8 @@ def plot_brain_and_optimizers(brain: Brain, goal: Goal[ContextT]) -> Figure:

# Determine if the node is targeted by an optimizer
edge_color = "none"
for i, optimizer_name in enumerate(goal.losses.keys()):
if node in goal.target_circuits[optimizer_name]:
for i, optimizer_name in enumerate(objective.losses.keys()):
if node in objective.target_circuits[optimizer_name]:
edge_color = optimizer_colors[i]
break
edge_colors.append(edge_color)
Expand Down Expand Up @@ -192,7 +192,7 @@ def plot_brain_and_optimizers(brain: Brain, goal: Goal[ContextT]) -> Figure:
markersize=15,
markeredgewidth=3,
)
for name, color in zip(goal.losses.keys(), optimizer_colors)
for name, color in zip(objective.losses.keys(), optimizer_colors)
]

# Add legend elements for sensor and circuit
Expand Down
23 changes: 17 additions & 6 deletions retinal_rl/classification/loss.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Objectives for training models."""

from typing import Dict, Tuple
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
Expand Down Expand Up @@ -38,9 +38,15 @@ def __init__(
class ClassificationLoss(Loss[ClassificationContext]):
"""Loss for computing the cross entropy loss."""

def __init__(self, weight: float = 1.0):
def __init__(
self,
min_epoch: int = 0,
max_epoch: int = 1,
target_circuits: List[str] = [],
weights: List[float] = [],
):
"""Initialize the classification loss."""
super().__init__(weight)
super().__init__(min_epoch, max_epoch, target_circuits, weights)
self.loss_fn = nn.CrossEntropyLoss()

def compute_value(self, context: ClassificationContext) -> Tensor:
Expand All @@ -59,9 +65,14 @@ def compute_value(self, context: ClassificationContext) -> Tensor:
class PercentCorrect(Loss[ClassificationContext]):
"""(Inverse) Loss for computing the percent correct classification."""

def __init__(self, weight: float = 1.0):
"""Initialize the percent correct classification loss."""
super().__init__(weight)
def __init__(
self,
min_epoch: int = 0,
max_epoch: int = 1,
target_circuits: List[str] = [],
weights: List[float] = [],
):
super().__init__(min_epoch, max_epoch, target_circuits, weights)

def compute_value(self, context: ClassificationContext) -> Tensor:
"""Compute the percent correct classification."""
Expand Down
20 changes: 10 additions & 10 deletions retinal_rl/classification/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@
get_classification_context,
)
from retinal_rl.models.brain import Brain
from retinal_rl.models.goal import Goal
from retinal_rl.models.objective import Objective

logger = logging.getLogger(__name__)


def run_epoch(
device: torch.device,
brain: Brain,
goal: Goal[ClassificationContext],
objective: Objective[ClassificationContext],
optimizer: Optimizer,
history: Dict[str, List[float]],
epoch: int,
Expand All @@ -43,7 +43,7 @@ def run_epoch(
----
device (torch.device): The device to run the computations on.
brain (Brain): The Brain model to train and evaluate.
goal (Goal): The goal object specifying the training objectives.
objective (Objective): The objective object specifying the training objectives.
optimizer (Optimizer): The optimizer for updating the model parameters.
history (Dict[str, List[float]]): A dictionary to store the training history.
epoch (int): The current epoch number.
Expand All @@ -56,10 +56,10 @@ def run_epoch(
"""
train_losses = process_dataset(
device, brain, goal, optimizer, epoch, trainloader, is_training=True
device, brain, objective, optimizer, epoch, trainloader, is_training=True
)
test_losses = process_dataset(
device, brain, goal, optimizer, epoch, testloader, is_training=False
device, brain, objective, optimizer, epoch, testloader, is_training=False
)

# Update history
Expand All @@ -76,7 +76,7 @@ def run_epoch(
def process_dataset(
device: torch.device,
brain: Brain,
goal: Goal[ClassificationContext],
objective: Objective[ClassificationContext],
optimizer: Optimizer,
epoch: int,
dataloader: DataLoader[Tuple[Tensor, Tensor, int]],
Expand Down Expand Up @@ -109,20 +109,20 @@ def process_dataset(

if is_training:
brain.train()
losses, obj_dict = goal.backward(context)
losses = objective.backward(context)
optimizer.step()
optimizer.zero_grad(set_to_none=True)

else:
with torch.no_grad():
brain.eval()
losses, obj_dict = goal.evaluate_objectives(context)
losses: Dict[str, float] = {}
for loss in objective.losses:
losses[loss.key_name] = loss(context).item()

# Accumulate losses and objectives
for key, value in losses.items():
total_losses[key] = total_losses.get(key, 0.0) + value
for key, value in obj_dict.items():
total_losses[key] = total_losses.get(key, 0.0) + value

steps += 1

Expand Down
Loading

0 comments on commit 5721a8b

Please sign in to comment.