Add SynthWordGenerator to text reco training scripts #825

Merged: 19 commits, Feb 22, 2022

4 changes: 2 additions & 2 deletions .github/workflows/references.yml
@@ -109,10 +109,10 @@ jobs:
unzip toy_recogition_set-036a4d80.zip -d reco_set
- if: matrix.framework == 'tensorflow'
name: Train for a short epoch (TF)
run: python references/recognition/train_tensorflow.py ./reco_set ./reco_set crnn_vgg16_bn -b 4 --epochs 1
run: python references/recognition/train_tensorflow.py crnn_vgg16_bn --train_path ./reco_set --val_path ./reco_set -b 4 --epochs 1
- if: matrix.framework == 'pytorch'
name: Train for a short epoch (PT)
run: python references/recognition/train_pytorch.py ./reco_set ./reco_set crnn_mobilenet_v3_small -b 4 --epochs 1
run: python references/recognition/train_pytorch.py crnn_mobilenet_v3_small --train_path ./reco_set --val_path ./reco_set -b 4 --epochs 1

latency-text-recognition:
runs-on: ${{ matrix.os }}
8 changes: 4 additions & 4 deletions references/recognition/README.md
@@ -16,19 +16,19 @@ pip install -r references/requirements.txt
You can start your training in TensorFlow:

```shell
python references/recognition/train_tensorflow.py path/to/your/train_set path/to/your/val_set crnn_vgg16_bn --epochs 5
python references/recognition/train_tensorflow.py crnn_vgg16_bn --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5
```
or PyTorch:

```shell
python references/recognition/train_pytorch.py path/to/your/train_set path/to/your/val_set crnn_vgg16_bn --epochs 5 --device 0
python references/recognition/train_pytorch.py crnn_vgg16_bn --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5 --device 0
```



## Data format

You need to provide both `train_path` and `val_path` arguments to start training.
Each of these paths must lead to a folder with 2 elements:

```shell
@@ -40,7 +40,7 @@ Each of these paths must lead to a 2-elements folder:
├── labels.json
```

The JSON files must contain word-labels for each picture as a string.
The order of entries in the json does not matter.
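
As a minimal sketch (file names and labels below are made up for illustration, not taken from this repository), such a folder could be produced like this:

```python
import json
from pathlib import Path

# Hypothetical dataset root expected by --train_path / --val_path
root = Path("my_reco_set")
(root / "images").mkdir(parents=True, exist_ok=True)

# One word label (a plain string) per image file; entry order is irrelevant
labels = {
    "img_1.jpg": "hello",
    "img_2.jpg": "world",
}

with open(root / "labels.json", "w") as f:
    json.dump(labels, f)
```

The resulting folder can then be passed to either training script via `--train_path` or `--val_path`.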

122 changes: 91 additions & 31 deletions references/recognition/train_pytorch.py
@@ -23,7 +23,7 @@
from torchvision.transforms import ColorJitter, Compose, Normalize

from doctr import transforms as T
from doctr.datasets import VOCABS, RecognitionDataset
from doctr.datasets import VOCABS, RecognitionDataset, WordGenerator
from doctr.models import recognition
from doctr.utils.metrics import TextMatch
from utils import plot_recorder, plot_samples
@@ -179,13 +179,36 @@ def main(args):

torch.backends.cudnn.benchmark = True

vocab = VOCABS[args.vocab]
fonts = args.font.split(",")

# Load val data generator
st = time.time()
val_set = RecognitionDataset(
img_folder=os.path.join(args.val_path, 'images'),
labels_path=os.path.join(args.val_path, 'labels.json'),
img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
)
if args.val_path is not None:
with open(os.path.join(args.val_path, 'labels.json'), 'rb') as f:
val_hash = hashlib.sha256(f.read()).hexdigest()

val_set = RecognitionDataset(
img_folder=os.path.join(args.val_path, 'images'),
labels_path=os.path.join(args.val_path, 'labels.json'),
img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
)
else:
val_hash = None
# Load synthetic data generator
val_set = WordGenerator(
vocab=vocab,
min_chars=args.min_chars,
max_chars=args.max_chars,
num_samples=args.val_samples * len(vocab),
font_family=fonts,
img_transforms=Compose([
T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
# Ensure we have a 90% split of white-background images
T.RandomApply(T.ColorInversion(min_val=1.0), 0.9),
]),
)

val_loader = DataLoader(
val_set,
batch_size=args.batch_size,
@@ -197,13 +220,11 @@ def main(args):
)
print(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in "
f"{len(val_loader)} batches)")
with open(os.path.join(args.val_path, 'labels.json'), 'rb') as f:
val_hash = hashlib.sha256(f.read()).hexdigest()

batch_transforms = Normalize(mean=(0.694, 0.695, 0.693), std=(0.299, 0.296, 0.301))

# Load doctr model
model = recognition.__dict__[args.arch](pretrained=args.pretrained, vocab=VOCABS[args.vocab])
model = recognition.__dict__[args.arch](pretrained=args.pretrained, vocab=vocab)

# Resume weights
if isinstance(args.resume, str):
@@ -237,24 +258,45 @@

st = time.time()

# Load train data generator
base_path = Path(args.train_path)
parts = [base_path] if base_path.joinpath('labels.json').is_file() else [
base_path.joinpath(sub) for sub in os.listdir(base_path)
]
train_set = RecognitionDataset(
parts[0].joinpath('images'),
parts[0].joinpath('labels.json'),
img_transforms=Compose([
T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
# Augmentations
T.RandomApply(T.ColorInversion(), .1),
ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02),
]),
)
if len(parts) > 1:
for subfolder in parts[1:]:
train_set.merge_dataset(RecognitionDataset(subfolder.joinpath('images'), subfolder.joinpath('labels.json')))
if args.train_path is not None:
# Load train data generator
base_path = Path(args.train_path)
parts = [base_path] if base_path.joinpath('labels.json').is_file() else [
base_path.joinpath(sub) for sub in os.listdir(base_path)
]
with open(parts[0].joinpath('labels.json'), 'rb') as f:
train_hash = hashlib.sha256(f.read()).hexdigest()

train_set = RecognitionDataset(
parts[0].joinpath('images'),
parts[0].joinpath('labels.json'),
img_transforms=Compose([
T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
# Augmentations
T.RandomApply(T.ColorInversion(), .1),
ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02),
]),
)
if len(parts) > 1:
for subfolder in parts[1:]:
train_set.merge_dataset(RecognitionDataset(
subfolder.joinpath('images'), subfolder.joinpath('labels.json')))
else:
train_hash = None
# Load synthetic data generator
train_set = WordGenerator(
vocab=vocab,
min_chars=args.min_chars,
max_chars=args.max_chars,
num_samples=args.train_samples * len(vocab),
font_family=fonts,
img_transforms=Compose([
T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
# Ensure we have a 90% split of white-background images
T.RandomApply(T.ColorInversion(min_val=1.0), 0.9),
ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02),
]),
)

train_loader = DataLoader(
train_set,
@@ -267,8 +309,6 @@
)
print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
f"{len(train_loader)} batches)")
with open(parts[0].joinpath('labels.json'), 'rb') as f:
train_hash = hashlib.sha256(f.read()).hexdigest()

if args.show_samples:
x, target = next(iter(train_loader))
@@ -348,9 +388,29 @@ def parse_args():
parser = argparse.ArgumentParser(description='DocTR training script for text recognition (PyTorch)',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument('train_path', type=str, help='path to train data folder(s)')
parser.add_argument('val_path', type=str, help='path to val data folder')
parser.add_argument('arch', type=str, help='text-recognition model to train')
parser.add_argument('--train_path', type=str, default=None, help='path to train data folder(s)')
parser.add_argument('--val_path', type=str, default=None, help='path to val data folder')
parser.add_argument(
'--train-samples',
type=int,
default=1000,
help='Multiplied by the vocab length, this gives the number of synthetic training samples that will be used.'
)
parser.add_argument(
'--val-samples',
type=int,
default=20,
help='Multiplied by the vocab length, this gives the number of synthetic validation samples that will be used.'
)
parser.add_argument(
'--font',
type=str,
default="FreeMono.ttf,FreeSans.ttf,FreeSerif.ttf",
help='Font family to be used'
)
parser.add_argument('--min-chars', type=int, default=1, help='Minimum number of characters per synthetic sample')
parser.add_argument('--max-chars', type=int, default=30, help='Maximum number of characters per synthetic sample')
parser.add_argument('--name', type=str, default=None, help='Name of your training experiment')
parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train the model on')
parser.add_argument('-b', '--batch_size', type=int, default=64, help='batch size for training')
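
To summarise the new behaviour, here is a condensed sketch of the train-set fallback introduced in this diff: if `--train_path` is given, images and a labels.json are read from disk as before; otherwise words are rendered on the fly with `WordGenerator`. The constructor arguments mirror the diff above, while the default values and the usage at the bottom are only illustrative.

```python
import os

from torch.utils.data import DataLoader

from doctr import transforms as T
from doctr.datasets import VOCABS, RecognitionDataset, WordGenerator


def build_train_set(train_path=None, vocab_name="french", train_samples=1000,
                    min_chars=1, max_chars=30, input_size=32,
                    fonts=("FreeMono.ttf", "FreeSans.ttf", "FreeSerif.ttf")):
    resize = T.Resize((input_size, 4 * input_size), preserve_aspect_ratio=True)
    if train_path is not None:
        # Real images + labels.json located on disk
        return RecognitionDataset(
            img_folder=os.path.join(train_path, "images"),
            labels_path=os.path.join(train_path, "labels.json"),
            img_transforms=resize,
        )
    vocab = VOCABS[vocab_name]
    # Synthetic words: --train-samples acts as a multiplier on the vocab length
    return WordGenerator(
        vocab=vocab,
        min_chars=min_chars,
        max_chars=max_chars,
        num_samples=train_samples * len(vocab),
        font_family=list(fonts),
        img_transforms=resize,
    )


# Illustrative usage: no train path given, so the synthetic generator is used
train_loader = DataLoader(build_train_set(), batch_size=64, shuffle=True)
```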