Skip to content

Commit

Permalink
refactor: Refactored constructor of SROIE dataset (#660)
Browse files Browse the repository at this point in the history
* start synth

* cleanup

* start synth

* add synthtext

* add docu and tests

* apply code factor suggestions

* apply changes

* clean

* refactor sroie

* apply changes
  • Loading branch information
felixdittrich92 authored Dec 7, 2021
1 parent 8382896 commit e076418
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 22 deletions.
41 changes: 21 additions & 20 deletions doctr/datasets/sroie.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,36 +49,37 @@ def __init__(
self.sample_transforms = sample_transforms
self.train = train

if rotated_bbox:
raise NotImplementedError

# # List images
tmp_root = os.path.join(self.root, 'images')
self.data: List[Tuple[str, Dict[str, Any]]] = []
np_dtype = np.float32

for img_path in os.listdir(tmp_root):

# File existence check
if not os.path.exists(os.path.join(tmp_root, img_path)):
raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}")

stem = Path(img_path).stem
_targets = []
with open(os.path.join(self.root, 'annotations', f"{stem}.txt"), encoding='latin') as f:
for row in csv.reader(f, delimiter=','):
# Safeguard for blank lines
if len(row) > 0:
# Label may contain commas
label = ",".join(row[8:])
# Reduce 8 coords to 4
p1_x, p1_y, p2_x, p2_y, p3_x, p3_y, p4_x, p4_y = map(int, row[:8])
left, right = min(p1_x, p2_x, p3_x, p4_x), max(p1_x, p2_x, p3_x, p4_x)
top, bot = min(p1_y, p2_y, p3_y, p4_y), max(p1_y, p2_y, p3_y, p4_y)
if len(label) > 0:
_targets.append((label, [left, top, right, bot]))

text_targets, box_targets = zip(*_targets)

self.data.append((img_path, dict(boxes=np.asarray(box_targets, dtype=np_dtype), labels=text_targets)))
_rows = [row for row in list(csv.reader(f, delimiter=',')) if len(row) > 0]

labels = [",".join(row[8:]) for row in _rows]
# reorder coordinates (8 -> (4,2)) and filter empty lines
coords = np.stack([np.array(list(map(int, row[:8])), dtype=np_dtype).reshape((4, 2))
for row in _rows], axis=0)

if rotated_bbox:
# x_center, y_center, w, h, alpha = 0
mins = coords.min(axis=1)
maxs = coords.max(axis=1)
box_targets = np.concatenate(
((mins + maxs) / 2, maxs - mins, np.zeros((coords.shape[0], 1))), axis=1)
else:
# xmin, ymin, xmax, ymax
box_targets = np.concatenate((coords.min(axis=1), coords.max(axis=1)), axis=1)

self.data.append((img_path, dict(boxes=box_targets, labels=labels)))

self.root = tmp_root

def extra_repr(self) -> str:
Expand Down
2 changes: 1 addition & 1 deletion tests/pytorch/test_datasets_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_visiondataset():
[
['FUNSD', True, [512, 512], 149, False],
['FUNSD', False, [512, 512], 50, True],
['SROIE', True, [512, 512], 626, False],
['SROIE', True, [512, 512], 626, True],
['SROIE', False, [512, 512], 360, False],
['CORD', True, [512, 512], 800, True],
['CORD', False, [512, 512], 100, False],
Expand Down
2 changes: 1 addition & 1 deletion tests/tensorflow/test_datasets_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
[
['FUNSD', True, [512, 512], 149, False],
['FUNSD', False, [512, 512], 50, True],
['SROIE', True, [512, 512], 626, False],
['SROIE', True, [512, 512], 626, True],
['SROIE', False, [512, 512], 360, False],
['CORD', True, [512, 512], 800, True],
['CORD', False, [512, 512], 100, False],
Expand Down

0 comments on commit e076418

Please sign in to comment.