Shortcuts

Source code for mmdet.models.detectors.panoptic_two_stage_segmentor

# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
import torch

from mmdet.core import INSTANCE_OFFSET, bbox2roi, multiclass_nms
from mmdet.core.visualization import imshow_det_bboxes
from ..builder import DETECTORS, build_head
from ..roi_heads.mask_heads.fcn_mask_head import _do_paste_mask
from .two_stage import TwoStageDetector


@DETECTORS.register_module()
class TwoStagePanopticSegmentor(TwoStageDetector):
    """Base class of Two-stage Panoptic Segmentor.

    As well as the components in TwoStageDetector, Panoptic Segmentor has extra
    semantic_head and panoptic_fusion_head.
    """

    def __init__(
            self,
            backbone,
            neck=None,
            rpn_head=None,
            roi_head=None,
            train_cfg=None,
            test_cfg=None,
            pretrained=None,
            init_cfg=None,
            # for panoptic segmentation
            semantic_head=None,
            panoptic_fusion_head=None):
        """Build the two-stage detector plus the optional panoptic heads.

        Args:
            backbone, neck, rpn_head, roi_head, train_cfg, test_cfg,
                pretrained, init_cfg: Forwarded unchanged to
                ``TwoStageDetector.__init__``.
            semantic_head (config or None): Config passed to ``build_head``
                to create ``self.semantic_head``. Skipped when None.
            panoptic_fusion_head (config or None): Config passed to
                ``build_head`` to create ``self.panoptic_fusion_head``. It
                must provide ``deepcopy()`` and ``update()``. Skipped when
                None.
        """
        super(TwoStagePanopticSegmentor,
              self).__init__(backbone, neck, rpn_head, roi_head, train_cfg,
                             test_cfg, pretrained, init_cfg)
        if semantic_head is not None:
            self.semantic_head = build_head(semantic_head)
        if panoptic_fusion_head is not None:
            panoptic_cfg = test_cfg.panoptic if test_cfg is not None else None
            # Work on a copy so the caller's config is not mutated.
            panoptic_fusion_head_ = panoptic_fusion_head.deepcopy()
            panoptic_fusion_head_.update(test_cfg=panoptic_cfg)
            self.panoptic_fusion_head = build_head(panoptic_fusion_head_)

            # Mirror the class counts of the fusion head on the segmentor.
            fusion_head = self.panoptic_fusion_head
            self.num_things_classes = fusion_head.num_things_classes
            self.num_stuff_classes = fusion_head.num_stuff_classes
            self.num_classes = fusion_head.num_classes

    @property
    def with_semantic_head(self):
        """bool: Whether a semantic head has been built."""
        return hasattr(self,
                       'semantic_head') and self.semantic_head is not None

    @property
    def with_panoptic_fusion_head(self):
        """bool: Whether a panoptic fusion head has been built."""
        # The attribute set in ``__init__`` is ``panoptic_fusion_head``
        # (singular); checking the plural name made this always False.
        return hasattr(self, 'panoptic_fusion_head') and \
            self.panoptic_fusion_head is not None

    def forward_dummy(self, img):
        """Used for computing network flops.

        See `mmdetection/tools/get_flops.py`

        Raises:
            NotImplementedError: Always; this detector has no dummy forward.
        """
        raise NotImplementedError(
            f'`forward_dummy` is not implemented in {self.__class__.__name__}')

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None,
                      gt_semantic_seg=None,
                      proposals=None,
                      **kwargs):
        """Run the training forward pass and collect all losses.

        Args:
            img: Input passed to ``self.extract_feat``.
            img_metas: Per-image meta information, forwarded to the heads.
            gt_bboxes: Ground-truth boxes, forwarded to the RPN and RoI head.
            gt_labels: Ground-truth labels, forwarded to the RoI head only
                (the RPN is called with ``gt_labels=None``).
            gt_bboxes_ignore: Boxes to ignore, forwarded to RPN and RoI head.
            gt_masks: Ground-truth masks, forwarded to the RoI head.
            gt_semantic_seg: Ground-truth semantic segmentation, forwarded
                to the semantic head.
            proposals: Used as the proposal list when the detector has no
                RPN.
            **kwargs: Extra keyword arguments forwarded to the RoI head.

        Returns:
            dict: RPN losses (if any), RoI-head losses and semantic-head
            losses merged into one dictionary.
        """
        x = self.extract_feat(img)
        losses = dict()

        # RPN forward and loss
        if self.with_rpn:
            proposal_cfg = self.train_cfg.get('rpn_proposal',
                                              self.test_cfg.rpn)
            rpn_losses, proposal_list = self.rpn_head.forward_train(
                x,
                img_metas,
                gt_bboxes,
                gt_labels=None,
                gt_bboxes_ignore=gt_bboxes_ignore,
                proposal_cfg=proposal_cfg)
            losses.update(rpn_losses)
        else:
            proposal_list = proposals

        roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list,
                                                 gt_bboxes, gt_labels,
                                                 gt_bboxes_ignore, gt_masks,
                                                 **kwargs)
        losses.update(roi_losses)

        semantic_loss = self.semantic_head.forward_train(x, gt_semantic_seg)
        losses.update(semantic_loss)

        return losses

    def simple_test_mask(self,
                         x,
                         img_metas,
                         det_bboxes,
                         det_labels,
                         rescale=False):
        """Simple test for mask head without augmentation.

        Args:
            x: Features returned by ``self.extract_feat``.
            img_metas (list[dict]): Per-image meta dicts; the keys
                ``ori_shape``, ``pad_shape`` and ``scale_factor`` are read.
            det_bboxes (list[Tensor]): Per-image detections; the first four
                columns are used as box coordinates.
            det_labels (list[Tensor]): Per-image class index of each box.
            rescale (bool): If True, masks are pasted onto ``ori_shape`` and
                boxes are multiplied by ``scale_factor`` before RoI
                extraction; otherwise ``pad_shape`` is used.

        Returns:
            dict: Output of ``self.roi_head._mask_forward`` with an extra
            ``masks`` entry holding one pasted mask tensor per image. When
            no image has any detection, a dict of empty placeholders is
            returned instead.
        """
        shape_key = 'ori_shape' if rescale else 'pad_shape'
        img_shapes = tuple(meta[shape_key] for meta in img_metas)
        scale_factors = tuple(meta['scale_factor'] for meta in img_metas)

        if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
            num_classes = self.roi_head.bbox_head.num_classes
            masks = []
            for img_shape in img_shapes:
                out_shape = (0, num_classes) + img_shape[:2]
                masks.append(det_bboxes[0].new_zeros(out_shape))
            # Use the configured class count rather than a hard-coded 80 so
            # the placeholder agrees with ``masks`` above.
            # NOTE(review): 28x28 is kept from the original placeholder;
            # presumably the mask head output resolution - confirm.
            mask_pred = det_bboxes[0].new_zeros((0, num_classes, 28, 28))
            mask_results = dict(
                masks=masks, mask_pred=mask_pred, mask_feats=None)
            return mask_results

        _bboxes = [det_bbox[:, :4] for det_bbox in det_bboxes]
        if rescale:
            if not isinstance(scale_factors[0], float):
                scale_factors = [
                    det_bboxes[0].new_tensor(scale_factor)
                    for scale_factor in scale_factors
                ]
            _bboxes = [
                boxes * scale_factor
                for boxes, scale_factor in zip(_bboxes, scale_factors)
            ]

        mask_rois = bbox2roi(_bboxes)
        mask_results = self.roi_head._mask_forward(x, mask_rois)
        mask_pred = mask_results['mask_pred']
        # split batch mask prediction back to each image
        num_mask_roi_per_img = [len(det_bbox) for det_bbox in det_bboxes]
        mask_preds = mask_pred.split(num_mask_roi_per_img, 0)

        # resize the mask_preds to (K, H, W)
        masks = []
        for det_bbox, det_label, img_mask_pred, img_shape in zip(
                det_bboxes, det_labels, mask_preds, img_shapes):
            # Pasting uses the (unscaled) detection boxes, not ``_bboxes``.
            paste_boxes = det_bbox[:, :4]
            probs = img_mask_pred.sigmoid()
            # Build the row index on the same device as the predictions.
            box_inds = torch.arange(probs.shape[0], device=probs.device)
            # Keep only the channel of each box's predicted class.
            probs = probs[box_inds, det_label][:, None]

            img_h, img_w, _ = img_shape
            pasted, _ = _do_paste_mask(
                probs, paste_boxes, img_h, img_w, skip_empty=False)
            masks.append(pasted)

        mask_results['masks'] = masks

        return mask_results

    def simple_test(self, img, img_metas, proposals=None, rescale=False):
        """Test without Augmentation.

        Args:
            img: Input passed to ``self.extract_feat``.
            img_metas (list[dict]): Per-image meta information.
            proposals: Precomputed proposals; when None they are produced by
                ``self.rpn_head.simple_test_rpn``.
            rescale (bool): Forwarded to the bbox, mask and semantic tests.

        Returns:
            list[dict]: One dict per image with key ``pan_results`` holding
            the fused panoptic map as an integer numpy array.
        """
        x = self.extract_feat(img)

        if proposals is None:
            proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
        else:
            proposal_list = proposals

        bboxes, scores = self.roi_head.simple_test_bboxes(
            x, img_metas, proposal_list, None, rescale=rescale)

        pan_cfg = self.test_cfg.panoptic
        # class-wise predictions
        det_bboxes = []
        det_labels = []
        for bbox, score in zip(bboxes, scores):
            det_bbox, det_label = multiclass_nms(bbox, score,
                                                 pan_cfg.score_thr,
                                                 pan_cfg.nms,
                                                 pan_cfg.max_per_img)
            det_bboxes.append(det_bbox)
            det_labels.append(det_label)

        mask_results = self.simple_test_mask(
            x, img_metas, det_bboxes, det_labels, rescale=rescale)
        masks = mask_results['masks']

        seg_preds = self.semantic_head.simple_test(x, img_metas, rescale)

        results = []
        for det_bbox, det_label, mask, seg_pred in zip(
                det_bboxes, det_labels, masks, seg_preds):
            pan_results = self.panoptic_fusion_head.simple_test(
                det_bbox, det_label, mask, seg_pred)
            pan_results = pan_results.int().detach().cpu().numpy()
            results.append(dict(pan_results=pan_results))
        return results

    def show_result(self,
                    img,
                    result,
                    score_thr=0.3,
                    bbox_color=(72, 101, 241),
                    text_color=(72, 101, 241),
                    mask_color=None,
                    thickness=2,
                    font_size=13,
                    win_name='',
                    show=False,
                    wait_time=0,
                    out_file=None):
        """Draw `result` over `img`.

        Args:
            img (str or Tensor): The image to be displayed.
            result (dict): The results; the ``pan_results`` entry is drawn.
            score_thr (float, optional): Minimum score of bboxes to be shown.
                Default: 0.3. Not used by this method's body.
            bbox_color (str or tuple(int) or :obj:`Color`):Color of bbox lines.
               The tuple of color should be in BGR order.
               Default: (72, 101, 241).
            text_color (str or tuple(int) or :obj:`Color`):Color of texts.
               The tuple of color should be in BGR order.
               Default: (72, 101, 241).
            mask_color (None or str or tuple(int) or :obj:`Color`):
               Color of masks. The tuple of color should be in BGR order.
               Default: None.
            thickness (int): Thickness of lines. Default: 2.
            font_size (int): Font size of texts. Default: 13.
            win_name (str): The window name. Default: ''.
            wait_time (float): Value of waitKey param.
                Default: 0.
            show (bool): Whether to show the image.
                Default: False.
            out_file (str or None): The filename to write the image.
                Default: None.

        Returns:
            The drawn image returned by ``imshow_det_bboxes``, only when
            neither `show` nor `out_file` is set; otherwise None.
        """
        img = mmcv.imread(img)
        img = img.copy()
        pan_results = result['pan_results']
        # keep objects ahead
        ids = np.unique(pan_results)[::-1]
        legal_indices = ids != self.num_classes  # for VOID label
        ids = ids[legal_indices]

        # The class label is the segment id modulo INSTANCE_OFFSET.
        labels = np.array([seg_id % INSTANCE_OFFSET for seg_id in ids],
                          dtype=np.int64)
        # One boolean mask per remaining segment id.
        segms = (pan_results[None] == ids[:, None, None])

        # if out_file specified, do not show image in window
        if out_file is not None:
            show = False
        # draw bounding boxes
        img = imshow_det_bboxes(
            img,
            segms=segms,
            labels=labels,
            class_names=self.CLASSES,
            bbox_color=bbox_color,
            text_color=text_color,
            mask_color=mask_color,
            thickness=thickness,
            font_size=font_size,
            win_name=win_name,
            show=show,
            wait_time=wait_time,
            out_file=out_file)

        if not (show or out_file):
            return img
Read the Docs v: v2.24.1
Versions
latest
stable
v2.24.1
v2.24.0
v2.23.0
v2.22.0
v2.21.0
v2.20.0
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
v2.12.0
v2.11.0
v2.10.0
v2.9.0
v2.8.0
v2.7.0
v2.6.0
v2.5.0
v2.4.0
v2.3.0
v2.2.1
v2.2.0
v2.1.0
v2.0.0
v1.2.0
dev
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.