Shortcuts

Source code for mmdet.models.detectors.cornernet

# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmdet.core import bbox2result, bbox_mapping_back
from ..builder import DETECTORS
from .single_stage import SingleStageDetector


@DETECTORS.register_module()
class CornerNet(SingleStageDetector):
    """CornerNet.

    This detector is the implementation of the paper `CornerNet: Detecting
    Objects as Paired Keypoints <https://arxiv.org/abs/1808.01244>`_ .
    """

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        """Forward every argument unchanged to ``SingleStageDetector``."""
        super(CornerNet, self).__init__(backbone, neck, bbox_head, train_cfg,
                                        test_cfg, pretrained, init_cfg)

    def merge_aug_results(self, aug_results, img_metas):
        """Merge augmented detection bboxes and score.

        Each per-augmentation result is mapped back with
        ``bbox_mapping_back`` using that augmentation's ``img_shape``,
        ``scale_factor`` and ``flip`` meta entries, then all results are
        concatenated and, if any box is left, passed through
        ``self.bbox_head._bboxes_nms``.

        Args:
            aug_results (list[list[Tensor]]): Det_bboxes and det_labels of each
                image.
            img_metas (list[list[dict]]): Meta information of each image, e.g.,
                image size, scaling factor, etc.

        Returns:
            tuple: (bboxes, labels)
        """
        recovered_bboxes, aug_labels = [], []
        # NOTE: zip() stops at the shorter of the two sequences, so
        # ``aug_results`` and ``img_metas`` are expected to line up one-to-one.
        for bboxes_labels, img_info in zip(aug_results, img_metas):
            img_shape = img_info[0]['img_shape']  # using shape before padding
            scale_factor = img_info[0]['scale_factor']
            flip = img_info[0]['flip']
            bboxes, labels = bboxes_labels
            # First four columns are treated as box coordinates and the last
            # column as the score; the score is re-attached after mapping.
            bboxes, scores = bboxes[:, :4], bboxes[:, -1:]
            bboxes = bbox_mapping_back(bboxes, img_shape, scale_factor, flip)
            recovered_bboxes.append(torch.cat([bboxes, scores], dim=-1))
            aug_labels.append(labels)

        bboxes = torch.cat(recovered_bboxes, dim=0)
        labels = torch.cat(aug_labels)

        # Only call the head's NMS routine when there is at least one box;
        # otherwise return the empty tensors as they are.
        if bboxes.shape[0] > 0:
            out_bboxes, out_labels = self.bbox_head._bboxes_nms(
                bboxes, labels, self.bbox_head.test_cfg)
        else:
            out_bboxes, out_labels = bboxes, labels

        return out_bboxes, out_labels

    def aug_test(self, imgs, img_metas, rescale=False):
        """Augment testing of CornerNet.

        Args:
            imgs (list[Tensor]): Augmented images.
            img_metas (list[list[dict]]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            rescale (bool): If True, return boxes in original image space.
                Default: False. This argument is not read by this
                implementation; boxes are always passed through
                ``bbox_mapping_back`` in :meth:`merge_aug_results`.

        Note:
            ``imgs`` must include flipped image pairs. Images are consumed
            two at a time as (index 0, index 1), (index 2, index 3), ...;
            only the first pair is checked for a flipped member, and a
            trailing unpaired image is not processed.

        Returns:
            list[list[np.ndarray]]: BBox results of each image and classes.
                The outer list corresponds to each image. The inner list
                corresponds to each class.
        """
        img_inds = list(range(len(imgs)))

        # At least one image of the first pair must be marked as flipped.
        assert img_metas[0][0]['flip'] + img_metas[1][0]['flip'], (
            'aug test must have flipped image pair')
        aug_results = []
        for ind, flip_ind in zip(img_inds[0::2], img_inds[1::2]):
            # Run both images of the pair through the network as one batch.
            img_pair = torch.cat([imgs[ind], imgs[flip_ind]])
            x = self.extract_feat(img_pair)
            outs = self.bbox_head(x)
            # NOTE(review): the two positional ``False`` values are presumably
            # ``rescale`` and ``with_nms`` -- confirm against
            # ``bbox_head.get_bboxes``.
            bbox_list = self.bbox_head.get_bboxes(
                *outs, [img_metas[ind], img_metas[flip_ind]], False, False)
            aug_results.append(bbox_list[0])
            aug_results.append(bbox_list[1])

        bboxes, labels = self.merge_aug_results(aug_results, img_metas)
        bbox_results = bbox2result(bboxes, labels, self.bbox_head.num_classes)

        return [bbox_results]
Read the Docs v: v2.25.0
Versions
latest
stable
v2.25.0
v2.24.1
v2.24.0
v2.23.0
v2.22.0
v2.21.0
v2.20.0
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
v2.12.0
v2.11.0
v2.10.0
v2.9.0
v2.8.0
v2.7.0
v2.6.0
v2.5.0
v2.4.0
v2.3.0
v2.2.1
v2.2.0
v2.1.0
v2.0.0
v1.2.0
dev
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.