Source code for mmdet.models.detectors.cornernet

import torch

from mmdet.core import bbox2result, bbox_mapping_back
from ..builder import DETECTORS
from .single_stage import SingleStageDetector


@DETECTORS.register_module()
class CornerNet(SingleStageDetector):
    """CornerNet.

    This detector is the implementation of the paper `CornerNet: Detecting
    Objects as Paired Keypoints <https://arxiv.org/abs/1808.01244>`_ .
    """

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        super(CornerNet, self).__init__(backbone, neck, bbox_head, train_cfg,
                                        test_cfg, pretrained, init_cfg)
    def merge_aug_results(self, aug_results, img_metas):
        """Merge augmented detection bboxes and scores.

        Args:
            aug_results (list[list[Tensor]]): Det_bboxes and det_labels of
                each image.
            img_metas (list[list[dict]]): Meta information of each image,
                e.g., image size, scaling factor, etc.

        Returns:
            tuple: (bboxes, labels)
        """
        recovered_bboxes, aug_labels = [], []
        for bboxes_labels, img_info in zip(aug_results, img_metas):
            img_shape = img_info[0]['img_shape']  # using shape before padding
            scale_factor = img_info[0]['scale_factor']
            flip = img_info[0]['flip']
            bboxes, labels = bboxes_labels
            bboxes, scores = bboxes[:, :4], bboxes[:, -1:]
            bboxes = bbox_mapping_back(bboxes, img_shape, scale_factor, flip)
            recovered_bboxes.append(torch.cat([bboxes, scores], dim=-1))
            aug_labels.append(labels)

        bboxes = torch.cat(recovered_bboxes, dim=0)
        labels = torch.cat(aug_labels)

        if bboxes.shape[0] > 0:
            out_bboxes, out_labels = self.bbox_head._bboxes_nms(
                bboxes, labels, self.bbox_head.test_cfg)
        else:
            out_bboxes, out_labels = bboxes, labels

        return out_bboxes, out_labels
    def aug_test(self, imgs, img_metas, rescale=False):
        """Augment testing of CornerNet.

        Args:
            imgs (list[Tensor]): Augmented images.
            img_metas (list[list[dict]]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            rescale (bool): If True, return boxes in original image space.
                Default: False.

        Note:
            ``imgs`` must include flipped image pairs.

        Returns:
            list[list[np.ndarray]]: BBox results of each image and classes.
                The outer list corresponds to each image. The inner list
                corresponds to each class.
        """
        img_inds = list(range(len(imgs)))

        assert img_metas[0][0]['flip'] + img_metas[1][0]['flip'], (
            'aug test must have flipped image pair')
        aug_results = []
        for ind, flip_ind in zip(img_inds[0::2], img_inds[1::2]):
            # Run the original image and its flipped counterpart through the
            # network as one batch, then decode both sets of corner heatmaps.
            img_pair = torch.cat([imgs[ind], imgs[flip_ind]])
            x = self.extract_feat(img_pair)
            outs = self.bbox_head(x)
            bbox_list = self.bbox_head.get_bboxes(
                *outs, [img_metas[ind], img_metas[flip_ind]], False, False)
            aug_results.append(bbox_list[0])
            aug_results.append(bbox_list[1])

        bboxes, labels = self.merge_aug_results(aug_results, img_metas)
        bbox_results = bbox2result(bboxes, labels, self.bbox_head.num_classes)

        return [bbox_results]
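
# Usage sketch: a minimal, hedged example of how a CornerNet detector is
# typically configured and built through the DETECTORS registry. The config
# below mirrors the structure of standard CornerNet configs (HourglassNet
# backbone, CornerHead), but the concrete hyper-parameter values here are
# illustrative assumptions and may differ from the released configs.
if __name__ == '__main__':
    from mmdet.models import build_detector

    model_cfg = dict(
        type='CornerNet',
        backbone=dict(
            type='HourglassNet',  # stacked hourglass backbone used by CornerNet
            downsample_times=5,
            num_stacks=2,
            stage_channels=[256, 256, 384, 384, 384, 512],
            stage_blocks=[2, 2, 2, 2, 2, 4],
            norm_cfg=dict(type='BN', requires_grad=True)),
        neck=None,  # CornerNet predicts directly from backbone features
        bbox_head=dict(
            type='CornerHead',
            num_classes=80,
            in_channels=256,
            num_feat_levels=2,
            corner_emb_channels=1),
        train_cfg=None,
        test_cfg=dict(
            corner_topk=100,
            local_maximum_kernel=3,
            distance_threshold=0.5,
            score_thr=0.05,
            max_per_img=100,
            nms=dict(type='soft_nms', iou_threshold=0.5, method='gaussian')))

    detector = build_detector(model_cfg)

    # ``aug_test`` expects the augmented batch to hold original/flipped image
    # pairs at even/odd indices, with per-image meta dicts providing
    # ``img_shape``, ``scale_factor`` and ``flip`` so that boxes from the
    # flipped view can be mapped back before the final NMS step.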