Shortcuts

Source code for mmdet.models.detectors.maskformer

# Copyright (c) OpenMMLab. All rights reserved.
import copy

import mmcv
import numpy as np

from mmdet.core import INSTANCE_OFFSET
from mmdet.core.visualization import imshow_det_bboxes
from ..builder import DETECTORS, build_backbone, build_head, build_neck
from .single_stage import SingleStageDetector


@DETECTORS.register_module()
class MaskFormer(SingleStageDetector):
    r"""Implementation of `Per-Pixel Classification is
    NOT All You Need for Semantic Segmentation
    <https://arxiv.org/pdf/2107.06278>`_.

    Args:
        backbone (dict): Config passed to ``build_backbone``.
        neck (dict, optional): Config passed to ``build_neck``. The neck is
            only built when this is not None. Defaults to None.
        panoptic_head (dict): Config passed to ``build_head``. Although the
            signature defaults to None, the head is built unconditionally,
            so a config must be supplied.
        train_cfg (dict, optional): Training config, injected into the head
            config under the key ``train_cfg``. Defaults to None.
        test_cfg (dict, optional): Testing config, injected into the head
            config under the key ``test_cfg``. Defaults to None.
        init_cfg (dict, optional): Initialization config forwarded to the
            base class constructor. Defaults to None.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 panoptic_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None):
        # Deliberately skip ``SingleStageDetector.__init__`` and call the
        # next class in the MRO: this detector builds its own head below.
        super(SingleStageDetector, self).__init__(init_cfg=init_cfg)
        self.backbone = build_backbone(backbone)
        if neck is not None:
            self.neck = build_neck(neck)

        # Inject the train/test configs into a private copy of the head
        # config so that the dict owned by the caller is left untouched and
        # can safely be reused (e.g. to build a second model).
        panoptic_head_ = copy.deepcopy(panoptic_head)
        panoptic_head_.update(train_cfg=train_cfg, test_cfg=test_cfg)
        self.panoptic_head = build_head(panoptic_head_)

        self.num_things_classes = self.panoptic_head.num_things_classes
        self.num_stuff_classes = self.panoptic_head.num_stuff_classes
        self.num_classes = self.panoptic_head.num_classes

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def forward_dummy(self, img, img_metas):
        """Used for computing network flops. See
        `mmdetection/tools/analysis_tools/get_flops.py`

        Args:
            img (Tensor): of shape (N, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.
            img_metas (list[Dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.

        Returns:
            The raw outputs of ``self.panoptic_head(x, img_metas)``.
        """
        # Same base-class call as in ``forward_train``, where it is used to
        # add batch_input_shape to img_metas.
        super(SingleStageDetector, self).forward_train(img, img_metas)
        x = self.extract_feat(img)
        outs = self.panoptic_head(x, img_metas)
        return outs

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_masks,
                      gt_semantic_seg,
                      gt_bboxes_ignore=None,
                      **kargs):
        """
        Args:
            img (Tensor): of shape (N, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.
            img_metas (list[Dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box.
            gt_masks (list[BitmapMasks]): true segmentation masks for each box
                used if the architecture supports a segmentation task.
            gt_semantic_seg (list[tensor]): semantic segmentation mask for
                images.
            gt_bboxes_ignore (list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.
                Defaults to None.
            **kargs: Extra keyword arguments; accepted but not used by this
                method.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        # add batch_input_shape in img_metas
        super(SingleStageDetector, self).forward_train(img, img_metas)
        x = self.extract_feat(img)
        losses = self.panoptic_head.forward_train(x, img_metas, gt_bboxes,
                                                  gt_labels, gt_masks,
                                                  gt_semantic_seg,
                                                  gt_bboxes_ignore)

        return losses

    def simple_test(self, img, img_metas, **kwargs):
        """Test without augmentation.

        Args:
            img (Tensor): Input images.
            img_metas (list[Dict]): List of image info dicts.
            **kwargs: Forwarded to ``self.panoptic_head.simple_test``.

        Returns:
            list[dict]: One dict per item returned by the head, each holding
            that item converted to a numpy array under the key
            'pan_results'.
        """
        feat = self.extract_feat(img)
        mask_results = self.panoptic_head.simple_test(feat, img_metas,
                                                      **kwargs)

        return [{
            'pan_results': mask.detach().cpu().numpy()
        } for mask in mask_results]

    def aug_test(self, imgs, img_metas, **kwargs):
        """Test with augmentation. Not implemented for this detector.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def onnx_export(self, img, img_metas):
        """Export to ONNX. Not implemented for this detector.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def show_result(self,
                    img,
                    result,
                    score_thr=0.3,
                    bbox_color=(72, 101, 241),
                    text_color=(72, 101, 241),
                    mask_color=None,
                    thickness=2,
                    font_size=13,
                    win_name='',
                    show=False,
                    wait_time=0,
                    out_file=None):
        """Draw `result` over `img`.

        Args:
            img (str or Tensor): The image to be displayed.
            result (dict): The results. Must contain the key 'pan_results'.
            score_thr (float, optional): Minimum score of bboxes to be shown.
                Not used by this method. Default: 0.3.
            bbox_color (str or tuple(int) or :obj:`Color`):Color of bbox lines.
               The tuple of color should be in BGR order.
               Default: (72, 101, 241).
            text_color (str or tuple(int) or :obj:`Color`):Color of texts.
               The tuple of color should be in BGR order.
               Default: (72, 101, 241).
            mask_color (None or str or tuple(int) or :obj:`Color`):
               Color of masks. The tuple of color should be in BGR order.
               Default: None.
            thickness (int): Thickness of lines. Default: 2.
            font_size (int): Font size of texts. Default: 13.
            win_name (str): The window name. Default: ''.
            wait_time (float): Value of waitKey param.
                Default: 0.
            show (bool): Whether to show the image.
                Default: False.
            out_file (str or None): The filename to write the image.
                Default: None.

        Returns:
            The drawn image as returned by ``imshow_det_bboxes``, only when
            neither `show` nor `out_file` is set; otherwise None.
        """
        img = mmcv.imread(img)
        img = img.copy()
        pan_results = result['pan_results']
        # keep objects ahead
        ids = np.unique(pan_results)[::-1]
        legal_indices = ids != self.num_classes  # for VOID label
        ids = ids[legal_indices]
        # The class label of each segment is its id modulo INSTANCE_OFFSET.
        labels = (ids % INSTANCE_OFFSET).astype(np.int64)
        # One boolean mask per remaining segment id.
        segms = (pan_results[None] == ids[:, None, None])

        # if out_file specified, do not show image in window
        if out_file is not None:
            show = False
        # draw the segment masks and their labels (no bboxes are passed)
        img = imshow_det_bboxes(
            img,
            segms=segms,
            labels=labels,
            class_names=self.CLASSES,
            bbox_color=bbox_color,
            text_color=text_color,
            mask_color=mask_color,
            thickness=thickness,
            font_size=font_size,
            win_name=win_name,
            show=show,
            wait_time=wait_time,
            out_file=out_file)

        if not (show or out_file):
            return img
        return None