diff --git a/main.py b/main.py index 532f154d..8547ba93 100644 --- a/main.py +++ b/main.py @@ -9,10 +9,8 @@ from hydra.utils import instantiate from omegaconf import DictConfig, OmegaConf -from retinal_rl.classification.loss import ClassificationContext from retinal_rl.framework_interface import TrainingFramework from retinal_rl.models.brain import Brain -from retinal_rl.models.goal import Goal from retinal_rl.rl.sample_factory.sf_framework import SFFramework from runner.analyze import analyze from runner.dataset import get_datasets @@ -40,7 +38,7 @@ def _program(cfg: DictConfig): brain = Brain(**cfg.brain).to(device) if hasattr(cfg, "optimizer"): - goal = Goal[ClassificationContext](brain, dict(cfg.optimizer.goal)) + objective = instantiate(cfg.optimizer.losses, brain=brain) optimizer = instantiate(cfg.optimizer.optimizer, brain.parameters()) else: warnings.warn("No optimizer config specified, is that wanted?") @@ -68,7 +66,7 @@ def _program(cfg: DictConfig): cfg, device, brain, - goal, + objective, optimizer, train_set, test_set, @@ -82,7 +80,7 @@ def _program(cfg: DictConfig): cfg, device, brain, - goal, + objective, histories, train_set, test_set, diff --git a/resources/config_templates/user/optimizer/class-recon.yaml b/resources/config_templates/user/optimizer/class-recon.yaml index 8c518a50..1a0b99d9 100644 --- a/resources/config_templates/user/optimizer/class-recon.yaml +++ b/resources/config_templates/user/optimizer/class-recon.yaml @@ -2,55 +2,35 @@ optimizer: # torch.optim Class and parameters _target_: torch.optim.Adam lr: 0.0003 -goal: - recon: - min_epoch: 0 # Epoch to start optimizer - max_epoch: 100 # Epoch to stop optimizer - losses: # Weighted optimizer losses as defined in retinal-rl - - _target_: retinal_rl.models.loss.ReconstructionLoss - weight: ${recon_weight_retina} - - _target_: retinal_rl.classification.loss.ClassificationLoss - weight: ${eval:'1-${recon_weight_retina}'} +losses: + - _target_: retinal_rl.classification.loss.PercentCorrect + - _target_: retinal_rl.classification.loss.ClassificationLoss target_circuits: # Circuit parameters to optimize with this optimizer. We train the retina and the decoder exclusively to maximize reconstruction - retina - decode: - min_epoch: 0 # Epoch to start optimizer - max_epoch: 100 # Epoch to stop optimizer - losses: # Weighted optimizer losses as defined in retinal-rl - - _target_: retinal_rl.models.loss.ReconstructionLoss - weight: 1 - target_circuits: # Circuit parameters to optimize with this optimizer. We train the retina and the decoder exclusively to maximize reconstruction - - decoder - - inferotemporal_decoder - mixed: - min_epoch: 0 - max_epoch: 100 - losses: - - _target_: retinal_rl.models.loss.ReconstructionLoss - weight: ${recon_weight_thalamus} - - _target_: retinal_rl.classification.loss.ClassificationLoss - weight: ${eval:'1-${recon_weight_thalamus}'} - target_circuits: # The thalamus is somewhat sensitive to task losses - thalamus - cortex: - min_epoch: 0 - max_epoch: 100 - losses: - - _target_: retinal_rl.models.loss.ReconstructionLoss - weight: ${recon_weight_cortex} - - _target_: retinal_rl.classification.loss.ClassificationLoss - weight: ${eval:'1-${recon_weight_cortex}'} - target_circuits: # Visual cortex and downstream layers are driven by the task - visual_cortex - inferotemporal - class: - min_epoch: 0 - max_epoch: 100 - losses: - - _target_: retinal_rl.classification.loss.ClassificationLoss - weight: 1 - - _target_: retinal_rl.classification.loss.PercentCorrect - weight: 0 - target_circuits: # Visual cortex and downstream layers are driven by the task - prefrontal - classifier + weights: + - ${eval:'1-${recon_weight_retina}'} + - ${eval:'1-${recon_weight_thalamus}'} + - ${eval:'1-${recon_weight_cortex}'} + - 1 + - 1 + - 1 + - _target_: retinal_rl.models.loss.ReconstructionLoss + min_epoch: 0 # Epoch to start optimizer + max_epoch: 100 # Epoch to stop optimizer + target_circuits: # Circuit parameters to optimize with this optimizer. We train the retina and the decoder exclusively to maximize reconstruction + - retina + - thalamus + - visual_cortex + - decoder + - inferotemporal_decoder + weights: + - ${recon_weight_retina} + - ${recon_weight_thalamus} + - ${recon_weight_cortex} + - 1 + - 1 diff --git a/retinal_rl/analysis/plot.py b/retinal_rl/analysis/plot.py index 1897e12e..413d64ee 100644 --- a/retinal_rl/analysis/plot.py +++ b/retinal_rl/analysis/plot.py @@ -18,7 +18,7 @@ from torchvision.utils import make_grid from retinal_rl.models.brain import Brain -from retinal_rl.models.goal import ContextT, Goal +from retinal_rl.models.objective import ContextT, Objective from retinal_rl.util import FloatArray @@ -107,7 +107,7 @@ def plot_transforms( return fig -def plot_brain_and_optimizers(brain: Brain, goal: Goal[ContextT]) -> Figure: +def plot_brain_and_optimizers(brain: Brain, objective: Objective[ContextT]) -> Figure: """Visualize the Brain's connectome organized by depth and highlight optimizer targets using border colors. Args: @@ -147,7 +147,7 @@ def plot_brain_and_optimizers(brain: Brain, goal: Goal[ContextT]) -> Figure: color_map = {"sensor": "lightblue", "circuit": "lightgreen"} # Generate colors for optimizers - optimizer_colors = sns.color_palette("husl", len(goal.losses)) + optimizer_colors = sns.color_palette("husl", len(objective.losses)) # Prepare node colors and edge colors node_colors: List[str] = [] @@ -160,8 +160,8 @@ def plot_brain_and_optimizers(brain: Brain, goal: Goal[ContextT]) -> Figure: # Determine if the node is targeted by an optimizer edge_color = "none" - for i, optimizer_name in enumerate(goal.losses.keys()): - if node in goal.target_circuits[optimizer_name]: + for i, optimizer_name in enumerate(objective.losses.keys()): + if node in objective.target_circuits[optimizer_name]: edge_color = optimizer_colors[i] break edge_colors.append(edge_color) @@ -192,7 +192,7 @@ def plot_brain_and_optimizers(brain: Brain, goal: Goal[ContextT]) -> Figure: markersize=15, markeredgewidth=3, ) - for name, color in zip(goal.losses.keys(), optimizer_colors) + for name, color in zip(objective.losses.keys(), optimizer_colors) ] # Add legend elements for sensor and circuit diff --git a/retinal_rl/classification/loss.py b/retinal_rl/classification/loss.py index 0061b3a7..8fcbca31 100644 --- a/retinal_rl/classification/loss.py +++ b/retinal_rl/classification/loss.py @@ -1,6 +1,6 @@ """Objectives for training models.""" -from typing import Dict, Tuple +from typing import Dict, List, Tuple import torch import torch.nn as nn @@ -38,9 +38,15 @@ def __init__( class ClassificationLoss(Loss[ClassificationContext]): """Loss for computing the cross entropy loss.""" - def __init__(self, weight: float = 1.0): + def __init__( + self, + min_epoch: int = 0, + max_epoch: int = 1, + target_circuits: List[str] = [], + weights: List[float] = [], + ): """Initialize the classification loss.""" - super().__init__(weight) + super().__init__(min_epoch, max_epoch, target_circuits, weights) self.loss_fn = nn.CrossEntropyLoss() def compute_value(self, context: ClassificationContext) -> Tensor: @@ -59,9 +65,14 @@ def compute_value(self, context: ClassificationContext) -> Tensor: class PercentCorrect(Loss[ClassificationContext]): """(Inverse) Loss for computing the percent correct classification.""" - def __init__(self, weight: float = 1.0): - """Initialize the percent correct classification loss.""" - super().__init__(weight) + def __init__( + self, + min_epoch: int = 0, + max_epoch: int = 1, + target_circuits: List[str] = [], + weights: List[float] = [], + ): + super().__init__(min_epoch, max_epoch, target_circuits, weights) def compute_value(self, context: ClassificationContext) -> Tensor: """Compute the percent correct classification.""" diff --git a/retinal_rl/classification/training.py b/retinal_rl/classification/training.py index 5fa5849d..ce1bf104 100644 --- a/retinal_rl/classification/training.py +++ b/retinal_rl/classification/training.py @@ -18,7 +18,7 @@ get_classification_context, ) from retinal_rl.models.brain import Brain -from retinal_rl.models.goal import Goal +from retinal_rl.models.objective import Objective logger = logging.getLogger(__name__) @@ -26,7 +26,7 @@ def run_epoch( device: torch.device, brain: Brain, - goal: Goal[ClassificationContext], + objective: Objective[ClassificationContext], optimizer: Optimizer, history: Dict[str, List[float]], epoch: int, @@ -43,7 +43,7 @@ def run_epoch( ---- device (torch.device): The device to run the computations on. brain (Brain): The Brain model to train and evaluate. - goal (Goal): The goal object specifying the training objectives. + objective (Objective): The objective object specifying the training objectives. optimizer (Optimizer): The optimizer for updating the model parameters. history (Dict[str, List[float]]): A dictionary to store the training history. epoch (int): The current epoch number. @@ -56,10 +56,10 @@ def run_epoch( """ train_losses = process_dataset( - device, brain, goal, optimizer, epoch, trainloader, is_training=True + device, brain, objective, optimizer, epoch, trainloader, is_training=True ) test_losses = process_dataset( - device, brain, goal, optimizer, epoch, testloader, is_training=False + device, brain, objective, optimizer, epoch, testloader, is_training=False ) # Update history @@ -76,7 +76,7 @@ def run_epoch( def process_dataset( device: torch.device, brain: Brain, - goal: Goal[ClassificationContext], + objective: Objective[ClassificationContext], optimizer: Optimizer, epoch: int, dataloader: DataLoader[Tuple[Tensor, Tensor, int]], @@ -109,20 +109,20 @@ def process_dataset( if is_training: brain.train() - losses, obj_dict = goal.backward(context) + losses = objective.backward(context) optimizer.step() optimizer.zero_grad(set_to_none=True) else: with torch.no_grad(): brain.eval() - losses, obj_dict = goal.evaluate_objectives(context) + losses: Dict[str, float] = {} + for loss in objective.losses: + losses[loss.key_name] = loss(context).item() # Accumulate losses and objectives for key, value in losses.items(): total_losses[key] = total_losses.get(key, 0.0) + value - for key, value in obj_dict.items(): - total_losses[key] = total_losses.get(key, 0.0) + value steps += 1 diff --git a/retinal_rl/models/goal.py b/retinal_rl/models/goal.py deleted file mode 100644 index c90d0f50..00000000 --- a/retinal_rl/models/goal.py +++ /dev/null @@ -1,210 +0,0 @@ -"""Module for managing optimization of complex neural network models with multiple circuits.""" - -import logging -from typing import Dict, Generic, List, Tuple - -import torch -from hydra.utils import instantiate -from omegaconf import DictConfig -from torch.nn.parameter import Parameter - -from retinal_rl.models.brain import Brain -from retinal_rl.models.loss import ContextT, Loss - -logger = logging.getLogger(__name__) - - -class Goal(Generic[ContextT]): - """Manages multiple optimizers that target NeuralCircuits in a Brain. - - This class handles the initialization, state management, and optimization steps - for multiple optimizers, each associated with specific circuits and objectives. - - - Attributes - ---------- - brain (Brain): The neural network model being optimized. - losses (OrderedDict[str, Optimizer]): Instantiated optimizers, sorted based on connectome. - objectives (Dict[str, List[WeightedLoss]]): Losses for each optimizer. - target_circuits (Dict[str, List[str]]): Target circuits for each optimizer. - min_epochs (Dict[str, int]): Minimum epochs for each optimizer. - max_epochs (Dict[str, int]): Maximum epochs for each optimizer. - - """ - - def __init__(self, brain: Brain, objective_configs: Dict[str, DictConfig]): - """Initialize the BrainOptimizer. - - Args: - ---- - brain (Brain): The neural network model to optimize. - optimizer (Optimizer): The optimizer to use for training. - objective_configs (Dict[str, DictConfig]): Configuration for each optimizer. - Each config should specify target_circuits, optimizer settings, and objectives. - - Raises: - ------ - ValueError: If a specified circuit is not found in the brain. - - """ - self.device = next(brain.parameters()).device - self.losses: Dict[str, List[Loss[ContextT]]] = {} - self.target_circuits: Dict[str, List[str]] = {} - self.min_epochs: Dict[str, int] = {} - self.max_epochs: Dict[str, int] = {} - self.params: Dict[str, List[Parameter]] = {} - - for objective, config in objective_configs.items(): - # Collect parameters from target circuits - params = [] - self.min_epochs[objective] = config.get("min_epoch", 0) - self.max_epochs[objective] = config.get("max_epoch", -1) - self.target_circuits[objective] = config.target_circuits - if not set(config.target_circuits).issubset(brain.connectome.nodes): - logger.warning( - f"Some target circuits for objective: {objective} are not in the brain's connectome" - ) - for circuit_name in config.target_circuits: - if circuit_name in brain.circuits: - params.extend(brain.circuits[circuit_name].parameters()) - - self.params[objective] = params - - # Initialize objectives - self.losses[objective] = [ - instantiate(obj_config) for obj_config in config.losses - ] - logger.info( - f"Initialized objective: {objective}, with losses: {[obj.key_name for obj in self.losses[objective]]}, and target circuits: {[circuit_name for circuit_name in self.target_circuits[objective]]}" - ) - - def evaluate_objective( - self, objective: str, context: ContextT - ) -> Tuple[torch.Tensor, Dict[str, float]]: - """Compute the total loss for a specific objective. - - Args: - ---- - objective (str): Name of the objective. - context (ContextT]): Context information for computing objectives. - - Returns: - ------- - Tuple[torch.Tensor, Dict[str, float]]: A tuple containing the total loss - and a dictionary of raw loss values for each objective. - - """ - total_loss = torch.tensor(0.0, device=self.device) - loss_dict: Dict[str, float] = {} - for loss in self.losses[objective]: - weighted_loss, raw_loss = loss(context) - total_loss += weighted_loss - loss_dict[loss.key_name] = raw_loss.item() - return total_loss, loss_dict - - def evaluate_objectives( - self, context: ContextT - ) -> Tuple[Dict[str, float], Dict[str, float]]: - """Compute all objectives without computing gradients. - - This method is useful for evaluation purposes. - - Args: - ---- - context (Dict[str, Any]): Context information for computing objectives. - - Returns: - ------- - Tuple[Dict[str, float], Dict[str, float]]: A tuple containing dictionaries - of total objectives and raw loss values for each objective. - - """ - objectives: Dict[str, float] = {} - loss_dict: Dict[str, float] = {} - for objective in self.losses.keys(): - loss, sub_obj_dict = self.evaluate_objective(objective, context) - objectives[f"{objective}_objective"] = loss.item() - loss_dict.update(sub_obj_dict) - return objectives, loss_dict - - def _is_training_epoch(self, name: str, epoch: int) -> bool: - """Check if the objective should currently be pursued. - - Args: - ---- - name (str): Name of the optimizer. - epoch (int): Current epoch number. - - Returns: - ------- - bool: True if the objective should continue training, False otherwise. - - """ - if epoch < self.min_epochs[name]: - return False - if self.max_epochs[name] < 0: - return True - return epoch < self.max_epochs[name] - - def backward(self, context: ContextT) -> Tuple[Dict[str, float], Dict[str, float]]: - """Compute a backward pass over the brain with respect to all objectives. - - This method computes losses, performs backpropagation, and updates parameters - for all NeuralCircuits. - - Args: - ---- - context (ContextT): Context information for computing objectives. - - Returns: - ------- - Tuple[Dict[str, float], Dict[str, float]]: A tuple containing dictionaries - of total losses and raw loss values for each objective. - - """ - objectives: Dict[str, float] = {} - loss_dict: Dict[str, float] = {} - - retain_graph = True - - for i, objective in enumerate(self.losses.keys()): - # Compute losses - loss, sub_loss_dict = self.evaluate_objective(objective, context) - objectives[f"{objective}_objective"] = loss.item() - loss_dict.update(sub_loss_dict) - - # Skip training if the optimizer is not at a training epoch - if not self._is_training_epoch(objective, context.epoch): - continue - - # Set retain_graph to True for all but the last optimizer - retain_graph = i < len(self.losses) - 1 - - # Get parameters for this optimizer - params = self.params[objective] - - # Compute gradients - grads = torch.autograd.grad( - loss, params, create_graph=False, retain_graph=retain_graph - ) - - # Manually update parameters - with torch.no_grad(): - for param, grad in zip(params, grads): - if param.grad is None: - param.grad = grad - else: - param.grad += grad - - # Perform optimization step - return objectives, loss_dict - - def num_epochs(self) -> int: - """Get the maximum number of epochs over all optimizers. - - Returns - ------- - int: The maximum number of epochs across all optimizers. - - """ - return max(self.max_epochs.values()) diff --git a/retinal_rl/models/loss.py b/retinal_rl/models/loss.py index 4417d0ec..6a0e7073 100644 --- a/retinal_rl/models/loss.py +++ b/retinal_rl/models/loss.py @@ -1,7 +1,7 @@ """Losses for training models, and the context required to evaluate them.""" from abc import abstractmethod -from typing import Dict, Generic, List, Tuple, TypeVar +from typing import Dict, Generic, List, TypeVar import torch import torch.nn as nn @@ -41,11 +41,20 @@ def __init__( class Loss(Generic[ContextT]): """Base class for losses that can be used to define a multiobjective optimization problem.""" - def __init__(self, weight: float = 1.0): + def __init__( + self, + min_epoch: int = 0, + max_epoch: int = 1, + target_circuits: List[str] = [], + weights: List[float] = [], + ): """Initialize the loss with a weight.""" - self.weight = weight + self.min_epoch = min_epoch + self.max_epoch = max_epoch + self.target_circuits = target_circuits + self.weights = weights - def __call__(self, context: ContextT) -> Tuple[Tensor, Tensor]: + def __call__(self, context: ContextT) -> Tensor: """Compute the weighted loss for this loss. Args: @@ -54,11 +63,28 @@ def __call__(self, context: ContextT) -> Tuple[Tensor, Tensor]: Returns: ------- - Tuple[Tensor, Tensor]: A tuple containing the weighted loss and the raw loss value. + Tensor: A tuple containing the weighted loss and the raw loss value. """ - value = self.compute_value(context) - return (self.weight * value, value) + return self.compute_value(context) + + def is_training_epoch(self, epoch: int) -> bool: + """Check if the objective should currently be pursued. + + Args: + ---- + epoch (int): Current epoch number. + + Returns: + ------- + bool: True if the objective should continue training, False otherwise. + + """ + if epoch < self.min_epoch: + return False + if self.max_epoch < 0: + return True + return epoch < self.max_epoch @abstractmethod def compute_value(self, context: ContextT) -> Tensor: @@ -74,9 +100,15 @@ def key_name(self) -> str: class ReconstructionLoss(Loss[ContextT]): """Loss for computing the reconstruction loss between inputs and reconstructions.""" - def __init__(self, weight: float = 1.0): + def __init__( + self, + min_epoch: int = 0, + max_epoch: int = 1, + target_circuits: List[str] = [], + weights: List[float] = [], + ): """Initialize the reconstruction loss loss.""" - super().__init__(weight) + super().__init__(min_epoch, max_epoch, target_circuits, weights) self.loss_fn = nn.MSELoss(reduction="mean") def compute_value(self, context: ContextT) -> Tensor: @@ -95,10 +127,19 @@ def compute_value(self, context: ContextT) -> Tensor: class L1Sparsity(Loss[ContextT]): """Loss for computing the L1 sparsity of activations.""" - def __init__(self, weight: float, target_responses: List[str]): + def __init__( + self, + target_responses: List[str], + min_epoch: int = 0, + max_epoch: int = 1, + target_circuits: List[str] = [], + weights: List[float] = [], + ): + """Initialize the reconstruction loss loss.""" + super().__init__(min_epoch, max_epoch, target_circuits, weights) + """Initialize the L1 sparsity loss.""" self.target_responses = target_responses - super().__init__(weight) def compute_value(self, context: ContextT) -> Tensor: """Compute the L1 sparsity of activations.""" @@ -115,12 +156,18 @@ class KLDivergenceSparsity(Loss[ContextT]): """Loss for computing the KL divergence sparsity of activations.""" def __init__( - self, weight: float, target_responses: List[str], target_sparsity: float = 0.05 + self, + target_responses: List[str], + target_sparsity: float = 0.05, + min_epoch: int = 0, + max_epoch: int = 1, + target_circuits: List[str] = [], + weights: List[float] = [], ): """Initialize the KL divergence sparsity loss.""" + super().__init__(min_epoch, max_epoch, target_circuits, weights) self.target_responses = target_responses self.target_sparsity = target_sparsity - super().__init__(weight) def compute_value(self, context: ContextT) -> torch.Tensor: """Compute the KL divergence sparsity of activations.""" diff --git a/runner/analyze.py b/runner/analyze.py index e1503df5..b3c97d0e 100644 --- a/runner/analyze.py +++ b/runner/analyze.py @@ -5,10 +5,10 @@ import matplotlib.pyplot as plt import torch +import wandb from matplotlib.figure import Figure from omegaconf import DictConfig -import wandb from retinal_rl.analysis.plot import ( layer_receptive_field_plots, plot_brain_and_optimizers, @@ -25,7 +25,7 @@ ) from retinal_rl.dataset import Imageset from retinal_rl.models.brain import Brain -from retinal_rl.models.goal import ContextT, Goal +from retinal_rl.models.objective import ContextT, Objective logger = logging.getLogger(__name__) @@ -91,7 +91,7 @@ def analyze( cfg: DictConfig, device: torch.device, brain: Brain, - goal: Goal[ContextT], + objective: Objective[ContextT], histories: Dict[str, List[float]], train_set: Imageset, test_set: Imageset, @@ -111,7 +111,7 @@ def analyze( if epoch == 0: rf_sizes_fig = plot_receptive_field_sizes(cnn_analysis) _process_figure(cfg, False, rf_sizes_fig, init_dir, "receptive_field_sizes", 0) - graph_fig = plot_brain_and_optimizers(brain, goal) + graph_fig = plot_brain_and_optimizers(brain, objective) _process_figure(cfg, False, graph_fig, init_dir, "brain_graph", 0) transforms = transform_base_images(train_set, num_steps=5, num_images=2) transforms_fig = plot_transforms(**transforms) diff --git a/runner/train.py b/runner/train.py index 264324d6..f673e835 100644 --- a/runner/train.py +++ b/runner/train.py @@ -5,16 +5,16 @@ from typing import Dict, List import torch +import wandb from omegaconf import DictConfig from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader -import wandb from retinal_rl.classification.loss import ClassificationContext from retinal_rl.classification.training import process_dataset, run_epoch from retinal_rl.dataset import Imageset from retinal_rl.models.brain import Brain -from retinal_rl.models.goal import Goal +from retinal_rl.models.objective import Objective from runner.analyze import analyze from runner.util import save_checkpoint @@ -26,7 +26,7 @@ def train( cfg: DictConfig, device: torch.device, brain: Brain, - goal: Goal[ClassificationContext], + objective: Objective[ClassificationContext], optimizer: Optimizer, train_set: Imageset, test_set: Imageset, @@ -40,7 +40,7 @@ def train( cfg (DictConfig): The configuration for the experiment. device (torch.device): The device to run the computations on. brain (Brain): The Brain model to train and evaluate. - goal (Goal): The optimizer for updating the model parameters. + objective (Objective): The optimizer for updating the model parameters. train_set (Imageset): The training dataset. test_set (Imageset): The test dataset. initial_epoch (int): The epoch to start training from. @@ -56,11 +56,23 @@ def train( if initial_epoch == 0: brain.train() train_losses = process_dataset( - device, brain, goal, optimizer, initial_epoch, trainloader, is_training=False + device, + brain, + objective, + optimizer, + initial_epoch, + trainloader, + is_training=False, ) brain.eval() test_losses = process_dataset( - device, brain, goal, optimizer, initial_epoch, testloader, is_training=False + device, + brain, + objective, + optimizer, + initial_epoch, + testloader, + is_training=False, ) # Initialize the history @@ -75,7 +87,7 @@ def train( cfg, device, brain, - goal, + objective, history, train_set, test_set, @@ -88,11 +100,11 @@ def train( logger.info("Initialization complete.") - for epoch in range(initial_epoch + 1, goal.num_epochs() + 1): + for epoch in range(initial_epoch + 1, objective.num_epochs() + 1): brain, history = run_epoch( device, brain, - goal, + objective, optimizer, history, epoch, @@ -122,7 +134,7 @@ def train( cfg, device, brain, - goal, + objective, history, train_set, test_set,