import torch
from mmdet.core import bbox2result, bbox_mapping_back, multiclass_nms
from ..builder import DETECTORS
from .single_stage import SingleStageDetector
[docs]@DETECTORS.register_module()
class RepPointsDetector(SingleStageDetector):
"""RepPoints: Point Set Representation for Object Detection.
This detector is the implementation of:
- RepPoints detector (https://arxiv.org/pdf/1904.11490)
"""
def __init__(self,
backbone,
neck,
bbox_head,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(RepPointsDetector,
self).__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
pretrained)
[docs] def merge_aug_results(self, aug_bboxes, aug_scores, img_metas):
"""Merge augmented detection bboxes and scores.
Args:
aug_bboxes (list[Tensor]): shape (n, 4*#class)
aug_scores (list[Tensor] or None): shape (n, #class)
img_shapes (list[Tensor]): shape (3, ).
Returns:
tuple: (bboxes, scores)
"""
recovered_bboxes = []
for bboxes, img_info in zip(aug_bboxes, img_metas):
img_shape = img_info[0]['img_shape']
scale_factor = img_info[0]['scale_factor']
flip = img_info[0]['flip']
flip_direction = img_info[0]['flip_direction']
bboxes = bbox_mapping_back(bboxes, img_shape, scale_factor, flip,
flip_direction)
recovered_bboxes.append(bboxes)
bboxes = torch.cat(recovered_bboxes, dim=0)
if aug_scores is None:
return bboxes
else:
scores = torch.cat(aug_scores, dim=0)
return bboxes, scores
def aug_test(self, imgs, img_metas, rescale=False):
# recompute feats to save memory
feats = self.extract_feats(imgs)
aug_bboxes = []
aug_scores = []
for x, img_meta in zip(feats, img_metas):
# only one image in the batch
outs = self.bbox_head(x)
bbox_inputs = outs + (img_metas, self.test_cfg, False, False)
det_bboxes, det_scores = self.bbox_head.get_bboxes(*bbox_inputs)[0]
aug_bboxes.append(det_bboxes)
aug_scores.append(det_scores)
# after merging, bboxes will be rescaled to the original image size
merged_bboxes, merged_scores = self.merge_aug_results(
aug_bboxes, aug_scores, img_metas)
det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
self.test_cfg.score_thr,
self.test_cfg.nms,
self.test_cfg.max_per_img)
if rescale:
_det_bboxes = det_bboxes
else:
_det_bboxes = det_bboxes.clone()
_det_bboxes[:, :4] *= det_bboxes.new_tensor(
img_metas[0][0]['scale_factor'])
bbox_results = bbox2result(_det_bboxes, det_labels,
self.bbox_head.num_classes)
return bbox_results