Shortcuts

Source code for mmdet.models.detectors.panoptic_two_stage_segmentor

# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
import torch

from mmdet.core import INSTANCE_OFFSET, bbox2roi, multiclass_nms
from mmdet.core.visualization import imshow_det_bboxes
from ..builder import DETECTORS, build_head
from ..roi_heads.mask_heads.fcn_mask_head import _do_paste_mask
from .two_stage import TwoStageDetector


[docs]@DETECTORS.register_module() class TwoStagePanopticSegmentor(TwoStageDetector): """Base class of Two-stage Panoptic Segmentor. As well as the components in TwoStageDetector, Panoptic Segmentor has extra semantic_head and panoptic_fusion_head. """ def __init__( self, backbone, neck=None, rpn_head=None, roi_head=None, train_cfg=None, test_cfg=None, pretrained=None, init_cfg=None, # for panoptic segmentation semantic_head=None, panoptic_fusion_head=None): super(TwoStagePanopticSegmentor, self).__init__(backbone, neck, rpn_head, roi_head, train_cfg, test_cfg, pretrained, init_cfg) if semantic_head is not None: self.semantic_head = build_head(semantic_head) if panoptic_fusion_head is not None: panoptic_cfg = test_cfg.panoptic if test_cfg is not None else None panoptic_fusion_head_ = panoptic_fusion_head.deepcopy() panoptic_fusion_head_.update(test_cfg=panoptic_cfg) self.panoptic_fusion_head = build_head(panoptic_fusion_head_) self.num_things_classes = self.panoptic_fusion_head.\ num_things_classes self.num_stuff_classes = self.panoptic_fusion_head.\ num_stuff_classes self.num_classes = self.panoptic_fusion_head.num_classes @property def with_semantic_head(self): return hasattr(self, 'semantic_head') and self.semantic_head is not None @property def with_panoptic_fusion_head(self): return hasattr(self, 'panoptic_fusion_heads') and \ self.panoptic_fusion_head is not None
[docs] def forward_dummy(self, img): """Used for computing network flops. See `mmdetection/tools/get_flops.py` """ raise NotImplementedError( f'`forward_dummy` is not implemented in {self.__class__.__name__}')
[docs] def forward_train(self, img, img_metas, gt_bboxes, gt_labels, gt_bboxes_ignore=None, gt_masks=None, gt_semantic_seg=None, proposals=None, **kwargs): x = self.extract_feat(img) losses = dict() # RPN forward and loss if self.with_rpn: proposal_cfg = self.train_cfg.get('rpn_proposal', self.test_cfg.rpn) rpn_losses, proposal_list = self.rpn_head.forward_train( x, img_metas, gt_bboxes, gt_labels=None, gt_bboxes_ignore=gt_bboxes_ignore, proposal_cfg=proposal_cfg) losses.update(rpn_losses) else: proposal_list = proposals roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list, gt_bboxes, gt_labels, gt_bboxes_ignore, gt_masks, **kwargs) losses.update(roi_losses) semantic_loss = self.semantic_head.forward_train(x, gt_semantic_seg) losses.update(semantic_loss) return losses
[docs] def simple_test_mask(self, x, img_metas, det_bboxes, det_labels, rescale=False): """Simple test for mask head without augmentation.""" img_shapes = tuple(meta['ori_shape'] for meta in img_metas) if rescale else tuple( meta['pad_shape'] for meta in img_metas) scale_factors = tuple(meta['scale_factor'] for meta in img_metas) if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes): masks = [] for img_shape in img_shapes: out_shape = (0, self.roi_head.bbox_head.num_classes) \ + img_shape[:2] masks.append(det_bboxes[0].new_zeros(out_shape)) mask_pred = det_bboxes[0].new_zeros((0, 80, 28, 28)) mask_results = dict( masks=masks, mask_pred=mask_pred, mask_feats=None) return mask_results _bboxes = [det_bboxes[i][:, :4] for i in range(len(det_bboxes))] if rescale: if not isinstance(scale_factors[0], float): scale_factors = [ det_bboxes[0].new_tensor(scale_factor) for scale_factor in scale_factors ] _bboxes = [ _bboxes[i] * scale_factors[i] for i in range(len(_bboxes)) ] mask_rois = bbox2roi(_bboxes) mask_results = self.roi_head._mask_forward(x, mask_rois) mask_pred = mask_results['mask_pred'] # split batch mask prediction back to each image num_mask_roi_per_img = [len(det_bbox) for det_bbox in det_bboxes] mask_preds = mask_pred.split(num_mask_roi_per_img, 0) # resize the mask_preds to (K, H, W) masks = [] for i in range(len(_bboxes)): det_bbox = det_bboxes[i][:, :4] det_label = det_labels[i] mask_pred = mask_preds[i].sigmoid() box_inds = torch.arange(mask_pred.shape[0]) mask_pred = mask_pred[box_inds, det_label][:, None] img_h, img_w, _ = img_shapes[i] mask_pred, _ = _do_paste_mask( mask_pred, det_bbox, img_h, img_w, skip_empty=False) masks.append(mask_pred) mask_results['masks'] = masks return mask_results
[docs] def simple_test(self, img, img_metas, proposals=None, rescale=False): """Test without Augmentation.""" x = self.extract_feat(img) if proposals is None: proposal_list = self.rpn_head.simple_test_rpn(x, img_metas) else: proposal_list = proposals bboxes, scores = self.roi_head.simple_test_bboxes( x, img_metas, proposal_list, None, rescale=rescale) pan_cfg = self.test_cfg.panoptic # class-wise predictions det_bboxes = [] det_labels = [] for bboxe, score in zip(bboxes, scores): det_bbox, det_label = multiclass_nms(bboxe, score, pan_cfg.score_thr, pan_cfg.nms, pan_cfg.max_per_img) det_bboxes.append(det_bbox) det_labels.append(det_label) mask_results = self.simple_test_mask( x, img_metas, det_bboxes, det_labels, rescale=rescale) masks = mask_results['masks'] seg_preds = self.semantic_head.simple_test(x, img_metas, rescale) results = [] for i in range(len(det_bboxes)): pan_results = self.panoptic_fusion_head.simple_test( det_bboxes[i], det_labels[i], masks[i], seg_preds[i]) pan_results = pan_results.int().detach().cpu().numpy() result = dict(pan_results=pan_results) results.append(result) return results
[docs] def show_result(self, img, result, score_thr=0.3, bbox_color=(72, 101, 241), text_color=(72, 101, 241), mask_color=None, thickness=2, font_size=13, win_name='', show=False, wait_time=0, out_file=None): """Draw `result` over `img`. Args: img (str or Tensor): The image to be displayed. result (dict): The results. score_thr (float, optional): Minimum score of bboxes to be shown. Default: 0.3. bbox_color (str or tuple(int) or :obj:`Color`):Color of bbox lines. The tuple of color should be in BGR order. Default: 'green'. text_color (str or tuple(int) or :obj:`Color`):Color of texts. The tuple of color should be in BGR order. Default: 'green'. mask_color (None or str or tuple(int) or :obj:`Color`): Color of masks. The tuple of color should be in BGR order. Default: None. thickness (int): Thickness of lines. Default: 2. font_size (int): Font size of texts. Default: 13. win_name (str): The window name. Default: ''. wait_time (float): Value of waitKey param. Default: 0. show (bool): Whether to show the image. Default: False. out_file (str or None): The filename to write the image. Default: None. Returns: img (Tensor): Only if not `show` or `out_file`. """ img = mmcv.imread(img) img = img.copy() pan_results = result['pan_results'] # keep objects ahead ids = np.unique(pan_results)[::-1] legal_indices = ids != self.num_classes # for VOID label ids = ids[legal_indices] labels = np.array([id % INSTANCE_OFFSET for id in ids], dtype=np.int64) segms = (pan_results[None] == ids[:, None, None]) # if out_file specified, do not show image in window if out_file is not None: show = False # draw bounding boxes img = imshow_det_bboxes( img, segms=segms, labels=labels, class_names=self.CLASSES, bbox_color=bbox_color, text_color=text_color, mask_color=mask_color, thickness=thickness, font_size=font_size, win_name=win_name, show=show, wait_time=wait_time, out_file=out_file) if not (show or out_file): return img