Skip to content

Commit

Permalink
Fix for comments
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyag-grovety committed Jul 21, 2023
1 parent 80aa621 commit 2f8ae58
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 7 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/op/vision/_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def _regular_nms_shape_func(boxes_shape, scores_shape, attrs):
out_scores_shape[0] = boxes_shape[0]
out_scores_shape[1] = int64(attrs.max_detections)

out_num_detections_shape[0] = int64(1)
out_num_detections_shape[0] = boxes_shape[0]

return out_boxes_shape, out_classes_shape, out_scores_shape, out_num_detections_shape

Expand Down
8 changes: 2 additions & 6 deletions python/tvm/topi/cuda/ssd/multibox.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,9 +363,7 @@ def transform_loc(loc, loc_base_idx, anchor, anchor_base_idx, clip, vx, vy, vw,
with ib.if_scope(tvm.tir.any(keep_background == 1, cls_id[tid] > 0)):
with ib.if_scope(j == 0):
out_base_idx = i * num_anchors * 6
out_loc[out_base_idx] = (
cls_id[tid] - 0.0 if keep_background == 1 else cls_id[tid] - 1.0
)
out_loc[out_base_idx] = cls_id[tid] if keep_background == 1 else cls_id[tid] - 1.0
out_loc[out_base_idx + 1] = score[tid]
(
out_loc[out_base_idx + 2],
Expand All @@ -385,9 +383,7 @@ def transform_loc(loc, loc_base_idx, anchor, anchor_base_idx, clip, vx, vy, vw,
)
with ib.else_scope():
out_base_idx = i * num_anchors * 6 + temp_valid_count[tid - 1] * 6
out_loc[out_base_idx] = (
cls_id[tid] - 0.0 if keep_background == 1 else cls_id[tid] - 1.0
)
out_loc[out_base_idx] = cls_id[tid] if keep_background == 1 else cls_id[tid] - 1.0
out_loc[out_base_idx + 1] = score[tid]
(
out_loc[out_base_idx + 2],
Expand Down

0 comments on commit 2f8ae58

Please sign in to comment.