Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ✨ specify output_dir in reference scripts #1820

Merged
Merged 2 commits on Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion references/classification/train_pytorch_character.py
odulcy-mindee marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import logging
import multiprocessing as mp
import time
from pathlib import Path

import numpy as np
import torch
Expand Down Expand Up @@ -335,7 +336,7 @@ def main(args):
val_loss, acc = evaluate(model, val_loader, batch_transforms)
if val_loss < min_loss:
print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
torch.save(model.state_dict(), f"./{exp_name}.pt")
torch.save(model.state_dict(), Path(args.output_dir) / f"{exp_name}.pt")
min_loss = val_loss
print(f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} (Acc: {acc:.2%})")
# W&B
Expand Down
4 changes: 3 additions & 1 deletion references/classification/train_pytorch_orientation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import logging
import multiprocessing as mp
import time
from pathlib import Path

import numpy as np
import torch
Expand Down Expand Up @@ -341,7 +342,7 @@ def main(args):
val_loss, acc = evaluate(model, val_loader, batch_transforms)
if val_loss < min_loss:
print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
torch.save(model.state_dict(), f"./{exp_name}.pt")
torch.save(model.state_dict(), Path(args.output_dir) / f"{exp_name}.pt")
min_loss = val_loss
print(f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} (Acc: {acc:.2%})")
# W&B
Expand Down Expand Up @@ -376,6 +377,7 @@ def parse_args():
)

parser.add_argument("arch", type=str, help="classification model to train")
parser.add_argument("--output_dir", type=str, default=".", help="path to save checkpoints and final model")
parser.add_argument("--type", type=str, required=True, choices=["page", "crop"], help="type of data to train on")
parser.add_argument("--train_path", type=str, required=True, help="path to training data folder")
parser.add_argument("--val_path", type=str, required=True, help="path to validation data folder")
Expand Down
4 changes: 3 additions & 1 deletion references/classification/train_tensorflow_character.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import datetime
import time
from pathlib import Path

import numpy as np
import tensorflow as tf
Expand Down Expand Up @@ -298,7 +299,7 @@ def main(args):
val_loss, acc = evaluate(model, val_loader, batch_transforms)
if val_loss < min_loss:
print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
model.save_weights(f"./{exp_name}.weights.h5")
model.save_weights(Path(args.output_dir) / f"{exp_name}.weights.h5")
min_loss = val_loss
print(f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} (Acc: {acc:.2%})")
# W&B
Expand Down Expand Up @@ -345,6 +346,7 @@ def parse_args():
)

parser.add_argument("arch", type=str, help="text-recognition model to train")
parser.add_argument("--output_dir", type=str, default=".", help="path to save checkpoints and final model")
parser.add_argument("--name", type=str, default=None, help="Name of your training experiment")
parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on")
parser.add_argument("-b", "--batch_size", type=int, default=64, help="batch size for training")
Expand Down
4 changes: 3 additions & 1 deletion references/classification/train_tensorflow_orientation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import datetime
import time
from pathlib import Path

import numpy as np
import tensorflow as tf
Expand Down Expand Up @@ -308,7 +309,7 @@ def main(args):
val_loss, acc = evaluate(model, val_loader, batch_transforms)
if val_loss < min_loss:
print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
model.save_weights(f"./{exp_name}.weights.h5")
model.save_weights(Path(args.output_dir) / f"{exp_name}.weights.h5")
min_loss = val_loss
print(f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} (Acc: {acc:.2%})")
# W&B
Expand Down Expand Up @@ -355,6 +356,7 @@ def parse_args():
)

parser.add_argument("arch", type=str, help="classification model to train")
parser.add_argument("--output_dir", type=str, default=".", help="path to save checkpoints and final model")
parser.add_argument("--type", type=str, required=True, choices=["page", "crop"], help="type of data to train on")
parser.add_argument("--train_path", type=str, help="path to training data folder")
parser.add_argument("--val_path", type=str, required=True, help="path to validation data folder")
Expand Down
6 changes: 4 additions & 2 deletions references/detection/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import logging
import multiprocessing as mp
import time
from pathlib import Path

import numpy as np
import torch
Expand Down Expand Up @@ -390,11 +391,11 @@ def main(args):
val_loss, recall, precision, mean_iou = evaluate(model, val_loader, batch_transforms, val_metric, amp=args.amp)
if val_loss < min_loss:
print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
torch.save(model.state_dict(), f"./{exp_name}.pt")
torch.save(model.state_dict(), Path(args.output_dir) / f"{exp_name}.pt")
min_loss = val_loss
if args.save_interval_epoch:
print(f"Saving state at epoch: {epoch + 1}")
torch.save(model.state_dict(), f"./{exp_name}_epoch{epoch + 1}.pt")
torch.save(model.state_dict(), Path(args.output_dir) / f"{exp_name}_epoch{epoch + 1}.pt")
log_msg = f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} "
if any(val is None for val in (recall, precision, mean_iou)):
log_msg += "(Undefined metric value, caused by empty GTs or predictions)"
Expand Down Expand Up @@ -428,6 +429,7 @@ def parse_args():
)

parser.add_argument("arch", type=str, help="text-detection model to train")
parser.add_argument("--output_dir", type=str, default=".", help="path to save checkpoints and final model")
parser.add_argument("--train_path", type=str, required=True, help="path to training data folder")
parser.add_argument("--val_path", type=str, required=True, help="path to validation data folder")
parser.add_argument("--name", type=str, default=None, help="Name of your training experiment")
Expand Down
6 changes: 4 additions & 2 deletions references/detection/train_pytorch_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import hashlib
import multiprocessing
import time
from pathlib import Path

