Add SynthWordGenerator to text reco training scripts #825

Merged · 19 commits · Feb 22, 2022
4 changes: 2 additions & 2 deletions .github/workflows/references.yml
@@ -109,10 +109,10 @@ jobs:
unzip toy_recogition_set-036a4d80.zip -d reco_set
- if: matrix.framework == 'tensorflow'
name: Train for a short epoch (TF)
run: python references/recognition/train_tensorflow.py ./reco_set ./reco_set crnn_vgg16_bn -b 4 --epochs 1
run: python references/recognition/train_tensorflow.py crnn_vgg16_bn --train_path ./reco_set --val_path ./reco_set -b 4 --epochs 1
- if: matrix.framework == 'pytorch'
name: Train for a short epoch (PT)
run: python references/recognition/train_pytorch.py ./reco_set ./reco_set crnn_mobilenet_v3_small -b 4 --epochs 1
run: python references/recognition/train_pytorch.py crnn_mobilenet_v3_small --train_path ./reco_set --val_path ./reco_set -b 4 --epochs 1

latency-text-recognition:
runs-on: ${{ matrix.os }}
23 changes: 19 additions & 4 deletions references/recognition/README.md
@@ -15,20 +15,35 @@ pip install -r references/requirements.txt

You can start your training in TensorFlow:

with your own dataset:

```shell
python references/recognition/train_tensorflow.py crnn_vgg16_bn --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5
```

with a synthetic dataset:

```shell
python references/recognition/train_tensorflow.py path/to/your/train_set path/to/your/val_set crnn_vgg16_bn --epochs 5
python references/recognition/train_tensorflow.py crnn_vgg16_bn --epochs 5
```

or PyTorch:

with your own dataset:

```shell
python references/recognition/train_pytorch.py path/to/your/train_set path/to/your/val_set crnn_vgg16_bn --epochs 5 --device 0
python references/recognition/train_pytorch.py crnn_vgg16_bn --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5 --device 0
```

with a synthetic dataset:

```shell
python references/recognition/train_pytorch.py crnn_vgg16_bn --epochs 5 --device 0
```
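
If you want to preview what the synthetic words look like before launching a run, a minimal sketch along these lines should work. It assumes the `WordGenerator` and `VOCABS` interfaces used by the training scripts in this PR; the vocab name and sample count are purely illustrative:

```python
from doctr.datasets import VOCABS, WordGenerator

# Illustrative values: "french" is assumed to be a valid VOCABS key,
# and 100 samples is just enough for a quick visual check.
vocab = VOCABS["french"]
ds = WordGenerator(
    vocab=vocab,
    min_chars=1,
    max_chars=17,
    num_samples=100,
    font_family=["FreeMono.ttf", "FreeSans.ttf", "FreeSerif.ttf"],
)

img, target = ds[0]  # a rendered word image and its label string
print(target)
```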

## Data format

You need to provide both `train_path` and `val_path` arguments to start training.
If you want to train with your own data, you need to provide the `train_path` and/or `val_path` arguments to start training.
Each of these paths must lead to a folder with two elements:

@@ -40,7 +55,7 @@ Each of these paths must lead to a 2-elements folder:
```shell
├── labels.json
```

The JSON files must contain the word label of each picture as a string.
The order of entries in the JSON does not matter.
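
As an illustration, such a labels.json file could be produced like this. This is only a sketch of the layout assumed by `RecognitionDataset` (keys are filenames under `images/`, values are the transcriptions); the filenames and labels are made up:

```python
import json

# Hypothetical content: each key is an image filename, each value its transcription
labels = {
    "img_1.jpg": "invoice",
    "img_2.jpg": "total",
}

with open("labels.json", "w") as f:
    json.dump(labels, f)
```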

117 changes: 85 additions & 32 deletions references/recognition/train_pytorch.py
@@ -20,10 +20,10 @@
from fastprogress.fastprogress import master_bar, progress_bar
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR, OneCycleLR
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torchvision.transforms import ColorJitter, Compose, Normalize
from torchvision.transforms import ColorJitter, Compose, GaussianBlur, InterpolationMode, Normalize, RandomRotation

from doctr import transforms as T
from doctr.datasets import VOCABS, RecognitionDataset
from doctr.datasets import VOCABS, RecognitionDataset, WordGenerator
from doctr.models import recognition
from doctr.utils.metrics import TextMatch
from utils import plot_recorder, plot_samples
@@ -179,13 +179,35 @@ def main(args):

torch.backends.cudnn.benchmark = True

vocab = VOCABS[args.vocab]
fonts = args.font.split(",")
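# Dataset hashes are only computed from labels.json when real data folders are provided; they stay None for synthetic data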
val_hash, train_hash = None, None

# Load val data generator
st = time.time()
val_set = RecognitionDataset(
img_folder=os.path.join(args.val_path, 'images'),
labels_path=os.path.join(args.val_path, 'labels.json'),
img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
)
if args.val_path is not None:
val_set = RecognitionDataset(
img_folder=os.path.join(args.val_path, 'images'),
labels_path=os.path.join(args.val_path, 'labels.json'),
img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
)
with open(os.path.join(args.val_path, 'labels.json'), 'rb') as f:
val_hash = hashlib.sha256(f.read()).hexdigest()
else:
# Load synthetic data generator
val_set = WordGenerator(
vocab=vocab,
min_chars=1,
max_chars=17,
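# Generate 20% as many synthetic samples for validation as for training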
num_samples=int(args.num_synth_samples * len(vocab) * 0.2),
font_family=fonts,
img_transforms=Compose([
T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
# Ensure we have a 90% split of white-background images
T.RandomApply(T.ColorInversion(min_val=1.0), 0.9),
]),
)

val_loader = DataLoader(
val_set,
batch_size=args.batch_size,
@@ -197,13 +219,11 @@
)
print(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in "
f"{len(val_loader)} batches)")
with open(os.path.join(args.val_path, 'labels.json'), 'rb') as f:
val_hash = hashlib.sha256(f.read()).hexdigest()

batch_transforms = Normalize(mean=(0.694, 0.695, 0.693), std=(0.299, 0.296, 0.301))

# Load doctr model
model = recognition.__dict__[args.arch](pretrained=args.pretrained, vocab=VOCABS[args.vocab])
model = recognition.__dict__[args.arch](pretrained=args.pretrained, vocab=vocab)

# Resume weights
if isinstance(args.resume, str):
@@ -237,24 +257,47 @@

st = time.time()

# Load train data generator
base_path = Path(args.train_path)
parts = [base_path] if base_path.joinpath('labels.json').is_file() else [
base_path.joinpath(sub) for sub in os.listdir(base_path)
]
train_set = RecognitionDataset(
parts[0].joinpath('images'),
parts[0].joinpath('labels.json'),
img_transforms=Compose([
T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
# Augmentations
T.RandomApply(T.ColorInversion(), .1),
ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02),
]),
)
if len(parts) > 1:
for subfolder in parts[1:]:
train_set.merge_dataset(RecognitionDataset(subfolder.joinpath('images'), subfolder.joinpath('labels.json')))
if args.train_path is not None:
# Load train data generator
base_path = Path(args.train_path)
parts = [base_path] if base_path.joinpath('labels.json').is_file() else [
base_path.joinpath(sub) for sub in os.listdir(base_path)
]
train_set = RecognitionDataset(
parts[0].joinpath('images'),
parts[0].joinpath('labels.json'),
img_transforms=Compose([
T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
# Augmentations
T.RandomApply(T.ColorInversion(), .1),
ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02),
]),
)
if len(parts) > 1:
for subfolder in parts[1:]:
train_set.merge_dataset(RecognitionDataset(
subfolder.joinpath('images'), subfolder.joinpath('labels.json')))
with open(parts[0].joinpath('labels.json'), 'rb') as f:
train_hash = hashlib.sha256(f.read()).hexdigest()
else:
# Load synthetic data generator
train_set = WordGenerator(
vocab=vocab,
min_chars=1,
max_chars=17,
num_samples=args.num_synth_samples * len(vocab),
font_family=fonts,
img_transforms=Compose([
T.RandomApply(RandomRotation(
[-6, 6], interpolation=InterpolationMode.BILINEAR, expand=True), .2),
T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
# Ensure we have a 90% split of white-background images
T.RandomApply(T.ColorInversion(min_val=1.0), 0.9),
T.RandomApply(GaussianBlur(
kernel_size=3, sigma=(0.3, 2.0)), .2),
ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02),
]),
)

train_loader = DataLoader(
train_set,
@@ -267,8 +310,6 @@
)
print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
f"{len(train_loader)} batches)")
with open(parts[0].joinpath('labels.json'), 'rb') as f:
train_hash = hashlib.sha256(f.read()).hexdigest()

if args.show_samples:
x, target = next(iter(train_loader))
@@ -348,9 +389,21 @@ def parse_args():
parser = argparse.ArgumentParser(description='DocTR training script for text recognition (PyTorch)',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument('train_path', type=str, help='path to train data folder(s)')
parser.add_argument('val_path', type=str, help='path to val data folder')
parser.add_argument('arch', type=str, help='text-recognition model to train')
parser.add_argument('--train_path', type=str, default=None, help='path to train data folder(s)')
parser.add_argument('--val_path', type=str, default=None, help='path to val data folder')
parser.add_argument(
'--num_synth_samples',
type=int,
default=1000,
help='Multiplied by the vocab length, this gives the number of synthetic samples that will be used.'
)
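# Example (illustrative): with --num_synth_samples 1000 and a vocab of length V,
# the synthetic train set holds 1000 * V word images and the synthetic val set
# roughly 0.2 * 1000 * V (see the 0.2 factor applied when building val_set above).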
parser.add_argument(
'--font',
type=str,
default="FreeMono.ttf,FreeSans.ttf,FreeSerif.ttf",
help='Font families to be used (comma-separated)'
)
parser.add_argument('--name', type=str, default=None, help='Name of your training experiment')
parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train the model on')
parser.add_argument('-b', '--batch_size', type=int, default=64, help='batch size for training')