Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize prediction on long image and deduplicate similar boxes with multiple labels #11366

Merged
merged 6 commits into from
Dec 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion ppocr/postprocess/picodet_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,24 @@ def area_of(left_top, right_bottom):
return hw[..., 0] * hw[..., 1]


def calculate_containment(boxes0, boxes1):
    """Compute how much of the smaller of each box pair is covered by the other.

    Containment is overlap_area / min(area0, area1): 1.0 when one box fully
    contains the other, 0.0 when the boxes are disjoint. This is used to find
    near-duplicate detections of the same region under different labels.

    Args:
        boxes0 (np.ndarray): (N, 4) boxes as (x1, y1, x2, y2) corner format.
        boxes1 (np.ndarray): a single box, shape (4,) or (1, 4), same format.

    Returns:
        np.ndarray: (N,) containment values in [0, 1].
    """
    # Clamp the denominator so degenerate (zero-area) boxes cannot cause a
    # divide-by-zero warning and NaN/inf results, which would poison the
    # `containment > threshold` comparisons downstream.
    eps = 1e-5

    overlap_left_top = np.maximum(boxes0[..., :2], boxes1[..., :2])
    overlap_right_bottom = np.minimum(boxes0[..., 2:], boxes1[..., 2:])

    # Negative extents mean the boxes are disjoint; clip them to zero so the
    # overlap area is 0 rather than a spurious positive product.
    overlap_wh = np.clip(overlap_right_bottom - overlap_left_top, 0.0, None)
    overlap_area = overlap_wh[..., 0] * overlap_wh[..., 1]

    wh0 = boxes0[..., 2:] - boxes0[..., :2]
    area0 = wh0[..., 0] * wh0[..., 1]
    wh1 = boxes1[..., 2:] - boxes1[..., :2]
    area1 = wh1[..., 0] * wh1[..., 1]

    return overlap_area / np.maximum(
        np.minimum(area0, np.expand_dims(area1, axis=0)), eps)


