Skip to content

Commit

Permalink
Clean up pytorch-lightning logger (rusty1s#210)
Browse files Browse the repository at this point in the history
* initial commit:

* code cov

* move progressbar
  • Loading branch information
rusty1s authored Feb 2, 2022
1 parent 37a0b59 commit 12a931e
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 2 deletions.
5 changes: 4 additions & 1 deletion bin/kumo-train
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ from kumo.config import cfg
from kumo.model import create_model
from kumo.train import create_loader
from kumo.train.encoder import get_emb_size
from kumo.train.trainer import ProgressBar
from kumo.utils import tracing
from kumo.utils.visualization import visualize_scalar_distribution

Expand Down Expand Up @@ -71,7 +72,9 @@ def main(args):
mode='min')
callbacks = [early_stopping, ckpt]
else:
callbacks = None
callbacks = []

callbacks.append(ProgressBar())

trainer = pl.Trainer(
gpus=gpus,
Expand Down
10 changes: 10 additions & 0 deletions kumo/train/trainer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any, List, Optional

import pytorch_lightning as pl
import torch

from kumo.task import Task
Expand All @@ -20,3 +21,12 @@ def __init__(self, model: torch.nn.Module, task: Task, epochs: int,

def fit(self):
pass


class ProgressBar(pl.callbacks.ProgressBar):
def get_metrics(self, trainer, model):
# Remove `loss` and `v_num` values from the logger.
items = super().get_metrics(trainer, model)
items.pop('loss', None)
items.pop('v_num', None)
return items
3 changes: 2 additions & 1 deletion test/train/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from kumo.model import create_model
from kumo.train import create_loader
from kumo.train.encoder import get_emb_size
from kumo.train.trainer import ProgressBar


@pytest.mark.parametrize('task', [
Expand Down Expand Up @@ -49,5 +50,5 @@ def test_trainer(task):
)

# Training
trainer = pl.Trainer(fast_dev_run=True)
trainer = pl.Trainer(fast_dev_run=True, callbacks=[ProgressBar()])
trainer.fit(model, loaders[0], loaders[1])

0 comments on commit 12a931e

Please sign in to comment.