# Source code for mmdet.models.dense_heads.anchor_free_head

from abc import abstractmethod

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import force_fp32

from mmdet.core import multi_apply
from ..builder import HEADS, build_loss
from .base_dense_head import BaseDenseHead
from .dense_test_mixins import BBoxTestMixin


@HEADS.register_module()
class AnchorFreeHead(BaseDenseHead, BBoxTestMixin):
    """Anchor-free head (FCOS, Fovea, RepPoints, etc.).

    Builds two parallel stacks of ``ConvModule`` layers (classification and
    regression) followed by one 3x3 predictor conv each. ``loss``,
    ``get_bboxes`` and ``get_targets`` are abstract and must be provided by
    child classes.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of hidden channels. Used in child classes.
        stacked_convs (int): Number of stacking convs of the head.
        strides (tuple): Downsample factor of each feature map.
        dcn_on_last_conv (bool): If true, use dcn in the last layer of
            towers. Default: False.
        conv_bias (bool | str): If specified as `auto`, it will be decided by
            the norm_cfg. Bias of conv will be set as True if `norm_cfg` is
            None, otherwise False. Default: "auto".
        loss_cls (dict): Config of classification loss.
        loss_bbox (dict): Config of localization loss.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        train_cfg (dict): Training config of anchor head.
        test_cfg (dict): Testing config of anchor head.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """  # noqa: W605

    # Checkpoints whose metadata carries no version are treated as legacy
    # by `_load_from_state_dict` and get their predictor keys renamed.
    _version = 1

    # NOTE(review): the dict defaults below are evaluated once and shared by
    # every call; this assumes `build_loss` and the base-class initializer do
    # not mutate them - confirm before relying on it.
    def __init__(self,
                 num_classes,
                 in_channels,
                 feat_channels=256,
                 stacked_convs=4,
                 strides=(4, 8, 16, 32, 64),
                 dcn_on_last_conv=False,
                 conv_bias='auto',
                 loss_cls=dict(
                     type='FocalLoss',
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=1.0),
                 loss_bbox=dict(type='IoULoss', loss_weight=1.0),
                 conv_cfg=None,
                 norm_cfg=None,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=dict(
                     type='Normal',
                     layer='Conv2d',
                     std=0.01,
                     override=dict(
                         type='Normal',
                         name='conv_cls',
                         std=0.01,
                         bias_prob=0.01))):
        super(AnchorFreeHead, self).__init__(init_cfg)
        self.num_classes = num_classes
        self.cls_out_channels = num_classes
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.stacked_convs = stacked_convs
        self.strides = strides
        self.dcn_on_last_conv = dcn_on_last_conv
        assert conv_bias == 'auto' or isinstance(conv_bias, bool)
        self.conv_bias = conv_bias
        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox = build_loss(loss_bbox)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.fp16_enabled = False

        # Layers are built last because the builders read the attributes
        # assigned above.
        self._init_layers()

    def _init_layers(self):
        """Initialize layers of the head."""
        self._init_cls_convs()
        self._init_reg_convs()
        self._init_predictor()

    def _build_conv_tower(self):
        """Build one stack of ``stacked_convs`` 3x3 ``ConvModule`` layers.

        The first layer maps ``in_channels`` to ``feat_channels``; every
        later layer keeps ``feat_channels``. When ``dcn_on_last_conv`` is
        set, the last layer uses a ``DCNv2`` conv config instead of
        ``self.conv_cfg``.

        Returns:
            nn.ModuleList: The freshly built conv layers, in order.
        """
        last_idx = self.stacked_convs - 1
        tower = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            if self.dcn_on_last_conv and i == last_idx:
                conv_cfg = dict(type='DCNv2')
            else:
                conv_cfg = self.conv_cfg
            tower.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=self.conv_bias))
        return tower

    def _init_cls_convs(self):
        """Initialize classification conv layers of the head."""
        self.cls_convs = self._build_conv_tower()

    def _init_reg_convs(self):
        """Initialize bbox regression conv layers of the head."""
        self.reg_convs = self._build_conv_tower()

    def _init_predictor(self):
        """Initialize predictor layers of the head."""
        self.conv_cls = nn.Conv2d(
            self.feat_channels, self.cls_out_channels, 3, padding=1)
        self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                              strict, missing_keys, unexpected_keys,
                              error_msgs):
        """Hack some keys of the model state dict so that can load checkpoints
        of previous version.

        When ``local_metadata`` has no ``'version'`` entry, every key that
        starts with ``prefix`` is inspected: if the second dot-separated
        component of the key ends with ``cls``, ``reg`` or ``centerness``,
        that component is renamed to ``conv_cls``, ``conv_reg`` or
        ``conv_centerness`` respectively (``state_dict`` is modified in
        place). All other keys are left untouched. The parent implementation
        is then called with the same arguments.
        """
        version = local_metadata.get('version', None)
        if version is None:
            # the key is different in early versions
            # for example, 'fcos_cls' become 'conv_cls' now
            legacy_suffixes = (
                ('cls', 'conv_cls'),
                ('reg', 'conv_reg'),
                ('centerness', 'conv_centerness'),
            )
            # Snapshot the keys first: state_dict is mutated further down.
            bbox_head_keys = [
                k for k in state_dict.keys() if k.startswith(prefix)
            ]
            ori_predictor_keys = []
            new_predictor_keys = []
            # e.g. 'fcos_cls' or 'fcos_reg'
            for key in bbox_head_keys:
                parts = key.split('.')
                if len(parts) < 2:
                    # No second component to inspect; nothing to rename.
                    continue
                conv_name = None
                for suffix, new_name in legacy_suffixes:
                    if parts[1].endswith(suffix):
                        conv_name = new_name
                        break
                if conv_name is None:
                    # Not a predictor key (e.g. a tower conv): keep as is.
                    continue
                parts[1] = conv_name
                ori_predictor_keys.append(key)
                new_predictor_keys.append('.'.join(parts))
            for old_key, new_key in zip(ori_predictor_keys,
                                        new_predictor_keys):
                state_dict[new_key] = state_dict.pop(old_key)
        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, missing_keys, unexpected_keys,
                                      error_msgs)

    def forward(self, feats):
        """Forward features from the upstream network.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: Usually contain classification scores and bbox predictions.
                cls_scores (list[Tensor]): Box scores for each scale level,
                    each is a 4D-tensor, the channel number is
                    num_points * num_classes.
                bbox_preds (list[Tensor]): Box energies / deltas for each scale
                    level, each is a 4D-tensor, the channel number is
                    num_points * 4.
        """
        # forward_single also returns the tower features; only the first two
        # outputs (scores and bbox predictions) are kept here.
        return multi_apply(self.forward_single, feats)[:2]

    def forward_single(self, x):
        """Forward features of a single scale level.

        Args:
            x (Tensor): FPN feature maps of the specified stride.

        Returns:
            tuple: Scores for each class, bbox predictions, features
                after classification and regression conv layers, some
                models needs these features like FCOS.
        """
        cls_feat = x
        reg_feat = x

        for cls_layer in self.cls_convs:
            cls_feat = cls_layer(cls_feat)
        cls_score = self.conv_cls(cls_feat)

        for reg_layer in self.reg_convs:
            reg_feat = reg_layer(reg_feat)
        bbox_pred = self.conv_reg(reg_feat)
        return cls_score, bbox_pred, cls_feat, reg_feat

    @abstractmethod
    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute loss of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_points * num_classes.
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_points * 4.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.
        """

        raise NotImplementedError

    @abstractmethod
    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def get_bboxes(self,
                   cls_scores,
                   bbox_preds,
                   img_metas,
                   cfg=None,
                   rescale=None):
        """Transform network output for a batch into bbox predictions.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                Has shape (N, num_points * num_classes, H, W)
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_points * 4, H, W)
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            cfg (mmcv.Config): Test / postprocessing configuration,
                if None, test_cfg would be used
            rescale (bool): If True, return boxes in original image space
        """

        raise NotImplementedError

    @abstractmethod
    def get_targets(self, points, gt_bboxes_list, gt_labels_list):
        """Compute regression, classification and centerness targets for
        points in multiple images.

        Args:
            points (list[Tensor]): Points of each fpn level, each has shape
                (num_points, 2).
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image,
                each has shape (num_gt, 4).
            gt_labels_list (list[Tensor]): Ground truth labels of each box,
                each has shape (num_gt,).
        """
        raise NotImplementedError

    def _get_points_single(self,
                           featmap_size,
                           stride,
                           dtype,
                           device,
                           flatten=False):
        """Get points of a single scale level.

        Args:
            featmap_size (tuple): ``(h, w)`` of the feature map.
            stride: Not used by this base implementation; it is part of the
                signature that ``get_points`` calls.
            dtype (torch.dtype): Type the coordinate ranges are converted to.
            device (torch.device): Device the ranges are created on.
            flatten (bool): If True, both grids are flattened to 1D.

        Returns:
            tuple[Tensor, Tensor]: ``(y, x)`` index grids built from
                ``arange(h)`` and ``arange(w)``.
        """

        h, w = featmap_size
        # First create Range with the default dtype, then convert to
        # target `dtype` for onnx exporting.
        x_range = torch.arange(w, device=device).to(dtype)
        y_range = torch.arange(h, device=device).to(dtype)
        y, x = torch.meshgrid(y_range, x_range)
        if flatten:
            y = y.flatten()
            x = x.flatten()
        return y, x

    def get_points(self, featmap_sizes, dtype, device, flatten=False):
        """Get points according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            dtype (torch.dtype): Type of points.
            device (torch.device): Device of points.
            flatten (bool): Forwarded to ``_get_points_single``.
                Default: False.

        Returns:
            list: One entry per feature level, each being whatever
                ``_get_points_single`` returns for that level.
        """
        # Level i is paired with self.strides[i]; indexing (rather than zip)
        # keeps the IndexError when there are more levels than strides.
        return [
            self._get_points_single(featmap_size, self.strides[i], dtype,
                                    device, flatten)
            for i, featmap_size in enumerate(featmap_sizes)
        ]

    def aug_test(self, feats, img_metas, rescale=False):
        """Test function with test time augmentation.

        Args:
            feats (list[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains features for all images in the batch.
            img_metas (list[list[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch. each dict has image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[ndarray]: bbox results of each class
        """
        return self.aug_test_bboxes(feats, img_metas, rescale=rescale)