-
Notifications
You must be signed in to change notification settings - Fork 95
/
Copy patheval_func.py
624 lines (537 loc) · 33.3 KB
/
eval_func.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
import os
import numpy as np
from keras_cv_attention_models import backend
from keras_cv_attention_models.backend import layers, models, functional, callbacks
from keras_cv_attention_models.coco import anchors_func, info
from keras_cv_attention_models.models import no_grad_if_torch
from tqdm import tqdm
@backend.register_keras_serializable(package="kecam/coco")
class DecodePredictions(layers.Layer):
    """
    Decode raw detection model outputs into bboxes / labels / scores, then apply NMS.

    The most simple version decoding prediction and NMS:
    >>> from keras_cv_attention_models import efficientdet, test_images
    >>> model = efficientdet.EfficientDetD0()
    >>> preds = model(model.preprocess_input(test_images.dog()))

    # Decode and NMS
    >>> from keras_cv_attention_models import coco
    >>> input_shape = model.input_shape[1:-1]
    >>> anchors = coco.get_anchors(input_shape=input_shape, pyramid_levels=[3, 7], anchor_scale=4)
    >>> dd = coco.decode_bboxes(preds[0], anchors).numpy()
    >>> rr = tf.image.non_max_suppression(dd[:, :4], dd[:, 4:].max(-1), score_threshold=0.3, max_output_size=15, iou_threshold=0.5)
    >>> dd_nms = tf.gather(dd, rr).numpy()
    >>> bboxes, labels, scores = dd_nms[:, :4], dd_nms[:, 4:].argmax(-1), dd_nms[:, 4:].max(-1)
    >>> print(f"{bboxes = }, {labels = }, {scores = }")
    >>> # bboxes = array([[0.433231 , 0.54432285, 0.8778939 , 0.8187578 ]], dtype=float32), labels = array([17]), scores = array([0.85373735], dtype=float32)
    """

    def __init__(
        self,
        input_shape=512,
        pyramid_levels=[3, 7],
        anchors_mode=None,
        use_object_scores="auto",
        anchor_scale="auto",
        aspect_ratios=(1, 2, 0.5),
        num_scales=3,
        regression_len=4,  # bbox output len, typical value is 4, for yolov8 reg_max=16 -> regression_len = 16 * 4 == 64
        score_threshold=0.3,  # decode parameter, can be set new value in `self.call`
        iou_or_sigma=0.5,  # decode parameter, can be set new value in `self.call`
        max_output_size=100,  # decode parameter, can be set new value in `self.call`
        method="hard",  # decode parameter, can be set new value in `self.call`
        mode="global",  # decode parameter, can be set new value in `self.call`
        topk=0,  # decode parameter, can be set new value in `self.call`
        use_static_output=False,  # Set to True if using this as an actual layer, especially for converting tflite
        use_sigmoid_on_score=False,  # whether applying sigmoid on score outputs. Set True if model is built using `classifier_activation=None`
        num_masks=0,  # Set > 0 value for segmentation model with masks output
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Expand the [min, max] range into every level, e.g. [3, 7] -> [3, 4, 5, 6, 7]
        self.pyramid_levels = list(range(min(pyramid_levels), max(pyramid_levels) + 1))
        use_object_scores, _, anchor_scale = anchors_func.get_anchors_mode_parameters(anchors_mode, use_object_scores, "auto", anchor_scale)
        self.regression_len, self.aspect_ratios, self.num_scales, self.num_masks = regression_len, aspect_ratios, num_scales, num_masks
        self.anchors_mode, self.use_object_scores, self.anchor_scale = anchors_mode, use_object_scores, anchor_scale

        # Anchors can only be built up-front when the spatial size is fully known: a single int (square input),
        # or a list / tuple without any `None` dimension. Otherwise they are built later from `input_shape` given to `call`.
        is_square_int_shape = isinstance(input_shape, int) and not isinstance(input_shape, bool)
        is_full_shape = isinstance(input_shape, (list, tuple)) and len(input_shape) >= 2 and all(dim is not None for dim in input_shape)
        if is_square_int_shape or is_full_shape:
            self.__init_anchor__(input_shape)
        else:
            self.anchors = None
        self.__input_shape__ = input_shape  # Keep the value as given, it is what `get_config` serializes

        self.use_static_output, self.use_sigmoid_on_score = use_static_output, use_sigmoid_on_score
        self.nms_kwargs = {
            "score_threshold": score_threshold,
            "iou_or_sigma": iou_or_sigma,
            "max_output_size": max_output_size,
            "method": method,
            "mode": mode,
            "topk": topk,
        }
        super().build(input_shape)

    def __init_anchor__(self, input_shape):
        """Build and return `self.anchors` for `input_shape`: an int, `[height, width]`, or a 3-D shape including channels."""
        if isinstance(input_shape, (list, tuple)) and len(input_shape) > 2:
            # input_shape = input_shape[:2] if backend.image_data_format() == "channels_last" else input_shape[-2:]
            # Assume the smallest value is the channel dimension, so both channels_last / channels_first work
            channel_axis = min(enumerate(input_shape), key=lambda xx: xx[1])[0]
            input_shape = [dim for axis, dim in enumerate(input_shape) if axis != channel_axis]
        elif isinstance(input_shape, int):
            input_shape = (input_shape, input_shape)

        if self.anchors_mode == anchors_func.ANCHOR_FREE_MODE:
            self.anchors = anchors_func.get_anchor_free_anchors(input_shape, self.pyramid_levels)
        elif self.anchors_mode == anchors_func.YOLOR_MODE:
            self.anchors = anchors_func.get_yolor_anchors(input_shape, self.pyramid_levels)
        elif self.anchors_mode == anchors_func.YOLOV8_MODE:
            self.anchors = anchors_func.get_anchor_free_anchors(input_shape, self.pyramid_levels, grid_zero_start=False)
        else:
            grid_zero_start = False
            self.anchors = anchors_func.get_anchors(input_shape, self.pyramid_levels, self.aspect_ratios, self.num_scales, self.anchor_scale, grid_zero_start)
        self.__input_shape__ = input_shape
        return self.anchors

    def __topk_class_boxes_single__(self, pred, topk=5000):
        """Pick the `topk` highest (anchor, class) scores of a single prediction, a bbox may be picked for multiple labels.

        Returns `(bboxes_topk, scores_topk, labels, picked_anchor_indices)`. `topk=-1` selects all.
        """
        # https://github.com/google/automl/tree/master/efficientdet/tf2/postprocess.py#L82
        bbox_outputs, class_outputs = pred[:, : self.regression_len], pred[:, self.regression_len :]
        num_classes = class_outputs.shape[-1]
        class_outputs_flatten = functional.reshape(class_outputs, -1)
        topk = class_outputs_flatten.shape[0] if topk == -1 else min(topk, class_outputs_flatten.shape[0])  # select all if -1
        _, class_topk_indices = functional.top_k(class_outputs_flatten, k=topk, sorted=False)
        # get original indices for class_outputs, original_indices_hh -> picking indices, original_indices_ww -> picked labels
        original_indices_hh, original_indices_ww = class_topk_indices // num_classes, class_topk_indices % num_classes
        class_indices = functional.stack([original_indices_hh, original_indices_ww], axis=-1)
        scores_topk = functional.gather_nd(class_outputs, class_indices)
        bboxes_topk = functional.gather(bbox_outputs, original_indices_hh)
        return bboxes_topk, scores_topk, original_indices_ww, original_indices_hh

    @staticmethod
    def nms_per_class(bbs, ccs, labels, score_threshold=0.3, iou_threshold=0.5, soft_nms_sigma=0.5, max_output_size=100):
        """NMS done independently per class. Returns `(bboxes, labels, nms_scores, kept_indices)`."""
        # From torchvision.ops.batched_nms strategy: in order to perform NMS independently per class. we add an offset to all the boxes.
        # The offset is dependent only on the class idx, and is large enough so that boxes from different classes do not overlap
        # Same result with per_class method: https://github.com/google/automl/tree/master/efficientdet/tf2/postprocess.py#L409
        cls_offset = functional.cast(labels, bbs.dtype) * (functional.reduce_max(bbs) + 1)
        bbs_per_class = bbs + functional.expand_dims(cls_offset, -1)
        indices, nms_scores = functional.non_max_suppression_with_scores(bbs_per_class, ccs, max_output_size, iou_threshold, score_threshold, soft_nms_sigma)
        return functional.gather(bbs, indices), functional.gather(labels, indices), nms_scores, indices

    @staticmethod
    def nms_global(bbs, ccs, labels, score_threshold=0.3, iou_threshold=0.5, soft_nms_sigma=0.5, max_output_size=100):
        """NMS over all boxes regardless of class. Returns `(bboxes, labels, nms_scores, kept_indices)`."""
        indices, nms_scores = functional.non_max_suppression_with_scores(bbs, ccs, max_output_size, iou_threshold, score_threshold, soft_nms_sigma)
        return functional.gather(bbs, indices), functional.gather(labels, indices), nms_scores, indices

    def __object_score_split__(self, pred):
        """Split the last column off as object score. Subclasses may overwrite."""
        return pred[:, :-1], pred[:, -1]  # May overwrite

    def __to_static__(self, bboxs, lables, confidences, masks=None, max_output_size=100):
        """Pack outputs into one zero-padded `[max_output_size, 4 + 1 + 1 (+ mask_pixels)]` tensor: bbox, label, score (, flattened masks)."""
        indices = functional.expand_dims(functional.range(functional.shape(bboxs)[0]), -1)
        lables = functional.cast(lables, bboxs.dtype)
        if masks is None:
            concated = functional.concat([bboxs, functional.expand_dims(lables, -1), functional.expand_dims(confidences, -1)], axis=-1)
        else:
            masks = functional.reshape(functional.cast(masks, bboxs.dtype), [-1, masks.shape[1] * masks.shape[2]])
            concated = functional.concat([bboxs, functional.expand_dims(lables, -1), functional.expand_dims(confidences, -1), masks], axis=-1)
        concated = functional.tensor_scatter_nd_update(functional.zeros([max_output_size, concated.shape[-1]], dtype=bboxs.dtype), indices, concated)
        return concated

    @staticmethod
    def process_mask_proto_single(mask_proto, masks, bboxs):
        """Combine mask coefficients with mask prototypes, then zero out everything outside each bbox."""
        # mask_proto: [input_height // 4, input_width // 4, 32], masks: [num, 32], bboxs: [num, 4]
        protos_height, protos_width = mask_proto.shape[:2]
        mask_proto = functional.transpose(functional.reshape(mask_proto, [-1, mask_proto.shape[-1]]), [1, 0])
        masks = functional.sigmoid(masks @ mask_proto)  # [num, protos_height * protos_width]
        masks = functional.reshape(masks, [-1, protos_height, protos_width])  # [num, protos_height, protos_width]

        # Filter by bbox area
        top, left, bottom, right = functional.split(bboxs[:, :, None], [1, 1, 1, 1], axis=1)  # [num, 1_pos, 1]
        height_range = functional.range(protos_height, dtype=top.dtype)[None, :, None] / protos_height  # [1, protos_height, 1]
        width_range = functional.range(protos_width, dtype=top.dtype)[None, None] / protos_width  # [1, 1, protos_width]
        height_cond = functional.logical_and(height_range >= top, height_range < bottom)  # [num, protos_height, 1]
        width_cond = functional.logical_and(width_range >= left, width_range < right)  # [num, 1, protos_width]
        masks *= functional.cast(functional.logical_and(height_cond, width_cond), masks.dtype)  # [num, protos_height, protos_width]
        return masks

    def __decode_single__(
        self, pred, mask_proto=None, score_threshold=0.3, iou_or_sigma=0.5, max_output_size=100, method="hard", mode="global", topk=0, input_shape=None
    ):
        """Decode + NMS a single (un-batched) prediction.

        mode "global" / "per_class" selects the NMS strategy, any other value returns the raw decoded data (for testing).
        method "gaussian" uses soft-NMS with `soft_nms_sigma = iou_or_sigma / 2`, otherwise hard NMS with `iou_threshold = iou_or_sigma`.
        Returns the static packed tensor if `use_static_output`, else `(bboxes, labels, scores[, masks])`.
        """
        # https://github.com/google/automl/tree/master/efficientdet/tf2/postprocess.py#L159
        pred = functional.cast(pred.detach() if hasattr(pred, "detach") else pred, "float32")
        if input_shape is not None:
            self.__init_anchor__(input_shape)

        if self.num_masks > 0:  # Segmentation masks
            pred, masks = pred[:, : -self.num_masks], pred[:, -self.num_masks :]
        else:
            masks = None

        if self.use_object_scores:  # YOLO outputs: [bboxes, classses_score, object_score]
            pred, object_scores = self.__object_score_split__(pred)

        if topk != 0:
            bbs, ccs, labels, picking_indices = self.__topk_class_boxes_single__(pred, topk)
            anchors = functional.gather(self.anchors, picking_indices)
            if self.use_object_scores:
                ccs = ccs * functional.gather(object_scores, picking_indices)
        else:
            bbs, scores = pred[:, : self.regression_len], pred[:, self.regression_len :]
            ccs, labels = functional.reduce_max(scores, axis=-1), functional.argmax(scores, axis=-1)
            anchors = self.anchors
            if self.use_object_scores:
                ccs = ccs * object_scores
        ccs = functional.sigmoid(ccs) if self.use_sigmoid_on_score else ccs

        bbs_decoded = anchors_func.decode_bboxes(bbs, anchors, regression_len=self.regression_len)
        iou_threshold, soft_nms_sigma = (1.0, iou_or_sigma / 2) if method.lower() == "gaussian" else (iou_or_sigma, 0.0)
        if mode == "per_class":
            bboxs, lables, confidences, indices = self.nms_per_class(bbs_decoded, ccs, labels, score_threshold, iou_threshold, soft_nms_sigma, max_output_size)
        elif mode == "global":
            bboxs, lables, confidences, indices = self.nms_global(bbs_decoded, ccs, labels, score_threshold, iou_threshold, soft_nms_sigma, max_output_size)
        else:
            bboxs, lables, confidences, indices = bbs_decoded, labels, ccs, None  # Return raw decoded data for testing

        if self.num_masks > 0 and indices is not None:  # Segmentation masks
            masks = functional.gather(masks, indices)
            masks = self.process_mask_proto_single(mask_proto, masks, bboxs)

        if self.use_static_output:
            return self.__to_static__(bboxs, lables, confidences, masks, max_output_size)
        elif self.num_masks > 0:
            return bboxs, lables, confidences, masks
        else:
            return bboxs, lables, confidences

    def call(self, preds, mask_protos=None, input_shape=None, training=False, **nms_kwargs):
        """
        https://github.com/google/automl/tree/master/efficientdet/tf2/postprocess.py#L159

        mask_protos: mask output from segmentation model.
        input_shape: actual input shape if model using dynamic input shape `[None, None, 3]`.
        nms_kwargs: override values are stored in `self.nms_kwargs`, so they also apply to later calls.
          score_threshold: float value in (0, 1), min score threshold, lower output score will be excluded. Default 0.3.
          iou_or_sigma: means `soft_nms_sigma` if method is "gaussian", else `iou_threshold`. Default 0.5.
          max_output_size: max output size for `tf.image.non_max_suppression_with_scores`. Default 100.
              If use_static_output=True, fixed output shape will be `[batch, max_output_size, 6]`.
          method: "gaussian" or "hard".  Default "hard".
          mode: "global" or "per_class". "per_class" is strategy from `torchvision.ops.batched_nms`. Default "global".
          topk: Using topk highest scores, each bbox may have multi labels. Set `0` to disable, `-1` using all. Default 0.
        """
        self.nms_kwargs.update(nms_kwargs)
        if self.num_masks > 0:  # Segmentation model
            assert mask_protos is not None, "self.num_masks={} > 0, but mask_protos not provided".format(self.num_masks)

        if self.use_static_output:
            # `map_fn` below decodes sample by sample without forwarding `input_shape`, so build anchors once here
            if input_shape is not None:
                self.__init_anchor__(input_shape)
            # Use the merged `self.nms_kwargs`, so values given in `__init__` are honored, not only the call-time overrides
            if self.num_masks > 0:  # Segmentation model
                decode_fn = lambda xx: self.__decode_single__(xx[0], xx[1], **self.nms_kwargs)
                return functional.map_fn(decode_fn, [preds, mask_protos], fn_output_signature=preds.dtype)
            return functional.map_fn(lambda xx: self.__decode_single__(xx, **self.nms_kwargs), preds)
        elif len(preds.shape) == 3 and self.num_masks > 0:  # Segmentation model
            return [self.__decode_single__(pred, mask_proto, **self.nms_kwargs, input_shape=input_shape) for pred, mask_proto in zip(preds, mask_protos)]
        elif len(preds.shape) == 3:
            return [self.__decode_single__(pred, **self.nms_kwargs, input_shape=input_shape) for pred in preds]
        else:
            return self.__decode_single__(preds, mask_protos, **self.nms_kwargs, input_shape=input_shape)

    def get_config(self):
        """Serialize every `__init__` argument, including `regression_len` and the current NMS parameters."""
        config = super().get_config()
        config.update(
            {
                "input_shape": self.__input_shape__,
                "pyramid_levels": self.pyramid_levels,
                "anchors_mode": self.anchors_mode,
                "use_object_scores": self.use_object_scores,
                "anchor_scale": self.anchor_scale,
                "aspect_ratios": self.aspect_ratios,
                "num_scales": self.num_scales,
                "regression_len": self.regression_len,  # Needed to rebuild e.g. yolov8 decoders with regression_len=64
                "use_static_output": self.use_static_output,
                "use_sigmoid_on_score": self.use_sigmoid_on_score,
                "num_masks": self.num_masks,
            }
        )
        config.update(self.nms_kwargs)
        return config
""" COCO Evaluation """
def scale_bboxes_back_single(bboxes, image_shape, scale, pad_top, pad_left, target_shape):
    """Map bboxes from the resized / padded model input back onto the original image, in COCO box format.

    bboxes: `[num, 4]` boxes as `[top, left, bottom, right]`, scaled by `target_shape` (presumably relative [0, 1] values).
        NOTE: a numpy array input is modified in place by the `*=` / `-=` / `/=` steps below.
    image_shape: original image `[height, width]`, used as the clipping upper bound.
    scale, pad_top, pad_left: resize scale and padding, as returned by `image_process`.
    target_shape: model input `[height, width]`.
    Returns `[num, 4]` boxes as `[left, top, width, height]` in original image pixels.
    """
    target_height, target_width = target_shape[0], target_shape[1]
    bboxes *= [target_height, target_width, target_height, target_width]  # -> model input pixels
    bboxes -= [pad_top, pad_left, pad_top, pad_left]  # remove padding
    bboxes /= scale  # undo resizing

    image_height, image_width = image_shape[0], image_shape[1]
    clip_value_max = functional.convert_to_tensor([image_height, image_width, image_height, image_width], dtype="float32")
    bboxes = functional.clip_by_value(bboxes, 0, clip_value_max=clip_value_max)

    # [top, left, bottom, right] -> [left, top, width, height]
    top, left, bottom, right = bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 3]
    return functional.stack([left, top, right - left, bottom - top], axis=-1)
