Source code for mmdet.models.roi_heads.htc_roi_head

import torch
import torch.nn.functional as F

from mmdet.core import (bbox2result, bbox2roi, bbox_mapping, merge_aug_bboxes,
                        merge_aug_masks, multiclass_nms)
from ..builder import HEADS, build_head, build_roi_extractor
from .cascade_roi_head import CascadeRoIHead


[docs]@HEADS.register_module() class HybridTaskCascadeRoIHead(CascadeRoIHead): """Hybrid task cascade roi head including one bbox head and one mask head. https://arxiv.org/abs/1901.07518 """ def __init__(self, num_stages, stage_loss_weights, semantic_roi_extractor=None, semantic_head=None, semantic_fusion=('bbox', 'mask'), interleaved=True, mask_info_flow=True, **kwargs): super(HybridTaskCascadeRoIHead, self).__init__(num_stages, stage_loss_weights, **kwargs) assert self.with_bbox and self.with_mask assert not self.with_shared_head # shared head is not supported if semantic_head is not None: self.semantic_roi_extractor = build_roi_extractor( semantic_roi_extractor) self.semantic_head = build_head(semantic_head) self.semantic_fusion = semantic_fusion self.interleaved = interleaved self.mask_info_flow = mask_info_flow def init_weights(self, pretrained): super(HybridTaskCascadeRoIHead, self).init_weights(pretrained) if self.with_semantic: self.semantic_head.init_weights() @property def with_semantic(self): if hasattr(self, 'semantic_head') and self.semantic_head is not None: return True else: return False def forward_dummy(self, x, proposals): outs = () # semantic head if self.with_semantic: _, semantic_feat = self.semantic_head(x) else: semantic_feat = None # bbox heads rois = bbox2roi([proposals]) for i in range(self.num_stages): bbox_results = self._bbox_forward( i, x, rois, semantic_feat=semantic_feat) outs = outs + (bbox_results['cls_score'], bbox_results['bbox_pred']) # mask heads if self.with_mask: mask_rois = rois[:100] mask_roi_extractor = self.mask_roi_extractor[-1] mask_feats = mask_roi_extractor( x[:len(mask_roi_extractor.featmap_strides)], mask_rois) if self.with_semantic and 'mask' in self.semantic_fusion: mask_semantic_feat = self.semantic_roi_extractor( [semantic_feat], mask_rois) mask_feats += mask_semantic_feat last_feat = None for i in range(self.num_stages): mask_head = self.mask_head[i] if self.mask_info_flow: mask_pred, last_feat = mask_head(mask_feats, last_feat) else: mask_pred = mask_head(mask_feats) outs = outs + (mask_pred, ) return outs def _bbox_forward_train(self, stage, x, sampling_results, gt_bboxes, gt_labels, rcnn_train_cfg, semantic_feat=None): bbox_head = self.bbox_head[stage] rois = bbox2roi([res.bboxes for res in sampling_results]) bbox_results = self._bbox_forward( stage, x, rois, semantic_feat=semantic_feat) bbox_targets = bbox_head.get_targets(sampling_results, gt_bboxes, gt_labels, rcnn_train_cfg) loss_bbox = bbox_head.loss(bbox_results['cls_score'], bbox_results['bbox_pred'], rois, *bbox_targets) bbox_results.update( loss_bbox=loss_bbox, rois=rois, bbox_targets=bbox_targets, ) return bbox_results def _mask_forward_train(self, stage, x, sampling_results, gt_masks, rcnn_train_cfg, semantic_feat=None): mask_roi_extractor = self.mask_roi_extractor[stage] mask_head = self.mask_head[stage] pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results]) mask_feats = mask_roi_extractor(x[:mask_roi_extractor.num_inputs], pos_rois) # semantic feature fusion # element-wise sum for original features and pooled semantic features if self.with_semantic and 'mask' in self.semantic_fusion: mask_semantic_feat = self.semantic_roi_extractor([semantic_feat], pos_rois) if mask_semantic_feat.shape[-2:] != mask_feats.shape[-2:]: mask_semantic_feat = F.adaptive_avg_pool2d( mask_semantic_feat, mask_feats.shape[-2:]) mask_feats += mask_semantic_feat # mask information flow # forward all previous mask heads to obtain last_feat, and fuse it # with the normal mask feature if self.mask_info_flow: last_feat = None for i in range(stage): last_feat = self.mask_head[i]( mask_feats, last_feat, return_logits=False) mask_pred = mask_head(mask_feats, last_feat, return_feat=False) else: mask_pred = mask_head(mask_feats, return_feat=False) mask_targets = mask_head.get_targets(sampling_results, gt_masks, rcnn_train_cfg) pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results]) loss_mask = mask_head.loss(mask_pred, mask_targets, pos_labels) mask_results = dict(loss_mask=loss_mask) return mask_results def _bbox_forward(self, stage, x, rois, semantic_feat=None): bbox_roi_extractor = self.bbox_roi_extractor[stage] bbox_head = self.bbox_head[stage] bbox_feats = bbox_roi_extractor( x[:len(bbox_roi_extractor.featmap_strides)], rois) if self.with_semantic and 'bbox' in self.semantic_fusion: bbox_semantic_feat = self.semantic_roi_extractor([semantic_feat], rois) if bbox_semantic_feat.shape[-2:] != bbox_feats.shape[-2:]: bbox_semantic_feat = F.adaptive_avg_pool2d( bbox_semantic_feat, bbox_feats.shape[-2:]) bbox_feats += bbox_semantic_feat cls_score, bbox_pred = bbox_head(bbox_feats) bbox_results = dict(cls_score=cls_score, bbox_pred=bbox_pred) return bbox_results def _mask_forward_test(self, stage, x, bboxes, semantic_feat=None): mask_roi_extractor = self.mask_roi_extractor[stage] mask_head = self.mask_head[stage] mask_rois = bbox2roi([bboxes]) mask_feats = mask_roi_extractor( x[:len(mask_roi_extractor.featmap_strides)], mask_rois) if self.with_semantic and 'mask' in self.semantic_fusion: mask_semantic_feat = self.semantic_roi_extractor([semantic_feat], mask_rois) if mask_semantic_feat.shape[-2:] != mask_feats.shape[-2:]: mask_semantic_feat = F.adaptive_avg_pool2d( mask_semantic_feat, mask_feats.shape[-2:]) mask_feats += mask_semantic_feat if self.mask_info_flow: last_feat = None last_pred = None for i in range(stage): mask_pred, last_feat = self.mask_head[i](mask_feats, last_feat) if last_pred is not None: mask_pred = mask_pred + last_pred last_pred = mask_pred mask_pred = mask_head(mask_feats, last_feat, return_feat=False) if last_pred is not None: mask_pred = mask_pred + last_pred else: mask_pred = mask_head(mask_feats) return mask_pred
[docs] def forward_train(self, x, img_metas, proposal_list, gt_bboxes, gt_labels, gt_bboxes_ignore=None, gt_masks=None, gt_semantic_seg=None): # semantic segmentation part # 2 outputs: segmentation prediction and embedded features losses = dict() if self.with_semantic: semantic_pred, semantic_feat = self.semantic_head(x) loss_seg = self.semantic_head.loss(semantic_pred, gt_semantic_seg) losses['loss_semantic_seg'] = loss_seg else: semantic_feat = None for i in range(self.num_stages): self.current_stage = i rcnn_train_cfg = self.train_cfg[i] lw = self.stage_loss_weights[i] # assign gts and sample proposals sampling_results = [] bbox_assigner = self.bbox_assigner[i] bbox_sampler = self.bbox_sampler[i] num_imgs = len(img_metas) if gt_bboxes_ignore is None: gt_bboxes_ignore = [None for _ in range(num_imgs)] for j in range(num_imgs): assign_result = bbox_assigner.assign(proposal_list[j], gt_bboxes[j], gt_bboxes_ignore[j], gt_labels[j]) sampling_result = bbox_sampler.sample( assign_result, proposal_list[j], gt_bboxes[j], gt_labels[j], feats=[lvl_feat[j][None] for lvl_feat in x]) sampling_results.append(sampling_result) # bbox head forward and loss bbox_results = \ self._bbox_forward_train( i, x, sampling_results, gt_bboxes, gt_labels, rcnn_train_cfg, semantic_feat) roi_labels = bbox_results['bbox_targets'][0] for name, value in bbox_results['loss_bbox'].items(): losses[f's{i}.{name}'] = ( value * lw if 'loss' in name else value) # mask head forward and loss if self.with_mask: # interleaved execution: use regressed bboxes by the box branch # to train the mask branch if self.interleaved: pos_is_gts = [res.pos_is_gt for res in sampling_results] with torch.no_grad(): proposal_list = self.bbox_head[i].refine_bboxes( bbox_results['rois'], roi_labels, bbox_results['bbox_pred'], pos_is_gts, img_metas) # re-assign and sample 512 RoIs from 512 RoIs sampling_results = [] for j in range(num_imgs): assign_result = bbox_assigner.assign( proposal_list[j], gt_bboxes[j], gt_bboxes_ignore[j], gt_labels[j]) sampling_result = bbox_sampler.sample( assign_result, proposal_list[j], gt_bboxes[j], gt_labels[j], feats=[lvl_feat[j][None] for lvl_feat in x]) sampling_results.append(sampling_result) mask_results = self._mask_forward_train( i, x, sampling_results, gt_masks, rcnn_train_cfg, semantic_feat) for name, value in mask_results['loss_mask'].items(): losses[f's{i}.{name}'] = ( value * lw if 'loss' in name else value) # refine bboxes (same as Cascade R-CNN) if i < self.num_stages - 1 and not self.interleaved: pos_is_gts = [res.pos_is_gt for res in sampling_results] with torch.no_grad(): proposal_list = self.bbox_head[i].refine_bboxes( bbox_results['rois'], roi_labels, bbox_results['bbox_pred'], pos_is_gts, img_metas) return losses
[docs] def simple_test(self, x, proposal_list, img_metas, rescale=False): if self.with_semantic: _, semantic_feat = self.semantic_head(x) else: semantic_feat = None img_shape = img_metas[0]['img_shape'] ori_shape = img_metas[0]['ori_shape'] scale_factor = img_metas[0]['scale_factor'] # "ms" in variable names means multi-stage ms_bbox_result = {} ms_segm_result = {} ms_scores = [] rcnn_test_cfg = self.test_cfg rois = bbox2roi(proposal_list) for i in range(self.num_stages): bbox_head = self.bbox_head[i] bbox_results = self._bbox_forward( i, x, rois, semantic_feat=semantic_feat) ms_scores.append(bbox_results['cls_score']) if i < self.num_stages - 1: bbox_label = bbox_results['cls_score'].argmax(dim=1) rois = bbox_head.regress_by_class(rois, bbox_label, bbox_results['bbox_pred'], img_metas[0]) cls_score = sum(ms_scores) / float(len(ms_scores)) det_bboxes, det_labels = self.bbox_head[-1].get_bboxes( rois, cls_score, bbox_results['bbox_pred'], img_shape, scale_factor, rescale=rescale, cfg=rcnn_test_cfg) bbox_result = bbox2result(det_bboxes, det_labels, self.bbox_head[-1].num_classes) ms_bbox_result['ensemble'] = bbox_result if self.with_mask: if det_bboxes.shape[0] == 0: mask_classes = self.mask_head[-1].num_classes segm_result = [[] for _ in range(mask_classes)] else: _bboxes = ( det_bboxes[:, :4] * det_bboxes.new_tensor(scale_factor) if rescale else det_bboxes) mask_rois = bbox2roi([_bboxes]) aug_masks = [] mask_roi_extractor = self.mask_roi_extractor[-1] mask_feats = mask_roi_extractor( x[:len(mask_roi_extractor.featmap_strides)], mask_rois) if self.with_semantic and 'mask' in self.semantic_fusion: mask_semantic_feat = self.semantic_roi_extractor( [semantic_feat], mask_rois) mask_feats += mask_semantic_feat last_feat = None for i in range(self.num_stages): mask_head = self.mask_head[i] if self.mask_info_flow: mask_pred, last_feat = mask_head(mask_feats, last_feat) else: mask_pred = mask_head(mask_feats) aug_masks.append(mask_pred.sigmoid().cpu().numpy()) merged_masks = merge_aug_masks(aug_masks, [img_metas] * self.num_stages, self.test_cfg) segm_result = self.mask_head[-1].get_seg_masks( merged_masks, _bboxes, det_labels, rcnn_test_cfg, ori_shape, scale_factor, rescale) ms_segm_result['ensemble'] = segm_result if self.with_mask: results = (ms_bbox_result['ensemble'], ms_segm_result['ensemble']) else: results = ms_bbox_result['ensemble'] return results
[docs] def aug_test(self, img_feats, proposal_list, img_metas, rescale=False): """Test with augmentations. If rescale is False, then returned bboxes and masks will fit the scale of imgs[0]. """ if self.with_semantic: semantic_feats = [ self.semantic_head(feat)[1] for feat in img_feats ] else: semantic_feats = [None] * len(img_metas) rcnn_test_cfg = self.test_cfg aug_bboxes = [] aug_scores = [] for x, img_meta, semantic in zip(img_feats, img_metas, semantic_feats): # only one image in the batch img_shape = img_meta[0]['img_shape'] scale_factor = img_meta[0]['scale_factor'] flip = img_meta[0]['flip'] flip_direction = img_meta[0]['flip_direction'] proposals = bbox_mapping(proposal_list[0][:, :4], img_shape, scale_factor, flip, flip_direction) # "ms" in variable names means multi-stage ms_scores = [] rois = bbox2roi([proposals]) for i in range(self.num_stages): bbox_head = self.bbox_head[i] bbox_results = self._bbox_forward( i, x, rois, semantic_feat=semantic) ms_scores.append(bbox_results['cls_score']) if i < self.num_stages - 1: bbox_label = bbox_results['cls_score'].argmax(dim=1) rois = bbox_head.regress_by_class( rois, bbox_label, bbox_results['bbox_pred'], img_meta[0]) cls_score = sum(ms_scores) / float(len(ms_scores)) bboxes, scores = self.bbox_head[-1].get_bboxes( rois, cls_score, bbox_results['bbox_pred'], img_shape, scale_factor, rescale=False, cfg=None) aug_bboxes.append(bboxes) aug_scores.append(scores) # after merging, bboxes will be rescaled to the original image size merged_bboxes, merged_scores = merge_aug_bboxes( aug_bboxes, aug_scores, img_metas, rcnn_test_cfg) det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores, rcnn_test_cfg.score_thr, rcnn_test_cfg.nms, rcnn_test_cfg.max_per_img) bbox_result = bbox2result(det_bboxes, det_labels, self.bbox_head[-1].num_classes) if self.with_mask: if det_bboxes.shape[0] == 0: segm_result = [[] for _ in range(self.mask_head[-1].num_classes - 1)] else: aug_masks = [] aug_img_metas = [] for x, img_meta, semantic in zip(img_feats, img_metas, semantic_feats): img_shape = img_meta[0]['img_shape'] scale_factor = img_meta[0]['scale_factor'] flip = img_meta[0]['flip'] flip_direction = img_meta[0]['flip_direction'] _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape, scale_factor, flip, flip_direction) mask_rois = bbox2roi([_bboxes]) mask_feats = self.mask_roi_extractor[-1]( x[:len(self.mask_roi_extractor[-1].featmap_strides)], mask_rois) if self.with_semantic: semantic_feat = semantic mask_semantic_feat = self.semantic_roi_extractor( [semantic_feat], mask_rois) if mask_semantic_feat.shape[-2:] != mask_feats.shape[ -2:]: mask_semantic_feat = F.adaptive_avg_pool2d( mask_semantic_feat, mask_feats.shape[-2:]) mask_feats += mask_semantic_feat last_feat = None for i in range(self.num_stages): mask_head = self.mask_head[i] if self.mask_info_flow: mask_pred, last_feat = mask_head( mask_feats, last_feat) else: mask_pred = mask_head(mask_feats) aug_masks.append(mask_pred.sigmoid().cpu().numpy()) aug_img_metas.append(img_meta) merged_masks = merge_aug_masks(aug_masks, aug_img_metas, self.test_cfg) ori_shape = img_metas[0][0]['ori_shape'] segm_result = self.mask_head[-1].get_seg_masks( merged_masks, det_bboxes, det_labels, rcnn_test_cfg, ori_shape, scale_factor=1.0, rescale=False) return bbox_result, segm_result else: return bbox_result