Add SynthWordGenerator to text reco training scripts #825
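This PR wires doctr's WordGenerator into both recognition training scripts (PyTorch and TensorFlow): when train_path / val_path are omitted, word crops are rendered on the fly from the chosen vocab using fonts taken from --fonts_folder, augmented with random rotation, color inversion, and Gaussian blur. It also exports rotated_img_tensor from doctr.transforms.functional so the TensorFlow script can use it for rotation.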

Merged · 19 commits · Feb 22, 2022
Changes from 15 commits
2 changes: 1 addition & 1 deletion doctr/transforms/functional/tensorflow.py
@@ -15,7 +15,7 @@

 from .base import crop_boxes

-__all__ = ["invert_colors", "rotate_sample", "crop_detection"]
+__all__ = ["invert_colors", "rotated_img_tensor", "rotate_sample", "crop_detection"]


 def invert_colors(img: tf.Tensor, min_val: float = 0.6) -> tf.Tensor:
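Exporting rotated_img_tensor via __all__ is what lets the TensorFlow training script below import it directly. A minimal sketch of the call, assuming a float HWC image tensor (the dummy shape is illustrative, and the signature is inferred from its use later in this diff):

    import tensorflow as tf
    from doctr.transforms.functional import rotated_img_tensor

    img = tf.random.uniform((32, 128, 3))                # dummy HWC word crop
    rotated = rotated_img_tensor(img, 6.0, expand=True)  # rotate by 6 degrees, expanding the canvas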
97 changes: 70 additions & 27 deletions references/recognition/train_pytorch.py
@@ -20,10 +20,10 @@
 from fastprogress.fastprogress import master_bar, progress_bar
 from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR, OneCycleLR
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
-from torchvision.transforms import ColorJitter, Compose, Normalize
+from torchvision.transforms import ColorJitter, Compose, GaussianBlur, InterpolationMode, Normalize, RandomRotation

 from doctr import transforms as T
-from doctr.datasets import VOCABS, RecognitionDataset
+from doctr.datasets import VOCABS, RecognitionDataset, WordGenerator
 from doctr.models import recognition
 from doctr.utils.metrics import TextMatch
 from utils import plot_recorder, plot_samples
@@ -181,11 +181,31 @@ def main(args):

     # Load val data generator
     st = time.time()
-    val_set = RecognitionDataset(
-        img_folder=os.path.join(args.val_path, 'images'),
-        labels_path=os.path.join(args.val_path, 'labels.json'),
-        img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
-    )
+    if args.val_path is not None:
+        val_set = RecognitionDataset(
+            img_folder=os.path.join(args.val_path, 'images'),
+            labels_path=os.path.join(args.val_path, 'labels.json'),
+            img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
+        )
+    else:
+        # Load synthetic data generator
+        val_set = WordGenerator(
+            vocab=VOCABS[args.vocab],
+            min_chars=1,
+            max_chars=17,
+            num_samples=int(args.num_synth_samples * 0.2),
+            font_family=[os.path.join(args.fonts_folder, f)
+                         for f in os.listdir(args.fonts_folder) if f.endswith('.ttf')],
+            img_transforms=Compose([
+                T.RandomApply(RandomRotation(
+                    [-6, 6], interpolation=InterpolationMode.BILINEAR, expand=True), .2),
+                T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
+                T.RandomApply(T.ColorInversion(min_val=1.0), 1.0),
+                T.RandomApply(GaussianBlur(
+                    kernel_size=3, sigma=(0.3, 2.0)), .2),
+            ]),
+        )

     val_loader = DataLoader(
         val_set,
         batch_size=args.batch_size,
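WordGenerator acts as a map-style dataset whose items are synthetically rendered word images paired with their string labels. A quick sanity-check sketch (the vocab key and font path are placeholder assumptions):

    from doctr.datasets import VOCABS, WordGenerator

    ds = WordGenerator(
        vocab=VOCABS['french'],              # assumed vocab key
        min_chars=1,
        max_chars=17,
        num_samples=8,
        font_family=['fonts/FreeMono.ttf'],  # hypothetical font path
    )
    img, target = ds[0]  # rendered word image and its ground-truth string
    print(target, img.shape)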
@@ -237,24 +257,44 @@ def main(args):

     st = time.time()

-    # Load train data generator
-    base_path = Path(args.train_path)
-    parts = [base_path] if base_path.joinpath('labels.json').is_file() else [
-        base_path.joinpath(sub) for sub in os.listdir(base_path)
-    ]
-    train_set = RecognitionDataset(
-        parts[0].joinpath('images'),
-        parts[0].joinpath('labels.json'),
-        img_transforms=Compose([
-            T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
-            # Augmentations
-            T.RandomApply(T.ColorInversion(), .1),
-            ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02),
-        ]),
-    )
-    if len(parts) > 1:
-        for subfolder in parts[1:]:
-            train_set.merge_dataset(RecognitionDataset(subfolder.joinpath('images'), subfolder.joinpath('labels.json')))
+    if args.train_path is not None:
+        # Load train data generator
+        base_path = Path(args.train_path)
+        parts = [base_path] if base_path.joinpath('labels.json').is_file() else [
+            base_path.joinpath(sub) for sub in os.listdir(base_path)
+        ]
+        train_set = RecognitionDataset(
+            parts[0].joinpath('images'),
+            parts[0].joinpath('labels.json'),
+            img_transforms=Compose([
+                T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
+                # Augmentations
+                T.RandomApply(T.ColorInversion(), .1),
+                ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02),
+            ]),
+        )
+        if len(parts) > 1:
+            for subfolder in parts[1:]:
+                train_set.merge_dataset(RecognitionDataset(
+                    subfolder.joinpath('images'), subfolder.joinpath('labels.json')))
+    else:
+        # Load synthetic data generator
+        train_set = WordGenerator(
+            vocab=VOCABS[args.vocab],
+            min_chars=1,
+            max_chars=17,
+            num_samples=args.num_synth_samples,
+            font_family=[os.path.join(args.fonts_folder, f)
+                         for f in os.listdir(args.fonts_folder) if f.endswith('.ttf')],
+            img_transforms=Compose([
+                T.RandomApply(RandomRotation(
+                    [-6, 6], interpolation=InterpolationMode.BILINEAR, expand=True), .2),
+                T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
+                T.RandomApply(T.ColorInversion(min_val=1.0), 0.9),
+                T.RandomApply(GaussianBlur(
+                    kernel_size=3, sigma=(0.3, 2.0)), .2),
+            ]),
+        )

     train_loader = DataLoader(
         train_set,
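The second argument to T.RandomApply is the application probability (.2 for rotation and blur here, .1 for color inversion on real data). Conceptually it reduces to the following simplified sketch, not doctr's actual implementation:

    import random

    class RandomApply:
        """Apply `transform` with probability `p`, else return the input unchanged."""
        def __init__(self, transform, p=0.5):
            self.transform = transform
            self.p = p

        def __call__(self, img):
            return self.transform(img) if random.random() < self.p else img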
@@ -348,9 +388,12 @@ def parse_args():
     parser = argparse.ArgumentParser(description='DocTR training script for text recognition (PyTorch)',
                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)

-    parser.add_argument('train_path', type=str, help='path to train data folder(s)')
-    parser.add_argument('val_path', type=str, help='path to val data folder')
+    parser.add_argument('train_path', type=str, nargs='?', default=None, help='path to train data folder(s)')
+    parser.add_argument('val_path', type=str, nargs='?', default=None, help='path to val data folder')
     parser.add_argument('arch', type=str, help='text-recognition model to train')
+    parser.add_argument('--fonts_folder', type=str, default=None, help='path to folder with fonts for synthetic data')
+    parser.add_argument('--num_synth_samples', type=int, default=8000000,
+                        help='number of synthetic samples to generate')
     parser.add_argument('--name', type=str, default=None, help='Name of your training experiment')
     parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train the model on')
     parser.add_argument('-b', '--batch_size', type=int, default=64, help='batch size for training')
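Note that positional arguments need nargs='?' for default=None to take effect; without it argparse still requires them, and the synthetic branches above would be unreachable. With the optional positionals, a purely synthetic run could then be launched as (hypothetical invocation; the arch name is just an example):

    python references/recognition/train_pytorch.py crnn_vgg16_bn --fonts_folder fonts/ --num_synth_samples 100000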
105 changes: 75 additions & 30 deletions references/recognition/train_tensorflow.py
@@ -25,8 +25,9 @@
     tf.config.experimental.set_memory_growth(gpu_devices[0], True)

 from doctr import transforms as T
-from doctr.datasets import VOCABS, DataLoader, RecognitionDataset
+from doctr.datasets import VOCABS, DataLoader, RecognitionDataset, WordGenerator
 from doctr.models import recognition
+from doctr.transforms.functional import rotated_img_tensor
 from doctr.utils.metrics import TextMatch
 from utils import plot_recorder, plot_samples
@@ -137,13 +138,34 @@ def main(args):
     if args.amp:
         mixed_precision.set_global_policy('mixed_float16')

-    # Load val data generator
     st = time.time()
-    val_set = RecognitionDataset(
-        img_folder=os.path.join(args.val_path, 'images'),
-        labels_path=os.path.join(args.val_path, 'labels.json'),
-        img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
-    )
+
+    if args.val_path is not None:
+        # Load val data generator
+        val_set = RecognitionDataset(
+            img_folder=os.path.join(args.val_path, 'images'),
+            labels_path=os.path.join(args.val_path, 'labels.json'),
+            img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
+        )
+    else:
+        # Load synthetic data generator
+        val_set = WordGenerator(
+            vocab=VOCABS[args.vocab],
+            min_chars=1,
+            max_chars=17,
+            num_samples=int(args.num_synth_samples * 0.2),
+            font_family=[os.path.join(args.fonts_folder, f)
+                         for f in os.listdir(args.fonts_folder) if f.endswith('.ttf')],
+            img_transforms=T.Compose([
+                T.RandomApply(T.LambdaTransformation(
+                    lambda x: rotated_img_tensor(x, np.random.choice([-6.0, 6.0]), expand=True)), .2),
+                T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
+                T.RandomApply(T.ColorInversion(min_val=1.0), 1.0),
+                T.RandomApply(T.GaussianBlur(
+                    kernel_shape=3, std=(0.3, 2.0)), .2),
+            ])
+        )

     val_loader = DataLoader(
         val_set,
         batch_size=args.batch_size,
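Two details worth flagging here: doctr's TensorFlow T.GaussianBlur takes kernel_shape/std rather than torchvision's kernel_size/sigma, and np.random.choice([-6.0, 6.0]) draws only the two endpoint angles, unlike the PyTorch script's RandomRotation([-6, 6]), which samples uniformly over the interval. If uniform angles were intended here as well, a one-line sketch (same rotated_img_tensor signature as above):

    lambda x: rotated_img_tensor(x, np.random.uniform(-6.0, 6.0), expand=True)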
@@ -181,28 +203,48 @@ def main(args):

     st = time.time()

-    # Load train data generator
-    base_path = Path(args.train_path)
-    parts = [base_path] if base_path.joinpath('labels.json').is_file() else [
-        base_path.joinpath(sub) for sub in os.listdir(base_path)
-    ]
-    train_set = RecognitionDataset(
-        parts[0].joinpath('images'),
-        parts[0].joinpath('labels.json'),
-        img_transforms=T.Compose([
-            T.RandomApply(T.ColorInversion(), .1),
-            T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
-            # Augmentations
-            T.RandomJpegQuality(60),
-            T.RandomSaturation(.3),
-            T.RandomContrast(.3),
-            T.RandomBrightness(.3),
-        ]),
-    )
-
-    if len(parts) > 1:
-        for subfolder in parts[1:]:
-            train_set.merge_dataset(RecognitionDataset(subfolder.joinpath('images'), subfolder.joinpath('labels.json')))
+    if args.train_path is not None:
+        # Load train data generator
+        base_path = Path(args.train_path)
+        parts = [base_path] if base_path.joinpath('labels.json').is_file() else [
+            base_path.joinpath(sub) for sub in os.listdir(base_path)
+        ]
+        train_set = RecognitionDataset(
+            parts[0].joinpath('images'),
+            parts[0].joinpath('labels.json'),
+            img_transforms=T.Compose([
+                T.RandomApply(T.ColorInversion(), .1),
+                T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
+                # Augmentations
+                T.RandomJpegQuality(60),
+                T.RandomSaturation(.3),
+                T.RandomContrast(.3),
+                T.RandomBrightness(.3),
+            ]),
+        )
+
+        if len(parts) > 1:
+            for subfolder in parts[1:]:
+                train_set.merge_dataset(RecognitionDataset(
+                    subfolder.joinpath('images'), subfolder.joinpath('labels.json')))
+    else:
+        # Load synthetic data generator
+        train_set = WordGenerator(
+            vocab=VOCABS[args.vocab],
+            min_chars=1,
+            max_chars=17,
+            num_samples=args.num_synth_samples,
+            font_family=[os.path.join(args.fonts_folder, f)
+                         for f in os.listdir(args.fonts_folder) if f.endswith('.ttf')],
+            img_transforms=T.Compose([
+                T.RandomApply(T.LambdaTransformation(
+                    lambda x: rotated_img_tensor(x, np.random.choice([-6.0, 6.0]), expand=True)), .2),
+                T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
+                T.RandomApply(T.ColorInversion(min_val=1.0), 0.9),
+                T.RandomApply(T.GaussianBlur(
+                    kernel_shape=3, std=(0.3, 2.0)), .2),
+            ])
+        )

     train_loader = DataLoader(
         train_set,
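With the defaults, this generates 8,000,000 synthetic training words per epoch and int(8_000_000 * 0.2) = 1,600,000 validation words, i.e. 125,000 training batches at the default batch size of 64.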
@@ -302,9 +344,12 @@ def parse_args():
     parser = argparse.ArgumentParser(description='DocTR training script for text recognition (TensorFlow)',
                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)

-    parser.add_argument('train_path', type=str, help='path to train data folder(s)')
-    parser.add_argument('val_path', type=str, help='path to val data folder')
+    parser.add_argument('train_path', type=str, nargs='?', default=None, help='path to train data folder(s)')
+    parser.add_argument('val_path', type=str, nargs='?', default=None, help='path to val data folder')
     parser.add_argument('arch', type=str, help='text-recognition model to train')
+    parser.add_argument('--fonts_folder', type=str, default=None, help='path to folder with fonts for synthetic data')
+    parser.add_argument('--num_synth_samples', type=int, default=8000000,
+                        help='number of synthetic samples to generate')
     parser.add_argument('--name', type=str, default=None, help='Name of your training experiment')
     parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train the model on')
     parser.add_argument('-b', '--batch_size', type=int, default=64, help='batch size for training')