forked from WongKinYiu/yolov7
Add option to use YOLOv5 AP metric (WongKinYiu#775)

* Add YOLOv5 metric option
* Inform if using v5 metric

Parent b1850c7, commit 55b90e1.
4 changed files with 28 additions and 12 deletions.
The evaluation script gains a --v5-metric flag and threads it through to ap_per_class:

@@ -39,7 +39,8 @@ def test(data,
          compute_loss=None,
          half_precision=True,
          trace=False,
-         is_coco=False):
+         is_coco=False,
+         v5_metric=False):
     # Initialize/load model and set device
     training = model is not None
     if training:  # called by train.py

@@ -89,6 +90,9 @@ def test(data,
         dataloader = create_dataloader(data[task], imgsz, batch_size, gs, opt, pad=0.5, rect=True,
                                        prefix=colorstr(f'{task}: '))[0]

+    if v5_metric:
+        print("Testing with YOLOv5 AP metric...")
+
     seen = 0
     confusion_matrix = ConfusionMatrix(nc=nc)
     names = {k: v for k, v in enumerate(model.names if hasattr(model, 'names') else model.module.names)}

@@ -217,7 +221,7 @@ def test(data,
     # Compute statistics
     stats = [np.concatenate(x, 0) for x in zip(*stats)]  # to numpy
     if len(stats) and stats[0].any():
-        p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
+        p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, v5_metric=v5_metric, save_dir=save_dir, names=names)
         ap50, ap = ap[:, 0], ap.mean(1)  # AP@0.5, AP@0.5:0.95
         mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
         nt = np.bincount(stats[3].astype(np.int64), minlength=nc)  # number of targets per class

@@ -304,6 +308,7 @@ def test(data,
     parser.add_argument('--name', default='exp', help='save to project/name')
     parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
     parser.add_argument('--no-trace', action='store_true', help='don`t trace model')
+    parser.add_argument('--v5-metric', action='store_true', help='assume maximum recall as 1.0 in AP calculation')
     opt = parser.parse_args()
     opt.save_json |= opt.data.endswith('coco.yaml')
     opt.data = check_file(opt.data)  # check file

@@ -325,11 +330,12 @@ def test(data,
              save_hybrid=opt.save_hybrid,
              save_conf=opt.save_conf,
              trace=not opt.no_trace,
+             v5_metric=opt.v5_metric
              )

     elif opt.task == 'speed':  # speed benchmarks
         for w in opt.weights:
-            test(opt.data, w, opt.batch_size, opt.img_size, 0.25, 0.45, save_json=False, plots=False)
+            test(opt.data, w, opt.batch_size, opt.img_size, 0.25, 0.45, save_json=False, plots=False, v5_metric=opt.v5_metric)

     elif opt.task == 'study':  # run over a range of settings and save/plot
         # python test.py --task study --data coco.yaml --iou 0.65 --weights yolov7.pt

@@ -340,7 +346,7 @@ def test(data,
             for i in x:  # img-size
                 print(f'\nRunning {f} point {i}...')
                 r, _, t = test(opt.data, w, opt.batch_size, i, opt.conf_thres, opt.iou_thres, opt.save_json,
-                               plots=False)
+                               plots=False, v5_metric=opt.v5_metric)
                 y.append(r + t)  # results and times
             np.savetxt(f, y, fmt='%10.4g')  # save
         os.system('zip -r study.zip study_*.txt')
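For reference, a minimal usage sketch of the new option. From the shell it is just the added flag, e.g. python test.py --data coco.yaml --weights yolov7.pt --v5-metric; when calling test() directly, the keyword follows the same positional call pattern the diff already uses for the speed task. The file paths and threshold values below are placeholders, not part of the commit.

import test  # assumes the evaluation script above is test.py, as its study-task comment suggests

# Placeholder arguments; the positional order (data, weights, batch_size, img_size,
# conf_thres, iou_thres) mirrors the speed-task call shown in the diff.
r, _, t = test.test('data/coco.yaml', 'yolov7.pt', 32, 640, 0.001, 0.65,
                    save_json=False, plots=False,
                    v5_metric=True)  # enable the YOLOv5-style AP integration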
The AP computation itself picks up the new argument in the metrics utilities (ap_per_class and compute_ap):

@@ -15,7 +15,7 @@ def fitness(x):
     return (x[:, :4] * w).sum(1)


-def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=()):
+def ap_per_class(tp, conf, pred_cls, target_cls, v5_metric=False, plot=False, save_dir='.', names=()):
     """ Compute the average precision, given the recall and precision curves.
     Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
     # Arguments

@@ -62,7 +62,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names

             # AP from recall-precision curve
             for j in range(tp.shape[1]):
-                ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
+                ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j], v5_metric=v5_metric)
                 if plot and j == 0:
                     py.append(np.interp(px, mrec, mpre))  # precision at mAP@0.5


@@ -78,17 +78,21 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names
     return p[:, i], r[:, i], ap, f1[:, i], unique_classes.astype('int32')


-def compute_ap(recall, precision):
+def compute_ap(recall, precision, v5_metric=False):
     """ Compute the average precision, given the recall and precision curves
     # Arguments
         recall:    The recall curve (list)
         precision: The precision curve (list)
+        v5_metric: Assume maximum recall to be 1.0, as in YOLOv5, MMDetection etc.
     # Returns
         Average precision, precision curve, recall curve
     """

     # Append sentinel values to beginning and end
-    mrec = np.concatenate(([0.], recall, [recall[-1] + 0.01]))
+    if v5_metric:  # New YOLOv5 metric, same as MMDetection and Detectron2 repositories
+        mrec = np.concatenate(([0.], recall, [1.0]))
+    else:  # Old YOLOv5 metric, i.e. default YOLOv7 metric
+        mrec = np.concatenate(([0.], recall, [recall[-1] + 0.01]))
     mpre = np.concatenate(([1.], precision, [0.]))

     # Compute the precision envelope