Skip to content

Commit

Permalink
apply code factor suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 committed Nov 15, 2021
1 parent 5d120c8 commit d0f3817
Showing 1 changed file with 29 additions and 23 deletions.
52 changes: 29 additions & 23 deletions doctr/datasets/synthtext.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class SynthText(VisionDataset):
<https://arxiv.org/abs/1604.06646>`_.
Example::
>>> # NOTE: This dataset has currently no train/test split
>>> from doctr.datasets import SynthText
>>> data_set = SynthText(download=True)
>>> img, target = data_set[0]
Expand All @@ -36,6 +37,7 @@ class SynthText(VisionDataset):

def __init__(
self,
train: bool = True,
sample_transforms: Optional[Callable[[Any], Any]] = None,
rotated_bbox: bool = False,
**kwargs: Any,
Expand All @@ -61,37 +63,41 @@ def __init__(
if not os.path.exists(os.path.join(tmp_root, img_path[0])):
raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path[0])}")

labels = self._text_to_words(txt)
labels = _text_to_words(txt)
word_boxes = word_boxes.transpose(2, 1, 0) if word_boxes.ndim == 3 else np.expand_dims(word_boxes, axis=0)

if rotated_bbox:
# x_center, y_center, w, h, alpha = 0
box_targets = [self._compute_rotated_box(pts) for pts in word_boxes]
box_targets = [_compute_rotated_box(pts) for pts in word_boxes]
else:
# xmin, ymin, xmax, ymax
box_targets = [self._compute_straight_box(pts) for pts in word_boxes] # type: ignore[misc]
box_targets = [_compute_straight_box(pts) for pts in word_boxes] # type: ignore[misc]

self.data.append((img_path[0], dict(boxes=np.asarray(box_targets, dtype=np_dtype), labels=labels)))

self.root = tmp_root

def _text_to_words(self, txt: np.ndarray) -> List[str]:
line = '\n'.join(txt)
return line.split()

def _compute_straight_box(self, pts: np.ndarray) -> Tuple[float, float, float, float]:
# pts: Nx2
xmin = np.min(pts[:, 0])
xmax = np.max(pts[:, 0])
ymin = np.min(pts[:, 1])
ymax = np.max(pts[:, 1])
return xmin, ymin, xmax, ymax

def _compute_rotated_box(self, pts: np.ndarray) -> Tuple[float, float, float, float, int]:
# pts: Nx2
x = np.min(pts[:, 0])
y = np.min(pts[:, 1])
width = np.max(pts[:, 0]) - x
height = np.max(pts[:, 1]) - y
# x_center, y_center, w, h, alpha = 0
return x + width / 2, y + height / 2, width, height, 0

def _text_to_words(txt: np.ndarray) -> List[str]:
"""Convert np.str-Array to list of str."""
line = '\n'.join(txt)
return line.split()


def _compute_straight_box(pts: np.ndarray) -> Tuple[float, float, float, float]:
# pts: Nx2
xmin = np.min(pts[:, 0])
xmax = np.max(pts[:, 0])
ymin = np.min(pts[:, 1])
ymax = np.max(pts[:, 1])
return xmin, ymin, xmax, ymax


def _compute_rotated_box(pts: np.ndarray) -> Tuple[float, float, float, float, int]:
# pts: Nx2
x = np.min(pts[:, 0])
y = np.min(pts[:, 1])
width = np.max(pts[:, 0]) - x
height = np.max(pts[:, 1]) - y
# x_center, y_center, w, h, alpha = 0
return x + width / 2, y + height / 2, width, height, 0

0 comments on commit d0f3817

Please sign in to comment.