def image_process(image, target_shape, mean, std, resize_method="bilinear", resize_antialias=False, use_bgr_input=False, letterbox_pad=-1):
    """Load (if given a path), aspect-aware resize / pad, and normalize a single image for evaluation.

    image: image path (str, or a scalar string tensor on TF backend) or an already decoded `[height, width, channels]` image.
    target_shape: model input `[height, width]`.
    mean, std: normalization values, applied as `(image - mean) / std`.
    use_bgr_input: flip channels to BGR after normalization.
    Returns `(image, scale, pad_top, pad_left, original_image_shape)`.
    Raises OSError on non-TensorFlow backends if the image file cannot be read.
    """
    if backend.is_tensorflow_backend:
        from keras_cv_attention_models.coco.tf_data import tf_imread as imread, aspect_aware_resize_and_crop_image
    else:
        import cv2
        from keras_cv_attention_models.coco.torch_data import aspect_aware_resize_and_crop_image

        def imread(image_path):
            bgr_image = cv2.imread(image_path)
            # cv2.imread reports missing / unreadable files by returning None instead of raising
            if bgr_image is None:
                raise OSError("cv2.imread failed to read image: {}".format(image_path))
            return bgr_image[:, :, ::-1]  # BGR -> RGB

    if isinstance(image, str) or len(image.shape) < 2:
        image = imread(image)  # it's image path
    if backend.is_tensorflow_backend:
        original_image_shape = functional.shape(image)[:2]
        image = functional.cast(image, "float32")
    else:
        original_image_shape, image = image.shape[:2], image.astype("float32")

    image, scale, pad_top, pad_left = aspect_aware_resize_and_crop_image(
        image, target_shape, letterbox_pad=letterbox_pad, method=resize_method, antialias=resize_antialias
    )
    image = (image - mean) / std  # automl behavior: rescale -> resize
    if use_bgr_input:
        image = image[:, :, ::-1]
    return image, scale, pad_top, pad_left, original_image_shape
