Source code for mmdet.models.roi_heads.standard_roi_head

import torch

from mmdet.core import bbox2result, bbox2roi, build_assigner, build_sampler
from ..builder import HEADS, build_head, build_roi_extractor
from .base_roi_head import BaseRoIHead
from .test_mixins import BBoxTestMixin, MaskTestMixin


@HEADS.register_module()
class StandardRoIHead(BaseRoIHead, BBoxTestMixin, MaskTestMixin):
    """Simplest base roi head including one bbox head and one mask head."""

    def init_assigner_sampler(self):
        """Initialize assigner and sampler.

        Both attributes are left as ``None`` unless ``self.train_cfg`` is
        truthy, in which case they are built from ``train_cfg.assigner`` and
        ``train_cfg.sampler``.
        """
        self.bbox_assigner = None
        self.bbox_sampler = None
        if self.train_cfg:
            self.bbox_assigner = build_assigner(self.train_cfg.assigner)
            self.bbox_sampler = build_sampler(
                self.train_cfg.sampler, context=self)

    def init_bbox_head(self, bbox_roi_extractor, bbox_head):
        """Initialize ``bbox_head``.

        Args:
            bbox_roi_extractor: Config passed to ``build_roi_extractor``.
            bbox_head: Config passed to ``build_head``.
        """
        self.bbox_roi_extractor = build_roi_extractor(bbox_roi_extractor)
        self.bbox_head = build_head(bbox_head)

    def init_mask_head(self, mask_roi_extractor, mask_head):
        """Initialize ``mask_head``.

        Args:
            mask_roi_extractor: Config passed to ``build_roi_extractor``, or
                ``None`` to reuse ``self.bbox_roi_extractor`` for masks.
            mask_head: Config passed to ``build_head``.
        """
        if mask_roi_extractor is not None:
            self.mask_roi_extractor = build_roi_extractor(mask_roi_extractor)
            self.share_roi_extractor = False
        else:
            # No dedicated extractor: the mask branch reuses the bbox one.
            self.share_roi_extractor = True
            self.mask_roi_extractor = self.bbox_roi_extractor
        self.mask_head = build_head(mask_head)

    def forward_dummy(self, x, proposals):
        """Dummy forward function.

        Args:
            x: Image features forwarded to the RoI extractors.
            proposals: Proposals of one image; wrapped in a list before
                being converted to RoIs with ``bbox2roi``.

        Returns:
            tuple: ``cls_score`` and ``bbox_pred`` if a bbox head exists,
            followed by ``mask_pred`` if a mask head exists.
        """
        outs = ()
        rois = bbox2roi([proposals])
        # bbox head
        if self.with_bbox:
            bbox_results = self._bbox_forward(x, rois)
            outs = outs + (bbox_results['cls_score'],
                           bbox_results['bbox_pred'])
        # mask head
        if self.with_mask:
            # Only the first 100 RoIs are fed to the mask head here.
            mask_rois = rois[:100]
            mask_results = self._mask_forward(x, mask_rois)
            outs = outs + (mask_results['mask_pred'], )
        return outs

    def forward_train(self,
                      x,
                      img_metas,
                      proposal_list,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None):
        """Assign/sample proposals and compute the RoI head losses.

        Args:
            x (list[Tensor]): list of multi-level img features.
            img_metas (list[dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.
            proposal_list (list[Tensors]): list of region proposals.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.
            gt_masks (None | Tensor) : true segmentation masks for each box
                used if the architecture supports a segmentation task.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        # assign gts and sample proposals
        if self.with_bbox or self.with_mask:
            num_imgs = len(img_metas)
            if gt_bboxes_ignore is None:
                gt_bboxes_ignore = [None] * num_imgs
            sampling_results = []
            for i in range(num_imgs):
                assign_result = self.bbox_assigner.assign(
                    proposal_list[i], gt_bboxes[i], gt_bboxes_ignore[i],
                    gt_labels[i])
                sampling_result = self.bbox_sampler.sample(
                    assign_result,
                    proposal_list[i],
                    gt_bboxes[i],
                    gt_labels[i],
                    # per-image slice of every feature level, batch dim kept
                    feats=[lvl_feat[i][None] for lvl_feat in x])
                sampling_results.append(sampling_result)

        losses = dict()
        # Stays None when there is no bbox head, so the mask branch below
        # never touches an unbound ``bbox_results``.
        bbox_feats = None
        # bbox head forward and loss
        if self.with_bbox:
            bbox_results = self._bbox_forward_train(x, sampling_results,
                                                    gt_bboxes, gt_labels,
                                                    img_metas)
            losses.update(bbox_results['loss_bbox'])
            bbox_feats = bbox_results['bbox_feats']

        # mask head forward and loss
        if self.with_mask:
            mask_results = self._mask_forward_train(x, sampling_results,
                                                    bbox_feats, gt_masks,
                                                    img_metas)
            losses.update(mask_results['loss_mask'])

        return losses

    def _bbox_forward(self, x, rois):
        """Box head forward function used in both training and testing.

        Returns:
            dict: ``cls_score``, ``bbox_pred`` and the extracted
            ``bbox_feats``.
        """
        # TODO: a more flexible way to decide which feature maps to use
        bbox_feats = self.bbox_roi_extractor(
            x[:self.bbox_roi_extractor.num_inputs], rois)
        if self.with_shared_head:
            bbox_feats = self.shared_head(bbox_feats)
        cls_score, bbox_pred = self.bbox_head(bbox_feats)

        bbox_results = dict(
            cls_score=cls_score, bbox_pred=bbox_pred, bbox_feats=bbox_feats)
        return bbox_results

    def _bbox_forward_train(self, x, sampling_results, gt_bboxes, gt_labels,
                            img_metas):
        """Run forward function and calculate loss for box head in training.

        Returns:
            dict: The output of ``_bbox_forward`` extended with a
            ``loss_bbox`` entry. ``img_metas`` is not used here.
        """
        rois = bbox2roi([res.bboxes for res in sampling_results])
        bbox_results = self._bbox_forward(x, rois)

        bbox_targets = self.bbox_head.get_targets(sampling_results, gt_bboxes,
                                                  gt_labels, self.train_cfg)
        loss_bbox = self.bbox_head.loss(bbox_results['cls_score'],
                                        bbox_results['bbox_pred'], rois,
                                        *bbox_targets)

        bbox_results.update(loss_bbox=loss_bbox)
        return bbox_results

    def _mask_forward_train(self, x, sampling_results, bbox_feats, gt_masks,
                            img_metas):
        """Run forward function and calculate loss for mask head in training.

        Args:
            x: Image features forwarded to the mask RoI extractor.
            sampling_results: Per-image sampling results.
            bbox_feats: RoI features from the bbox branch. Only read when
                the RoI extractor is shared with the bbox branch.
            gt_masks: Ground-truth masks handed to
                ``mask_head.get_targets``.
            img_metas: Not used here.

        Returns:
            dict: The output of ``_mask_forward`` extended with
            ``loss_mask`` and ``mask_targets``.
        """
        if not self.share_roi_extractor:
            pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
            mask_results = self._mask_forward(x, pos_rois)
        else:
            # Build a boolean mask that picks the positive samples out of
            # ``bbox_feats``; per image the positives come first, then the
            # negatives. ``torch.bool`` is used because indexing with a
            # ``torch.uint8`` mask is deprecated in PyTorch.
            pos_inds = []
            device = bbox_feats.device
            for res in sampling_results:
                pos_inds.append(
                    torch.ones(
                        res.pos_bboxes.shape[0],
                        device=device,
                        dtype=torch.bool))
                pos_inds.append(
                    torch.zeros(
                        res.neg_bboxes.shape[0],
                        device=device,
                        dtype=torch.bool))
            pos_inds = torch.cat(pos_inds)

            mask_results = self._mask_forward(
                x, pos_inds=pos_inds, bbox_feats=bbox_feats)

        mask_targets = self.mask_head.get_targets(sampling_results, gt_masks,
                                                  self.train_cfg)
        pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
        loss_mask = self.mask_head.loss(mask_results['mask_pred'],
                                        mask_targets, pos_labels)

        mask_results.update(loss_mask=loss_mask, mask_targets=mask_targets)
        return mask_results

    def _mask_forward(self, x, rois=None, pos_inds=None, bbox_feats=None):
        """Mask head forward function used in both training and testing.

        Exactly one of the two input modes must be used: either ``rois``
        (features are extracted from ``x``), or ``pos_inds`` together with
        ``bbox_feats`` (features are selected from the bbox branch).

        Returns:
            dict: ``mask_pred`` and the ``mask_feats`` it was computed from.
        """
        assert ((rois is not None) ^
                (pos_inds is not None and bbox_feats is not None))
        if rois is not None:
            mask_feats = self.mask_roi_extractor(
                x[:self.mask_roi_extractor.num_inputs], rois)
            if self.with_shared_head:
                mask_feats = self.shared_head(mask_feats)
        else:
            assert bbox_feats is not None
            mask_feats = bbox_feats[pos_inds]

        mask_pred = self.mask_head(mask_feats)
        mask_results = dict(mask_pred=mask_pred, mask_feats=mask_feats)
        return mask_results

    async def async_simple_test(self,
                                x,
                                proposal_list,
                                img_metas,
                                proposals=None,
                                rescale=False):
        """Async test without augmentation.

        ``proposals`` is accepted but not used by this method.

        Returns:
            The bbox results alone when there is no mask head, otherwise a
            ``(bbox_results, segm_results)`` tuple.
        """
        assert self.with_bbox, 'Bbox head must be implemented.'

        det_bboxes, det_labels = await self.async_test_bboxes(
            x, img_metas, proposal_list, self.test_cfg, rescale=rescale)
        bbox_results = bbox2result(det_bboxes, det_labels,
                                   self.bbox_head.num_classes)
        if not self.with_mask:
            return bbox_results
        else:
            segm_results = await self.async_test_mask(
                x,
                img_metas,
                det_bboxes,
                det_labels,
                rescale=rescale,
                mask_test_cfg=self.test_cfg.get('mask'))
            return bbox_results, segm_results

    def simple_test(self,
                    x,
                    proposal_list,
                    img_metas,
                    proposals=None,
                    rescale=False):
        """Test without augmentation.

        ``proposals`` is accepted but not used by this method.

        Returns:
            list: One entry per image: the bbox result alone when there is
            no mask head, otherwise a ``(bbox_result, segm_result)`` tuple.
        """
        assert self.with_bbox, 'Bbox head must be implemented.'

        det_bboxes, det_labels = self.simple_test_bboxes(
            x, img_metas, proposal_list, self.test_cfg, rescale=rescale)

        bbox_results = [
            bbox2result(bboxes, labels, self.bbox_head.num_classes)
            for bboxes, labels in zip(det_bboxes, det_labels)
        ]

        if not self.with_mask:
            return bbox_results
        else:
            segm_results = self.simple_test_mask(
                x, img_metas, det_bboxes, det_labels, rescale=rescale)
            return list(zip(bbox_results, segm_results))

    def aug_test(self, x, proposal_list, img_metas, rescale=False):
        """Test with augmentations.

        If rescale is False, then returned bboxes and masks will fit the scale
        of imgs[0].

        Returns:
            list: A single-element list holding the bbox results, or a
            ``(bbox_results, segm_results)`` tuple when a mask head exists.
        """
        det_bboxes, det_labels = self.aug_test_bboxes(x, img_metas,
                                                      proposal_list,
                                                      self.test_cfg)
        if rescale:
            _det_bboxes = det_bboxes
        else:
            # Work on a copy so ``det_bboxes`` itself is left unscaled for
            # the mask branch below.
            _det_bboxes = det_bboxes.clone()
            _det_bboxes[:, :4] *= det_bboxes.new_tensor(
                img_metas[0][0]['scale_factor'])
        bbox_results = bbox2result(_det_bboxes, det_labels,
                                   self.bbox_head.num_classes)

        # det_bboxes always keep the original scale
        if self.with_mask:
            segm_results = self.aug_test_mask(x, img_metas, det_bboxes,
                                              det_labels)
            return [(bbox_results, segm_results)]
        else:
            return [bbox_results]

    def onnx_export(self, x, proposals, img_metas, rescale=False):
        """Run the bbox branch, and the mask branch if present, for ONNX
        export.

        Returns:
            tuple: ``(det_bboxes, det_labels)`` when there is no mask head,
            otherwise ``(det_bboxes, det_labels, segm_results)``.
        """
        assert self.with_bbox, 'Bbox head must be implemented.'
        det_bboxes, det_labels = self.bbox_onnx_export(
            x, img_metas, proposals, self.test_cfg, rescale=rescale)

        if not self.with_mask:
            return det_bboxes, det_labels
        else:
            segm_results = self.mask_onnx_export(
                x, img_metas, det_bboxes, det_labels, rescale=rescale)
            return det_bboxes, det_labels, segm_results

    def mask_onnx_export(self, x, img_metas, det_bboxes, det_labels, **kwargs):
        """Export mask branch to onnx which supports batch inference.

        Args:
            x (tuple[Tensor]): Feature maps of all scale level.
            img_metas (list[dict]): Image meta info.
            det_bboxes (Tensor): Bboxes and corresponding scores.
                has shape [N, num_bboxes, 5].
            det_labels (Tensor): class labels of
                shape [N, num_bboxes].
            **kwargs: Accepted but not used by this method.

        Returns:
            Tensor: The segmentation results of shape [N, num_bboxes,
                image_height, image_width].

        Raises:
            RuntimeError: If every image in the batch has zero detections.
        """
        if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
            raise RuntimeError('[ONNX Error] Can not record MaskHead '
                               'as it has not been executed this time')
        batch_size = det_bboxes.size(0)
        # Keep only the four box coordinates; the trailing column is dropped.
        det_bboxes = det_bboxes[..., :4]
        # Prepend each box's batch index to form flat 5-column RoIs.
        batch_index = torch.arange(
            det_bboxes.size(0), device=det_bboxes.device).float().view(
                -1, 1, 1).expand(det_bboxes.size(0), det_bboxes.size(1), 1)
        mask_rois = torch.cat([batch_index, det_bboxes], dim=-1)
        mask_rois = mask_rois.view(-1, 5)
        mask_results = self._mask_forward(x, mask_rois)
        mask_pred = mask_results['mask_pred']
        max_shape = img_metas[0]['img_shape_for_onnx']
        num_det = det_bboxes.shape[1]
        det_bboxes = det_bboxes.reshape(-1, 4)
        det_labels = det_labels.reshape(-1)
        segm_results = self.mask_head.onnx_export(mask_pred, det_bboxes,
                                                  det_labels, self.test_cfg,
                                                  max_shape)
        # Recover the batch dimension
        segm_results = segm_results.reshape(batch_size, num_det, max_shape[0],
                                            max_shape[1])
        return segm_results

    def bbox_onnx_export(self, x, img_metas, proposals, rcnn_test_cfg,
                         **kwargs):
        """Export bbox branch to onnx which supports batch inference.

        Args:
            x (tuple[Tensor]): Feature maps of all scale level.
            img_metas (list[dict]): Image meta info.
            proposals (Tensor): Region proposals with
                batch dimension, has shape [N, num_bboxes, 5].
            rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN.
            **kwargs: Accepted but not used by this method.

        Returns:
            tuple[Tensor, Tensor]: bboxes of shape [N, num_bboxes, 5]
                and class labels of shape [N, num_bboxes].
        """
        # get origin input shape to support onnx dynamic input shape
        assert len(
            img_metas
        ) == 1, 'Only support one input image while in exporting to ONNX'
        img_shapes = img_metas[0]['img_shape_for_onnx']

        rois = proposals
        # Replace the last column of each proposal with its batch index,
        # placed in front of the four box coordinates.
        batch_index = torch.arange(
            rois.size(0), device=rois.device).float().view(-1, 1, 1).expand(
                rois.size(0), rois.size(1), 1)
        rois = torch.cat([batch_index, rois[..., :4]], dim=-1)
        batch_size = rois.shape[0]
        num_proposals_per_img = rois.shape[1]

        # Eliminate the batch dimension
        rois = rois.view(-1, 5)
        bbox_results = self._bbox_forward(x, rois)
        cls_score = bbox_results['cls_score']
        bbox_pred = bbox_results['bbox_pred']

        # Recover the batch dimension
        rois = rois.reshape(batch_size, num_proposals_per_img, rois.size(-1))
        cls_score = cls_score.reshape(batch_size, num_proposals_per_img,
                                      cls_score.size(-1))

        bbox_pred = bbox_pred.reshape(batch_size, num_proposals_per_img,
                                      bbox_pred.size(-1))
        det_bboxes, det_labels = self.bbox_head.onnx_export(
            rois, cls_score, bbox_pred, img_shapes, cfg=rcnn_test_cfg)

        return det_bboxes, det_labels