Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add shrinked target in linknet + dilation in postprocessing #822

Merged
merged 7 commits into from
Feb 17, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 72 additions & 12 deletions doctr/models/detection/linknet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import cv2
import numpy as np
import pyclipper
from shapely.geometry import Polygon

from doctr.file_utils import is_tf_available
from doctr.models.core import BaseModel
Expand Down Expand Up @@ -37,6 +39,51 @@ def __init__(
bin_thresh,
assume_straight_pages
)
self.unclip_ratio = 1.5

def polygon_to_box(
    self,
    points: np.ndarray,
) -> np.ndarray:
    """Expand a polygon by a factor `unclip_ratio` and return the box enclosing the expanded polygon.

    The expansion distance is `area * unclip_ratio / perimeter`; the polygon is offset
    outwards by that distance with pyclipper (rounded joins).

    Args:
        points: vertices of the polygon to expand, in absolute coordinates
            (presumably an (N, 2) array; confirm against the caller)

    Returns:
        - if `assume_straight_pages` is True: the result of `cv2.boundingRect` on the expanded polygon,
          i.e. a tuple (x, y, w, h) with the top-left corner, width and height
        - otherwise: a (4, 2) array (quadrangle) holding the corners of the minimum-area rectangle
          enclosing the expanded polygon
        - None when the polygon is degenerate (zero perimeter) or the offsetting yields no polygon
    """
    if not self.assume_straight_pages:
        # Compute the rectangle polygon enclosing the raw polygon
        rect = cv2.minAreaRect(points)
        points = cv2.boxPoints(rect)
        # Add 1 pixel to each side to correct cv2 approx
        area = (rect[1][0] + 1) * (1 + rect[1][1])
        length = 2 * (rect[1][0] + rect[1][1]) + 2
    else:
        poly = Polygon(points)
        area = poly.area
        length = poly.length

    # A zero-perimeter polygon cannot be expanded: report "no box" instead of dividing by zero
    if length == 0:
        return None

    distance = area * self.unclip_ratio / length  # compute distance to expand polygon
    offset = pyclipper.PyclipperOffset()
    offset.AddPath(points, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
    expanded_paths = offset.Execute(distance)
    if len(expanded_paths) == 0:
        return None

    # Keep only the path with the most vertices (first one on ties), wrapped in a list
    # so that the result can always be cast to a regular ndarray
    expanded_points = np.asarray([max(expanded_paths, key=len)])  # expanded polygon

    if self.assume_straight_pages:
        return cv2.boundingRect(expanded_points)
    # Shift the corner order by one position
    return np.roll(cv2.boxPoints(cv2.minAreaRect(expanded_points)), -1, axis=0)

def bitmap_to_boxes(
self,
Expand Down Expand Up @@ -75,6 +122,11 @@ def bitmap_to_boxes(
if score < self.box_thresh: # remove polygons with a weak objectness
continue

if self.assume_straight_pages:
_box = self.polygon_to_box(points)
else:
_box = self.polygon_to_box(np.squeeze(contour))

if self.assume_straight_pages:
# compute relative polygon to get rid of img shape
xmin, ymin, xmax, ymax = x / width, y / height, (x + w) / width, (y + h) / height
Expand Down Expand Up @@ -102,6 +154,7 @@ class _LinkNet(BaseModel):

min_size_box: int = 3
assume_straight_pages: bool = True
shrink_ratio = 0.5

def build_target(
self,
Expand All @@ -117,11 +170,7 @@ def build_target(
h, w = output_shape
target_shape = (len(target), h, w, 1)

if self.assume_straight_pages:
seg_target = np.zeros(target_shape, dtype=bool)
else:
seg_target = np.zeros(target_shape, dtype=np.uint8)

seg_target = np.zeros(target_shape, dtype=np.uint8)
seg_mask = np.ones(target_shape, dtype=bool)

for idx, _target in enumerate(target):
Expand Down Expand Up @@ -156,13 +205,24 @@ def build_target(
if box_size < self.min_size_box:
seg_mask[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = False
continue
# Fill polygon with 1
if not self.assume_straight_pages:
cv2.fillPoly(seg_target[idx], [poly.astype(np.int32)], 1)
else:
if box.shape == (4, 2):
box = [np.min(box[:, 0]), np.min(box[:, 1]), np.max(box[:, 0]), np.max(box[:, 1])]
seg_target[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = True

# Negative shrink for gt, as described in paper
polygon = Polygon(poly)
distance = polygon.area * (1 - np.power(self.shrink_ratio, 2)) / polygon.length
subject = [tuple(coor) for coor in poly]
padding = pyclipper.PyclipperOffset()
padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
shrunken = padding.Execute(-distance)

# Draw polygon on gt if it is valid
if len(shrunken) == 0:
seg_mask[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = False
continue
shrunken = np.array(shrunken[0]).reshape(-1, 2)
if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
seg_mask[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = False
continue
cv2.fillPoly(seg_target[idx], [shrunken.astype(np.int32)], 1)

# Don't forget to switch back to channel first if PyTorch is used
if not is_tf_available():
Expand Down