Skip to content

Commit

Permalink
added multimask output support for ambiguous prompts
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxTeselkin committed Sep 27, 2024
1 parent 6a84427 commit e7e8cca
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 38 deletions.
2 changes: 1 addition & 1 deletion .devcontainer/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ RUN SAM2_BUILD_ALLOW_ERRORS=0 pip3 install -v -e ".[demo]"
RUN python3 setup.py clean --all
RUN python3 setup.py build_ext --inplace

RUN python3 -m pip install supervisely==6.73.145
RUN python3 -m pip install supervisely==6.73.171

RUN apt-get install ffmpeg libgeos-dev libsm6 libxext6 libexiv2-dev libxrender-dev libboost-all-dev -y
RUN pip install opencv-python
Expand Down
105 changes: 68 additions & 37 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@
from supervisely.sly_logger import logger
from supervisely.imaging import image as sly_image
from supervisely.io.fs import silent_remove
from supervisely._utils import rand_str, is_debug_with_sly_net
from supervisely._utils import rand_str
from supervisely.app.content import get_data_dir
import supervisely.app.development as sly_app_development
from supervisely.app.widgets import Switch, Field
from sam2.build_sam import build_sam2, build_sam2_video_predictor
from sam2.sam2_image_predictor import SAM2ImagePredictor
Expand Down Expand Up @@ -257,12 +256,21 @@ def predict(
self.set_image_data(input_image, settings)
self.previous_image_id = settings["input_image_id"]
# get predicted masks
masks, _, _ = self.predictor.predict(
point_coords=point_coordinates,
point_labels=point_labels,
multimask_output=False,
)
mask = masks[0]
if len(point_labels) > 1:
masks, _, _ = self.predictor.predict(
point_coords=point_coordinates,
point_labels=point_labels,
multimask_output=False,
)
mask = masks[0]
else:
masks, scores, logits = self.predictor.predict(
point_coords=point_coordinates,
point_labels=point_labels,
multimask_output=True,
)
max_score_ind = np.argmax(scores)
mask = masks[max_score_ind]
predictions.append(sly.nn.PredictionMask(class_name=class_name, mask=mask))
elif settings["mode"] == "combined":
# get point coordinates
Expand Down Expand Up @@ -305,13 +313,24 @@ def predict(
mask_input = self.model_cache.get(settings["input_image_id"])[
"mask_input"
]
masks, scores, logits = self.predictor.predict(
point_coords=point_coordinates,
point_labels=point_labels,
box=bbox_coordinates[None, :],
mask_input=mask_input[None, :, :],
multimask_output=False,
)
if len(point_labels) > 1:
masks, scores, logits = self.predictor.predict(
point_coords=point_coordinates,
point_labels=point_labels,
box=bbox_coordinates[None, :],
mask_input=mask_input[None, :, :],
multimask_output=False,
)
else:
masks, scores, logits = self.predictor.predict(
point_coords=point_coordinates,
point_labels=point_labels,
box=bbox_coordinates[None, :],
mask_input=mask_input[None, :, :],
multimask_output=True,
)
max_score_ind = np.argmax(scores)
masks = [masks[max_score_ind]]
elif init_mask is not None:
# transform
mask_input = self.predictor.transform.apply_image(init_mask)
Expand All @@ -328,20 +347,41 @@ def predict(
mask_input = mask_input.astype(float)
mask_input[mask_input > 0] = 20
mask_input[mask_input <= 0] = -20
masks, scores, logits = self.predictor.predict(
point_coords=point_coordinates,
point_labels=point_labels,
box=bbox_coordinates[None, :],
mask_input=mask_input[None, :, :],
multimask_output=False,
)
if len(point_labels) > 1:
masks, scores, logits = self.predictor.predict(
point_coords=point_coordinates,
point_labels=point_labels,
box=bbox_coordinates[None, :],
mask_input=mask_input[None, :, :],
multimask_output=False,
)
else:
masks, scores, logits = self.predictor.predict(
point_coords=point_coordinates,
point_labels=point_labels,
box=bbox_coordinates[None, :],
mask_input=mask_input[None, :, :],
multimask_output=True,
)
max_score_ind = np.argmax(scores)
masks = [masks[max_score_ind]]
else:
masks, scores, logits = self.predictor.predict(
point_coords=point_coordinates,
point_labels=point_labels,
box=bbox_coordinates[None, :],
multimask_output=False,
)
if len(point_labels) > 1:
masks, scores, logits = self.predictor.predict(
point_coords=point_coordinates,
point_labels=point_labels,
box=bbox_coordinates[None, :],
multimask_output=False,
)
else:
masks, scores, logits = self.predictor.predict(
point_coords=point_coordinates,
point_labels=point_labels,
box=bbox_coordinates[None, :],
multimask_output=True,
)
max_score_ind = np.argmax(scores)
masks = [masks[max_score_ind]]
# save bbox coordinates and mask to cache
if settings["input_image_id"] in self.model_cache:
image_id = settings["input_image_id"]
Expand Down Expand Up @@ -847,15 +887,6 @@ def track(request: Request):
sly.logger.info("Successfully finished tracking process")


if is_debug_with_sly_net():
team_id = sly.env.team_id()
original_dir = os.getcwd()
sly_app_development.supervisely_vpn_network(action="up")
task = sly_app_development.create_debug_task(
team_id, port="8000", update_status=True
)
os.chdir(original_dir)

m = SegmentAnything2(
use_gui=True,
model_dir="app_data",
Expand Down

0 comments on commit e7e8cca

Please sign in to comment.