def init_eval_dataset(
    data_name="coco/2017",
    input_shape=(512, 512),
    batch_size=8,
    rescale_mode="torch",
    resize_method="bilinear",
    resize_antialias=False,
    letterbox_pad=-1,
    use_bgr_input=False,
):
    """Build the batched evaluation dataset and return `(dataset, num_classes)`.

    Each batch is `(resized_images, scales, pad_tops, pad_lefts, original_image_shapes, image_ids)`.
    On TensorFlow backend `data_name` is a tfds name or a custom `.json` file, and the "validation" split is used, falling back to "test".
    On other backends `data_name` is loaded by `torch_data.load_from_custom_json`, and image ids are the image paths.
    Raises ValueError on TensorFlow backend if neither "validation" nor "test" split exists.
    """
    if backend.is_tensorflow_backend:
        from keras_cv_attention_models.coco import tf_data

        if data_name.endswith(".json"):
            dataset, _, num_classes = tf_data.detection_dataset_from_custom_json(data_name, with_info=True)
        else:
            import tensorflow_datasets as tfds

            # Named `ds_info`, not `info`, to avoid shadowing the module-level `coco.info` import
            dataset, ds_info = tfds.load(data_name, with_info=True)
            num_classes = ds_info.features["objects"]["label"].num_classes

        ds = dataset.get("validation", dataset.get("test", None))
        if ds is None:
            raise ValueError("Neither 'validation' nor 'test' split found in dataset: {}".format(data_name))
        mean, std = tf_data.init_mean_std_by_rescale_mode(rescale_mode)
        __image_process__ = lambda image: image_process(image, input_shape, mean, std, resize_method, resize_antialias, use_bgr_input, letterbox_pad)
        # ds: [resized_image, scale, pad_top, pad_left, original_image_shape, image_id]
        ds = ds.map(lambda datapoint: (*__image_process__(datapoint["image"]), datapoint.get("image/id", datapoint["image"])))
        ds = ds.batch(batch_size)
        return ds, num_classes
    else:
        import torch
        from torch.utils.data import Dataset, DataLoader
        from keras_cv_attention_models.coco import torch_data

        _, test, total_images, num_classes = torch_data.load_from_custom_json(data_name)
        mean, std = torch_data.init_mean_std_by_rescale_mode(rescale_mode, convert_to_image_data_format=False)

        class EvalDataset(Dataset):
            def __len__(self):
                return len(test)

            def __getitem__(self, index):
                image_path = test[index]["image"]
                image, scale, pad_top, pad_left, original_image_shape = image_process(
                    image_path, input_shape, mean, std, resize_method, resize_antialias, use_bgr_input, letterbox_pad
                )
                # `use_bgr_input` flips channels via a negative-stride view, which `torch.from_numpy` rejects. No copy if already contiguous.
                image = torch.from_numpy(np.ascontiguousarray(image)).permute([2, 0, 1]).contiguous()
                return image, scale, pad_top, pad_left, torch.as_tensor(original_image_shape), image_path

        ds = DataLoader(EvalDataset(), batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True, sampler=None, drop_last=False)
        # ds.element_spec = next(iter(ds))
        return ds, num_classes
