Skip to content

Commit

Permalink
fix : add if statement when n_queries * n_cls is less than 10000
Browse files Browse the repository at this point in the history
  • Loading branch information
SangbumChoi committed Aug 4, 2023
1 parent 985fa0b commit 4ea705b
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions models/deformable_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,11 +486,12 @@ def forward(self, outputs, target_sizes):
score = all_scores[b]
lbls = all_labels[b]

pre_topk = score.topk(10000).indices
box = box[pre_topk]
score = score[pre_topk]
lbls = lbls[pre_topk]

if n_queries * n_cls > 10000:
pre_topk = score.topk(10000).indices
box = box[pre_topk]
score = score[pre_topk]
lbls = lbls[pre_topk]

keep_inds = batched_nms(box, score, lbls, 0.7)[:100]
results.append({
'scores': score[keep_inds],
Expand Down

0 comments on commit 4ea705b

Please sign in to comment.