From 50e0bec0a09391c02eb328cf5940cc73f33c66a6 Mon Sep 17 00:00:00 2001 From: Kentaro Wada Date: Tue, 30 Jul 2024 19:18:58 +0900 Subject: [PATCH] Implement "AI Text to Rectangles" --- labelme/ai/__init__.py | 3 + labelme/ai/text_to_annotation.py | 92 +++++++++++++++++++++++ labelme/app.py | 72 +++++++++++++++++- labelme/widgets/__init__.py | 2 + labelme/widgets/ai_prompt_widget.py | 112 ++++++++++++++++++++++++++++ setup.py | 1 + 6 files changed, 281 insertions(+), 1 deletion(-) create mode 100644 labelme/ai/text_to_annotation.py create mode 100644 labelme/widgets/ai_prompt_widget.py diff --git a/labelme/ai/__init__.py b/labelme/ai/__init__.py index 1dad86774..717e4924b 100644 --- a/labelme/ai/__init__.py +++ b/labelme/ai/__init__.py @@ -2,6 +2,9 @@ from .efficient_sam import EfficientSam from .segment_anything_model import SegmentAnythingModel +from .text_to_annotation import get_rectangles_from_texts # NOQA: F401 +from .text_to_annotation import get_shapes_from_annotations # NOQA: F401 +from .text_to_annotation import non_maximum_suppression # NOQA: F401 class SegmentAnythingModelVitB(SegmentAnythingModel): diff --git a/labelme/ai/text_to_annotation.py b/labelme/ai/text_to_annotation.py new file mode 100644 index 000000000..35ca4ba7c --- /dev/null +++ b/labelme/ai/text_to_annotation.py @@ -0,0 +1,92 @@ +import json +import time + +import numpy as np +import osam + +from labelme.logger import logger + + +def get_rectangles_from_texts( + model: str, image: np.ndarray, texts: list[str] +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + request: osam.types.GenerateRequest = osam.types.GenerateRequest( + model=model, + image=image, + prompt=osam.types.Prompt( + texts=texts, + iou_threshold=1.0, + score_threshold=0.01, + max_annotations=1000, + ), + ) + logger.debug( + f"Requesting with model={model!r}, image={(image.shape, image.dtype)}, " + f"prompt={request.prompt!r}" + ) + t_start = time.time() + response: osam.types.GenerateResponse = osam.apis.generate(request=request) + + num_annotations = len(response.annotations) + logger.debug( + f"Response: num_annotations={num_annotations}, " + f"elapsed_time={time.time() - t_start:.3f} [s]" + ) + + boxes: np.ndarray = np.empty((num_annotations, 4), dtype=np.float32) + scores: np.ndarray = np.empty((num_annotations,), dtype=np.float32) + labels: np.ndarray = np.empty((num_annotations,), dtype=np.int32) + for i, annotation in enumerate(response.annotations): + boxes[i] = [ + annotation.bounding_box.xmin, + annotation.bounding_box.ymin, + annotation.bounding_box.xmax, + annotation.bounding_box.ymax, + ] + scores[i] = annotation.score + labels[i] = texts.index(annotation.text) + + return boxes, scores, labels + + +def non_maximum_suppression( + boxes: np.ndarray, + scores: np.ndarray, + labels: np.ndarray, + iou_threshold: float, + score_threshold: float, + max_num_detections: int, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + num_classes = np.max(labels) + 1 + scores_of_all_classes = np.zeros((len(boxes), num_classes), dtype=np.float32) + for i, (score, label) in enumerate(zip(scores, labels)): + scores_of_all_classes[i, label] = score + logger.debug(f"Input: num_boxes={len(boxes)}") + boxes, scores, labels = osam.apis.non_maximum_suppression( + boxes=boxes, + scores=scores_of_all_classes, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + max_num_detections=max_num_detections, + ) + logger.debug(f"Output: num_boxes={len(boxes)}") + return boxes, scores, labels + + +def get_shapes_from_annotations( + boxes: np.ndarray, scores: np.ndarray, labels: np.ndarray, texts: list[str] +) -> list[dict]: + shapes: list[dict] = [] + for box, score, label in zip(boxes.tolist(), scores.tolist(), labels.tolist()): + text = texts[label] + xmin, ymin, xmax, ymax = box + shape = { + "label": text, + "points": [[xmin, ymin], [xmax, ymax]], + "group_id": None, + "shape_type": "rectangle", + "flags": {}, + "description": json.dumps(dict(score=score, text=text)), + } + shapes.append(shape) + return shapes diff --git a/labelme/app.py b/labelme/app.py index 9a46469aa..7bbce4936 100644 --- a/labelme/app.py +++ b/labelme/app.py @@ -18,12 +18,14 @@ from labelme import PY2 from labelme import __appname__ +from labelme import ai from labelme.ai import MODELS from labelme.config import get_config from labelme.label_file import LabelFile from labelme.label_file import LabelFileError from labelme.logger import logger from labelme.shape import Shape +from labelme.widgets import AiPromptWidget from labelme.widgets import BrightnessContrastDialog from labelme.widgets import Canvas from labelme.widgets import FileDialogPreview @@ -784,7 +786,7 @@ def __init__( selectAiModel.setDefaultWidget(QtWidgets.QWidget()) selectAiModel.defaultWidget().setLayout(QtWidgets.QVBoxLayout()) # - selectAiModelLabel = QtWidgets.QLabel(self.tr("AI Model")) + selectAiModelLabel = QtWidgets.QLabel(self.tr("AI Mask Model")) selectAiModelLabel.setAlignment(QtCore.Qt.AlignCenter) selectAiModel.defaultWidget().layout().addWidget(selectAiModelLabel) # @@ -809,6 +811,12 @@ def __init__( else None ) + self._ai_prompt_widget: QtWidgets.QWidget = AiPromptWidget( + on_submit=self._submit_ai_prompt, parent=self + ) + ai_prompt_action = QtWidgets.QWidgetAction(self) + ai_prompt_action.setDefaultWidget(self._ai_prompt_widget) + self.tools = self.toolbar("Tools") self.actions.tool = ( open_, @@ -829,6 +837,8 @@ def __init__( zoom, None, selectAiModel, + None, + ai_prompt_action, ) self.statusBar().showMessage(str(self.tr("%s started.")) % __appname__) @@ -989,6 +999,66 @@ def queueEvent(self, function): def status(self, message, delay=5000): self.statusBar().showMessage(message, delay) + def _submit_ai_prompt(self, _) -> None: + texts = self._ai_prompt_widget.get_text_prompt().split(",") + boxes, scores, labels = ai.get_rectangles_from_texts( + model="yoloworld", + image=utils.img_qt_to_arr(self.image)[:, :, :3], + texts=texts, + ) + + for shape in self.canvas.shapes: + if shape.shape_type != "rectangle" or shape.label not in texts: + continue + box = np.array( + [ + shape.points[0].x(), + shape.points[0].y(), + shape.points[1].x(), + shape.points[1].y(), + ], + dtype=np.float32, + ) + boxes = np.r_[boxes, [box]] + scores = np.r_[scores, [1.01]] + labels = np.r_[labels, [texts.index(shape.label)]] + + boxes, scores, labels = ai.non_maximum_suppression( + boxes=boxes, + scores=scores, + labels=labels, + iou_threshold=self._ai_prompt_widget.get_iou_threshold(), + score_threshold=self._ai_prompt_widget.get_score_threshold(), + max_num_detections=100, + ) + + keep = scores != 1.01 + boxes = boxes[keep] + scores = scores[keep] + labels = labels[keep] + + shape_dicts: list[dict] = ai.get_shapes_from_annotations( + boxes=boxes, + scores=scores, + labels=labels, + texts=texts, + ) + + shapes: list[Shape] = [] + for shape_dict in shape_dicts: + shape = Shape( + label=shape_dict["label"], + shape_type=shape_dict["shape_type"], + description=shape_dict["description"], + ) + for point in shape_dict["points"]: + shape.addPoint(QtCore.QPointF(*point)) + shapes.append(shape) + + self.canvas.storeShapes() + self.loadShapes(shapes, replace=False) + self.setDirty() + def resetState(self): self.labelList.clear() self.filename = None diff --git a/labelme/widgets/__init__.py b/labelme/widgets/__init__.py index 999cc4551..6283ef1e8 100644 --- a/labelme/widgets/__init__.py +++ b/labelme/widgets/__init__.py @@ -1,5 +1,7 @@ # flake8: noqa +from .ai_prompt_widget import AiPromptWidget + from .brightness_contrast_dialog import BrightnessContrastDialog from .canvas import Canvas diff --git a/labelme/widgets/ai_prompt_widget.py b/labelme/widgets/ai_prompt_widget.py new file mode 100644 index 000000000..c37c9ab95 --- /dev/null +++ b/labelme/widgets/ai_prompt_widget.py @@ -0,0 +1,112 @@ +from qtpy import QtWidgets + + +class AiPromptWidget(QtWidgets.QWidget): + def __init__(self, on_submit, parent=None): + super().__init__(parent=parent) + + self.setLayout(QtWidgets.QVBoxLayout()) + self.layout().setSpacing(0) + + text_prompt_widget = _TextPromptWidget(on_submit=on_submit, parent=self) + text_prompt_widget.setMaximumWidth(400) + self.layout().addWidget(text_prompt_widget) + + nms_params_widget = _NmsParamsWidget(parent=self) + nms_params_widget.setMaximumWidth(400) + self.layout().addWidget(nms_params_widget) + + def get_text_prompt(self) -> str: + text_prompt_widget: QtWidgets.QWidget = self.layout().itemAt(0).widget() + return text_prompt_widget.get_text_prompt() + + def get_iou_threshold(self) -> float: + nms_params_widget = self.layout().itemAt(1).widget() + return nms_params_widget.get_iou_threshold() + + def get_score_threshold(self) -> float: + nms_params_widget = self.layout().itemAt(1).widget() + return nms_params_widget.get_score_threshold() + + +class _TextPromptWidget(QtWidgets.QWidget): + def __init__(self, on_submit, parent=None): + super().__init__(parent=parent) + + self.setLayout(QtWidgets.QHBoxLayout()) + self.layout().setContentsMargins(0, 0, 0, 0) + + label = QtWidgets.QLabel(self.tr("AI Prompt")) + self.layout().addWidget(label) + + texts_widget = QtWidgets.QLineEdit() + texts_widget.setPlaceholderText(self.tr("e.g., dog,cat,bird")) + self.layout().addWidget(texts_widget) + + submit_button = QtWidgets.QPushButton(text="Submit", parent=self) + submit_button.clicked.connect(slot=on_submit) + self.layout().addWidget(submit_button) + + def get_text_prompt(self) -> str: + texts_widget: QtWidgets.QWidget = self.layout().itemAt(1).widget() + return texts_widget.text() + + +class _NmsParamsWidget(QtWidgets.QWidget): + def __init__(self, parent=None): + super().__init__(parent=parent) + + self.setLayout(QtWidgets.QHBoxLayout()) + self.layout().setContentsMargins(0, 0, 0, 0) + self.layout().addWidget(_ScoreThresholdWidget(parent=parent)) + self.layout().addWidget(_IouThresholdWidget(parent=parent)) + + def get_score_threshold(self) -> float: + score_threshold_widget: QtWidgets.QWidget = self.layout().itemAt(0).widget() + return score_threshold_widget.get_value() + + def get_iou_threshold(self) -> float: + iou_threshold_widget: QtWidgets.QWidget = self.layout().itemAt(1).widget() + return iou_threshold_widget.get_value() + + +class _ScoreThresholdWidget(QtWidgets.QWidget): + def __init__(self, parent=None): + super().__init__(parent=parent) + + self.setLayout(QtWidgets.QHBoxLayout()) + self.layout().setContentsMargins(0, 0, 0, 0) + + label = QtWidgets.QLabel(self.tr("Score Threshold")) + self.layout().addWidget(label) + + threshold_widget: QtWidgets.QWidget = QtWidgets.QDoubleSpinBox() + threshold_widget.setRange(0, 1) + threshold_widget.setSingleStep(0.05) + threshold_widget.setValue(0.1) + self.layout().addWidget(threshold_widget) + + def get_value(self) -> float: + threshold_widget: QtWidgets.QWidget = self.layout().itemAt(1).widget() + return threshold_widget.value() + + +class _IouThresholdWidget(QtWidgets.QWidget): + def __init__(self, parent=None): + super().__init__(parent=parent) + + self.setLayout(QtWidgets.QHBoxLayout()) + self.layout().setContentsMargins(0, 0, 0, 0) + + label = QtWidgets.QLabel(self.tr("IoU Threshold")) + self.layout().addWidget(label) + + threshold_widget: QtWidgets.QWidget = QtWidgets.QDoubleSpinBox() + threshold_widget.setRange(0, 1) + threshold_widget.setSingleStep(0.05) + threshold_widget.setValue(0.5) + self.layout().addWidget(threshold_widget) + + def get_value(self) -> float: + threshold_widget: QtWidgets.QWidget = self.layout().itemAt(1).widget() + return threshold_widget.value() diff --git a/setup.py b/setup.py index 2cc1e134d..799ef02d0 100644 --- a/setup.py +++ b/setup.py @@ -27,6 +27,7 @@ def get_install_requires(): "natsort>=7.1.0", "numpy", "onnxruntime>=1.14.1,!=1.16.0", + "osam>=0.2.1", "Pillow>=2.8", "PyYAML", "qtpy!=1.11.2",