Skip to content

Commit

Permalink
Merge pull request huggingface#4 from Superb-AI-Suite/dev
Browse files Browse the repository at this point in the history
minor chore detr and remove obj in mask2former
  • Loading branch information
SangbumChoi authored Apr 2, 2024
2 parents 395b650 + 8604f8c commit 3f13ad3
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
5 changes: 0 additions & 5 deletions src/transformers/models/detr/image_processing_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1855,11 +1855,6 @@ def post_process_instance_segmentation(
num_classes = class_queries_logits.shape[-1] - 1
num_queries = class_queries_logits.shape[-2]

# mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]

# # Predicted label and score of each query (batch_size, num_queries)
# pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1)

# Loop over items in batch size
results: List[Dict[str, TensorType]] = []

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1121,18 +1121,22 @@ def post_process_instance_segmentation(
pred_scores = scores_per_image * mask_scores_per_image
pred_classes = labels_per_image

mask_pred, pred_scores, pred_classes = remove_low_and_no_objects(
mask_pred, pred_scores, pred_classes, threshold, num_classes
)

segmentation = torch.zeros((384, 384)) - 1
if target_sizes is not None:
size = target_sizes[i] if isinstance(target_sizes[i], tuple) else target_sizes[i].cpu().tolist()
segmentation = torch.zeros(size) - 1
pred_masks = torch.nn.functional.interpolate(
pred_masks.unsqueeze(0), size=size, mode="nearest"
pred_masks.unsqueeze(0).cpu(), size=size, mode="nearest"
)[0]

instance_maps, segments = [], []
current_segment_id = 0
for j in range(num_queries):
score = pred_scores[j].item()
for j, score in enumerate(pred_scores):
score = score.item()

if not torch.all(pred_masks[j] == 0) and score >= threshold:
segmentation[pred_masks[j] == 1] = current_segment_id
Expand Down

0 comments on commit 3f13ad3

Please sign in to comment.