diff --git a/ppocr/postprocess/picodet_postprocess.py b/ppocr/postprocess/picodet_postprocess.py index 1a0aeb4387..4053714d30 100644 --- a/ppocr/postprocess/picodet_postprocess.py +++ b/ppocr/postprocess/picodet_postprocess.py @@ -78,6 +78,24 @@ def area_of(left_top, right_bottom): return hw[..., 0] * hw[..., 1] +def calculate_containment(boxes0, boxes1): + """ + Calculate the containment of the boxes. + Args: + boxes0 (N, 4): ground truth boxes. + boxes1 (N or 1, 4): predicted boxes. + Returns: + containment (N): containment values. + """ + overlap_left_top = np.maximum(boxes0[..., :2], boxes1[..., :2]) + overlap_right_bottom = np.minimum(boxes0[..., 2:], boxes1[..., 2:]) + + overlap_area = area_of(overlap_left_top, overlap_right_bottom) + area0 = area_of(boxes0[..., :2], boxes0[..., 2:]) + area1 = area_of(boxes1[..., :2], boxes1[..., 2:]) + return overlap_area / np.minimum(area0, np.expand_dims(area1, axis=0)) + + class PicoDetPostProcess(object): """ Args: @@ -245,6 +263,24 @@ def __call__(self, ori_img, img, preds): for dt in out_boxes_list: clsid, bbox, score = int(dt[0]), dt[2:], dt[1] label = self.labels[clsid] - result = {'bbox': bbox, 'label': label} + result = {'bbox': bbox, 'label': label, 'score': score} results.append(result) + + # Handle conflict where a box is simultaneously recognized as multiple labels. + # Use IoU to find similar boxes. Prioritize labels as table, text, and others when deduplicate similar boxes. + bboxes = np.array([x['bbox'] for x in results]) + duplicate_idx = list() + for i in range(len(results)): + if i in duplicate_idx: + continue + containments = calculate_containment(bboxes, bboxes[i, ...]) + overlaps = np.where(containments > 0.5)[0] + if len(overlaps) > 1: + table_box = [x for x in overlaps if results[x]['label'] == 'table'] + if len(table_box) > 0: + keep = sorted([(x, results[x]) for x in table_box], key=lambda x: x[1]['score'], reverse=True)[0][0] + else: + keep = sorted([(x, results[x]) for x in overlaps], key=lambda x: x[1]['score'], reverse=True)[0][0] + duplicate_idx.extend([x for x in overlaps if x != keep]) + results = [x for i, x in enumerate(results) if i not in duplicate_idx] return results diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index 6c5c36cf86..831a49e6d4 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -217,7 +217,7 @@ def filter_tag_det_res_only_clip(self, dt_boxes, image_shape): dt_boxes = np.array(dt_boxes_new) return dt_boxes - def __call__(self, img): + def predict(self, img): ori_im = img.copy() data = {'image': img} @@ -283,6 +283,75 @@ def __call__(self, img): et = time.time() return dt_boxes, et - st + def __call__(self, img): + # For image like poster with one side much greater than the other side, + # splitting recursively and processing with overlap to enhance performance. + MIN_BOUND_DISTANCE = 50 + dt_boxes = np.zeros((0, 4, 2), dtype=np.float32) + elapse = 0 + if img.shape[0] / img.shape[1] > 2 and img.shape[0] > self.args.det_limit_side_len: + start_h = 0 + end_h = 0 + while end_h <= img.shape[0]: + end_h = start_h + img.shape[1] * 3 // 4 + subimg = img[start_h: end_h, :] + if len(subimg) == 0: + break + sub_dt_boxes, sub_elapse = self.predict(subimg) + offset = start_h + # To prevent text blocks from being cut off, roll back a certain buffer area. + if len(sub_dt_boxes) == 0 or img.shape[1] - max([x[-1][1] for x in sub_dt_boxes]) > MIN_BOUND_DISTANCE: + start_h = end_h + else: + sorted_indices = np.argsort(sub_dt_boxes[:, 2, 1]) + sub_dt_boxes = sub_dt_boxes[sorted_indices] + bottom_line = 0 if len(sub_dt_boxes) <= 1 else int(np.max(sub_dt_boxes[:-1, 2, 1])) + if bottom_line > 0: + start_h += bottom_line + sub_dt_boxes = sub_dt_boxes[sub_dt_boxes[:, 2, 1] <= bottom_line] + else: + start_h = end_h + if len(sub_dt_boxes) > 0: + if dt_boxes.shape[0] == 0: + dt_boxes = sub_dt_boxes + np.array([0, offset], dtype=np.float32) + else: + dt_boxes = np.append(dt_boxes, + sub_dt_boxes + np.array([0, offset], dtype=np.float32), + axis=0) + elapse += sub_elapse + elif img.shape[1] / img.shape[0] > 3 and img.shape[1] > self.args.det_limit_side_len * 3: + start_w = 0 + end_w = 0 + while end_w <= img.shape[1]: + end_w = start_w + img.shape[0] * 3 // 4 + subimg = img[:, start_w: end_w] + if len(subimg) == 0: + break + sub_dt_boxes, sub_elapse = self.predict(subimg) + offset = start_w + if len(sub_dt_boxes) == 0 or img.shape[0] - max([x[-1][0] for x in sub_dt_boxes]) > MIN_BOUND_DISTANCE: + start_w = end_w + else: + sorted_indices = np.argsort(sub_dt_boxes[:, 2, 0]) + sub_dt_boxes = sub_dt_boxes[sorted_indices] + right_line = 0 if len(sub_dt_boxes) <= 1 else int(np.max(sub_dt_boxes[:-1, 1, 0])) + if right_line > 0: + start_w += right_line + sub_dt_boxes = sub_dt_boxes[sub_dt_boxes[:, 1, 0] <= right_line] + else: + start_w = end_w + if len(sub_dt_boxes) > 0: + if dt_boxes.shape[0] == 0: + dt_boxes = sub_dt_boxes + np.array([offset, 0], dtype=np.float32) + else: + dt_boxes = np.append(dt_boxes, + sub_dt_boxes + np.array([offset, 0], dtype=np.float32), + axis=0) + elapse += sub_elapse + else: + dt_boxes, elapse = self.predict(img) + return dt_boxes, elapse + if __name__ == "__main__": args = utility.parse_args()