def model_detection_and_decode(model, eval_dataset, pred_decoder, nms_kwargs=None, is_coco=True, image_id_map=None, num_classes=80):
    """Run `model` over `eval_dataset`, decode + NMS with `pred_decoder`, and collect detections in COCO result format.

    nms_kwargs: extra keyword arguments forwarded to `pred_decoder`. None means no overrides.
    is_coco: if True, map model labels to 1-based COCO 91-category ids. 80-class models go through `info.COCO_80_to_90_LABEL_DICT`.
    image_id_map: optional dict mapping dataset image ids (decoded to str if bytes) to annotation image ids.
    Returns `np.array` of rows `[image_id, x, y, width, height, score, class]`, empty if nothing is detected.
    """
    nms_kwargs = {} if nms_kwargs is None else nms_kwargs  # Avoid a shared mutable default
    sample_image = next(iter(eval_dataset))[0]
    target_shape = sample_image.shape[1:-1] if backend.image_data_format() == "channels_last" else sample_image.shape[2:]
    # num_classes = model.output_shape[-1] - 4
    if is_coco:
        to_91_labels = (lambda label: label + 1) if num_classes >= 90 else (lambda label: info.COCO_80_to_90_LABEL_DICT[label] + 1)
    else:
        to_91_labels = lambda label: label
    # Format: [image_id, x, y, width, height, score, class]
    to_coco_eval_single = lambda image_id, bbox, label, score: [image_id, *bbox.tolist(), score, to_91_labels(label)]

    results = []
    for images, scales, pad_tops, pad_lefts, original_image_shapes, image_ids in tqdm(eval_dataset):
        preds = model.predict(images).cpu().float() if backend.is_torch_backend else functional.cast(model(images), "float32")
        # decoded_preds: [[bboxes, labels, scores], [bboxes, labels, scores], ...]
        decoded_preds = pred_decoder(preds, **nms_kwargs)

        # Loop on batch
        for rr, image_shape, scale, pad_top, pad_left, image_id in zip(decoded_preds, original_image_shapes, scales, pad_tops, pad_lefts, image_ids):
            bboxes, labels, scores = rr
            image_id, bboxes, labels, scores = np.array(image_id).item(), bboxes.numpy(), labels.numpy(), scores.numpy()
            if image_id_map is not None:
                image_id = image_id_map[image_id.decode() if isinstance(image_id, bytes) else image_id]
            bboxes = scale_bboxes_back_single(bboxes, image_shape, scale, pad_top, pad_left, target_shape).numpy()
            results.extend([to_coco_eval_single(image_id, bb, cc, ss) for bb, cc, ss in zip(bboxes, labels, scores)])  # Loop on prediction results
    return np.array(results)
