diff --git a/scripts/detection/faster_rcnn/train_faster_rcnn.py b/scripts/detection/faster_rcnn/train_faster_rcnn.py index 52d2449e1e..d9d17464f0 100644 --- a/scripts/detection/faster_rcnn/train_faster_rcnn.py +++ b/scripts/detection/faster_rcnn/train_faster_rcnn.py @@ -531,12 +531,13 @@ def train(net, train_data, val_data, eval_metric, batch_size, ctx, args): metric_losses[k].append(result[k]) for k in range(len(add_losses)): add_losses[k].append(result[len(metric_losses) + k]) + trainer.step(batch_size) + for metric, record in zip(metrics, metric_losses): metric.update(0, record) for metric, records in zip(metrics2, add_losses): for pred in records: metric.update(pred[0], pred[1]) - trainer.step(batch_size) # update metrics if (not args.horovod or hvd.rank() == 0) and args.log_interval \ diff --git a/scripts/instance/mask_rcnn/train_mask_rcnn.py b/scripts/instance/mask_rcnn/train_mask_rcnn.py index 36afa78ea0..e8100e21d3 100644 --- a/scripts/instance/mask_rcnn/train_mask_rcnn.py +++ b/scripts/instance/mask_rcnn/train_mask_rcnn.py @@ -10,7 +10,7 @@ os.environ['MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD'] = '25' os.environ['MXNET_GPU_COPY_NTHREADS'] = '1' os.environ['MXNET_OPTIMIZER_AGGREGATION_SIZE'] = '54' -os.environ['MXNET_USE_FUSION'] = '0 +os.environ['MXNET_USE_FUSION'] = '0' import logging import time @@ -559,6 +559,7 @@ def train(net, train_data, val_data, eval_metric, batch_size, ctx, logger, args) metric.reset() tic = time.time() btic = time.time() + speed = [] train_data_iter = iter(train_data) next_data_batch = next(train_data_iter) next_data_batch = split_and_load(next_data_batch, ctx_list=ctx) @@ -595,12 +596,14 @@ def train(net, train_data, val_data, eval_metric, batch_size, ctx, logger, args) except StopIteration: pass + trainer.step(batch_size) + for metric, record in zip(metrics, metric_losses): metric.update(0, record) for metric, records in zip(metrics2, add_losses): for pred in records: metric.update(pred[0], pred[1]) - trainer.step(batch_size) + if (not args.horovod or hvd.rank() == 0) and args.log_interval \ and not (i + 1) % args.log_interval: msg = ','.join(['{}={:.3f}'.format(*metric.get()) for metric in metrics + metrics2])