import numpy as np
import torch
Expand Down Expand Up @@ -410,11 +411,11 @@ def main(rank: int, world_size: int, args):
)
if val_loss < min_loss:
print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
torch.save(model.module.state_dict(), f"./{exp_name}.pt")
torch.save(model.module.state_dict(), Path(args.output_dir) / f"{exp_name}.pt")
min_loss = val_loss
if args.save_interval_epoch:
print(f"Saving state at epoch: {epoch + 1}")
torch.save(model.state_dict(), f"./{exp_name}_epoch{epoch + 1}.pt")
torch.save(model.state_dict(), Path(args.output_dir) / f"{exp_name}_epoch{epoch + 1}.pt")
odulcy-mindee marked this conversation as resolved.
Show resolved Hide resolved
log_msg = f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} "
if any(val is None for val in (recall, precision, mean_iou)):
log_msg += "(Undefined metric value, caused by empty GTs or predictions)"
Expand Down Expand Up @@ -453,6 +454,7 @@ def parse_args():
parser.add_argument("--backend", default="nccl", type=str, help="backend to use for torch DDP")

parser.add_argument("arch", type=str, help="text-detection model to train")
parser.add_argument("--output_dir", type=str, default=".", help="path to save checkpoints and final model")
parser.add_argument("--train_path", type=str, required=True, help="path to training data folder")
parser.add_argument("--val_path", type=str, required=True, help="path to validation data folder")
parser.add_argument("--name", type=str, default=None, help="Name of your training experiment")
Expand Down
4 changes: 3 additions & 1 deletion references/detection/train_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import datetime
import hashlib
import time
from pathlib import Path

import numpy as np
import tensorflow as tf
Expand Down Expand Up @@ -357,7 +358,7 @@ def main(args):
min_loss = val_loss
if args.save_interval_epoch:
print(f"Saving state at epoch: {epoch + 1}")
model.save_weights(f"./{exp_name}_{epoch + 1}.weights.h5")
model.save_weights(Path(args.output_dir) / f"{exp_name}_{epoch + 1}.weights.h5")
odulcy-mindee marked this conversation as resolved.
Show resolved Hide resolved
log_msg = f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} "
if any(val is None for val in (recall, precision, mean_iou)):
log_msg += "(Undefined metric value, caused by empty GTs or predictions)"
Expand Down Expand Up @@ -401,6 +402,7 @@ def parse_args():
)

parser.add_argument("arch", type=str, help="text-detection model to train")
parser.add_argument("--output_dir", type=str, default=".", help="path to save checkpoints and final model")
parser.add_argument("--train_path", type=str, required=True, help="path to training data folder")
parser.add_argument("--val_path", type=str, required=True, help="path to validation data folder")
parser.add_argument("--name", type=str, default=None, help="Name of your training experiment")
Expand Down
3 changes: 2 additions & 1 deletion references/recognition/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def main(args):
val_loss, exact_match, partial_match = evaluate(model, val_loader, batch_transforms, val_metric, amp=args.amp)
if val_loss < min_loss:
print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
torch.save(model.state_dict(), f"./{exp_name}.pt")
torch.save(model.state_dict(), Path(args.output_dir) / f"{exp_name}.pt")
min_loss = val_loss
print(
f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} "
Expand Down Expand Up @@ -427,6 +427,7 @@ def parse_args():
)

parser.add_argument("arch", type=str, help="text-recognition model to train")
parser.add_argument("--output_dir", type=str, default=".", help="path to save checkpoints and final model")
parser.add_argument("--train_path", type=str, default=None, help="path to train data folder(s)")
parser.add_argument("--val_path", type=str, default=None, help="path to val data folder")
parser.add_argument(
Expand Down
3 changes: 2 additions & 1 deletion references/recognition/train_pytorch_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def main(rank: int, world_size: int, args):
# random parameters and gradients are synchronized in backward passes.
# Therefore, saving it in one process is sufficient.
print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
torch.save(model.module.state_dict(), f"./{exp_name}.pt")
torch.save(model.module.state_dict(), Path(args.output_dir) / f"{exp_name}.pt")
min_loss = val_loss
print(
f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} "
Expand Down Expand Up @@ -365,6 +365,7 @@ def parse_args():
parser.add_argument("--backend", default="nccl", type=str, help="Backend to use for Torch DDP")

parser.add_argument("arch", type=str, help="text-recognition model to train")
parser.add_argument("--output_dir", type=str, default=".", help="path to save checkpoints and final model")
parser.add_argument("--train_path", type=str, default=None, help="path to train data folder(s)")
parser.add_argument("--val_path", type=str, default=None, help="path to val data folder")
parser.add_argument(
Expand Down
3 changes: 2 additions & 1 deletion references/recognition/train_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def main(args):
val_loss, exact_match, partial_match = evaluate(model, val_loader, batch_transforms, val_metric)
if val_loss < min_loss:
print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
model.save_weights(f"./{exp_name}.weights.h5")
model.save_weights(Path(args.output_dir) / f"{exp_name}.weights.h5")
min_loss = val_loss
print(
f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} "
Expand Down Expand Up @@ -391,6 +391,7 @@ def parse_args():
)

parser.add_argument("arch", type=str, help="text-recognition model to train")
parser.add_argument("--output_dir", type=str, default=".", help="path to save checkpoints and final model")
parser.add_argument("--train_path", type=str, default=None, help="path to train data folder(s)")
parser.add_argument("--val_path", type=str, default=None, help="path to val data folder")
parser.add_argument(
Expand Down
Loading