# Source code for mmdet.models.detectors.maskformer
# Copyright (c) OpenMMLab. All rights reserved.
import copy

import mmcv
import numpy as np

from mmdet.core import INSTANCE_OFFSET
from mmdet.core.visualization import imshow_det_bboxes
from ..builder import DETECTORS, build_backbone, build_head, build_neck
from .single_stage import SingleStageDetector
@DETECTORS.register_module()
class MaskFormer(SingleStageDetector):
    r"""Implementation of `Per-Pixel Classification is
    NOT All You Need for Semantic Segmentation
    <https://arxiv.org/pdf/2107.06278>`_."""

    def __init__(self,
                 backbone,
                 neck=None,
                 panoptic_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None):
        """Build the backbone, the optional neck and the panoptic head.

        Args:
            backbone (dict): Config passed to ``build_backbone``.
            neck (dict, optional): Config passed to ``build_neck``. When
                None, no ``self.neck`` attribute is created here.
            panoptic_head (dict): Config passed to ``build_head``. Although
                it defaults to None, a config is required (``update`` is
                called on it). It is deep-copied before ``train_cfg`` and
                ``test_cfg`` are merged in, so the caller's object is not
                modified.
            train_cfg (dict, optional): Training config, merged into the head
                config and stored as ``self.train_cfg``.
            test_cfg (dict, optional): Testing config, merged into the head
                config and stored as ``self.test_cfg``.
            init_cfg (dict, optional): Initialization config forwarded to the
                base class.
        """
        # Deliberately bypass SingleStageDetector.__init__ and initialise the
        # next class in the MRO directly: this detector builds its own
        # panoptic head instead of the components SingleStageDetector builds.
        super(SingleStageDetector, self).__init__(init_cfg=init_cfg)
        self.backbone = build_backbone(backbone)
        if neck is not None:
            self.neck = build_neck(neck)
        # Merge train/test cfg into a private copy so the (possibly shared)
        # config object supplied by the caller is left untouched.
        panoptic_head_ = copy.deepcopy(panoptic_head)
        panoptic_head_.update(train_cfg=train_cfg)
        panoptic_head_.update(test_cfg=test_cfg)
        self.panoptic_head = build_head(panoptic_head_)
        self.num_things_classes = self.panoptic_head.num_things_classes
        self.num_stuff_classes = self.panoptic_head.num_stuff_classes
        self.num_classes = self.panoptic_head.num_classes
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def forward_dummy(self, img, img_metas):
        """Used for computing network flops. See
        `mmdetection/tools/analysis_tools/get_flops.py`

        Args:
            img (Tensor): of shape (N, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.
            img_metas (list[Dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.

        Returns:
            The raw output of ``self.panoptic_head`` applied to the extracted
            features and ``img_metas``.
        """
        # Let the base detector add batch_input_shape to img_metas before the
        # head consumes them (same step as in forward_train).
        super(SingleStageDetector, self).forward_train(img, img_metas)
        x = self.extract_feat(img)
        outs = self.panoptic_head(x, img_metas)
        return outs

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_masks,
                      gt_semantic_seg,
                      gt_bboxes_ignore=None,
                      **kwargs):
        """Compute training losses.

        Args:
            img (Tensor): of shape (N, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.
            img_metas (list[Dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box.
            gt_masks (list[BitmapMasks]): true segmentation masks for each box
                used if the architecture supports a segmentation task.
            gt_semantic_seg (list[tensor]): semantic segmentation mask for
                images.
            gt_bboxes_ignore (list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.
                Defaults to None.
            **kwargs: Accepted and ignored.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        # add batch_input_shape in img_metas
        super(SingleStageDetector, self).forward_train(img, img_metas)
        x = self.extract_feat(img)
        losses = self.panoptic_head.forward_train(x, img_metas, gt_bboxes,
                                                  gt_labels, gt_masks,
                                                  gt_semantic_seg,
                                                  gt_bboxes_ignore)
        return losses

    def simple_test(self, img, img_metas, **kwargs):
        """Test without augmentation.

        Args:
            img (Tensor): Batched input images.
            img_metas (list[dict]): Meta information for each image.
            **kwargs: Forwarded unchanged to
                ``self.panoptic_head.simple_test``.

        Returns:
            list[dict]: One dict per mask returned by the head, each holding
            that mask as a detached CPU numpy array under ``'pan_results'``.
        """
        feat = self.extract_feat(img)
        mask_results = self.panoptic_head.simple_test(feat, img_metas,
                                                      **kwargs)
        return [{
            'pan_results': mask.detach().cpu().numpy()
        } for mask in mask_results]

    def show_result(self,
                    img,
                    result,
                    score_thr=0.3,
                    bbox_color=(72, 101, 241),
                    text_color=(72, 101, 241),
                    mask_color=None,
                    thickness=2,
                    font_size=13,
                    win_name='',
                    show=False,
                    wait_time=0,
                    out_file=None):
        """Draw `result` over `img`.

        Args:
            img (str or np.ndarray): The image, or a path to it, to be
                displayed. It is read with ``mmcv.imread`` and copied, so an
                array passed in is not modified.
            result (dict): The results. Must contain ``'pan_results'``: a
                panoptic id map in which each id modulo ``INSTANCE_OFFSET`` is
                the class label and ``self.num_classes`` marks VOID pixels.
            score_thr (float, optional): Not used by this method.
                Default: 0.3.
            bbox_color (str or tuple(int) or :obj:`Color`): Color of bbox
                lines. The tuple of color should be in BGR order.
                Default: (72, 101, 241).
            text_color (str or tuple(int) or :obj:`Color`): Color of texts.
                The tuple of color should be in BGR order.
                Default: (72, 101, 241).
            mask_color (None or str or tuple(int) or :obj:`Color`):
                Color of masks. The tuple of color should be in BGR order.
                Default: None.
            thickness (int): Thickness of lines. Default: 2.
            font_size (int): Font size of texts. Default: 13.
            win_name (str): The window name. Default: ''.
            show (bool): Whether to show the image. Forced to False when
                ``out_file`` is given. Default: False.
            wait_time (float): Value of waitKey param. Default: 0.
            out_file (str or None): The filename to write the image.
                Default: None.

        Returns:
            np.ndarray or None: The drawn image (as returned by
            ``imshow_det_bboxes``) only if neither ``show`` nor ``out_file``
            is set; otherwise None.
        """
        img = mmcv.imread(img)
        img = img.copy()
        pan_results = result['pan_results']
        # Unique segment ids in descending order to keep objects ahead.
        ids = np.unique(pan_results)[::-1]
        # Drop the VOID label, which is encoded as ``self.num_classes``.
        ids = ids[ids != self.num_classes]
        # Class label of each segment id (vectorized; also avoids shadowing
        # the builtin ``id`` in a Python-level loop).
        labels = (ids % INSTANCE_OFFSET).astype(np.int64)
        # One boolean mask per remaining id, shape (len(ids),
        # *pan_results.shape).
        segms = pan_results[None] == ids[:, None, None]
        # if out_file specified, do not show image in window
        if out_file is not None:
            show = False
        # Draw the segment masks with their class labels.
        img = imshow_det_bboxes(
            img,
            segms=segms,
            labels=labels,
            class_names=self.CLASSES,
            bbox_color=bbox_color,
            text_color=text_color,
            mask_color=mask_color,
            thickness=thickness,
            font_size=font_size,
            win_name=win_name,
            show=show,
            wait_time=wait_time,
            out_file=out_file)
        if not (show or out_file):
            return img