Skip to content

Commit

Permalink
chore: format classification framework
Browse files Browse the repository at this point in the history
  • Loading branch information
fabioseel committed Oct 30, 2024
1 parent 09eeecb commit 6dcade9
Showing 1 changed file with 33 additions and 23 deletions.
56 changes: 33 additions & 23 deletions runner/classification/classification_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@


class ClassificationFramework(TrainingFramework):

def __init__(self, cfg: DictConfig):
self.cfg = cfg
self.train_set, self.test_set = get_datasets(self.cfg)
Expand All @@ -27,28 +26,39 @@ def initialize(self, brain: Brain, optimizer: torch.optim.Optimizer):
)
return brain, optimizer

def train(self, device: torch.device, brain: Brain, optimizer: torch.optim.Optimizer, objective: Optional[Objective[ContextT]] = None):
#TODO: check objective type
def train(
self,
device: torch.device,
brain: Brain,
optimizer: torch.optim.Optimizer,
objective: Optional[Objective[ContextT]] = None,
):
# TODO: check objective type
train(
self.cfg,
device,
brain,
objective,
optimizer,
self.train_set,
self.test_set,
self.completed_epochs,
self.histories,
)
self.cfg,
device,
brain,
objective,
optimizer,
self.train_set,
self.test_set,
self.completed_epochs,
self.histories,
)

def analyze(self, device: torch.device, brain: Brain, objective: Optional[Objective[ContextT]] = None):
def analyze(
self,
device: torch.device,
brain: Brain,
objective: Optional[Objective[ContextT]] = None,
):
analyze(
self.cfg,
device,
brain,
objective,
self.histories,
self.train_set,
self.test_set,
self.completed_epochs,
)
self.cfg,
device,
brain,
objective,
self.histories,
self.train_set,
self.test_set,
self.completed_epochs,
)

0 comments on commit 6dcade9

Please sign in to comment.