Shortcuts

Source code for mmdet.models.dense_heads.maskformer_head

# Copyright (c) OpenMMLab. All rights reserved.
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d, build_plugin_layer, caffe2_xavier_init
from mmcv.cnn.bricks.transformer import (build_positional_encoding,
                                         build_transformer_layer_sequence)
from mmcv.runner import force_fp32

from mmdet.core import build_assigner, build_sampler, multi_apply, reduce_mean
from mmdet.core.evaluation import INSTANCE_OFFSET
from mmdet.models.utils import preprocess_panoptic_gt
from ..builder import HEADS, build_loss
from .anchor_free_head import AnchorFreeHead
from .anchor_free_head import AnchorFreeHead


@HEADS.register_module()
class MaskFormerHead(AnchorFreeHead):
    """Implements the MaskFormer head.

    See `Per-Pixel Classification is Not All You Need for Semantic
    Segmentation <https://arxiv.org/pdf/2107.06278>`_ for details.

    Args:
        in_channels (list[int]): Number of channels in the input feature map.
        feat_channels (int): Number of channels for feature.
        out_channels (int): Number of channels for output.
        num_things_classes (int): Number of things.
        num_stuff_classes (int): Number of stuff.
        num_queries (int): Number of query in Transformer.
        pixel_decoder (:obj:`mmcv.ConfigDict` | dict): Config for pixel
            decoder. Must be provided despite the ``None`` default. The
            config is copied before use, so the caller's object is not
            modified.
        enforce_decoder_input_project (bool, optional): Whether to add a layer
            to change the embed_dim of transformer encoder in pixel decoder to
            the embed_dim of transformer decoder. Defaults to False.
        transformer_decoder (:obj:`mmcv.ConfigDict` | dict): Config for
            transformer decoder. Defaults to None.
        positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for
            transformer decoder position encoding. Defaults to None.
        loss_cls (:obj:`mmcv.ConfigDict` | dict): Config of the classification
            loss. Defaults to `CrossEntropyLoss`. The config is copied before
            ``class_weight``/``bg_cls_weight`` are rewritten, so neither the
            caller's object nor the default value is modified.
        loss_mask (:obj:`mmcv.ConfigDict` | dict): Config of the mask loss.
            Defaults to `FocalLoss`.
        loss_dice (:obj:`mmcv.ConfigDict` | dict): Config of the dice loss.
            Defaults to `DiceLoss`.
        train_cfg (:obj:`mmcv.ConfigDict` | dict): Training config of
            Maskformer head.
        test_cfg (:obj:`mmcv.ConfigDict` | dict): Testing config of Maskformer
            head.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 feat_channels,
                 out_channels,
                 num_things_classes=80,
                 num_stuff_classes=53,
                 num_queries=100,
                 pixel_decoder=None,
                 enforce_decoder_input_project=False,
                 transformer_decoder=None,
                 positional_encoding=None,
                 loss_cls=dict(
                     type='CrossEntropyLoss',
                     bg_cls_weight=0.1,
                     use_sigmoid=False,
                     loss_weight=1.0,
                     class_weight=1.0),
                 loss_mask=dict(
                     type='FocalLoss',
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=20.0),
                 loss_dice=dict(
                     type='DiceLoss',
                     use_sigmoid=True,
                     activate=True,
                     naive_dice=True,
                     loss_weight=1.0),
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None,
                 **kwargs):
        # Deliberately bypass AnchorFreeHead.__init__ and initialise the next
        # class in the MRO with ``init_cfg`` only; all layers are built below.
        super(AnchorFreeHead, self).__init__(init_cfg)
        self.num_things_classes = num_things_classes
        self.num_stuff_classes = num_stuff_classes
        self.num_classes = self.num_things_classes + self.num_stuff_classes
        self.num_queries = num_queries

        # Work on private copies: ``pixel_decoder.update`` and the
        # ``loss_cls.update``/``pop`` calls below would otherwise mutate the
        # caller's config or, worse, the shared default ``loss_cls`` dict
        # (making a second construction fail the ``float`` assertion).
        pixel_decoder = copy.deepcopy(pixel_decoder)
        loss_cls = copy.deepcopy(loss_cls)

        pixel_decoder.update(
            in_channels=in_channels,
            feat_channels=feat_channels,
            out_channels=out_channels)
        self.pixel_decoder = build_plugin_layer(pixel_decoder)[1]
        self.transformer_decoder = build_transformer_layer_sequence(
            transformer_decoder)
        self.decoder_embed_dims = self.transformer_decoder.embed_dims
        pixel_decoder_type = pixel_decoder.get('type')
        # Only the plain 'PixelDecoder' gets a 1x1 projection, and only when
        # the channel counts differ or the projection is explicitly forced.
        if pixel_decoder_type == 'PixelDecoder' and (
                self.decoder_embed_dims != in_channels[-1]
                or enforce_decoder_input_project):
            self.decoder_input_proj = Conv2d(
                in_channels[-1], self.decoder_embed_dims, kernel_size=1)
        else:
            self.decoder_input_proj = nn.Identity()
        self.decoder_pe = build_positional_encoding(positional_encoding)
        self.query_embed = nn.Embedding(self.num_queries, out_channels)

        # One extra logit for the VOID ("no object") class at index
        # ``num_classes``.
        self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
        self.mask_embed = nn.Sequential(
            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
            nn.Linear(feat_channels, out_channels))

        self.test_cfg = test_cfg
        self.train_cfg = train_cfg
        if train_cfg:
            assert 'assigner' in train_cfg, 'assigner should be provided '\
                'when train_cfg is set.'
            assigner = train_cfg['assigner']
            self.assigner = build_assigner(assigner)
            sampler_cfg = dict(type='MaskPseudoSampler')
            self.sampler = build_sampler(sampler_cfg, context=self)

        self.bg_cls_weight = 0
        class_weight = loss_cls.get('class_weight', None)
        if class_weight is not None and (self.__class__ is MaskFormerHead):
            assert isinstance(class_weight, float), 'Expected ' \
                'class_weight to have type float. Found ' \
                f'{type(class_weight)}.'
            # NOTE following the official MaskFormerHead repo, bg_cls_weight
            # means relative classification weight of the VOID class.
            bg_cls_weight = loss_cls.get('bg_cls_weight', class_weight)
            assert isinstance(bg_cls_weight, float), 'Expected ' \
                'bg_cls_weight to have type float. Found ' \
                f'{type(bg_cls_weight)}.'
            class_weight = torch.ones(self.num_classes + 1) * class_weight
            # set VOID class as the last index
            class_weight[self.num_classes] = bg_cls_weight
            loss_cls.update({'class_weight': class_weight})
            if 'bg_cls_weight' in loss_cls:
                loss_cls.pop('bg_cls_weight')
            self.bg_cls_weight = bg_cls_weight
        self.loss_cls = build_loss(loss_cls)
        self.loss_mask = build_loss(loss_mask)
        self.loss_dice = build_loss(loss_dice)

    def init_weights(self):
        """Initialize the projection, pixel decoder and transformer decoder.

        The optional input projection gets caffe2-style Xavier init, the
        pixel decoder uses its own ``init_weights``, and every transformer
        decoder parameter with more than one dimension gets Xavier-uniform.
        """
        if isinstance(self.decoder_input_proj, Conv2d):
            caffe2_xavier_init(self.decoder_input_proj, bias=0)

        self.pixel_decoder.init_weights()

        for p in self.transformer_decoder.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def preprocess_gt(self, gt_labels_list, gt_masks_list, gt_semantic_segs):
        """Preprocess the ground truth for all images.

        Args:
            gt_labels_list (list[Tensor]): Each is ground truth
                labels of each bbox, with shape (num_gts, ).
            gt_masks_list (list[BitmapMasks]): Each is ground truth
                masks of each instances of a image, shape
                (num_gts, h, w).
            gt_semantic_segs (Tensor | list[Tensor]): Ground truth of
                semantic segmentation, iterated per image. Values in
                [0, num_thing_class - 1] mean things,
                [num_thing_class, num_class-1] mean stuff,
                255 means VOID.

        Returns:
            tuple: a tuple containing the following targets.
                - labels (list[Tensor]): Ground truth class indices\
                    for all images. Each with shape (n, ), n is the sum of\
                    number of stuff type and number of instance in a image.
                - masks (list[Tensor]): Ground truth mask for each\
                    image, each with shape (n, h, w).
        """
        num_things_list = [self.num_things_classes] * len(gt_labels_list)
        num_stuff_list = [self.num_stuff_classes] * len(gt_labels_list)

        labels, masks = multi_apply(preprocess_panoptic_gt, gt_labels_list,
                                    gt_masks_list, gt_semantic_segs,
                                    num_things_list, num_stuff_list)
        return labels, masks

    def get_targets(self, cls_scores_list, mask_preds_list, gt_labels_list,
                    gt_masks_list, img_metas):
        """Compute classification and mask targets for all images for a decoder
        layer.

        Args:
            cls_scores_list (list[Tensor]): Mask score logits from a single
                decoder layer for all images. Each with shape (num_queries,
                cls_out_channels).
            mask_preds_list (list[Tensor]): Mask logits from a single decoder
                layer for all images. Each with shape (num_queries, h, w).
            gt_labels_list (list[Tensor]): Ground truth class indices for all
                images. Each with shape (n, ), n is the sum of number of stuff
                type and number of instance in a image.
            gt_masks_list (list[Tensor]): Ground truth mask for each image,
                each with shape (n, h, w).
            img_metas (list[dict]): List of image meta information.

        Returns:
            tuple[list[Tensor]]: a tuple containing the following targets.
                - labels_list (list[Tensor]): Labels of all images.\
                    Each with shape (num_queries, ).
                - label_weights_list (list[Tensor]): Label weights\
                    of all images. Each with shape (num_queries, ).
                - mask_targets_list (list[Tensor]): Mask targets of\
                    all images. Each with shape (num_queries, h, w).
                - mask_weights_list (list[Tensor]): Mask weights of\
                    all images. Each with shape (num_queries, ).
                - num_total_pos (int): Number of positive samples in\
                    all images.
                - num_total_neg (int): Number of negative samples in\
                    all images.
        """
        (labels_list, label_weights_list, mask_targets_list, mask_weights_list,
         pos_inds_list,
         neg_inds_list) = multi_apply(self._get_target_single, cls_scores_list,
                                      mask_preds_list, gt_labels_list,
                                      gt_masks_list, img_metas)

        num_total_pos = sum(inds.numel() for inds in pos_inds_list)
        num_total_neg = sum(inds.numel() for inds in neg_inds_list)
        return (labels_list, label_weights_list, mask_targets_list,
                mask_weights_list, num_total_pos, num_total_neg)

    def _get_target_single(self, cls_score, mask_pred, gt_labels, gt_masks,
                           img_metas):
        """Compute classification and mask targets for one image.

        Args:
            cls_score (Tensor): Mask score logits from a single decoder layer
                for one image. Shape (num_queries, cls_out_channels).
            mask_pred (Tensor): Mask logits for a single decoder layer for one
                image. Shape (num_queries, h, w).
            gt_labels (Tensor): Ground truth class indices for one image with
                shape (n, ). n is the sum of number of stuff type and number
                of instance in a image.
            gt_masks (Tensor): Ground truth mask for each image, each with
                shape (n, h, w).
            img_metas (dict): Image information.

        Returns:
            tuple[Tensor]: a tuple containing the following for one image.
                - labels (Tensor): Labels of each image.
                    shape (num_queries, ).
                - label_weights (Tensor): Label weights of each image.
                    shape (num_queries, ).
                - mask_targets (Tensor): Mask targets of each image.
                    shape (num_positive_queries, h, w), at the full ground
                    truth resolution.
                - mask_weights (Tensor): Mask weights of each image.
                    shape (num_queries, ).
                - pos_inds (Tensor): Sampled positive indices for each image.
                - neg_inds (Tensor): Sampled negative indices for each image.
        """
        # Matching is done at prediction resolution, so the GT masks are
        # nearest-resized to the shape of ``mask_pred`` for the assigner only.
        target_shape = mask_pred.shape[-2:]
        if gt_masks.shape[0] > 0:
            gt_masks_downsampled = F.interpolate(
                gt_masks.unsqueeze(1).float(), target_shape,
                mode='nearest').squeeze(1).long()
        else:
            gt_masks_downsampled = gt_masks

        # assign and sample
        assign_result = self.assigner.assign(cls_score, mask_pred, gt_labels,
                                             gt_masks_downsampled, img_metas)
        sampling_result = self.sampler.sample(assign_result, mask_pred,
                                              gt_masks)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds

        # label target: unmatched queries are labelled VOID (num_classes)
        labels = gt_labels.new_full((self.num_queries, ),
                                    self.num_classes,
                                    dtype=torch.long)
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        label_weights = gt_labels.new_ones(self.num_queries)

        # mask target: only matched queries contribute to the mask losses
        mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds]
        mask_weights = mask_pred.new_zeros((self.num_queries, ))
        mask_weights[pos_inds] = 1.0

        return (labels, label_weights, mask_targets, mask_weights, pos_inds,
                neg_inds)

    @force_fp32(apply_to=('all_cls_scores', 'all_mask_preds'))
    def loss(self, all_cls_scores, all_mask_preds, gt_labels_list,
             gt_masks_list, img_metas):
        """Loss function.

        Args:
            all_cls_scores (Tensor): Classification scores for all decoder
                layers with shape (num_decoder, batch_size, num_queries,
                cls_out_channels).
            all_mask_preds (Tensor): Mask scores for all decoder layers with
                shape (num_decoder, batch_size, num_queries, h, w).
            gt_labels_list (list[Tensor]): Ground truth class indices for each
                image with shape (n, ). n is the sum of number of stuff type
                and number of instance in a image.
            gt_masks_list (list[Tensor]): Ground truth mask for each image with
                shape (n, h, w).
            img_metas (list[dict]): List of image meta information.

        Returns:
            dict[str, Tensor]: A dictionary of loss components. The last
                decoder layer's losses use plain keys; earlier layers are
                prefixed ``d0.``, ``d1.``, ...
        """
        num_dec_layers = len(all_cls_scores)
        all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
        all_gt_masks_list = [gt_masks_list for _ in range(num_dec_layers)]
        img_metas_list = [img_metas for _ in range(num_dec_layers)]
        losses_cls, losses_mask, losses_dice = multi_apply(
            self.loss_single, all_cls_scores, all_mask_preds,
            all_gt_labels_list, all_gt_masks_list, img_metas_list)

        loss_dict = dict()
        # loss from the last decoder layer
        loss_dict['loss_cls'] = losses_cls[-1]
        loss_dict['loss_mask'] = losses_mask[-1]
        loss_dict['loss_dice'] = losses_dice[-1]
        # loss from other decoder layers
        for num_dec_layer, (loss_cls_i, loss_mask_i,
                            loss_dice_i) in enumerate(
                                zip(losses_cls[:-1], losses_mask[:-1],
                                    losses_dice[:-1])):
            loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
            loss_dict[f'd{num_dec_layer}.loss_mask'] = loss_mask_i
            loss_dict[f'd{num_dec_layer}.loss_dice'] = loss_dice_i
        return loss_dict

    def loss_single(self, cls_scores, mask_preds, gt_labels_list,
                    gt_masks_list, img_metas):
        """Loss function for outputs from a single decoder layer.

        Args:
            cls_scores (Tensor): Mask score logits from a single decoder layer
                for all images. Shape (batch_size, num_queries,
                cls_out_channels).
            mask_preds (Tensor): Mask logits for a pixel decoder for all
                images. Shape (batch_size, num_queries, h, w).
            gt_labels_list (list[Tensor]): Ground truth class indices for each
                image, each with shape (n, ). n is the sum of number of stuff
                types and number of instances in a image.
            gt_masks_list (list[Tensor]): Ground truth mask for each image,
                each with shape (n, h, w).
            img_metas (list[dict]): List of image meta information.

        Returns:
            tuple[Tensor]: Loss components (loss_cls, loss_mask, loss_dice)
                for outputs from a single decoder layer.
        """
        num_imgs = cls_scores.size(0)
        cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
        mask_preds_list = [mask_preds[i] for i in range(num_imgs)]

        (labels_list, label_weights_list, mask_targets_list, mask_weights_list,
         num_total_pos,
         num_total_neg) = self.get_targets(cls_scores_list, mask_preds_list,
                                           gt_labels_list, gt_masks_list,
                                           img_metas)
        # shape (batch_size, num_queries)
        labels = torch.stack(labels_list, dim=0)
        # shape (batch_size, num_queries)
        label_weights = torch.stack(label_weights_list, dim=0)
        # shape (num_total_gts, h, w)
        mask_targets = torch.cat(mask_targets_list, dim=0)
        # shape (batch_size, num_queries)
        mask_weights = torch.stack(mask_weights_list, dim=0)

        # classification loss
        # shape (batch_size * num_queries, )
        cls_scores = cls_scores.flatten(0, 1)
        labels = labels.flatten(0, 1)
        label_weights = label_weights.flatten(0, 1)

        # avg_factor: every non-VOID query counts 1 and every VOID query
        # counts self.bg_cls_weight.
        class_weight = cls_scores.new_ones(self.num_classes + 1)
        class_weight[-1] = self.bg_cls_weight
        loss_cls = self.loss_cls(
            cls_scores,
            labels,
            label_weights,
            avg_factor=class_weight[labels].sum())

        # Number of matched masks averaged across processes, at least 1.
        num_total_masks = reduce_mean(cls_scores.new_tensor([num_total_pos]))
        num_total_masks = max(num_total_masks, 1)

        # extract positive ones
        # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w)
        mask_preds = mask_preds[mask_weights > 0]
        target_shape = mask_targets.shape[-2:]

        if mask_targets.shape[0] == 0:
            # zero match: return zero-valued losses that stay attached to the
            # graph
            loss_dice = mask_preds.sum()
            loss_mask = mask_preds.sum()
            return loss_cls, loss_mask, loss_dice

        # upsample to shape of target
        # shape (num_total_gts, h, w)
        mask_preds = F.interpolate(
            mask_preds.unsqueeze(1),
            target_shape,
            mode='bilinear',
            align_corners=False).squeeze(1)

        # dice loss
        loss_dice = self.loss_dice(
            mask_preds, mask_targets, avg_factor=num_total_masks)

        # mask loss
        # FocalLoss support input of shape (n, num_class)
        h, w = mask_preds.shape[-2:]
        # shape (num_total_gts, h, w) -> (num_total_gts * h * w, 1)
        mask_preds = mask_preds.reshape(-1, 1)
        # shape (num_total_gts, h, w) -> (num_total_gts * h * w)
        mask_targets = mask_targets.reshape(-1)
        # target is (1 - mask_targets) !!! Presumably because the sigmoid
        # FocalLoss treats label 0 as the single foreground class and label
        # num_classes (= 1) as background -- confirm against FocalLoss.
        loss_mask = self.loss_mask(
            mask_preds, 1 - mask_targets, avg_factor=num_total_masks * h * w)

        return loss_cls, loss_mask, loss_dice

    def forward(self, feats, img_metas):
        """Forward function.

        Args:
            feats (list[Tensor]): Features from the upstream network, each
                is a 4D-tensor.
            img_metas (list[dict]): List of image information.

        Returns:
            tuple: a tuple contains two elements.
                - all_cls_scores (Tensor): Classification scores for each\
                    scale level. Each is a 4D-tensor with shape\
                    (num_decoder, batch_size, num_queries, cls_out_channels).\
                    Note `cls_out_channels` should includes background.
                - all_mask_preds (Tensor): Mask scores for each decoder\
                    layer. Each with shape (num_decoder, batch_size,\
                    num_queries, h, w).
        """
        batch_size = len(img_metas)
        input_img_h, input_img_w = img_metas[0]['batch_input_shape']
        # 1 marks padded pixels, 0 marks the valid image area.
        padding_mask = feats[-1].new_ones(
            (batch_size, input_img_h, input_img_w), dtype=torch.float32)
        for i, meta in enumerate(img_metas):
            # Only height and width are needed; tolerate (h, w) shapes too.
            img_h, img_w = meta['img_shape'][:2]
            padding_mask[i, :img_h, :img_w] = 0
        padding_mask = F.interpolate(
            padding_mask.unsqueeze(1),
            size=feats[-1].shape[-2:],
            mode='nearest').to(torch.bool).squeeze(1)
        # when backbone is swin, memory is output of last stage of swin.
        # when backbone is r50, memory is output of transformer encoder.
        mask_features, memory = self.pixel_decoder(feats, img_metas)
        pos_embed = self.decoder_pe(padding_mask)
        memory = self.decoder_input_proj(memory)
        # shape (batch_size, c, h, w) -> (h*w, batch_size, c)
        memory = memory.flatten(2).permute(2, 0, 1)
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        # shape (batch_size, h * w)
        padding_mask = padding_mask.flatten(1)
        # shape = (num_queries, embed_dims)
        query_embed = self.query_embed.weight
        # shape = (num_queries, batch_size, embed_dims)
        query_embed = query_embed.unsqueeze(1).repeat(1, batch_size, 1)
        target = torch.zeros_like(query_embed)
        # shape (num_decoder, num_queries, batch_size, embed_dims)
        out_dec = self.transformer_decoder(
            query=target,
            key=memory,
            value=memory,
            key_pos=pos_embed,
            query_pos=query_embed,
            key_padding_mask=padding_mask)
        # shape (num_decoder, batch_size, num_queries, embed_dims)
        out_dec = out_dec.transpose(1, 2)

        # cls_scores
        all_cls_scores = self.cls_embed(out_dec)

        # mask_preds
        mask_embed = self.mask_embed(out_dec)
        all_mask_preds = torch.einsum('lbqc,bchw->lbqhw', mask_embed,
                                      mask_features)

        return all_cls_scores, all_mask_preds

    def forward_train(self,
                      feats,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_masks,
                      gt_semantic_seg,
                      gt_bboxes_ignore=None):
        """Forward function for training mode.

        Args:
            feats (list[Tensor]): Multi-level features from the upstream
                network, each is a 4D-tensor.
            img_metas (list[Dict]): List of image information.
            gt_bboxes (list[Tensor]): Each element is ground truth bboxes of
                the image, shape (num_gts, 4). Not used here.
            gt_labels (list[Tensor]): Each element is ground truth labels of
                each box, shape (num_gts,).
            gt_masks (list[BitmapMasks]): Each element is masks of instances
                of a image, shape (num_gts, h, w).
            gt_semantic_seg (list[tensor]):Each element is the ground truth
                of semantic segmentation with the shape (N, H, W).
                [0, num_thing_class - 1] means things,
                [num_thing_class, num_class-1] means stuff,
                255 means VOID.
            gt_bboxes_ignore (list[Tensor]): Ground truth bboxes to be
                ignored. Defaults to None; any other value is rejected.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        # not consider ignoring bboxes
        assert gt_bboxes_ignore is None

        # forward
        all_cls_scores, all_mask_preds = self(feats, img_metas)

        # preprocess ground truth
        gt_labels, gt_masks = self.preprocess_gt(gt_labels, gt_masks,
                                                 gt_semantic_seg)

        # loss
        losses = self.loss(all_cls_scores, all_mask_preds, gt_labels, gt_masks,
                           img_metas)

        return losses

    def simple_test(self, feats, img_metas, rescale=False):
        """Test segment without test-time augmentation.

        Only the output of last decoder layers was used.

        Args:
            feats (list[Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            img_metas (list[dict]): List of image information.
            rescale (bool, optional): If True, return boxes in
                original image space. Default False.

        Returns:
            list[Tensor]: Panoptic segmentation result for each image, as
                produced by :meth:`post_process` (shape (h, w)).
        """
        all_cls_scores, all_mask_preds = self(feats, img_metas)
        mask_cls_results = all_cls_scores[-1]
        mask_pred_results = all_mask_preds[-1]

        # upsample masks to the padded batch input size
        img_shape = img_metas[0]['batch_input_shape']
        mask_pred_results = F.interpolate(
            mask_pred_results,
            size=(img_shape[0], img_shape[1]),
            mode='bilinear',
            align_corners=False)

        results = []
        for mask_cls_result, mask_pred_result, meta in zip(
                mask_cls_results, mask_pred_results, img_metas):
            # remove padding
            img_height, img_width = meta['img_shape'][:2]
            mask_pred_result = mask_pred_result[:, :img_height, :img_width]

            if rescale:
                # return result in original resolution
                ori_height, ori_width = meta['ori_shape'][:2]
                mask_pred_result = F.interpolate(
                    mask_pred_result.unsqueeze(1),
                    size=(ori_height, ori_width),
                    mode='bilinear',
                    align_corners=False).squeeze(1)

            mask = self.post_process(mask_cls_result, mask_pred_result)
            results.append(mask)

        return results

    def post_process(self, mask_cls, mask_pred):
        """Panoptic segmentation inference.

        This implementation is modified from `MaskFormer
        <https://github.com/facebookresearch/MaskFormer>`_.

        Args:
            mask_cls (Tensor): Classification outputs for a image.
                shape = (num_queries, cls_out_channels).
            mask_pred (Tensor): Mask outputs for a image.
                shape = (num_queries, h, w).

        Returns:
            Tensor: panoptic segment result of shape (h, w), dtype int32.
                Thing pixels hold ``class_id + instance_id *
                INSTANCE_OFFSET`` (instance_id starting at 1), stuff pixels
                hold ``class_id`` (instance_id 0), and unassigned pixels hold
                ``num_classes``.
        """
        # Fall back to the defaults when no test config was given.
        test_cfg = self.test_cfg if self.test_cfg is not None else {}
        object_mask_thr = test_cfg.get('object_mask_thr', 0.8)
        iou_thr = test_cfg.get('iou_thr', 0.8)

        scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
        mask_pred = mask_pred.sigmoid()

        # Keep non-VOID queries whose top class score beats the threshold.
        keep = labels.ne(self.num_classes) & (scores > object_mask_thr)
        cur_scores = scores[keep]
        cur_classes = labels[keep]
        cur_masks = mask_pred[keep]

        cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks

        h, w = cur_masks.shape[-2:]
        panoptic_seg = torch.full((h, w),
                                  self.num_classes,
                                  dtype=torch.int32,
                                  device=cur_masks.device)
        if cur_masks.shape[0] == 0:
            # No mask survived the filter; every pixel stays num_classes.
            return panoptic_seg

        # Each pixel goes to the kept query with the highest score * prob.
        cur_mask_ids = cur_prob_masks.argmax(0)
        instance_id = 1
        for k in range(cur_classes.shape[0]):
            pred_class = int(cur_classes[k].item())
            isthing = pred_class < self.num_things_classes
            mask = cur_mask_ids == k
            mask_area = mask.sum().item()
            original_area = (cur_masks[k] >= 0.5).sum().item()

            if mask_area > 0 and original_area > 0:
                # Drop segments that lost too much area to other queries.
                if mask_area / original_area < iou_thr:
                    continue

                if not isthing:
                    # different stuff regions of same class will be
                    # merged here, and stuff share the instance_id 0.
                    panoptic_seg[mask] = pred_class
                else:
                    panoptic_seg[mask] = (
                        pred_class + instance_id * INSTANCE_OFFSET)
                    instance_id += 1

        return panoptic_seg