class PicoDetPostProcess(object):
"""
Args:
Expand Down Expand Up @@ -245,6 +263,24 @@ def __call__(self, ori_img, img, preds):
for dt in out_boxes_list:
clsid, bbox, score = int(dt[0]), dt[2:], dt[1]
label = self.labels[clsid]
result = {'bbox': bbox, 'label': label}
result = {'bbox': bbox, 'label': label, 'score': score}
results.append(result)

# Handle conflict where a box is simultaneously recognized as multiple labels.
# Use IoU to find similar boxes. Prioritize labels as table, text, and others when deduplicate similar boxes.
bboxes = np.array([x['bbox'] for x in results])
duplicate_idx = list()
for i in range(len(results)):
if i in duplicate_idx:
continue
containments = calculate_containment(bboxes, bboxes[i, ...])
overlaps = np.where(containments > 0.5)[0]
if len(overlaps) > 1:
table_box = [x for x in overlaps if results[x]['label'] == 'table']
if len(table_box) > 0:
keep = sorted([(x, results[x]) for x in table_box], key=lambda x: x[1]['score'], reverse=True)[0][0]
else:
keep = sorted([(x, results[x]) for x in overlaps], key=lambda x: x[1]['score'], reverse=True)[0][0]
duplicate_idx.extend([x for x in overlaps if x != keep])
results = [x for i, x in enumerate(results) if i not in duplicate_idx]
return results
71 changes: 70 additions & 1 deletion tools/infer/predict_det.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
dt_boxes = np.array(dt_boxes_new)
return dt_boxes

def __call__(self, img):
def predict(self, img):
ori_im = img.copy()
data = {'image': img}

Expand Down Expand Up @@ -283,6 +283,75 @@ def __call__(self, img):
et = time.time()
return dt_boxes, et - st

def __call__(self, img):
    """Detect text boxes, splitting very elongated images into overlapping
    strips so each strip keeps an aspect ratio the detector handles well.

    Args:
        img: image array of shape (H, W, C).

    Returns:
        tuple: (dt_boxes, elapse). dt_boxes is an (N, 4, 2) float32 array of
        quadrilateral boxes in the ORIGINAL image's coordinate frame; elapse
        is the accumulated prediction time in seconds.
    """
    # For image like poster with one side much greater than the other side,
    # splitting recursively and processing with overlap to enhance performance.
    MIN_BOUND_DISTANCE = 50
    dt_boxes = np.zeros((0, 4, 2), dtype=np.float32)
    elapse = 0
    if img.shape[0] / img.shape[1] > 2 and img.shape[0] > self.args.det_limit_side_len:
        # Tall image: slide a strip of height width*3//4 down the image.
        start_h = 0
        end_h = 0
        while end_h <= img.shape[0]:
            end_h = start_h + img.shape[1] * 3 // 4
            subimg = img[start_h: end_h, :]
            if len(subimg) == 0:
                break
            sub_dt_boxes, sub_elapse = self.predict(subimg)
            offset = start_h
            # To prevent text blocks from being cut off, roll back a certain buffer area.
            # NOTE(review): this compares img.shape[1] (the full image WIDTH)
            # against x[-1][1], a y-coordinate inside a strip of height
            # width*3//4 — verify whether subimg.shape[0] was intended here.
            if len(sub_dt_boxes) == 0 or img.shape[1] - max([x[-1][1] for x in sub_dt_boxes]) > MIN_BOUND_DISTANCE:
                # No box near the strip's bottom edge: nothing can be cut off,
                # so the next strip starts right where this one ended.
                start_h = end_h
            else:
                # Sort boxes by bottom-right y so the lowest box comes last.
                sorted_indices = np.argsort(sub_dt_boxes[:, 2, 1])
                sub_dt_boxes = sub_dt_boxes[sorted_indices]
                # Restart the next strip at the bottom of the second-lowest box;
                # the lowest (possibly truncated) box will be re-detected whole
                # in the overlapping next strip.
                bottom_line = 0 if len(sub_dt_boxes) <= 1 else int(np.max(sub_dt_boxes[:-1, 2, 1]))
                if bottom_line > 0:
                    start_h += bottom_line
                    # Drop boxes extending past the restart line; they will be
                    # found again (complete) in the next strip.
                    sub_dt_boxes = sub_dt_boxes[sub_dt_boxes[:, 2, 1] <= bottom_line]
                else:
                    start_h = end_h
            if len(sub_dt_boxes) > 0:
                # Shift strip-local coordinates back into full-image coordinates.
                if dt_boxes.shape[0] == 0:
                    dt_boxes = sub_dt_boxes + np.array([0, offset], dtype=np.float32)
                else:
                    dt_boxes = np.append(dt_boxes,
                                         sub_dt_boxes + np.array([0, offset], dtype=np.float32),
                                         axis=0)
            elapse += sub_elapse
    elif img.shape[1] / img.shape[0] > 3 and img.shape[1] > self.args.det_limit_side_len * 3:
        # Wide image: same scheme, sliding a strip of width height*3//4 rightwards.
        start_w = 0
        end_w = 0
        while end_w <= img.shape[1]:
            end_w = start_w + img.shape[0] * 3 // 4
            subimg = img[:, start_w: end_w]
            if len(subimg) == 0:
                break
            sub_dt_boxes, sub_elapse = self.predict(subimg)
            offset = start_w
            # NOTE(review): symmetric to the tall branch — img.shape[0] (full
            # HEIGHT) is compared against an x-coordinate; verify whether
            # subimg.shape[1] was intended here.
            if len(sub_dt_boxes) == 0 or img.shape[0] - max([x[-1][0] for x in sub_dt_boxes]) > MIN_BOUND_DISTANCE:
                start_w = end_w
            else:
                # Sort boxes by bottom-right x so the rightmost box comes last.
                sorted_indices = np.argsort(sub_dt_boxes[:, 2, 0])
                sub_dt_boxes = sub_dt_boxes[sorted_indices]
                # Restart the next strip at the right edge (top-right x) of the
                # second-rightmost box; the rightmost box is re-detected whole.
                right_line = 0 if len(sub_dt_boxes) <= 1 else int(np.max(sub_dt_boxes[:-1, 1, 0]))
                if right_line > 0:
                    start_w += right_line
                    # Drop boxes extending past the restart line; they will be
                    # found again (complete) in the next strip.
                    sub_dt_boxes = sub_dt_boxes[sub_dt_boxes[:, 1, 0] <= right_line]
                else:
                    start_w = end_w
            if len(sub_dt_boxes) > 0:
                # Shift strip-local coordinates back into full-image coordinates.
                if dt_boxes.shape[0] == 0:
                    dt_boxes = sub_dt_boxes + np.array([offset, 0], dtype=np.float32)
                else:
                    dt_boxes = np.append(dt_boxes,
                                         sub_dt_boxes + np.array([offset, 0], dtype=np.float32),
                                         axis=0)
            elapse += sub_elapse
    else:
        # Aspect ratio is ordinary: run detection on the whole image at once.
        dt_boxes, elapse = self.predict(img)
    return dt_boxes, elapse


if __name__ == "__main__":
args = utility.parse_args()
Expand Down