class COCOEvaluation:
    """Wrapper around pycocotools `COCOeval` for bbox evaluation against a fixed ground truth."""

    def __init__(self, annotations=None):
        """
        annotations: None to download the COCO val2017 instances annotation file, a json file path,
            or an already loaded annotation dict.
        """
        from pycocotools.coco import COCO

        if annotations is None:
            url = "https://github.com/leondgarse/keras_cv_attention_models/releases/download/assets/coco_annotations_instances_val2017.json"
            file_hash = "b681580a54b900b3cb44022fd1102ad5"
            annotations = backend.get_file(origin=url, file_hash=file_hash)

        if isinstance(annotations, dict):  # json already loaded as dict
            coco_gt = COCO()
            coco_gt.dataset = annotations
            coco_gt.createIndex()
        else:
            coco_gt = COCO(annotations)
        self.coco_gt = coco_gt

    def __call__(self, detection_results):
        """Evaluate `detection_results`, print the summary, and return the `COCOeval` object.

        detection_results: list of dicts with an "image_id" key, or rows `[image_id, x, y, width, height, score, category_id]`.
            Only images appearing in `detection_results` are evaluated.
        Raises ValueError if `detection_results` is empty.
        """
        # `len` rather than truthiness, so numpy arrays are handled too
        if len(detection_results) == 0:
            raise ValueError("detection_results is empty, nothing to evaluate")
        from pycocotools.cocoeval import COCOeval

        image_ids = [ii["image_id"] for ii in detection_results] if isinstance(detection_results[0], dict) else [ii[0] for ii in detection_results]
        image_ids = list(set(image_ids))
        print("len(image_ids) =", len(image_ids))
        coco_dt = self.coco_gt.loadRes(detection_results)
        coco_eval = COCOeval(cocoGt=self.coco_gt, cocoDt=coco_dt, iouType="bbox")
        coco_eval.params.imgIds = image_ids
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()
        return coco_eval
