Skip to content

Commit

Permalink
feat: Added automatic detection of rotated bbox in training utils (#534)
Browse files Browse the repository at this point in the history
* feat: Makes rotated bbox mode automatic

* fix: Avoids inplace modifications

* feat: Added safeguard on subplot layout
  • Loading branch information
fg-mindee authored Oct 21, 2021
1 parent 22e28f7 commit 49dc156
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 12 deletions.
4 changes: 2 additions & 2 deletions references/classification/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

def plot_samples(images, targets):
# Unnormalize image
num_samples = 12
num_rows = 3
num_samples = min(len(images), 12)
num_rows = min(len(images), 3)
num_cols = int(math.ceil(num_samples / num_rows))
_, axes = plt.subplots(num_rows, num_cols, figsize=(20, 5))
for idx in range(num_samples):
Expand Down
16 changes: 9 additions & 7 deletions references/detection/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,28 @@
from typing import List, Dict


def plot_samples(images, targets: List[Dict[str, np.ndarray]], rotation: bool = False) -> None:
def plot_samples(images, targets: List[Dict[str, np.ndarray]]) -> None:
# Unnormalize image
nb_samples = 4
nb_samples = min(len(images), 4)
_, axes = plt.subplots(2, nb_samples, figsize=(20, 5))
for idx in range(nb_samples):
img = (255 * images[idx].numpy()).round().clip(0, 255).astype(np.uint8)
if img.shape[0] == 3 and img.shape[2] != 3:
img = img.transpose(1, 2, 0)

target = np.zeros(img.shape[:2], np.uint8)
boxes = targets[idx]['boxes'][np.logical_not(targets[idx]['flags'])]
boxes = targets[idx].copy()
boxes[:, [0, 2]] = boxes[:, [0, 2]] * img.shape[1]
boxes[:, [1, 3]] = boxes[:, [1, 3]] * img.shape[0]
for box in boxes.round().astype(int):
if rotation:
box = cv2.boxPoints(((box[0], box[2]), (box[1], box[3]), box[4]))
boxes[:, :4] = boxes[:, :4].round().astype(int)

for box in boxes:
if boxes.shape[1] == 5:
box = cv2.boxPoints(((int(box[0]), int(box[1])), (int(box[2]), int(box[3])), -box[4]))
box = np.int0(box)
cv2.fillPoly(target, [box], 1)
else:
target[box[1]: box[3] + 1, box[0]: box[2] + 1] = 1
target[int(box[1]): int(box[3]) + 1, int(box[0]): int(box[2]) + 1] = 1

axes[0][idx].imshow(img)
axes[0][idx].axis('off')
Expand Down
6 changes: 3 additions & 3 deletions references/recognition/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@


def plot_samples(images, targets):
# Unnormalize image
num_samples = 12
num_rows = 3
# Unnormalize image
num_samples = min(len(images), 12)
num_rows = min(len(images), 3)
num_cols = int(math.ceil(num_samples / num_rows))
_, axes = plt.subplots(num_rows, num_cols, figsize=(20, 5))
for idx in range(num_samples):
Expand Down

0 comments on commit 49dc156

Please sign in to comment.