Skip to content

Commit

Permalink
Annotator txt_color updates (ultralytics#13842)
Browse files Browse the repository at this point in the history
Signed-off-by: Glenn Jocher <[email protected]>
Co-authored-by: Muhammad Rizwan Munawar <[email protected]>
Co-authored-by: UltralyticsAssistant <[email protected]>
  • Loading branch information
3 people authored Jun 20, 2024
1 parent 31de5d0 commit 3f90100
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 20 deletions.
6 changes: 3 additions & 3 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ Thank you for your interest in contributing to Ultralytics open-source YOLO repo

1. [Code of Conduct](#code-of-conduct)
2. [Contributing via Pull Requests](#contributing-via-pull-requests)
- [CLA Signing](#cla-signing)
- [Google-Style Docstrings](#google-style-docstrings)
- [GitHub Actions CI Tests](#github-actions-ci-tests)
- [CLA Signing](#cla-signing)
- [Google-Style Docstrings](#google-style-docstrings)
- [GitHub Actions CI Tests](#github-actions-ci-tests)
3. [Reporting Bugs](#reporting-bugs)
4. [License](#license)
5. [Conclusion](#conclusion)
Expand Down
80 changes: 80 additions & 0 deletions docs/en/usage/simple-utilities.md
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,86 @@ for obb in obb_boxes:
image_with_obb = ann.result()
```

#### Bounding Boxes Circle Annotation ([Circle Label](https://docs.ultralytics.com/reference/utils/plotting/#ultralytics.utils.plotting.Annotator.circle_label))

```python
import cv2

from ultralytics import YOLO
from ultralytics.utils.plotting import Annotator, colors

model = YOLO("yolov8s.pt")
cap = cv2.VideoCapture("path/to/video/file.mp4")

w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
writer = cv2.VideoWriter("Ultralytics circle annotation.avi", cv2.VideoWriter_fourcc(*"MJPG"), fps, (w, h))

while True:
ret, im0 = cap.read()
if not ret:
break

annotator = Annotator(im0, line_width=2)

results = model.predict(im0)
boxes = results[0].boxes.xyxy.cpu()
clss = results[0].boxes.cls.cpu().tolist()

for box, cls in zip(boxes, clss):
x1, y1 = int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2)
annotator.circle_label(box, label=model.names[int(cls)], color=colors(int(cls), True))

writer.write(im0)
cv2.imshow("Ultralytics circle annotation", im0)

if cv2.waitKey(1) & 0xFF == ord("q"):
break

writer.release()
cap.release()
cv2.destroyAllWindows()
```

#### Bounding Boxes Text Annotation ([Text Label](https://docs.ultralytics.com/reference/utils/plotting/#ultralytics.utils.plotting.Annotator.text_label))

```python
import cv2

from ultralytics import YOLO
from ultralytics.utils.plotting import Annotator, colors

model = YOLO("yolov8s.pt")
cap = cv2.VideoCapture("path/to/video/file.mp4")

w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
writer = cv2.VideoWriter("Ultralytics text annotation.avi", cv2.VideoWriter_fourcc(*"MJPG"), fps, (w, h))

while True:
ret, im0 = cap.read()
if not ret:
break

annotator = Annotator(im0, line_width=2)

results = model.predict(im0)
boxes = results[0].boxes.xyxy.cpu()
clss = results[0].boxes.cls.cpu().tolist()

for box, cls in zip(boxes, clss):
x1, y1 = int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2)
annotator.text_label(box, label=model.names[int(cls)], color=colors(int(cls), True))

writer.write(im0)
cv2.imshow("Ultralytics text annotation", im0)

if cv2.waitKey(1) & 0xFF == ord("q"):
break

writer.release()
cap.release()
cv2.destroyAllWindows()
```

See the [`Annotator` Reference Page](../reference/utils/plotting.md#ultralytics.utils.plotting.Annotator) for additional insight.

## Miscellaneous
Expand Down
8 changes: 4 additions & 4 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ We greatly appreciate contributions from the community, including examples, appl

1. **Create a pull request (PR)** with the title prefix `[Example]`, adding your new example folder to the `examples/` directory within the repository.
2. **Ensure your project adheres to the following standards:**
- Makes use of the `ultralytics` package.
- Includes a `README.md` with clear instructions for setting up and running the example.
- Avoids adding large files or dependencies unless they are absolutely necessary for the example.
- Contributors should be willing to provide support for their examples and address related issues.
- Makes use of the `ultralytics` package.
- Includes a `README.md` with clear instructions for setting up and running the example.
- Avoids adding large files or dependencies unless they are absolutely necessary for the example.
- Contributors should be willing to provide support for their examples and address related issues.

For more detailed information and guidance on contributing, please visit our [contribution documentation](https://docs.ultralytics.com/help/contributing).

Expand Down
18 changes: 9 additions & 9 deletions examples/YOLOv8-ONNXRuntime-CPP/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,21 +72,21 @@ Note (2): Due to ONNX Runtime, we need to use CUDA 11 and cuDNN 8. Keep in mind

3. Create a build directory and navigate to it:

```console
mkdir build && cd build
```
```console
mkdir build && cd build
```

4. Run CMake to generate the build files:

```console
cmake ..
```
```console
cmake ..
```

5. Build the project:

```console
make
```
```console
make
```

6. The built executable should now be located in the `build` directory.

Expand Down
114 changes: 110 additions & 4 deletions ultralytics/utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,108 @@ def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=Fa
(104, 31, 17),
}

def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False):
"""Add one xyxy box to image with label."""
txt_color = (
(104, 31, 17) if color in self.dark_colors else (255, 255, 255) if color in self.light_colors else txt_color
def get_txt_color(self, color=(128, 128, 128), txt_color=(255, 255, 255)):
"""Assign text color based on background color."""
if color in self.dark_colors:
return 104, 31, 17
elif color in self.light_colors:
return 255, 255, 255
else:
return txt_color

def circle_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), margin=2):
"""
Draws a label with a background rectangle centered within a given bounding box.
Args:
box (tuple): The bounding box coordinates (x1, y1, x2, y2).
label (str): The text label to be displayed.
color (tuple, optional): The background color of the rectangle (R, G, B).
txt_color (tuple, optional): The color of the text (R, G, B).
margin (int, optional): The margin between the text and the rectangle border.
"""

# If label have more than 3 characters, skip other characters, due to circle size
if len(label) > 3:
print(
f"Length of label is {len(label)}, initial 3 label characters will be considered for circle annotation!"
)
label = label[:3]

# Calculate the center of the box
x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
# Get the text size
text_size = cv2.getTextSize(str(label), cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.15, self.tf)[0]
# Calculate the required radius to fit the text with the margin
required_radius = int(((text_size[0] ** 2 + text_size[1] ** 2) ** 0.5) / 2) + margin
# Draw the circle with the required radius
cv2.circle(self.im, (x_center, y_center), required_radius, color, -1)
# Calculate the position for the text
text_x = x_center - text_size[0] // 2
text_y = y_center + text_size[1] // 2
# Draw the text
cv2.putText(
self.im,
str(label),
(text_x, text_y),
cv2.FONT_HERSHEY_SIMPLEX,
self.sf - 0.15,
self.get_txt_color(color, txt_color),
self.tf,
lineType=cv2.LINE_AA,
)

def text_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), margin=5):
"""
Draws a label with a background rectangle centered within a given bounding box.
Args:
box (tuple): The bounding box coordinates (x1, y1, x2, y2).
label (str): The text label to be displayed.
color (tuple, optional): The background color of the rectangle (R, G, B).
txt_color (tuple, optional): The color of the text (R, G, B).
margin (int, optional): The margin between the text and the rectangle border.
"""

# Calculate the center of the bounding box
x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
# Get the size of the text
text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.1, self.tf)[0]
# Calculate the top-left corner of the text (to center it)
text_x = x_center - text_size[0] // 2
text_y = y_center + text_size[1] // 2
# Calculate the coordinates of the background rectangle
rect_x1 = text_x - margin
rect_y1 = text_y - text_size[1] - margin
rect_x2 = text_x + text_size[0] + margin
rect_y2 = text_y + margin
# Draw the background rectangle
cv2.rectangle(self.im, (rect_x1, rect_y1), (rect_x2, rect_y2), color, -1)
# Draw the text on top of the rectangle
cv2.putText(
self.im,
label,
(text_x, text_y),
cv2.FONT_HERSHEY_SIMPLEX,
self.sf - 0.1,
self.get_txt_color(color, txt_color),
self.tf,
lineType=cv2.LINE_AA,
)

def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False):
"""
Draws a bounding box to image with label.
Args:
box (tuple): The bounding box coordinates (x1, y1, x2, y2).
label (str): The text label to be displayed.
color (tuple, optional): The background color of the rectangle (R, G, B).
txt_color (tuple, optional): The color of the text (R, G, B).
rotated (bool, optional): Variable used to check if task is OBB
"""

txt_color = self.get_txt_color(color, txt_color)
if isinstance(box, torch.Tensor):
box = box.tolist()
if self.pil or not is_ascii(label):
Expand Down Expand Up @@ -242,6 +339,7 @@ def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
alpha (float): Mask transparency: 0.0 fully transparent, 1.0 opaque
retina_masks (bool): Whether to use high resolution masks or not. Defaults to False.
"""

if self.pil:
# Convert to numpy first
self.im = np.asarray(self.im).copy()
Expand Down Expand Up @@ -281,6 +379,7 @@ def kpts(self, kpts, shape=(640, 640), radius=5, kpt_line=True, conf_thres=0.25)
Note:
`kpt_line=True` currently only supports human pose plotting.
"""

if self.pil:
# Convert to numpy first
self.im = np.asarray(self.im).copy()
Expand Down Expand Up @@ -376,6 +475,7 @@ def get_bbox_dimension(self, bbox=None):
Returns:
angle (degree): Degree value of angle between three points
"""

x_min, y_min, x_max, y_max = bbox
width = x_max - x_min
height = y_max - y_min
Expand All @@ -390,6 +490,7 @@ def draw_region(self, reg_pts=None, color=(0, 255, 0), thickness=5):
color (tuple): Region Color value
thickness (int): Region area thickness value
"""

cv2.polylines(self.im, [np.array(reg_pts, dtype=np.int32)], isClosed=True, color=color, thickness=thickness)

def draw_centroid_and_tracks(self, track, color=(255, 0, 255), track_thickness=2):
Expand All @@ -401,6 +502,7 @@ def draw_centroid_and_tracks(self, track, color=(255, 0, 255), track_thickness=2
color (tuple): tracks line color
track_thickness (int): track line thickness value
"""

points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
cv2.polylines(self.im, [points], isClosed=False, color=color, thickness=track_thickness)
cv2.circle(self.im, (int(track[-1][0]), int(track[-1][1])), track_thickness * 2, color, -1)
Expand Down Expand Up @@ -513,6 +615,7 @@ def estimate_pose_angle(a, b, c):
Returns:
angle (degree): Degree value of angle between three points
"""

a, b, c = np.array(a), np.array(b), np.array(c)
radians = np.arctan2(c[1] - b[1], c[0] - b[0]) - np.arctan2(a[1] - b[1], a[0] - b[0])
angle = np.abs(radians * 180.0 / np.pi)
Expand All @@ -530,6 +633,7 @@ def draw_specific_points(self, keypoints, indices=None, shape=(640, 640), radius
shape (tuple): imgsz for model inference
radius (int): Keypoint radius value
"""

if indices is None:
indices = [2, 5, 7]
for i, k in enumerate(keypoints):
Expand Down Expand Up @@ -626,6 +730,7 @@ def seg_bbox(self, mask, mask_color=(255, 0, 255), det_label=None, track_label=N
det_label (str): Detection label text
track_label (str): Tracking label text
"""

cv2.polylines(self.im, [np.int32([mask])], isClosed=True, color=mask_color, thickness=2)

label = f"Track ID: {track_label}" if track_label else det_label
Expand Down Expand Up @@ -695,6 +800,7 @@ def visioneye(self, box, center_point, color=(235, 219, 11), pin_color=(255, 0,
color (tuple): object centroid and line color value
pin_color (tuple): visioneye point color value
"""

center_bbox = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
cv2.circle(self.im, center_point, self.tf * 2, pin_color, -1)
cv2.circle(self.im, center_bbox, self.tf * 2, color, -1)
Expand Down

0 comments on commit 3f90100

Please sign in to comment.