# Source code for mmdet.models.detectors.panoptic_two_stage_segmentor
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
import torch
from mmdet.core import INSTANCE_OFFSET, bbox2roi, multiclass_nms
from mmdet.core.visualization import imshow_det_bboxes
from ..builder import DETECTORS, build_head
from ..roi_heads.mask_heads.fcn_mask_head import _do_paste_mask
from .two_stage import TwoStageDetector
@DETECTORS.register_module()
class TwoStagePanopticSegmentor(TwoStageDetector):
    """Base class of Two-stage Panoptic Segmentor.

    As well as the components in TwoStageDetector, Panoptic Segmentor has
    extra ``semantic_head`` and ``panoptic_fusion_head``.

    Args:
        backbone (dict): Config of the backbone network.
        neck (dict, optional): Config of the neck (e.g. FPN).
        rpn_head (dict, optional): Config of the RPN head.
        roi_head (dict, optional): Config of the RoI head.
        train_cfg (dict, optional): Training config.
        test_cfg (dict, optional): Testing config. ``test_cfg.panoptic`` is
            forwarded to the panoptic fusion head as its ``test_cfg``.
        pretrained (str, optional): Path to pretrained weights.
        init_cfg (dict, optional): Weight initialization config.
        semantic_head (dict, optional): Config of the semantic (stuff) head.
        panoptic_fusion_head (dict, optional): Config of the head that fuses
            instance and semantic predictions into a panoptic result.
    """

    def __init__(
            self,
            backbone,
            neck=None,
            rpn_head=None,
            roi_head=None,
            train_cfg=None,
            test_cfg=None,
            pretrained=None,
            init_cfg=None,
            # for panoptic segmentation
            semantic_head=None,
            panoptic_fusion_head=None):
        super(TwoStagePanopticSegmentor,
              self).__init__(backbone, neck, rpn_head, roi_head, train_cfg,
                             test_cfg, pretrained, init_cfg)
        if semantic_head is not None:
            self.semantic_head = build_head(semantic_head)
        if panoptic_fusion_head is not None:
            # The fusion head receives the panoptic sub-config of test_cfg;
            # deepcopy so the caller's config dict is not mutated.
            panoptic_cfg = test_cfg.panoptic if test_cfg is not None else None
            panoptic_fusion_head_ = panoptic_fusion_head.deepcopy()
            panoptic_fusion_head_.update(test_cfg=panoptic_cfg)
            self.panoptic_fusion_head = build_head(panoptic_fusion_head_)

            # Mirror the class counts onto the detector for convenient access
            # (e.g. `show_result` uses `self.num_classes` as the VOID label).
            self.num_things_classes = self.panoptic_fusion_head.\
                num_things_classes
            self.num_stuff_classes = self.panoptic_fusion_head.\
                num_stuff_classes
            self.num_classes = self.panoptic_fusion_head.num_classes

    @property
    def with_semantic_head(self):
        """bool: whether the detector has a semantic segmentation head."""
        return hasattr(self,
                       'semantic_head') and self.semantic_head is not None

    @property
    def with_panoptic_fusion_head(self):
        """bool: whether the detector has a panoptic fusion head."""
        # Bug fix: the original checked the misspelled attribute name
        # 'panoptic_fusion_heads' (trailing 's'), so this property always
        # returned False even when the fusion head was built in __init__.
        return hasattr(self, 'panoptic_fusion_head') and \
            self.panoptic_fusion_head is not None

    def forward_dummy(self, img):
        """Used for computing network flops.

        See `mmdetection/tools/get_flops.py`
        """
        raise NotImplementedError(
            f'`forward_dummy` is not implemented in {self.__class__.__name__}')

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None,
                      gt_semantic_seg=None,
                      proposals=None,
                      **kwargs):
        """Forward and compute losses during training.

        Args:
            img (Tensor): Input images.
            img_metas (list[dict]): Image meta info, one dict per image.
            gt_bboxes (list[Tensor]): Ground-truth boxes per image.
            gt_labels (list[Tensor]): Class labels of the gt boxes.
            gt_bboxes_ignore (list[Tensor], optional): Boxes to ignore.
            gt_masks (list[BitmapMasks], optional): Instance masks.
            gt_semantic_seg (Tensor, optional): Semantic segmentation map
                consumed by the semantic head.
            proposals (list[Tensor], optional): Precomputed proposals used
                when no RPN head is configured.

        Returns:
            dict[str, Tensor]: Losses from RPN, RoI and semantic heads.
        """
        x = self.extract_feat(img)
        losses = dict()

        # RPN forward and loss (class-agnostic: gt_labels is None).
        if self.with_rpn:
            proposal_cfg = self.train_cfg.get('rpn_proposal',
                                              self.test_cfg.rpn)
            rpn_losses, proposal_list = self.rpn_head.forward_train(
                x,
                img_metas,
                gt_bboxes,
                gt_labels=None,
                gt_bboxes_ignore=gt_bboxes_ignore,
                proposal_cfg=proposal_cfg)
            losses.update(rpn_losses)
        else:
            proposal_list = proposals

        roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list,
                                                 gt_bboxes, gt_labels,
                                                 gt_bboxes_ignore, gt_masks,
                                                 **kwargs)
        losses.update(roi_losses)

        semantic_loss = self.semantic_head.forward_train(x, gt_semantic_seg)
        losses.update(semantic_loss)

        return losses

    def simple_test_mask(self,
                         x,
                         img_metas,
                         det_bboxes,
                         det_labels,
                         rescale=False):
        """Simple test for mask head without augmentation.

        Args:
            x (tuple[Tensor]): Multi-level image features.
            det_bboxes (list[Tensor]): Detected boxes per image, shape
                (num_dets, 5) with the score in the last column.
            det_labels (list[Tensor]): Labels of the detected boxes.
            rescale (bool): If True, paste masks at the original image
                resolution instead of the padded input resolution.

        Returns:
            dict: ``masks`` (list of per-image (K, H, W) mask probabilities),
            ``mask_pred`` (raw head output) and ``mask_feats``.
        """
        img_shapes = tuple(meta['ori_shape']
                           for meta in img_metas) if rescale else tuple(
                               meta['pad_shape'] for meta in img_metas)
        scale_factors = tuple(meta['scale_factor'] for meta in img_metas)

        if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
            # No detections in the whole batch: return empty placeholders.
            masks = []
            for img_shape in img_shapes:
                out_shape = (0, self.roi_head.bbox_head.num_classes) \
                    + img_shape[:2]
                masks.append(det_bboxes[0].new_zeros(out_shape))
            # Dummy raw prediction. The original hard-coded 80 (COCO thing
            # classes); use the configured class count instead, consistent
            # with `out_shape` above. 28x28 is the mask head's output size
            # -- TODO(review): read it from the mask head config.
            num_classes = self.roi_head.bbox_head.num_classes
            mask_pred = det_bboxes[0].new_zeros((0, num_classes, 28, 28))
            mask_results = dict(
                masks=masks, mask_pred=mask_pred, mask_feats=None)
            return mask_results

        _bboxes = [det_bboxes[i][:, :4] for i in range(len(det_bboxes))]
        if rescale:
            # Boxes were predicted at the network input scale; map them back
            # to the original image scale before RoI mask extraction.
            if not isinstance(scale_factors[0], float):
                scale_factors = [
                    det_bboxes[0].new_tensor(scale_factor)
                    for scale_factor in scale_factors
                ]
            _bboxes = [
                _bboxes[i] * scale_factors[i] for i in range(len(_bboxes))
            ]

        mask_rois = bbox2roi(_bboxes)
        mask_results = self.roi_head._mask_forward(x, mask_rois)
        mask_pred = mask_results['mask_pred']

        # Split batch mask prediction back to each image.
        num_mask_roi_per_img = [len(det_bbox) for det_bbox in det_bboxes]
        mask_preds = mask_pred.split(num_mask_roi_per_img, 0)

        # Resize the mask_preds to (K, H, W).
        masks = []
        for i in range(len(_bboxes)):
            det_bbox = det_bboxes[i][:, :4]
            det_label = det_labels[i]

            mask_pred = mask_preds[i].sigmoid()

            # Select, for each RoI, the mask channel of its predicted class.
            box_inds = torch.arange(mask_pred.shape[0])
            mask_pred = mask_pred[box_inds, det_label][:, None]

            img_h, img_w, _ = img_shapes[i]
            mask_pred, _ = _do_paste_mask(
                mask_pred, det_bbox, img_h, img_w, skip_empty=False)
            masks.append(mask_pred)

        mask_results['masks'] = masks
        return mask_results

    def simple_test(self, img, img_metas, proposals=None, rescale=False):
        """Test without Augmentation.

        Returns:
            list[dict]: One dict per image, with key ``pan_results`` holding
            the panoptic segmentation map as a numpy int array.
        """
        x = self.extract_feat(img)

        if proposals is None:
            proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
        else:
            proposal_list = proposals

        bboxes, scores = self.roi_head.simple_test_bboxes(
            x, img_metas, proposal_list, None, rescale=rescale)

        pan_cfg = self.test_cfg.panoptic
        # Class-wise NMS over the raw box/score predictions.
        det_bboxes = []
        det_labels = []
        for bbox, score in zip(bboxes, scores):
            det_bbox, det_label = multiclass_nms(bbox, score,
                                                 pan_cfg.score_thr,
                                                 pan_cfg.nms,
                                                 pan_cfg.max_per_img)
            det_bboxes.append(det_bbox)
            det_labels.append(det_label)

        mask_results = self.simple_test_mask(
            x, img_metas, det_bboxes, det_labels, rescale=rescale)
        masks = mask_results['masks']

        seg_preds = self.semantic_head.simple_test(x, img_metas, rescale)

        results = []
        for i in range(len(det_bboxes)):
            # Fuse instance (things) and semantic (stuff) predictions.
            pan_results = self.panoptic_fusion_head.simple_test(
                det_bboxes[i], det_labels[i], masks[i], seg_preds[i])
            pan_results = pan_results.int().detach().cpu().numpy()
            result = dict(pan_results=pan_results)
            results.append(result)
        return results

    def show_result(self,
                    img,
                    result,
                    score_thr=0.3,
                    bbox_color=(72, 101, 241),
                    text_color=(72, 101, 241),
                    mask_color=None,
                    thickness=2,
                    font_size=13,
                    win_name='',
                    show=False,
                    wait_time=0,
                    out_file=None):
        """Draw `result` over `img`.

        Args:
            img (str or Tensor): The image to be displayed.
            result (dict): The results.

            score_thr (float, optional): Minimum score of bboxes to be shown.
                Default: 0.3.
            bbox_color (str or tuple(int) or :obj:`Color`):Color of bbox lines.
               The tuple of color should be in BGR order. Default: 'green'.
            text_color (str or tuple(int) or :obj:`Color`):Color of texts.
               The tuple of color should be in BGR order. Default: 'green'.
            mask_color (None or str or tuple(int) or :obj:`Color`):
               Color of masks. The tuple of color should be in BGR order.
               Default: None.
            thickness (int): Thickness of lines. Default: 2.
            font_size (int): Font size of texts. Default: 13.
            win_name (str): The window name. Default: ''.
            wait_time (float): Value of waitKey param.
                Default: 0.
            show (bool): Whether to show the image.
                Default: False.
            out_file (str or None): The filename to write the image.
                Default: None.

        Returns:
            img (Tensor): Only if not `show` or `out_file`.
        """
        img = mmcv.imread(img)
        img = img.copy()
        pan_results = result['pan_results']
        # Keep objects ahead: iterate ids in descending order so thing
        # instances are drawn on top of stuff regions.
        ids = np.unique(pan_results)[::-1]
        legal_indices = ids != self.num_classes  # for VOID label
        ids = ids[legal_indices]
        # Panoptic ids encode `label + instance_idx * INSTANCE_OFFSET`;
        # recover the semantic label. (Renamed from `id` to avoid shadowing
        # the builtin.)
        labels = np.array([pan_id % INSTANCE_OFFSET for pan_id in ids],
                          dtype=np.int64)
        segms = (pan_results[None] == ids[:, None, None])

        # If out_file specified, do not show image in window.
        if out_file is not None:
            show = False
        # Draw bounding boxes.
        img = imshow_det_bboxes(
            img,
            segms=segms,
            labels=labels,
            class_names=self.CLASSES,
            bbox_color=bbox_color,
            text_color=text_color,
            mask_color=mask_color,
            thickness=thickness,
            font_size=font_size,
            win_name=win_name,
            show=show,
            wait_time=wait_time,
            out_file=out_file)

        if not (show or out_file):
            return img