feat: Added option to use AMP with TF scripts
fg-mindee committed Dec 7, 2021
1 parent 1671cdd commit 5db608f
Showing 3 changed files with 37 additions and 6 deletions.
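
All three scripts receive the same three-step change: switch the Keras global dtype policy to mixed_float16, wrap the optimizer in a LossScaleOptimizer, and unscale the gradients before applying them. For reference, here is a minimal, self-contained sketch of that mixed-precision recipe using the tf.keras.mixed_precision API; the model, optimizer and dummy batch below are illustrative placeholders rather than doctr code, and the float32 cast plus get_scaled_loss are the companion steps from the same API (the diffs below only show the unscaling side).

import tensorflow as tf
from tensorflow.keras import mixed_precision

# 1. Compute in float16 while keeping variables in float32
mixed_precision.set_global_policy('mixed_float16')

# Placeholder model and optimizer (illustrative only)
model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.build(input_shape=(None, 32))
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
# 2. Wrap the optimizer so the loss and gradients are scaled, avoiding float16 underflow
optimizer = mixed_precision.LossScaleOptimizer(optimizer)


def train_step(images, targets):
    with tf.GradientTape() as tape:
        # Cast the logits back to float32 so the loss is computed in full precision
        logits = tf.cast(model(images, training=True), tf.float32)
        loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(targets, logits))
        # Scale the loss so small gradients stay representable in float16
        scaled_loss = optimizer.get_scaled_loss(loss)
    scaled_grads = tape.gradient(scaled_loss, model.trainable_weights)
    # 3. Undo the scaling before the weight update
    grads = optimizer.get_unscaled_gradients(scaled_grads)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    return loss


# Dummy batch, just to show the step runs end to end
images = tf.random.normal((8, 32))
targets = tf.random.uniform((8,), maxval=10, dtype=tf.int32)
print(float(train_step(images, targets)))
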
15 changes: 13 additions & 2 deletions references/classification/train_tensorflow.py
@@ -16,6 +16,7 @@
import tensorflow as tf
import wandb
from fastprogress.fastprogress import master_bar, progress_bar
from tensorflow.keras import mixed_precision

gpu_devices = tf.config.experimental.list_physical_devices('GPU')
if any(gpu_devices):
@@ -27,7 +28,7 @@
from utils import plot_samples


def fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb):
def fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb, amp=False):
    train_iter = iter(train_loader)
    # Iterate over the batches of the dataset
    for _ in progress_bar(range(train_loader.num_batches), parent=mb):
@@ -39,6 +40,8 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb):
            out = model(images, training=True)
            train_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(targets, out)
        grads = tape.gradient(train_loss, model.trainable_weights)
        if amp:
            grads = optimizer.get_unscaled_gradients(grads)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        mb.child.comment = f'Training loss: {train_loss.numpy().mean():.6}'
@@ -81,6 +84,10 @@ def main(args):

    vocab = VOCABS[args.vocab]

    # AMP
    if args.amp:
        mixed_precision.set_global_policy('mixed_float16')

    # Load val data generator
    st = time.time()
    val_set = CharacterGenerator(
@@ -108,6 +115,7 @@ def main(args):
        num_classes=len(vocab),
        include_top=True,
    )

    # Resume weights
    if isinstance(args.resume, str):
        model.load_weights(args.resume)
@@ -169,6 +177,8 @@ def main(args):
        beta_2=0.99,
        epsilon=1e-6,
    )
    if args.amp:
        optimizer = mixed_precision.LossScaleOptimizer(optimizer)

    # Tensorboard to monitor training
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
@@ -201,7 +211,7 @@ def main(args):
    # Training loop
    mb = master_bar(range(args.epochs))
    for epoch in mb:
        fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb)
        fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb, args.amp)

        # Validation loop at the end of each epoch
        val_loss, acc = evaluate(model, val_loader, batch_transforms)
@@ -257,6 +267,7 @@ def parse_args():
                        help='Log to Weights & Biases')
    parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                        help='Load pretrained parameters before starting the training')
    parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true")
    args = parser.parse_args()

    return args
14 changes: 12 additions & 2 deletions references/detection/train_tensorflow.py
@@ -17,6 +17,7 @@
import tensorflow as tf
import wandb
from fastprogress.fastprogress import master_bar, progress_bar
from tensorflow.keras import mixed_precision

gpu_devices = tf.config.experimental.list_physical_devices('GPU')
if any(gpu_devices):
@@ -29,7 +30,7 @@
from utils import plot_samples


def fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb):
def fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb, amp=False):
    train_iter = iter(train_loader)
    # Iterate over the batches of the dataset
    for images, targets in progress_bar(train_iter, parent=mb):
@@ -39,6 +40,8 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb):
        with tf.GradientTape() as tape:
            train_loss = model(images, targets, training=True)['loss']
        grads = tape.gradient(train_loss, model.trainable_weights)
        if amp:
            grads = optimizer.get_unscaled_gradients(grads)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        mb.child.comment = f'Training loss: {train_loss.numpy():.6}'
@@ -74,6 +77,10 @@ def main(args):
    if not isinstance(args.workers, int):
        args.workers = min(16, mp.cpu_count())

    # AMP
    if args.amp:
        mixed_precision.set_global_policy('mixed_float16')

    st = time.time()
    val_set = DetectionDataset(
        img_folder=os.path.join(args.val_path, 'images'),
@@ -151,6 +158,8 @@ def main(args):
        epsilon=1e-6,
        clipnorm=5
    )
    if args.amp:
        optimizer = mixed_precision.LossScaleOptimizer(optimizer)

    # Tensorboard to monitor training
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
@@ -188,7 +197,7 @@ def main(args):
    # Training loop
    mb = master_bar(range(args.epochs))
    for epoch in mb:
        fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb)
        fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb, args.amp)
        # Validation loop at the end of each epoch
        val_loss, recall, precision, mean_iou = evaluate(model, val_loader, batch_transforms, val_metric)
        if val_loss < min_loss:
@@ -240,6 +249,7 @@ def parse_args():
                        help='Load pretrained parameters before starting the training')
    parser.add_argument('--rotation', dest='rotation', action='store_true',
                        help='train with rotated bbox')
    parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true")
    args = parser.parse_args()

    return args
14 changes: 12 additions & 2 deletions references/recognition/train_tensorflow.py
@@ -18,6 +18,7 @@
import tensorflow as tf
import wandb
from fastprogress.fastprogress import master_bar, progress_bar
from tensorflow.keras import mixed_precision

gpu_devices = tf.config.experimental.list_physical_devices('GPU')
if any(gpu_devices):
@@ -30,7 +31,7 @@
from utils import plot_samples


def fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb):
def fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb, amp=False):
    train_iter = iter(train_loader)
    # Iterate over the batches of the dataset
    for images, targets in progress_bar(train_iter, parent=mb):
@@ -40,6 +41,8 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb):
        with tf.GradientTape() as tape:
            train_loss = model(images, targets, training=True)['loss']
        grads = tape.gradient(train_loss, model.trainable_weights)
        if amp:
            grads = optimizer.get_unscaled_gradients(grads)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        mb.child.comment = f'Training loss: {train_loss.numpy().mean():.6}'
@@ -76,6 +79,10 @@ def main(args):
    if not isinstance(args.workers, int):
        args.workers = min(16, mp.cpu_count())

    # AMP
    if args.amp:
        mixed_precision.set_global_policy('mixed_float16')

    # Load val data generator
    st = time.time()
    val_set = RecognitionDataset(
@@ -162,6 +169,8 @@ def main(args):
        epsilon=1e-6,
        clipnorm=5
    )
    if args.amp:
        optimizer = mixed_precision.LossScaleOptimizer(optimizer)

    # Tensorboard to monitor training
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
@@ -195,7 +204,7 @@ def main(args):
    # Training loop
    mb = master_bar(range(args.epochs))
    for epoch in mb:
        fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb)
        fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb, args.amp)

        # Validation loop at the end of each epoch
        val_loss, exact_match, partial_match = evaluate(model, val_loader, batch_transforms, val_metric)
@@ -240,6 +249,7 @@ def parse_args():
                        help='Log to Weights & Biases')
    parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                        help='Load pretrained parameters before starting the training')
    parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true")
    args = parser.parse_args()

    return args
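
With these additions, mixed precision is opt-in for each reference script: appending the new --amp flag to an otherwise unchanged command line (for instance python references/recognition/train_tensorflow.py ... --amp, keeping whatever dataset and model arguments the script already expects) switches the global policy to mixed_float16 and wraps the optimizer in a LossScaleOptimizer, while omitting the flag keeps the previous full-precision behaviour.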