def to_coco_json(detection_results, save_path, indent=2):
    """Write detection rows `[image_id, x, y, width, height, score, category_id]` to `save_path` as a COCO results json list."""
    import json

    def row_to_dict(row):
        image_id, bbox, score, category_id = row[0], row[1:5], row[5], row[6]
        return {"image_id": int(image_id), "bbox": [float(coord) for coord in bbox], "score": float(score), "category_id": int(category_id)}

    coco_results = list(map(row_to_dict, detection_results))
    with open(save_path, "w") as json_file:
        json.dump(coco_results, json_file, indent=indent)
def to_coco_annotation(json_path):
    """Convert a custom detection json into a COCO-style ground-truth annotation dict.

    Reads the "validation" split, falling back to "test". Each item's `objects.bbox` is `[top, left, bottom, right]`
    relative to image size. Image sizes are read from the image files, joined with `info.base_path` if given.
    Returns `(annotation_dict, image_id_map)`, where `image_id_map` maps image file path -> assigned integer image id.
    """
    import json
    from PIL import Image

    with open(json_path, "r") as ff:
        aa = json.load(ff)
    # int conversion just in case key is str
    categories = {int(kk): vv for kk, vv in aa["indices_2_labels"].items()} if "indices_2_labels" in aa else {}
    base_path = os.path.expanduser(aa["info"]["base_path"]) if "base_path" in aa.get("info", {}) and len(aa["info"]["base_path"]) > 0 else None

    annotations, images, image_id_map = [], [], {}
    for image_id, ii in enumerate(aa.get("validation", aa.get("test", []))):
        image_file = os.path.join(base_path, ii["image"]) if base_path else ii["image"]
        # PIL opens lazily and keeps the file handle open; `with` closes it once the size is read
        with Image.open(image_file) as image:
            width, height = image.size
        for bb, label in zip(ii["objects"]["bbox"], ii["objects"]["label"]):
            # bb [top, left, bottom, right] in [0, 1] -> [left, top, bbox_width, bbox_height] with actual coordinates
            top = bb[0] * height
            left = bb[1] * width
            bbox_height = bb[2] * height - top
            bbox_width = bb[3] * width - left
            bb = [left, top, bbox_width, bbox_height]
            area = bbox_width * bbox_height  # Actual area in COCO is the segmentation area, doesn't matter in detection mission
            label = int(label)
            annotations.append({"bbox": bb, "category_id": label, "image_id": image_id, "id": len(annotations), "iscrowd": 0, "area": area})
            if label not in categories:
                categories[label] = str(len(categories))
        images.append({"id": image_id, "file_name": image_file, "height": height, "width": width})
        image_id_map[image_file] = image_id
    categories = [{"id": kk, "name": vv} for kk, vv in categories.items()]
    return {"images": images, "annotations": annotations, "categories": categories}, image_id_map
