diff --git a/.dev_scripts/benchmark_inference_fps.py b/.dev_scripts/benchmark_inference_fps.py
index b45825925de..5d46d52a8e3 100644
--- a/.dev_scripts/benchmark_inference_fps.py
+++ b/.dev_scripts/benchmark_inference_fps.py
@@ -6,7 +6,8 @@
 import mmcv
 from mmcv import Config, DictAction
 from mmcv.runner import init_dist
-from tools.analysis_tools.benchmark import measure_inferense_speed
+from terminaltables import GithubFlavoredMarkdownTable
+from tools.analysis_tools.benchmark import repeat_measure_inference_speed
 
 
 def parse_args():
@@ -19,12 +20,17 @@ def parse_args():
         type=int,
         default=1,
         help='round a number to a given precision in decimal digits')
+    parser.add_argument(
+        '--repeat-num',
+        type=int,
+        default=1,
+        help='number of repeat times of measurement for averaging the results')
     parser.add_argument(
         '--out', type=str, help='output path of gathered fps to be stored')
     parser.add_argument(
-        '--max-iter', type=int, default=400, help='num of max iter')
+        '--max-iter', type=int, default=2000, help='num of max iter')
     parser.add_argument(
-        '--log-interval', type=int, default=40, help='interval of logging')
+        '--log-interval', type=int, default=50, help='interval of logging')
     parser.add_argument(
         '--fuse-conv-bn',
         action='store_true',
@@ -52,9 +58,43 @@ def parse_args():
     return args
 
 
+def results2markdown(result_dict):
+    table_data = []
+    is_multiple_results = False
+    for cfg_name, value in result_dict.items():
+        name = cfg_name.replace('configs/', '')
+        fps = value['fps']
+        ms_times_pre_image = value['ms_times_pre_image']
+        if isinstance(fps, list):
+            is_multiple_results = True
+            mean_fps = value['mean_fps']
+            mean_times_pre_image = value['mean_times_pre_image']
+            fps_str = ','.join([str(s) for s in fps])
+            ms_times_pre_image_str = ','.join(
+                [str(s) for s in ms_times_pre_image])
+            table_data.append([
+                name, fps_str, mean_fps, ms_times_pre_image_str,
+                mean_times_pre_image
+            ])
+        else:
+            table_data.append([name, fps, ms_times_pre_image])
+
+    if is_multiple_results:
+        table_data.insert(0, [
+            'model', 'fps', 'mean_fps', 'times_pre_image(ms)',
+            'mean_times_pre_image(ms)'
+        ])
+
+    else:
+        table_data.insert(0, ['model', 'fps', 'times_pre_image(ms)'])
+    table = GithubFlavoredMarkdownTable(table_data)
+    print(table.table, flush=True)
+
+
 if __name__ == '__main__':
     args = parse_args()
     assert args.round_num >= 0
+    assert args.repeat_num >= 1
 
     config = Config.fromfile(args.config)
 
@@ -75,20 +115,55 @@ def parse_args():
         checkpoint = osp.join(args.checkpoint_root,
                               model_info['checkpoint'].strip())
         try:
-            fps = measure_inferense_speed(cfg, checkpoint, args.max_iter,
-                                          args.log_interval,
-                                          args.fuse_conv_bn)
-            print(
-                f'{cfg_path} fps : {fps:.{args.round_num}f} img / s, '
-                f'times per image: {1000/fps:.{args.round_num}f} ms / img',
-                flush=True)
-            result_dict[cfg_path] = dict(
-                fps=round(fps, args.round_num),
-                ms_times_pre_image=round(1000 / fps, args.round_num))
+            fps = repeat_measure_inference_speed(cfg, checkpoint,
+                                                 args.max_iter,
+                                                 args.log_interval,
+                                                 args.fuse_conv_bn,
+                                                 args.repeat_num)
+            if args.repeat_num > 1:
+                fps_list = [round(fps_, args.round_num) for fps_ in fps]
+                times_pre_image_list = [
+                    round(1000 / fps_, args.round_num) for fps_ in fps
+                ]
+                mean_fps = round(
+                    sum(fps_list) / len(fps_list), args.round_num)
+                mean_times_pre_image = round(
+                    sum(times_pre_image_list) / len(times_pre_image_list),
+                    args.round_num)
+                print(
+                    f'{cfg_path} '
+                    f'Overall fps: {fps_list}[{mean_fps}] img / s, '
+                    f'times per image: '
+                    f'{times_pre_image_list}[{mean_times_pre_image}] '
+                    f'ms / img',
+                    flush=True)
+                result_dict[cfg_path] = dict(
+                    fps=fps_list,
+                    mean_fps=mean_fps,
+                    ms_times_pre_image=times_pre_image_list,
+                    mean_times_pre_image=mean_times_pre_image)
+            else:
+                print(
+                    f'{cfg_path} fps : {fps:.{args.round_num}f} img / s, '
+                    f'times per image: {1000 / fps:.{args.round_num}f} '
+                    f'ms / img',
+                    flush=True)
+                result_dict[cfg_path] = dict(
+                    fps=round(fps, args.round_num),
+                    ms_times_pre_image=round(1000 / fps, args.round_num))
         except Exception as e:
-            print(f'{config} error: {repr(e)}')
-            result_dict[cfg_path] = 0
+            print(f'{cfg_path} error: {repr(e)}')
+            if args.repeat_num > 1:
+                result_dict[cfg_path] = dict(
+                    fps=[0],
+                    mean_fps=0,
+                    ms_times_pre_image=[0],
+                    mean_times_pre_image=0)
+            else:
+                result_dict[cfg_path] = dict(fps=0, ms_times_pre_image=0)
 
     if args.out:
         mmcv.mkdir_or_exist(args.out)
         mmcv.dump(result_dict, osp.join(args.out, 'batch_inference_fps.json'))
+
+    results2markdown(result_dict)
diff --git a/docs/1_exist_data_model.md b/docs/1_exist_data_model.md
index 1e61a287210..fc7c286f472 100644
--- a/docs/1_exist_data_model.md
+++ b/docs/1_exist_data_model.md
@@ -283,7 +283,7 @@ Optional arguments:
 
 ### Examples
 
-Assume that you have already downloaded the checkpoints to the directory `checkpoints/`.
+Assuming that you have already downloaded the checkpoints to the directory `checkpoints/`.
 
 1. Test Faster R-CNN and visualize the results. Press any key for the next image. Config and checkpoint files are available [here](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn).
 
diff --git a/docs/useful_tools.md b/docs/useful_tools.md
index a9472402db7..99f86966208 100644
--- a/docs/useful_tools.md
+++ b/docs/useful_tools.md
@@ -377,10 +377,35 @@ python tools/dataset_converters/cityscapes.py ${CITYSCAPES_PATH} [-h] [--img-dir
 python tools/dataset_converters/pascal_voc.py ${DEVKIT_PATH} [-h] [-o ${OUT_DIR}]
 ```
 
-## Robust Detection Benchmark
+## Benchmark
+
+### Robust Detection Benchmark
 
 `tools/analysis_tools/test_robustness.py` and`tools/analysis_tools/robustness_eval.py` helps users to evaluate model robustness. The core idea comes from [Benchmarking Robustness in Object Detection: Autonomous Driving when Winter is Coming](https://arxiv.org/abs/1907.07484). For more information how to evaluate models on corrupted images and results for a set of standard models please refer to [robustness_benchmarking.md](robustness_benchmarking.md).
 
+### FPS Benchmark
+
+`tools/analysis_tools/benchmark.py` helps users to calculate FPS. The FPS value includes the model forward pass and post-processing. To get a more accurate value, it currently only supports single-GPU distributed startup mode.
+
+```shell
+python -m torch.distributed.launch --nproc_per_node=1 --master_port=${PORT} tools/analysis_tools/benchmark.py \
+    ${CONFIG} \
+    ${CHECKPOINT} \
+    [--repeat-num ${REPEAT_NUM}] \
+    [--max-iter ${MAX_ITER}] \
+    [--log-interval ${LOG_INTERVAL}] \
+    --launcher pytorch
+```
+
+Example: Assuming that you have already downloaded the `Faster R-CNN` model checkpoint to the directory `checkpoints/`, run:
+
+```shell
+python -m torch.distributed.launch --nproc_per_node=1 --master_port=29500 tools/analysis_tools/benchmark.py \
+       configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py \
+       checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth \
+       --launcher pytorch
+```
+
 ## Miscellaneous
 
 ### Evaluating a metric
diff --git a/tools/analysis_tools/benchmark.py b/tools/analysis_tools/benchmark.py
index 70b18180193..91f34c74063 100644
--- a/tools/analysis_tools/benchmark.py
+++ b/tools/analysis_tools/benchmark.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import argparse
+import copy
 import os
 import time
 
@@ -18,6 +19,11 @@ def parse_args():
     parser = argparse.ArgumentParser(description='MMDet benchmark a model')
     parser.add_argument('config', help='test config file path')
     parser.add_argument('checkpoint', help='checkpoint file')
+    parser.add_argument(
+        '--repeat-num',
+        type=int,
+        default=1,
+        help='number of repeat times of measurement for averaging the results')
     parser.add_argument(
         '--max-iter', type=int, default=2000, help='num of max iter')
     parser.add_argument(
@@ -49,7 +55,7 @@ def parse_args():
     return args
 
 
-def measure_inferense_speed(cfg, checkpoint, max_iter, log_interval,
+def measure_inference_speed(cfg, checkpoint, max_iter, log_interval,
                             is_fuse_conv_bn):
     # set cudnn_benchmark
     if cfg.get('cudnn_benchmark', False):
@@ -66,7 +72,10 @@ def measure_inferense_speed(cfg, checkpoint, max_iter, log_interval,
     data_loader = build_dataloader(
         dataset,
         samples_per_gpu=1,
-        workers_per_gpu=cfg.data.workers_per_gpu,
+        # Because multiple worker processes occupy additional CPU resources,
+        # FPS statistics become more unstable when workers_per_gpu is not 0,
+        # so workers_per_gpu is fixed to 0 here.
+        workers_per_gpu=0,
         dist=True,
         shuffle=False)
 
@@ -123,6 +132,40 @@ def measure_inferense_speed(cfg, checkpoint, max_iter, log_interval,
     return fps
 
 
+def repeat_measure_inference_speed(cfg,
+                                   checkpoint,
+                                   max_iter,
+                                   log_interval,
+                                   is_fuse_conv_bn,
+                                   repeat_num=1):
+    assert repeat_num >= 1
+
+    fps_list = []
+
+    for _ in range(repeat_num):
+        # deep copy the config so each run starts from an unmodified copy
+        cp_cfg = copy.deepcopy(cfg)
+
+        fps_list.append(
+            measure_inference_speed(cp_cfg, checkpoint, max_iter, log_interval,
+                                    is_fuse_conv_bn))
+
+    if repeat_num > 1:
+        fps_list_ = [round(fps, 1) for fps in fps_list]
+        times_pre_image_list_ = [round(1000 / fps, 1) for fps in fps_list]
+        mean_fps_ = sum(fps_list_) / len(fps_list_)
+        mean_times_pre_image_ = sum(times_pre_image_list_) / len(
+            times_pre_image_list_)
+        print(
+            f'Overall fps: {fps_list_}[{mean_fps_:.1f}] img / s, '
+            f'times per image: '
+            f'{times_pre_image_list_}[{mean_times_pre_image_:.1f}] ms / img',
+            flush=True)
+        return fps_list
+
+    return fps_list[0]
+
+
 def main():
     args = parse_args()
 
@@ -135,8 +178,9 @@ def main():
     else:
         init_dist(args.launcher, **cfg.dist_params)
 
-    measure_inferense_speed(cfg, args.checkpoint, args.max_iter,
-                            args.log_interval, args.fuse_conv_bn)
+    repeat_measure_inference_speed(cfg, args.checkpoint, args.max_iter,
+                                   args.log_interval, args.fuse_conv_bn,
+                                   args.repeat_num)
 
 
 if __name__ == '__main__':
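Note (usage illustration only, not part of the patch): the sketch below shows how the new `repeat_measure_inference_speed` helper could be driven from a small Python script instead of the CLI. It assumes the script is launched with `torch.distributed.launch` on a single GPU so that `init_dist` succeeds (mirroring `tools/analysis_tools/benchmark.py`), and it reuses the config and checkpoint paths from the documentation example above; substitute your own paths as needed.

```python
# Minimal usage sketch. Assumed launch command (single GPU), mirroring the docs:
#   python -m torch.distributed.launch --nproc_per_node=1 this_script.py
from mmcv import Config
from mmcv.runner import init_dist

from tools.analysis_tools.benchmark import repeat_measure_inference_speed

if __name__ == '__main__':
    # single-GPU distributed init, as done in tools/analysis_tools/benchmark.py
    init_dist('pytorch')

    cfg = Config.fromfile('configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py')
    fps = repeat_measure_inference_speed(
        cfg,
        'checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth',
        max_iter=2000,    # same default as --max-iter
        log_interval=50,  # same default as --log-interval
        is_fuse_conv_bn=False,
        repeat_num=3)     # repeat the measurement three times
    print(fps)
```

With `repeat_num > 1` the helper returns the list of per-run FPS values (and prints their mean); with `repeat_num == 1` it returns a single float, which is how `.dev_scripts/benchmark_inference_fps.py` consumes it.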