Shortcuts

Source code for mmdet.models.roi_heads.mask_scoring_roi_head

# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmdet.core import bbox2roi
from ..builder import HEADS, build_head
from .standard_roi_head import StandardRoIHead


@HEADS.register_module()
class MaskScoringRoIHead(StandardRoIHead):
    """RoI head used by Mask Scoring R-CNN.

    On top of the standard RoI head this class owns a mask IoU head, which
    is fed the mask features together with the mask predicted for each
    RoI's class. Its output contributes an extra loss during training and
    per-detection mask scores at test time.

    Paper: https://arxiv.org/abs/1903.00241

    Args:
        mask_iou_head: Config handed to ``build_head`` to construct the
            mask IoU head. Must not be ``None``.
        **kwargs: Forwarded unchanged to ``StandardRoIHead.__init__``.
    """

    def __init__(self, mask_iou_head, **kwargs):
        assert mask_iou_head is not None
        super().__init__(**kwargs)
        self.mask_iou_head = build_head(mask_iou_head)

    def _mask_forward_train(self, x, sampling_results, bbox_feats, gt_masks,
                            img_metas):
        """Compute the mask head loss and add the mask IoU loss to it.

        The parent implementation produces the regular mask results. When
        its ``'loss_mask'`` entry is ``None`` those results are returned
        untouched; otherwise the mask IoU head is run and its loss dict is
        merged into ``mask_results['loss_mask']``.

        Returns:
            dict: The parent's mask results, with ``'loss_mask'`` updated
            in place when a mask loss is available.
        """
        # Ground-truth labels of the positive samples, over all images.
        pos_labels = torch.cat(
            [result.pos_gt_labels for result in sampling_results])

        mask_results = super()._mask_forward_train(
            x, sampling_results, bbox_feats, gt_masks, img_metas)
        if mask_results['loss_mask'] is None:
            return mask_results

        # Pick, for every positive RoI, the mask predicted for its label.
        mask_pred = mask_results['mask_pred']
        pos_mask_pred = mask_pred[range(mask_pred.size(0)), pos_labels]

        # Mask IoU head forward; keep only the IoU of the labelled class.
        mask_iou_pred = self.mask_iou_head(mask_results['mask_feats'],
                                           pos_mask_pred)
        pos_mask_iou_pred = mask_iou_pred[range(mask_iou_pred.size(0)),
                                          pos_labels]

        mask_iou_targets = self.mask_iou_head.get_targets(
            sampling_results, gt_masks, pos_mask_pred,
            mask_results['mask_targets'], self.train_cfg)
        loss_mask_iou = self.mask_iou_head.loss(pos_mask_iou_pred,
                                                mask_iou_targets)
        mask_results['loss_mask'].update(loss_mask_iou)
        return mask_results

    def simple_test_mask(self,
                         x,
                         img_metas,
                         det_bboxes,
                         det_labels,
                         rescale=False):
        """Predict masks and mask scores without test-time augmentation.

        Args:
            x: Features passed through to ``self._mask_forward``.
            img_metas: Per-image meta dicts; ``'ori_shape'`` and
                ``'scale_factor'`` are read from each.
            det_bboxes: Per-image detection boxes. Only the first four
                columns are used as coordinates when ``rescale`` is set.
            det_labels: Per-image detection labels.
            rescale (bool): Treat ``det_bboxes`` as being in original image
                scale; they are multiplied by the scale factor to obtain
                RoIs at testing scale.

        Returns:
            list[tuple]: One ``(segm_result, mask_score)`` pair per image.
            An image without detections gets one empty list per class in
            both entries.
        """
        ori_shapes = tuple(meta['ori_shape'] for meta in img_metas)
        scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
        num_imgs = len(det_bboxes)

        # Nothing detected anywhere: skip the mask branch entirely.
        if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
            num_classes = self.mask_head.num_classes
            segm_results = [[[] for _ in range(num_classes)]
                            for _ in range(num_imgs)]
            mask_scores = [[[] for _ in range(num_classes)]
                           for _ in range(num_imgs)]
            return list(zip(segm_results, mask_scores))

        # Boxes given in original image scale must be mapped back to the
        # testing scale before RoIs are built. Scale factors that are not
        # plain floats are converted from numpy into tensors on the
        # detections' device.
        if rescale and not isinstance(scale_factors[0], float):
            device = det_bboxes[0].device
            scale_factors = [
                torch.from_numpy(factor).to(device)
                for factor in scale_factors
            ]
        if rescale:
            _bboxes = [
                det_bboxes[i][:, :4] * scale_factors[i]
                for i in range(num_imgs)
            ]
        else:
            _bboxes = [det_bboxes[i] for i in range(num_imgs)]

        mask_results = self._mask_forward(x, bbox2roi(_bboxes))
        mask_feats = mask_results['mask_feats']
        mask_pred = mask_results['mask_pred']

        # Score each detection's mask with the mask IoU head, using the
        # mask channel that corresponds to the detected label.
        concat_det_labels = torch.cat(det_labels)
        label_mask_pred = mask_pred[range(concat_det_labels.size(0)),
                                    concat_det_labels]
        mask_iou_pred = self.mask_iou_head(mask_feats, label_mask_pred)

        # Split the batched predictions back into per-image chunks.
        num_bboxes_per_img = tuple(len(_bbox) for _bbox in _bboxes)
        mask_preds = mask_pred.split(num_bboxes_per_img, 0)
        mask_iou_preds = mask_iou_pred.split(num_bboxes_per_img, 0)

        # Post-process every image on its own.
        segm_results = []
        mask_scores = []
        for i in range(num_imgs):
            if det_bboxes[i].shape[0] == 0:
                segm_results.append(
                    [[] for _ in range(self.mask_head.num_classes)])
                mask_scores.append(
                    [[] for _ in range(self.mask_head.num_classes)])
                continue

            segm_result = self.mask_head.get_seg_masks(
                mask_preds[i], _bboxes[i], det_labels[i], self.test_cfg,
                ori_shapes[i], scale_factors[i], rescale)
            mask_score = self.mask_iou_head.get_mask_scores(
                mask_iou_preds[i], det_bboxes[i], det_labels[i])
            segm_results.append(segm_result)
            mask_scores.append(mask_score)
        return list(zip(segm_results, mask_scores))
Read the Docs v: v2.19.1
Versions
latest
stable
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
v2.12.0
v2.11.0
v2.10.0
v2.9.0
v2.8.0
v2.7.0
v2.6.0
v2.5.0
v2.4.0
v2.3.0
v2.2.1
v2.2.0
v2.1.0
v2.0.0
v1.2.0
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.