Skip to content

Commit

Permalink
Use configurable for StandardROIHeads
Browse files Browse the repository at this point in the history
Reviewed By: rbgirshick

Differential Revision: D21386044

fbshipit-source-id: 80fb5481dbaa9bd6c53ed4d594e64108e92ae7a3
  • Loading branch information
ppwwyyxx authored and facebook-github-bot committed May 16, 2020
1 parent 806d9ca commit 67c30d0
Show file tree
Hide file tree
Showing 8 changed files with 229 additions and 93 deletions.
2 changes: 1 addition & 1 deletion detectron2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@

# This line will be programmatically read and written by setup.py.
# Leave it at the bottom of this file and don't touch it.
__version__ = "0.1.2"
__version__ = "0.1.3"
2 changes: 1 addition & 1 deletion detectron2/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def from_config(cls, cfg):
assert init_func.__name__ == "__init__", "@configurable should only be used for __init__!"
if init_func.__module__.startswith("detectron2."):
assert (
"experimental" in init_func.__doc__
init_func.__doc__ is not None and "experimental" in init_func.__doc__
), f"configurable {init_func} should be marked experimental"

@functools.wraps(init_func)
Expand Down
96 changes: 72 additions & 24 deletions detectron2/modeling/roi_heads/cascade_rcnn.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import List
import torch
from torch import nn
from torch.autograd.function import Function

from detectron2.config import configurable
from detectron2.layers import ShapeSpec
from detectron2.structures import Boxes, Instances, pairwise_iou
from detectron2.utils.events import get_event_storage
Expand Down Expand Up @@ -32,27 +34,77 @@ class CascadeROIHeads(StandardROIHeads):
Implement :paper:`Cascade R-CNN`.
"""

def _init_box_head(self, cfg, input_shape):
@configurable
def __init__(
self,
*,
box_in_features: List[str],
box_pooler: ROIPooler,
box_heads: List[nn.Module],
box_predictors: List[nn.Module],
proposal_matchers: List[Matcher],
**kwargs,
):
"""
NOTE: this interface is experimental.
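Args:
box_in_features (list[str]): names of the feature maps (keys of ``features``) that the box heads read from
box_pooler (ROIPooler): pooler that extracts region features from given boxes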
box_heads (list[nn.Module]): box head for each cascade stage
box_predictors (list[nn.Module]): box predictor for each cascade stage
proposal_matchers (list[Matcher]): matcher with different IoU thresholds to
match boxes with ground truth for each stage. The first matcher matches
RPN proposals with ground truth, the other matchers use boxes predicted
by the previous stage as proposals and match them with ground truth.
"""
assert "proposal_matcher" not in kwargs, (
"CascadeROIHeads takes 'proposal_matchers=' for each stage instead "
"of one 'proposal_matcher='."
)
# The first matcher matches RPN proposals with ground truth, done in the base class
kwargs["proposal_matcher"] = proposal_matchers[0]
num_stages = self.num_cascade_stages = len(box_heads)
box_heads = nn.ModuleList(box_heads)
box_predictors = nn.ModuleList(box_predictors)
assert len(box_predictors) == num_stages, f"{len(box_predictors)} != {num_stages}!"
assert len(proposal_matchers) == num_stages, f"{len(proposal_matchers)} != {num_stages}!"
super().__init__(
box_in_features=box_in_features,
box_pooler=box_pooler,
box_head=box_heads,
box_predictor=box_predictors,
**kwargs,
)
self.proposal_matchers = proposal_matchers

@classmethod
def from_config(cls, cfg, input_shape):
ret = super().from_config(cfg, input_shape)
ret.pop("proposal_matcher")
return ret

@classmethod
def _init_box_head(cls, cfg, input_shape):
# fmt: off
in_features = cfg.MODEL.ROI_HEADS.IN_FEATURES
pooler_resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION
pooler_scales = tuple(1.0 / input_shape[k].stride for k in self.in_features)
pooler_scales = tuple(1.0 / input_shape[k].stride for k in in_features)
sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO
pooler_type = cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE
cascade_bbox_reg_weights = cfg.MODEL.ROI_BOX_CASCADE_HEAD.BBOX_REG_WEIGHTS
cascade_ious = cfg.MODEL.ROI_BOX_CASCADE_HEAD.IOUS
self.num_cascade_stages = len(cascade_ious)
assert len(cascade_bbox_reg_weights) == self.num_cascade_stages
assert len(cascade_bbox_reg_weights) == len(cascade_ious)
assert cfg.MODEL.ROI_BOX_HEAD.CLS_AGNOSTIC_BBOX_REG, \
"CascadeROIHeads only support class-agnostic regression now!"
assert cascade_ious[0] == cfg.MODEL.ROI_HEADS.IOU_THRESHOLDS[0]
# fmt: on

in_channels = [input_shape[f].channels for f in self.in_features]
in_channels = [input_shape[f].channels for f in in_features]
# Check all channel counts are equal
assert len(set(in_channels)) == 1, in_channels
in_channels = in_channels[0]

self.box_pooler = ROIPooler(
box_pooler = ROIPooler(
output_size=pooler_resolution,
scales=pooler_scales,
sampling_ratio=sampling_ratio,
Expand All @@ -62,29 +114,25 @@ def _init_box_head(self, cfg, input_shape):
channels=in_channels, width=pooler_resolution, height=pooler_resolution
)

self.box_head = nn.ModuleList()
self.box_predictor = nn.ModuleList()
self.box2box_transform = []
self.proposal_matchers = []
for k in range(self.num_cascade_stages):
box_heads, box_predictors, proposal_matchers = [], [], []
for match_iou, bbox_reg_weights in zip(cascade_ious, cascade_bbox_reg_weights):
box_head = build_box_head(cfg, pooled_shape)
self.box_head.append(box_head)
# NOTE: use list of predictor in explicit args?
self.box_predictor.append(
box_heads.append(box_head)
box_predictors.append(
FastRCNNOutputLayers(
cfg,
box_head.output_shape,
box2box_transform=Box2BoxTransform(weights=cascade_bbox_reg_weights[k]),
box2box_transform=Box2BoxTransform(weights=bbox_reg_weights),
)
)

if k == 0:
# The first matching is done by the matcher of ROIHeads (self.proposal_matcher).
self.proposal_matchers.append(None)
else:
self.proposal_matchers.append(
Matcher([cascade_ious[k]], [0, 1], allow_low_quality_matches=False)
)
proposal_matchers.append(Matcher([match_iou], [0, 1], allow_low_quality_matches=False))
return {
"box_in_features": in_features,
"box_pooler": box_pooler,
"box_heads": box_heads,
"box_predictors": box_predictors,
"proposal_matchers": proposal_matchers,
}

def forward(self, images, features, proposals, targets=None):
del images
Expand Down Expand Up @@ -112,7 +160,7 @@ def _forward_box(self, features, proposals, targets=None):
Each has fields "proposal_boxes", and "objectness_logits",
"gt_classes", "gt_boxes".
"""
features = [features[f] for f in self.in_features]
features = [features[f] for f in self.box_in_features]
head_outputs = [] # (predictor, predictions, proposals)
prev_pred_boxes = None
image_sizes = [x.image_size for x in proposals]
Expand Down
4 changes: 2 additions & 2 deletions detectron2/modeling/roi_heads/fast_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def _predict_boxes(self):

"""
A subclass is expected to have the following methods because
they are used to query information about the head predictions.0
they are used to query information about the head predictions.
"""

def losses(self):
Expand Down Expand Up @@ -370,7 +370,7 @@ def __init__(
test_topk_per_image (int): number of top predictions to produce per image.
"""
super().__init__()
if isinstance(input_shape, int): # some backward compatbility
if isinstance(input_shape, int): # some backward compatibility
input_shape = ShapeSpec(channels=input_shape)
input_size = input_shape.channels * (input_shape.width or 1) * (input_shape.height or 1)
# The prediction layer for num_classes foreground classes and one background class
Expand Down
Loading

0 comments on commit 67c30d0

Please sign in to comment.