Shortcuts

Source code for mmdet.models.dense_heads.embedding_rpn_head

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.runner import BaseModule

from mmdet.models.builder import HEADS
from ...core import bbox_cxcywh_to_xyxy


@HEADS.register_module()
class EmbeddingRPNHead(BaseModule):
    """RPNHead in the `Sparse R-CNN <https://arxiv.org/abs/2011.12450>`_ .

    Unlike traditional RPNHead, this module does not need FPN input, but just
    decodes `init_proposal_bboxes` and expands the first dimension of
    `init_proposal_bboxes` and `init_proposal_features` to the batch_size.

    Args:
        num_proposals (int): Number of init_proposals. Default 100.
        proposal_feature_channel (int): Channel number of
            init_proposal_feature. Defaults to 256.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Must be None; the embeddings are initialized by
            :meth:`init_weights` instead. Default: None
    """

    def __init__(self,
                 num_proposals=100,
                 proposal_feature_channel=256,
                 init_cfg=None,
                 **kwargs):
        # Any extra keyword arguments are accepted and silently ignored
        # (presumably so generic RPN-head config keys do not break
        # construction -- TODO confirm against the configs that use this).
        assert init_cfg is None, 'To prevent abnormal initialization ' \
            'behavior, init_cfg is not allowed to be set'
        super(EmbeddingRPNHead, self).__init__(init_cfg)
        self.num_proposals = num_proposals
        self.proposal_feature_channel = proposal_feature_channel
        self._init_layers()

    def _init_layers(self):
        """Initialize a sparse set of proposal boxes and proposal features."""
        # One learnable box (normalized [c_x, c_y, w, h]) per proposal.
        self.init_proposal_bboxes = nn.Embedding(self.num_proposals, 4)
        # One learnable feature vector per proposal.
        self.init_proposal_features = nn.Embedding(
            self.num_proposals, self.proposal_feature_channel)

    def init_weights(self):
        """Initialize the init_proposal_bboxes as normalized.

        [c_x, c_y, w, h], and we initialize it to the size of the entire
        image: centre at (0.5, 0.5) with width and height 1.
        """
        super(EmbeddingRPNHead, self).init_weights()
        nn.init.constant_(self.init_proposal_bboxes.weight[:, :2], 0.5)
        nn.init.constant_(self.init_proposal_bboxes.weight[:, 2:], 1)

    def _decode_init_proposals(self, imgs, img_metas):
        """Decode init_proposal_bboxes according to the size of images and
        expand dimension of init_proposal_features to batch_size.

        Args:
            imgs (list[Tensor]): List of FPN features. Only ``imgs[0]`` is
                used: its first dimension gives the batch size and it
                provides the device/dtype for the image-size tensor.
            img_metas (list[dict]): List of meta-information of images.
                Need the img_shape to decode the init_proposals; only its
                first two entries (height, width) are read.

        Returns:
            Tuple(Tensor):

                - proposals (Tensor): Decoded proposal bboxes in
                  [x1, y1, x2, y2] pixel coordinates, has shape
                  (batch_size, num_proposals, 4).
                - init_proposal_features (Tensor): Expanded proposal
                  features (an ``expand`` view, not a copy), has shape
                  (batch_size, num_proposals, proposal_feature_channel).
                - imgs_whwh (Tensor): Tensor with shape (batch_size, 1, 4),
                  the last dimension means
                  [img_width, img_height, img_width, img_height].
        """
        # clone() keeps the autograd link to the embedding weight while
        # giving downstream code its own tensor (presumably to guard the
        # parameter against in-place ops -- TODO confirm).
        proposals = self.init_proposal_bboxes.weight.clone()
        proposals = bbox_cxcywh_to_xyxy(proposals)
        num_imgs = len(imgs[0])

        # Slicing [:2] accepts both (h, w) and (h, w, c) img_shape entries.
        whwh = []
        for meta in img_metas:
            h, w = meta['img_shape'][:2]
            whwh.append([w, h, w, h])
        # Build the whole batch in a single tensor (one allocation/copy) on
        # the same device and dtype as the features. reshape keeps the
        # (batch_size, 4) layout even when img_metas is empty.
        imgs_whwh = imgs[0].new_tensor(whwh).reshape(-1, 4)
        # imgs_whwh has shape (batch_size, 1, 4)
        imgs_whwh = imgs_whwh[:, None, :]

        # Broadcasting changes proposals from (num_proposals, 4)
        # to (batch_size, num_proposals, 4), scaled to pixel units.
        proposals = proposals * imgs_whwh

        init_proposal_features = self.init_proposal_features.weight.clone()
        init_proposal_features = init_proposal_features[None].expand(
            num_imgs, *init_proposal_features.size())
        return proposals, init_proposal_features, imgs_whwh

    def forward_dummy(self, img, img_metas):
        """Dummy forward function.

        Used in flops calculation.
        """
        return self._decode_init_proposals(img, img_metas)

    def forward_train(self, img, img_metas):
        """Forward function in training stage."""
        return self._decode_init_proposals(img, img_metas)

    def simple_test_rpn(self, img, img_metas):
        """Forward function in testing stage."""
        return self._decode_init_proposals(img, img_metas)

    def simple_test(self, img, img_metas):
        """Forward function in testing stage.

        Not supported by this head; use :meth:`simple_test_rpn` to obtain
        the initial proposals.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            'EmbeddingRPNHead does not implement simple_test; use '
            'simple_test_rpn to obtain the initial proposals')

    def aug_test_rpn(self, feats, img_metas):
        """Test-time augmentation is not supported by this head.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            'EmbeddingRPNHead does not support test-time augmentation')
Read the Docs v: v2.18.1
Versions
latest
stable
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
v2.12.0
v2.11.0
v2.10.0
v2.9.0
v2.8.0
v2.7.0
v2.6.0
v2.5.0
v2.4.0
v2.3.0
v2.2.1
v2.2.0
v2.1.0
v2.0.0
v1.2.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.