Source code for mmdet.models.dense_heads.rpn_head

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import normal_init

from mmdet.ops import batched_nms
from ..builder import HEADS
from .anchor_head import AnchorHead


@HEADS.register_module()
class RPNHead(AnchorHead):
    """Region-proposal head built on top of ``AnchorHead``.

    The head applies one shared 3x3 convolution followed by two 1x1
    convolutions that produce the classification scores and the box
    regression deltas for every anchor.

    Args:
        in_channels (int): Number of channels of the input feature map.
        **kwargs: Forwarded unchanged to ``AnchorHead.__init__``.
    """

    def __init__(self, in_channels, **kwargs):
        # The first positional argument is fixed to 1 and
        # ``background_label`` to 0; their exact meaning is defined by
        # ``AnchorHead`` (presumably a single foreground class -- confirm
        # against the base class).
        super(RPNHead, self).__init__(
            1, in_channels, background_label=0, **kwargs)

    def _init_layers(self):
        """Create the shared conv and the two prediction convs."""
        self.rpn_conv = nn.Conv2d(
            self.in_channels, self.feat_channels, 3, padding=1)
        self.rpn_cls = nn.Conv2d(self.feat_channels,
                                 self.num_anchors * self.cls_out_channels, 1)
        self.rpn_reg = nn.Conv2d(self.feat_channels, self.num_anchors * 4, 1)

    def init_weights(self):
        """Initialize all three convs with ``normal_init`` (std=0.01)."""
        normal_init(self.rpn_conv, std=0.01)
        normal_init(self.rpn_cls, std=0.01)
        normal_init(self.rpn_reg, std=0.01)

    def forward_single(self, x):
        """Run the head on the feature map of a single level.

        Args:
            x (Tensor): Input feature map accepted by ``self.rpn_conv``.

        Returns:
            tuple: ``(rpn_cls_score, rpn_bbox_pred)``, the outputs of
            ``self.rpn_cls`` and ``self.rpn_reg`` respectively.
        """
        x = self.rpn_conv(x)
        x = F.relu(x, inplace=True)
        rpn_cls_score = self.rpn_cls(x)
        rpn_bbox_pred = self.rpn_reg(x)
        return rpn_cls_score, rpn_bbox_pred

    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute the RPN losses by delegating to ``AnchorHead.loss``.

        The parent method is called with ``None`` in the position between
        ``gt_bboxes`` and ``img_metas`` (presumably the ground-truth labels,
        which this head does not take -- confirm against ``AnchorHead``).

        Args:
            cls_scores: Classification outputs, forwarded to the parent.
            bbox_preds: Regression outputs, forwarded to the parent.
            gt_bboxes: Ground-truth boxes, forwarded to the parent.
            img_metas: Image meta information, forwarded to the parent.
            gt_bboxes_ignore: Optional boxes to ignore, forwarded to the
                parent. Defaults to None.

        Returns:
            dict: ``loss_rpn_cls`` and ``loss_rpn_bbox``, taken from the
            parent's ``loss_cls`` and ``loss_bbox`` entries.
        """
        losses = super(RPNHead, self).loss(
            cls_scores,
            bbox_preds,
            gt_bboxes,
            None,
            img_metas,
            gt_bboxes_ignore=gt_bboxes_ignore)
        return dict(
            loss_rpn_cls=losses['loss_cls'], loss_rpn_bbox=losses['loss_bbox'])

    def _get_bboxes_single(self,
                           cls_scores,
                           bbox_preds,
                           mlvl_anchors,
                           img_shape,
                           scale_factor,
                           cfg,
                           rescale=False):
        """Turn the raw outputs for one image into proposals.

        Args:
            cls_scores (list[Tensor]): Per-level classification outputs.
                Each tensor is 3-D and its first dimension is moved last
                before flattening.
            bbox_preds (list[Tensor]): Per-level regression outputs whose
                last two dimensions match those of ``cls_scores``.
            mlvl_anchors (list[Tensor]): Per-level anchors, indexed
                row-wise in step with the flattened predictions.
            img_shape: Passed to the box coder as ``max_shape``.
            scale_factor: Unused here; kept for interface compatibility.
            cfg: Test config providing ``nms_pre``, ``nms_post``,
                ``nms_thr`` and ``min_bbox_size``. ``self.test_cfg`` is
                used when this is None.
            rescale (bool): Unused here; kept for interface compatibility.

        Returns:
            Tensor: The first ``cfg.nms_post`` detections returned by
            ``batched_nms``.
        """
        cfg = self.test_cfg if cfg is None else cfg
        # bboxes from different level should be independent during NMS,
        # level_ids are used as labels for batched NMS to separate them
        level_ids = []
        mlvl_scores = []
        mlvl_bbox_preds = []
        mlvl_valid_anchors = []
        for idx, rpn_cls_score in enumerate(cls_scores):
            rpn_bbox_pred = bbox_preds[idx]
            assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
            rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
            if self.use_sigmoid_cls:
                rpn_cls_score = rpn_cls_score.reshape(-1)
                scores = rpn_cls_score.sigmoid()
            else:
                rpn_cls_score = rpn_cls_score.reshape(-1, 2)
                # remind that we set FG labels to [0, num_class-1]
                # since mmdet v2.0
                # BG cat_id: num_class
                # Index the first column directly so ``scores`` is 1-D.
                # Slicing with ``[:, :-1]`` kept a trailing dimension of
                # size 1, which made the sort below run over that
                # singleton dimension and return all-zero indices.
                scores = rpn_cls_score.softmax(dim=1)[:, 0]
            rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            anchors = mlvl_anchors[idx]
            if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
                # sort is faster than topk
                ranked_scores, rank_inds = scores.sort(descending=True)
                topk_inds = rank_inds[:cfg.nms_pre]
                scores = ranked_scores[:cfg.nms_pre]
                rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
                anchors = anchors[topk_inds, :]
            mlvl_scores.append(scores)
            mlvl_bbox_preds.append(rpn_bbox_pred)
            mlvl_valid_anchors.append(anchors)
            level_ids.append(
                scores.new_full((scores.size(0), ), idx, dtype=torch.long))

        scores = torch.cat(mlvl_scores)
        anchors = torch.cat(mlvl_valid_anchors)
        rpn_bbox_pred = torch.cat(mlvl_bbox_preds)
        proposals = self.bbox_coder.decode(
            anchors, rpn_bbox_pred, max_shape=img_shape)
        ids = torch.cat(level_ids)

        if cfg.min_bbox_size > 0:
            w = proposals[:, 2] - proposals[:, 0]
            h = proposals[:, 3] - proposals[:, 1]
            # Filter with a boolean mask. The previous code compared the
            # *sum of the valid indices* with the number of proposals, so
            # the filter could be skipped by coincidence, and squeezing a
            # single surviving index produced a 0-dim tensor that dropped
            # the row dimension of ``proposals``.
            valid_mask = (w >= cfg.min_bbox_size) & (h >= cfg.min_bbox_size)
            if not valid_mask.all():
                proposals = proposals[valid_mask]
                scores = scores[valid_mask]
                ids = ids[valid_mask]

        # TODO: remove the hard coded nms type
        nms_cfg = dict(type='nms', iou_thr=cfg.nms_thr)
        dets, keep = batched_nms(proposals, scores, ids, nms_cfg)
        return dets[:cfg.nms_post]