""" Wrapper a callback for using in training """
class COCOEvalCallback(callbacks.Callback):
    """
    Run COCO bbox evaluation at the end of epochs during training, optionally keeping the best model by AP.

    Basic test:
    >>> from keras_cv_attention_models import efficientdet, coco
    >>> model = efficientdet.EfficientDetD0()
    >>> ee = coco.eval_func.COCOEvalCallback(batch_size=4)
    >>> ee.model = model
    >>> ee.on_epoch_end()
    """

    def __init__(
        self,
        data_name="coco/2017",  # [init_eval_dataset parameters]
        batch_size=8,
        resize_method="bilinear",
        resize_antialias=False,
        rescale_mode="auto",
        letterbox_pad=-1,
        use_bgr_input=False,
        take_samples=-1,
        nms_score_threshold=0.001,  # [model_detection_and_decode parameters]
        nms_iou_or_sigma=0.5,
        nms_max_output_size=100,
        nms_method="gaussian",  # gaussian or hard
        nms_mode="per_class",  # per_class or global
        nms_topk=5000,
        anchors_mode="auto",  # [model anchors related parameters]
        anchor_scale=4,  # Init anchors for model prediction. "auto" means 1 if (anchors_mode=="anchor_free" or anchors_mode=="yolor"), else 4
        aspect_ratios=(1, 2, 0.5),  # For efficientdet anchors only
        num_scales=3,  # For efficientdet anchors only
        annotation_file=None,
        save_json=None,
        start_epoch=0,  # [training callbacks parameters]
        frequency=1,
        model_basic_save_name=None,
    ):
        super().__init__()
        self.anchors_mode = anchors_mode
        self.take_samples, self.annotation_file, self.start_epoch, self.frequency = take_samples, annotation_file, start_epoch, frequency
        self.save_json, self.model_basic_save_name, self.save_path, self.item_key = save_json, model_basic_save_name, "checkpoints", "val_ap_ar"
        self.data_name = data_name
        self.dataset_kwargs = {
            "data_name": data_name,
            "batch_size": batch_size,
            "rescale_mode": rescale_mode,
            "resize_method": resize_method,
            "resize_antialias": resize_antialias,
            "letterbox_pad": letterbox_pad,
            "use_bgr_input": use_bgr_input,
        }
        self.nms_kwargs = {
            "score_threshold": nms_score_threshold,
            "iou_or_sigma": nms_iou_or_sigma,
            "max_output_size": nms_max_output_size,
            "method": nms_method,
            "mode": nms_mode,
            "topk": nms_topk,
        }
        self.anchor_kwargs = {
            "anchor_scale": anchor_scale,
            "aspect_ratios": aspect_ratios,
            "num_scales": num_scales,
        }
        self.efficient_det_num_anchors = len(aspect_ratios) * num_scales
        self.is_coco = data_name.startswith("coco") and not data_name.endswith(".json")
        if self.data_name.endswith(".json") and self.annotation_file is None:
            self.annotation_file, self.image_id_map = to_coco_annotation(self.data_name)
        else:
            self.image_id_map = None
        self.built = False

    def build(self, input_shape, output_shape):
        """Lazily create the eval dataset, the prediction decoder, the save-best pattern and the COCO evaluator from model shapes."""
        import re

        input_shape = (
            (int(input_shape[1]), int(input_shape[2])) if backend.image_data_format() == "channels_last" else (int(input_shape[2]), int(input_shape[3]))
        )
        self.eval_dataset, self.num_classes = init_eval_dataset(input_shape=input_shape, **self.dataset_kwargs)
        print("\n>>>> [COCOEvalCallback] self.dataset_kwargs:", self.dataset_kwargs)

        regression_len = (output_shape[-1] - self.num_classes) // 4 * 4
        if self.anchors_mode is None or self.anchors_mode == "auto":
            self.anchors_mode, num_anchors = anchors_func.get_anchors_mode_by_anchors(input_shape, total_anchors=output_shape[1], regression_len=regression_len)
        elif self.anchors_mode == anchors_func.EFFICIENTDET_MODE:
            num_anchors = self.efficient_det_num_anchors
        else:
            num_anchors = anchors_func.NUM_ANCHORS.get(self.anchors_mode, 9)
        pyramid_levels = anchors_func.get_pyramid_levels_by_anchors(input_shape, total_anchors=output_shape[1], num_anchors=num_anchors)
        print(">>>> [COCOEvalCallback] input_shape: {}, pyramid_levels: {}, anchors_mode: {}".format(input_shape, pyramid_levels, self.anchors_mode))

        # Models ending with a "*_sigmoid" layer already output probabilities
        use_sigmoid_on_score = not any(ii.name.endswith("_sigmoid") for ii in self.model.layers[-50:])
        print(">>>> use_sigmoid_on_score:", use_sigmoid_on_score)
        self.pred_decoder = DecodePredictions(
            input_shape, pyramid_levels, self.anchors_mode, regression_len=regression_len, use_sigmoid_on_score=use_sigmoid_on_score, **self.anchor_kwargs
        )

        # Training saving best
        if self.model_basic_save_name is not None:
            monitor_save_name = self.model_basic_save_name + "_epoch_{}_" + self.item_key + "_{}.h5"
            # Escape literal parts, so "." / "+" / "(" in the model name or ".h5" are not treated as regex syntax,
            # and anchor the end so unrelated files sharing the prefix are never matched for deletion
            monitor_save_pattern = (
                re.escape(self.model_basic_save_name + "_epoch_") + r"\d*" + re.escape("_" + self.item_key + "_") + r"[\d\.]*" + re.escape(".h5") + "$"
            )
            self.monitor_save_re = re.compile(monitor_save_pattern)
            self.monitor_save = os.path.join(self.save_path, monitor_save_name)
            self.is_better = lambda cur, pre: cur >= pre
            self.pre_best = -1e5
        self.coco_evaluation = COCOEvaluation(self.annotation_file)
        self.built = True

    @no_grad_if_torch
    def on_epoch_end(self, epoch=0, logs=None):
        """Evaluate on `eval_dataset` if this epoch is scheduled, record stats into model history, save best model. Returns the `COCOeval` or None."""
        if not self.built:
            if self.dataset_kwargs["rescale_mode"] == "auto":
                self.dataset_kwargs["rescale_mode"] = getattr(self.model, "rescale_mode", "torch")
            self.build(self.model.input_shape, self.model.output_shape)

        if epoch + 1 < self.start_epoch or epoch % self.frequency != 0:
            return

        # pred_decoder = self.model.decode_predictions
        eval_dataset = self.eval_dataset.take(self.take_samples) if self.take_samples > 0 else self.eval_dataset
        detection_results = model_detection_and_decode(
            self.model, eval_dataset, self.pred_decoder, self.nms_kwargs, self.is_coco, self.image_id_map, self.num_classes
        )
        coco_eval = None if len(detection_results) == 0 else self.coco_evaluation(detection_results)
        if self.save_json is not None:
            to_coco_json(detection_results, self.save_json)
            print(">>>> Detection results saved to:", self.save_json)

        if hasattr(self.model, "history") and hasattr(self.model.history, "history"):
            self.model.history.history.setdefault(self.item_key, []).append(([0] * 12) if coco_eval is None else coco_eval.stats.tolist())

        # Training save best
        cur_ap = coco_eval.stats[0] if coco_eval is not None else 0
        if self.model_basic_save_name is not None and self.is_better(cur_ap, self.pre_best):
            self.pre_best = cur_ap
            # Both `os.listdir` and `model.save` below fail if the folder does not exist yet
            os.makedirs(self.save_path, exist_ok=True)
            pre_monitor_saves = [ii for ii in os.listdir(self.save_path) if self.monitor_save_re.match(ii)]
            if len(pre_monitor_saves) != 0:
                os.remove(os.path.join(self.save_path, pre_monitor_saves[0]))
            monitor_save = self.monitor_save.format(epoch + 1, "{:.4f}".format(cur_ap))
            print("\n>>>> Save best to:", monitor_save)
            if self.model is not None:
                self.model.save(monitor_save)
        return coco_eval