Shortcuts

Source code for mmdet.models.detectors.centernet

# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmdet.core import bbox2result
from mmdet.models.builder import DETECTORS
from ...core.utils import flip_tensor
from .single_stage import SingleStageDetector


@DETECTORS.register_module()
class CenterNet(SingleStageDetector):
    """Implementation of CenterNet(Objects as Points)

    <https://arxiv.org/abs/1904.07850>.
    """

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        # All arguments are forwarded unchanged, in order, to the parent.
        super().__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
                         pretrained, init_cfg)

    def merge_aug_results(self, aug_results, with_nms):
        """Concatenate the detections of every augmentation into one set.

        Args:
            aug_results (list): One entry per augmentation. Each entry is a
                per-image list whose first element is a
                ``(det_bboxes, det_labels)`` pair. Only that first image's
                pair is read from each entry.
            with_nms (bool): If True, the concatenated detections are passed
                through ``self.bbox_head._bboxes_nms`` with
                ``self.bbox_head.test_cfg`` before being returned.

        Returns:
            tuple: ``(out_bboxes, out_labels)``.
        """
        first_image_dets = [aug_result[0] for aug_result in aug_results]
        merged_bboxes = torch.cat([dets[0] for dets in first_image_dets],
                                  dim=0).contiguous()
        merged_labels = torch.cat([dets[1] for dets in first_image_dets
                                   ]).contiguous()

        if not with_nms:
            return merged_bboxes, merged_labels

        out_bboxes, out_labels = self.bbox_head._bboxes_nms(
            merged_bboxes, merged_labels, self.bbox_head.test_cfg)
        return out_bboxes, out_labels

    def aug_test(self, imgs, img_metas, rescale=True):
        """Run flip-augmented testing for CenterNet.

        Images are consumed as consecutive ``(original, flipped)`` pairs:
        index 0 with 1, index 2 with 3, and so on. A trailing unpaired image
        is ignored. Unlike CornerNet, the two heads' outputs are averaged at
        the feature-map level instead of merging separately detected boxes.

        Args:
            imgs (list[Tensor]): Augmented images, arranged in flip pairs.
            img_metas (list[list[dict]]): Meta information for each image.
                The ``'flip'`` and ``'flip_direction'`` keys are read here.
            rescale (bool): If True, return boxes in original image space.
                Default: True.

        Note:
            ``imgs`` must include flipped image pairs. The first pair is
            checked for a flip flag with an ``assert``.

        Returns:
            list[list[np.ndarray]]: BBox results of each image and classes.
                The outer list corresponds to each image. The inner list
                corresponds to each class.
        """
        assert img_metas[0][0]['flip'] + img_metas[1][0]['flip'], (
            'aug test must have flipped image pair')

        def _average_with_flipped(level_preds, direction):
            # Row 0 is the original image's prediction. Row 1 (the flipped
            # image's prediction) is passed through flip_tensor along
            # `direction` before the two are averaged.
            return (level_preds[0:1] +
                    flip_tensor(level_preds[1:2], direction)) / 2

        aug_results = []
        for ind in range(0, len(imgs) - 1, 2):
            flip_ind = ind + 1
            flip_direction = img_metas[flip_ind][0]['flip_direction']

            # Run the pair through the network as a single batch of two.
            feats = self.extract_feat(torch.cat([imgs[ind], imgs[flip_ind]]))
            center_heatmap_preds, wh_preds, offset_preds = self.bbox_head(
                feats)
            assert len(center_heatmap_preds) == len(wh_preds) == len(
                offset_preds) == 1

            # Feature map averaging of the heatmap and wh branches.
            center_heatmap_preds[0] = _average_with_flipped(
                center_heatmap_preds[0], flip_direction)
            wh_preds[0] = _average_with_flipped(wh_preds[0], flip_direction)

            # Offsets are not averaged: only the original image's row is used.
            aug_results.append(
                self.bbox_head.get_bboxes(
                    center_heatmap_preds,
                    wh_preds, [offset_preds[0][0:1]],
                    img_metas[ind],
                    rescale=rescale,
                    with_nms=False))

        # NMS on the merged detections is enabled only when the head's
        # test_cfg provides an 'nms_cfg' entry.
        # NOTE(review): merge_aug_results passes the whole test_cfg to
        # `_bboxes_nms`; confirm which key that helper actually reads.
        with_nms = self.bbox_head.test_cfg.get('nms_cfg', None) is not None

        det_bboxes, det_labels = self.merge_aug_results(aug_results, with_nms)
        return [
            bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
        ]
Read the Docs v: v2.21.0
Versions
latest
stable
v2.21.0
v2.20.0
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
v2.12.0
v2.11.0
v2.10.0
v2.9.0
v2.8.0
v2.7.0
v2.6.0
v2.5.0
v2.4.0
v2.3.0
v2.2.1
v2.2.0
v2.1.0
v2.0.0
v1.2.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.