Skip to content

Commit

Permalink
added use bbox switcher
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxTeselkin committed Sep 13, 2024
1 parent 9302ba9 commit 4ccb8a2
Showing 1 changed file with 41 additions and 21 deletions.
62 changes: 41 additions & 21 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from supervisely._utils import rand_str, is_debug_with_sly_net
from supervisely.app.content import get_data_dir
import supervisely.app.development as sly_app_development
from supervisely.app.widgets import Switch, Field
from sam2.build_sam import build_sam2, build_sam2_video_predictor
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
Expand All @@ -40,6 +41,20 @@

class SegmentAnything2(sly.nn.inference.PromptableSegmentation):

def add_content_to_pretrained_tab(self, gui):
self.use_bbox = Switch(switched=True)
use_bbox_field = Field(
content=self.use_bbox,
title="Use bounding box prompt",
description=(
"Define whether to use bounding box prompt when labeling images and videos or not. "
"If turned off, then only point prompts (positive and negative clicks) will be used. "
"Adding bounding box prompt can be useful when labeling entire objects, while using "
"only point prompts can be better when segmenting specific parts of objects."
),
)
return use_bbox_field

def support_custom_models(self):
return False

Expand Down Expand Up @@ -364,20 +379,21 @@ def generate_artificial_prompt(self, bitmap):
col_index += origin_col
point_coordinates.append([col_index, row_index])
point_labels.append(1)
# generate box prompt
rectangle = bitmap.to_bbox()
padding = 0.03
bbox = [
rectangle.left * (1 - padding),
rectangle.top * (1 - padding),
rectangle.right * (1 + padding),
rectangle.bottom * (1 + padding),
]
prompt = {
"point_coordinates": point_coordinates,
"point_labels": point_labels,
"bbox": bbox,
}
if self.use_bbox.is_switched():
# generate box prompt
rectangle = bitmap.to_bbox()
padding = 0.03
bbox = [
rectangle.left * (1 - padding),
rectangle.top * (1 - padding),
rectangle.right * (1 + padding),
rectangle.bottom * (1 + padding),
]
prompt["bbox"] = bbox
return prompt

def serve(self):
Expand Down Expand Up @@ -492,7 +508,10 @@ def smart_segmentation(response: Response, request: Request):
try:
# predict
logger.debug("Preparing settings for inference request...")
settings["mode"] = "combined"
if self.use_bbox.is_switched():
settings["mode"] = "combined"
else:
settings["mode"] = "points"
if "image_id" in smtool_state:
settings["input_image_id"] = smtool_state["image_id"]
elif "video" in smtool_state:
Expand Down Expand Up @@ -736,15 +755,16 @@ def track(request: Request):
try:
bbox_str = bitmap_frame_data[bitmap_center_str]
figure_prompt = frame_prompts[bbox_str]
top, left, bottom, right = bbox_str.split("-")
padding = 0.03
bbox = [
int(left) * (1 - padding),
int(top) * (1 - padding),
int(right) * (1 + padding),
int(bottom) * (1 + padding),
]
figure_prompt["bbox"] = bbox
if self.use_bbox.is_switched():
top, left, bottom, right = bbox_str.split("-")
padding = 0.03
bbox = [
int(left) * (1 - padding),
int(top) * (1 - padding),
int(right) * (1 + padding),
int(bottom) * (1 + padding),
]
figure_prompt["bbox"] = bbox
except Exception:
mode = "artificial clicks"
if mode == "artificial clicks":
Expand All @@ -764,7 +784,7 @@ def track(request: Request):
fig_prompt = figure_data["figure_prompt"]
point_coordinates = fig_prompt["point_coordinates"]
point_labels = fig_prompt["point_labels"]
bbox = fig_prompt["bbox"]
bbox = fig_prompt.get("bbox")
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=0,
Expand Down

0 comments on commit 4ccb8a2

Please sign in to comment.