-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix generate_proposal_labels in cascade-rcnn series model, test=develop #27892
fix generate_proposal_labels in cascade-rcnn series model, test=develop #27892
Conversation
Thanks for your contribution! |
✅ This PR's description meets the template requirements! |
ba40645
to
8ed941c
Compare
… fix_generate_proposals_labels
template <typename T> | ||
void MaxOverlaps(const framework::Tensor& iou, | ||
framework::Tensor* max_overlaps) { | ||
const T* proposal_to_gt_overlaps = iou.data<T>(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
作为util的通用函数,命名也需要更通用。 proposal_to_gt_overlaps -> iou_data
@@ -149,5 +149,19 @@ void ClipTiledBoxes(const platform::DeviceContext& ctx, | |||
} | |||
} | |||
|
|||
template <typename T> | |||
void MaxOverlaps(const framework::Tensor& iou, | |||
framework::Tensor* max_overlaps) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
max_overlaps -> max_iou?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
const T* proposal_to_gt_overlaps = iou.data<T>(); | ||
int row = iou.dims()[0]; | ||
int col = iou.dims()[1]; | ||
T* max_overlaps_dt = max_overlaps->data<T>(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
max_overlaps_dt -> max_iou_data
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
@@ -2651,25 +2653,29 @@ def generate_proposal_labels(rpn_rois, | |||
use_random(bool): Use random sampling to choose foreground and background boxes. | |||
is_cls_agnostic(bool): bbox regression use class agnostic simply which only represent fg and bg boxes. | |||
is_cascade_rcnn(bool): it will filter some bbox crossing the image's boundary when setting True. | |||
max_overlap(Variable): Maximum overlap between each Input box and ground-truth. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Input box -> each proposal box .. ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
@@ -2651,25 +2653,29 @@ def generate_proposal_labels(rpn_rois, | |||
use_random(bool): Use random sampling to choose foreground and background boxes. | |||
is_cls_agnostic(bool): bbox regression use class agnostic simply which only represent fg and bg boxes. | |||
is_cascade_rcnn(bool): it will filter some bbox crossing the image's boundary when setting True. | |||
max_overlap(Variable): Maximum overlap between each Input box and ground-truth. | |||
return_max_overlap(bool): Whether return the maximum overlap between each Output box and ground-truth. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上,Output box更精确的描述
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
is_cascade_rcnn=False): | ||
is_cascade_rcnn=False, | ||
max_overlap=None, | ||
return_max_overlap=False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
max_overlap是否始终return,而不加这个参数
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里希望可以兼容PaddleDetection先前的版本,如果始终return,https://github.com/PaddlePaddle/PaddleDetection/blob/release/0.4/ppdet/modeling/architectures/mask_rcnn.py#L125 计算 loss的位置会受到影响
gt_boxes = fluid.data( | ||
name='gt_boxes', shape=[6, 4], dtype='float32', lod_level=1) | ||
im_info = fluid.data( | ||
name='im_info', shape=[1, 3], dtype='float32', lod_level=1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
im_info的lod_level是0?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, thx
max_overlap_slice = | ||
max_overlap->Slice(rpn_rois_lod[i], rpn_rois_lod[i + 1]); | ||
} else { | ||
max_overlap_slice.mutable_data<T>(im_info_slice.dims(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的shape是和im_info_slice相同吗? im_info_slice的shape是[1, 3]?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, thx
AddInput("MaxOverlap", | ||
"(LoDTensor), This input is a 1D LoDTensor with shape [N]." | ||
"N is the number of Input(RpnRois), " | ||
"each element is the maxoverlap between " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maxoverlap -> max overlap
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
for (int i = 0; i < rois_num; ++i) { | ||
if ((rpn_rois_dt[i * 4 + 2] - rpn_rois_dt[i * 4 + 0] + 1) > 0 && | ||
(rpn_rois_dt[i * 4 + 3] - rpn_rois_dt[i * 4 + 1] + 1) > 0 && | ||
max_overlap_dt[i] < 1.) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add comments for why filter max_overlap_dt < 1.0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
考虑到只训练使用,这里不做兼容要求
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
generate_proposal_labels is used in static graph
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
Function optimization
PR changes
OPs
Describe
Fix generate_proposal_labels in cascade-rcnn series model and add max_overlap as input.
This op is only used at the stage of training so it will have no effect on inference.
This op is only used in static mode so core.ops,xxx is not used as well.
The mAP of related models:
Cascade Faster RCNN is 40.8 and previous version is 40.9.
HTC is 42.9/37.0 and previous is 42.7/36.8
The difference on document is marked as below:
data:image/s3,"s3://crabby-images/8b403/8b403b43ff3a380aedd45452810f2b7856de149b" alt="image"
data:image/s3,"s3://crabby-images/cd55f/cd55fc76b1e9cf62eda37e5e99a55003ba6fe068" alt="image"
data:image/s3,"s3://crabby-images/fe2ef/fe2ef4847076c00463e54fa1c6d7fe44ae7e2d1e" alt="image"