
Source code for mmdet.models.dense_heads.solo_head

# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule

from mmdet.core import InstanceData, mask_matrix_nms, multi_apply
from mmdet.core.utils import center_of_mass, generate_coordinate
from mmdet.models.builder import HEADS, build_loss
from .base_mask_head import BaseMaskHead


@HEADS.register_module()
class SOLOHead(BaseMaskHead):
    """SOLO mask head used in `SOLO: Segmenting Objects by Locations
    <https://arxiv.org/abs/1912.04488>`_.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of hidden channels. Used in child
            classes. Default: 256.
        stacked_convs (int): Number of stacking convs of the head.
            Default: 4.
        strides (tuple): Downsample factor of each feature map.
        scale_ranges (tuple[tuple[int, int]]): Area ranges of the
            multi-level masks, in the format
            [(min1, max1), (min2, max2), ...]. A range of (16, 64) means
            that instances with areas between 16 and 64 are assigned to
            this level.
        pos_scale (float): Constant scale factor to control the center
            region.
        num_grids (list[int]): Divide the image into uniform grids; each
            feature map uses a different grid value. The number of output
            channels is grid ** 2. Default: [40, 36, 24, 16, 12].
        cls_down_index (int): The index of the downsample operation in the
            classification branch. Default: 0.
        loss_mask (dict): Config of mask loss.
        loss_cls (dict): Config of classification loss.
        norm_cfg (dict): Dictionary to construct and config the norm layer.
            Default: norm_cfg=dict(type='GN', num_groups=32,
            requires_grad=True).
        train_cfg (dict): Training config of head.
        test_cfg (dict): Testing config of head.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(
        self,
        num_classes,
        in_channels,
        feat_channels=256,
        stacked_convs=4,
        strides=(4, 8, 16, 32, 64),
        scale_ranges=((8, 32), (16, 64), (32, 128), (64, 256), (128, 512)),
        pos_scale=0.2,
        num_grids=[40, 36, 24, 16, 12],
        cls_down_index=0,
        loss_mask=None,
        loss_cls=None,
        norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
        train_cfg=None,
        test_cfg=None,
        init_cfg=[
            dict(type='Normal', layer='Conv2d', std=0.01),
            dict(
                type='Normal',
                std=0.01,
                bias_prob=0.01,
                override=dict(name='conv_mask_list')),
            dict(
                type='Normal',
                std=0.01,
                bias_prob=0.01,
                override=dict(name='conv_cls'))
        ],
    ):
        super(SOLOHead, self).__init__(init_cfg)
        self.num_classes = num_classes
        self.cls_out_channels = self.num_classes
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.stacked_convs = stacked_convs
        self.strides = strides
        self.num_grids = num_grids
        # number of FPN feats
        self.num_levels = len(strides)
        assert self.num_levels == len(scale_ranges) == len(num_grids)
        self.scale_ranges = scale_ranges
        self.pos_scale = pos_scale
        self.cls_down_index = cls_down_index
        self.loss_cls = build_loss(loss_cls)
        self.loss_mask = build_loss(loss_mask)
        self.norm_cfg = norm_cfg
        self.init_cfg = init_cfg
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self._init_layers()

    def _init_layers(self):
        self.mask_convs = nn.ModuleList()
        self.cls_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            # the mask branch additionally takes two coordinate channels
            chn = self.in_channels + 2 if i == 0 else self.feat_channels
            self.mask_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=self.norm_cfg))
            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=self.norm_cfg))
        self.conv_mask_list = nn.ModuleList()
        for num_grid in self.num_grids:
            self.conv_mask_list.append(
                nn.Conv2d(self.feat_channels, num_grid**2, 1))
        self.conv_cls = nn.Conv2d(
            self.feat_channels, self.cls_out_channels, 3, padding=1)
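
A hedged usage sketch (not part of this file): instantiating the head and running a dummy forward pass. The loss configs mirror those used in the official SOLO configs, but are an assumption here; any 5-level FPN-like feature list with `in_channels` channels works.

import torch

head = SOLOHead(
    num_classes=80,
    in_channels=256,
    # assumed loss configs, following the official SOLO configs
    loss_mask=dict(type='DiceLoss', use_sigmoid=True, loss_weight=3.0),
    loss_cls=dict(
        type='FocalLoss', use_sigmoid=True, gamma=2.0, alpha=0.25,
        loss_weight=1.0))
# five hypothetical FPN levels for a 256x256 input, strides (4, ..., 64)
feats = [torch.rand(1, 256, 256 // s, 256 // s) for s in head.strides]
mask_preds, cls_preds = head(feats)
# mask_preds[0]: (1, 40**2, 64, 64); cls_preds[0]: (1, 80, 40, 40)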
    def resize_feats(self, feats):
        """Downsample the first feat and upsample the last feat in feats."""
        out = []
        for i in range(len(feats)):
            if i == 0:
                out.append(
                    F.interpolate(
                        feats[0], scale_factor=0.5, mode='bilinear'))
            elif i == len(feats) - 1:
                out.append(
                    F.interpolate(
                        feats[i],
                        size=feats[i - 1].shape[-2:],
                        mode='bilinear'))
            else:
                out.append(feats[i])
        return out
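
A minimal shape sketch of `resize_feats` on hypothetical FPN features: the finest level is halved and the coarsest is upsampled to the size of the level before it, so spatial sizes 64, 32, 16, 8, 4 become 32, 32, 16, 8, 8.

import torch
import torch.nn.functional as F

feats = [torch.rand(1, 256, 64 // 2**i, 64 // 2**i) for i in range(5)]
out0 = F.interpolate(feats[0], scale_factor=0.5, mode='bilinear')  # 32x32
out4 = F.interpolate(
    feats[4], size=feats[3].shape[-2:], mode='bilinear')  # 8x8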
    def forward(self, feats):
        assert len(feats) == self.num_levels
        feats = self.resize_feats(feats)
        mlvl_mask_preds = []
        mlvl_cls_preds = []
        for i in range(self.num_levels):
            x = feats[i]
            mask_feat = x
            cls_feat = x
            # generate and concat the coordinate
            coord_feat = generate_coordinate(mask_feat.size(),
                                             mask_feat.device)
            mask_feat = torch.cat([mask_feat, coord_feat], 1)

            for mask_layer in self.mask_convs:
                mask_feat = mask_layer(mask_feat)

            mask_feat = F.interpolate(
                mask_feat, scale_factor=2, mode='bilinear')
            mask_pred = self.conv_mask_list[i](mask_feat)

            # cls branch
            for j, cls_layer in enumerate(self.cls_convs):
                if j == self.cls_down_index:
                    num_grid = self.num_grids[i]
                    cls_feat = F.interpolate(
                        cls_feat, size=num_grid, mode='bilinear')
                cls_feat = cls_layer(cls_feat)

            cls_pred = self.conv_cls(cls_feat)

            if not self.training:
                feat_wh = feats[0].size()[-2:]
                upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2)
                mask_pred = F.interpolate(
                    mask_pred.sigmoid(), size=upsampled_size,
                    mode='bilinear')
                cls_pred = cls_pred.sigmoid()
                # get local maximum
                local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1)
                keep_mask = local_max[:, :, :-1, :-1] == cls_pred
                cls_pred = cls_pred * keep_mask

            mlvl_mask_preds.append(mask_pred)
            mlvl_cls_preds.append(cls_pred)
        return mlvl_mask_preds, mlvl_cls_preds
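
A minimal sketch of the point NMS used in the inference branch above: `max_pool2d` with kernel 2, stride 1 and padding 1 yields, after the `[:-1, :-1]` crop, the maximum over each score's 2x2 neighbourhood, so only local maxima survive the equality test.

import torch
import torch.nn.functional as F

cls_pred = torch.rand(1, 80, 40, 40)  # hypothetical sigmoid scores
local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1)
keep_mask = local_max[:, :, :-1, :-1] == cls_pred
cls_pred = cls_pred * keep_mask  # non-maximum grid cells are zeroed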
    def loss(self,
             mlvl_mask_preds,
             mlvl_cls_preds,
             gt_labels,
             gt_masks,
             img_metas,
             gt_bboxes=None,
             **kwargs):
        """Calculate the loss of the total batch.

        Args:
            mlvl_mask_preds (list[Tensor]): Multi-level mask predictions.
                Each element in the list has shape
                (batch_size, num_grids**2, h, w).
            mlvl_cls_preds (list[Tensor]): Multi-level scores. Each element
                in the list has shape
                (batch_size, num_classes, num_grids, num_grids).
            gt_labels (list[Tensor]): Labels of multiple images.
            gt_masks (list[Tensor]): Ground truth masks of multiple images.
                Each has shape (num_instances, h, w).
            img_metas (list[dict]): Meta information of multiple images.
            gt_bboxes (list[Tensor]): Ground truth bboxes of multiple
                images. Default: None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        num_levels = self.num_levels
        num_imgs = len(gt_labels)

        featmap_sizes = [featmap.size()[-2:] for featmap in mlvl_mask_preds]

        # each `BoolTensor` in `pos_masks` represents
        # whether the corresponding point is positive
        pos_mask_targets, labels, pos_masks = multi_apply(
            self._get_targets_single,
            gt_bboxes,
            gt_labels,
            gt_masks,
            featmap_sizes=featmap_sizes)

        # change from the outside list meaning multiple images
        # to the outside list meaning multiple levels
        mlvl_pos_mask_targets = [[] for _ in range(num_levels)]
        mlvl_pos_mask_preds = [[] for _ in range(num_levels)]
        mlvl_pos_masks = [[] for _ in range(num_levels)]
        mlvl_labels = [[] for _ in range(num_levels)]
        for img_id in range(num_imgs):
            assert num_levels == len(pos_mask_targets[img_id])
            for lvl in range(num_levels):
                mlvl_pos_mask_targets[lvl].append(
                    pos_mask_targets[img_id][lvl])
                mlvl_pos_mask_preds[lvl].append(
                    mlvl_mask_preds[lvl][img_id,
                                         pos_masks[img_id][lvl], ...])
                mlvl_pos_masks[lvl].append(pos_masks[img_id][lvl].flatten())
                mlvl_labels[lvl].append(labels[img_id][lvl].flatten())

        # cat multiple images
        temp_mlvl_cls_preds = []
        for lvl in range(num_levels):
            mlvl_pos_mask_targets[lvl] = torch.cat(
                mlvl_pos_mask_targets[lvl], dim=0)
            mlvl_pos_mask_preds[lvl] = torch.cat(
                mlvl_pos_mask_preds[lvl], dim=0)
            mlvl_pos_masks[lvl] = torch.cat(mlvl_pos_masks[lvl], dim=0)
            mlvl_labels[lvl] = torch.cat(mlvl_labels[lvl], dim=0)
            temp_mlvl_cls_preds.append(mlvl_cls_preds[lvl].permute(
                0, 2, 3, 1).reshape(-1, self.cls_out_channels))

        num_pos = sum(item.sum() for item in mlvl_pos_masks)
        # dice loss
        loss_mask = []
        for pred, target in zip(mlvl_pos_mask_preds, mlvl_pos_mask_targets):
            if pred.size()[0] == 0:
                loss_mask.append(pred.sum().unsqueeze(0))
                continue
            loss_mask.append(
                self.loss_mask(pred, target, reduction_override='none'))
        if num_pos > 0:
            loss_mask = torch.cat(loss_mask).sum() / num_pos
        else:
            loss_mask = torch.cat(loss_mask).mean()

        flatten_labels = torch.cat(mlvl_labels)
        flatten_cls_preds = torch.cat(temp_mlvl_cls_preds)
        loss_cls = self.loss_cls(
            flatten_cls_preds, flatten_labels, avg_factor=num_pos + 1)
        return dict(loss_mask=loss_mask, loss_cls=loss_cls)
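
A hedged sketch of the per-sample Dice loss applied above. The actual implementation is mmdet's `DiceLoss`; this assumes the standard sigmoid-activated formulation with a small epsilon, and is only meant to show why the per-positive losses can be summed and then divided by `num_pos`.

import torch

def dice_loss(pred, target, eps=1e-3):
    # pred: (num_pos, h, w) logits; target: (num_pos, h, w) binary masks
    pred = pred.sigmoid().flatten(1)
    target = target.float().flatten(1)
    inter = (pred * target).sum(1)
    denom = (pred * pred).sum(1) + (target * target).sum(1) + eps
    return 1 - 2 * inter / denom  # one loss value per positive sample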
    def _get_targets_single(self,
                            gt_bboxes,
                            gt_labels,
                            gt_masks,
                            featmap_sizes=None):
        """Compute targets for predictions of a single image.

        Args:
            gt_bboxes (Tensor): Ground truth bbox of each instance,
                shape (num_gts, 4).
            gt_labels (Tensor): Ground truth label of each instance,
                shape (num_gts,).
            gt_masks (Tensor): Ground truth mask of each instance,
                shape (num_gts, h, w).
            featmap_sizes (list[:obj:`torch.size`]): Size of each feature
                map from the feature pyramid, each element means
                (feat_h, feat_w). Default: None.

        Returns:
            Tuple: Usually returns a tuple containing targets for
            predictions.

                - mlvl_pos_mask_targets (list[Tensor]): Each element
                  represents the binary mask targets for positive points
                  in this level, has shape (num_pos, out_h, out_w).
                - mlvl_labels (list[Tensor]): Each element is the
                  classification labels for all points in this level,
                  has shape (num_grid, num_grid).
                - mlvl_pos_masks (list[Tensor]): Each element is a
                  `BoolTensor` representing whether the corresponding
                  point in a single level is positive,
                  has shape (num_grid**2,).
        """
        device = gt_labels.device
        gt_areas = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) *
                              (gt_bboxes[:, 3] - gt_bboxes[:, 1]))

        mlvl_pos_mask_targets = []
        mlvl_labels = []
        mlvl_pos_masks = []
        for (lower_bound, upper_bound), stride, featmap_size, num_grid \
                in zip(self.scale_ranges, self.strides,
                       featmap_sizes, self.num_grids):

            mask_target = torch.zeros(
                [num_grid**2, featmap_size[0], featmap_size[1]],
                dtype=torch.uint8,
                device=device)
            # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
            labels = torch.zeros([num_grid, num_grid],
                                 dtype=torch.int64,
                                 device=device) + self.num_classes
            pos_mask = torch.zeros([num_grid**2],
                                   dtype=torch.bool,
                                   device=device)

            gt_inds = ((gt_areas >= lower_bound) &
                       (gt_areas <= upper_bound)).nonzero().flatten()
            if len(gt_inds) == 0:
                mlvl_pos_mask_targets.append(
                    mask_target.new_zeros(0, featmap_size[0],
                                          featmap_size[1]))
                mlvl_labels.append(labels)
                mlvl_pos_masks.append(pos_mask)
                continue
            hit_gt_bboxes = gt_bboxes[gt_inds]
            hit_gt_labels = gt_labels[gt_inds]
            hit_gt_masks = gt_masks[gt_inds, ...]

            pos_w_ranges = 0.5 * (hit_gt_bboxes[:, 2] -
                                  hit_gt_bboxes[:, 0]) * self.pos_scale
            pos_h_ranges = 0.5 * (hit_gt_bboxes[:, 3] -
                                  hit_gt_bboxes[:, 1]) * self.pos_scale

            # make sure hit_gt_masks has a value
            valid_mask_flags = hit_gt_masks.sum(dim=-1).sum(dim=-1) > 0
            output_stride = stride / 2

            for gt_mask, gt_label, pos_h_range, pos_w_range, \
                    valid_mask_flag in \
                    zip(hit_gt_masks, hit_gt_labels, pos_h_ranges,
                        pos_w_ranges, valid_mask_flags):
                if not valid_mask_flag:
                    continue
                upsampled_size = (featmap_sizes[0][0] * 4,
                                  featmap_sizes[0][1] * 4)
                center_h, center_w = center_of_mass(gt_mask)

                coord_w = int(
                    (center_w / upsampled_size[1]) // (1. / num_grid))
                coord_h = int(
                    (center_h / upsampled_size[0]) // (1. / num_grid))

                # left, top, right, down
                top_box = max(
                    0,
                    int(((center_h - pos_h_range) / upsampled_size[0]) //
                        (1. / num_grid)))
                down_box = min(
                    num_grid - 1,
                    int(((center_h + pos_h_range) / upsampled_size[0]) //
                        (1. / num_grid)))
                left_box = max(
                    0,
                    int(((center_w - pos_w_range) / upsampled_size[1]) //
                        (1. / num_grid)))
                right_box = min(
                    num_grid - 1,
                    int(((center_w + pos_w_range) / upsampled_size[1]) //
                        (1. / num_grid)))

                top = max(top_box, coord_h - 1)
                down = min(down_box, coord_h + 1)
                left = max(coord_w - 1, left_box)
                right = min(right_box, coord_w + 1)

                labels[top:(down + 1), left:(right + 1)] = gt_label
                # instance mask target
                gt_mask = np.uint8(gt_mask.cpu().numpy())
                # Follow the original implementation and rescale with
                # cv2-based mmcv.imrescale, since F.interpolate behaves
                # differently from cv2 resizing
                gt_mask = mmcv.imrescale(gt_mask, scale=1. / output_stride)
                gt_mask = torch.from_numpy(gt_mask).to(device=device)

                for i in range(top, down + 1):
                    for j in range(left, right + 1):
                        index = int(i * num_grid + j)
                        mask_target[index, :gt_mask.shape[0],
                                    :gt_mask.shape[1]] = gt_mask
                        pos_mask[index] = True
            mlvl_pos_mask_targets.append(mask_target[pos_mask])
            mlvl_labels.append(labels)
            mlvl_pos_masks.append(pos_mask)
        return mlvl_pos_mask_targets, mlvl_labels, mlvl_pos_masks
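
A small worked example of the centre-to-grid mapping used above, with hypothetical numbers: an object whose mass centre is at (200, 320) on an 800x1344 upsampled image lands in one cell of a 40x40 grid.

num_grid = 40
upsampled_h, upsampled_w = 800, 1344  # assumed image size
center_h, center_w = 200.0, 320.0    # assumed mass centre
coord_h = int((center_h / upsampled_h) // (1. / num_grid))  # -> 10
coord_w = int((center_w / upsampled_w) // (1. / num_grid))  # -> 9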
    def get_results(self, mlvl_mask_preds, mlvl_cls_scores, img_metas,
                    **kwargs):
        """Get multi-image mask results.

        Args:
            mlvl_mask_preds (list[Tensor]): Multi-level mask predictions.
                Each element in the list has shape
                (batch_size, num_grids**2, h, w).
            mlvl_cls_scores (list[Tensor]): Multi-level scores. Each element
                in the list has shape
                (batch_size, num_classes, num_grids, num_grids).
            img_metas (list[dict]): Meta information of all images.

        Returns:
            list[:obj:`InstanceData`]: Processed results of multiple
            images. Each :obj:`InstanceData` usually contains the
            following keys.

                - scores (Tensor): Classification scores, has shape
                  (num_instances,).
                - labels (Tensor): Has shape (num_instances,).
                - masks (Tensor): Processed mask results, has shape
                  (num_instances, h, w).
        """
        mlvl_cls_scores = [
            item.permute(0, 2, 3, 1) for item in mlvl_cls_scores
        ]
        assert len(mlvl_mask_preds) == len(mlvl_cls_scores)
        num_levels = len(mlvl_cls_scores)

        results_list = []
        for img_id in range(len(img_metas)):
            cls_pred_list = [
                mlvl_cls_scores[lvl][img_id].view(-1, self.cls_out_channels)
                for lvl in range(num_levels)
            ]
            mask_pred_list = [
                mlvl_mask_preds[lvl][img_id] for lvl in range(num_levels)
            ]

            cls_pred_list = torch.cat(cls_pred_list, dim=0)
            mask_pred_list = torch.cat(mask_pred_list, dim=0)

            results = self._get_results_single(
                cls_pred_list, mask_pred_list, img_meta=img_metas[img_id])
            results_list.append(results)
        return results_list
    def _get_results_single(self, cls_scores, mask_preds, img_meta,
                            cfg=None):
        """Get processed mask related results of a single image.

        Args:
            cls_scores (Tensor): Classification scores of all points in a
                single image, has shape (num_points, num_classes).
            mask_preds (Tensor): Mask predictions of all points in a
                single image, has shape (num_points, feat_h, feat_w).
            img_meta (dict): Meta information of the corresponding image.
            cfg (dict, optional): Config used in the test phase.
                Default: None.

        Returns:
            :obj:`InstanceData`: Processed results of a single image. It
            usually contains the following keys.

                - scores (Tensor): Classification scores, has shape
                  (num_instances,).
                - labels (Tensor): Has shape (num_instances,).
                - masks (Tensor): Processed mask results, has shape
                  (num_instances, h, w).
        """

        def empty_results(results, cls_scores):
            """Generate empty results."""
            results.scores = cls_scores.new_ones(0)
            results.masks = cls_scores.new_zeros(0, *results.ori_shape[:2])
            results.labels = cls_scores.new_ones(0)
            return results

        cfg = self.test_cfg if cfg is None else cfg
        assert len(cls_scores) == len(mask_preds)
        results = InstanceData(img_meta)

        featmap_size = mask_preds.size()[-2:]

        img_shape = results.img_shape
        ori_shape = results.ori_shape

        h, w, _ = img_shape
        upsampled_size = (featmap_size[0] * 4, featmap_size[1] * 4)

        score_mask = (cls_scores > cfg.score_thr)
        cls_scores = cls_scores[score_mask]
        if len(cls_scores) == 0:
            return empty_results(results, cls_scores)

        inds = score_mask.nonzero()
        cls_labels = inds[:, 1]

        # filter out masks whose area is smaller than the
        # stride of the corresponding feature level
        lvl_interval = cls_labels.new_tensor(self.num_grids).pow(2).cumsum(0)
        strides = cls_scores.new_ones(lvl_interval[-1])
        strides[:lvl_interval[0]] *= self.strides[0]
        for lvl in range(1, self.num_levels):
            strides[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
                self.strides[lvl]
        strides = strides[inds[:, 0]]
        mask_preds = mask_preds[inds[:, 0]]

        masks = mask_preds > cfg.mask_thr
        sum_masks = masks.sum((1, 2)).float()
        keep = sum_masks > strides
        if keep.sum() == 0:
            return empty_results(results, cls_scores)
        masks = masks[keep]
        mask_preds = mask_preds[keep]
        sum_masks = sum_masks[keep]
        cls_scores = cls_scores[keep]
        cls_labels = cls_labels[keep]

        # maskness
        mask_scores = (mask_preds * masks).sum((1, 2)) / sum_masks
        cls_scores *= mask_scores

        scores, labels, _, keep_inds = mask_matrix_nms(
            masks,
            cls_labels,
            cls_scores,
            mask_area=sum_masks,
            nms_pre=cfg.nms_pre,
            max_num=cfg.max_per_img,
            kernel=cfg.kernel,
            sigma=cfg.sigma,
            filter_thr=cfg.filter_thr)
        mask_preds = mask_preds[keep_inds]
        mask_preds = F.interpolate(
            mask_preds.unsqueeze(0), size=upsampled_size,
            mode='bilinear')[:, :, :h, :w]
        mask_preds = F.interpolate(
            mask_preds, size=ori_shape[:2], mode='bilinear').squeeze(0)
        masks = mask_preds > cfg.mask_thr

        results.masks = masks
        results.labels = labels
        results.scores = scores

        return results
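
A minimal sketch of the maskness rescoring above: each instance's classification score is multiplied by the mean soft-mask confidence inside its binarized mask, so diffuse masks are downweighted before matrix NMS.

import torch

mask_preds = torch.rand(3, 100, 100)  # hypothetical soft masks
masks = mask_preds > 0.5              # binarized with an assumed mask_thr
sum_masks = masks.sum((1, 2)).float()
mask_scores = (mask_preds * masks).sum((1, 2)) / sum_masks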
@HEADS.register_module()
class DecoupledSOLOHead(SOLOHead):
    """Decoupled SOLO mask head used in `SOLO: Segmenting Objects by
    Locations <https://arxiv.org/abs/1912.04488>`_.

    Args:
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 *args,
                 init_cfg=[
                     dict(type='Normal', layer='Conv2d', std=0.01),
                     dict(
                         type='Normal',
                         std=0.01,
                         bias_prob=0.01,
                         override=dict(name='conv_mask_list_x')),
                     dict(
                         type='Normal',
                         std=0.01,
                         bias_prob=0.01,
                         override=dict(name='conv_mask_list_y')),
                     dict(
                         type='Normal',
                         std=0.01,
                         bias_prob=0.01,
                         override=dict(name='conv_cls'))
                 ],
                 **kwargs):
        super(DecoupledSOLOHead, self).__init__(
            *args, init_cfg=init_cfg, **kwargs)

    def _init_layers(self):
        self.mask_convs_x = nn.ModuleList()
        self.mask_convs_y = nn.ModuleList()
        self.cls_convs = nn.ModuleList()

        for i in range(self.stacked_convs):
            # each decoupled mask branch takes one coordinate channel
            chn = self.in_channels + 1 if i == 0 else self.feat_channels
            self.mask_convs_x.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=self.norm_cfg))
            self.mask_convs_y.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=self.norm_cfg))

            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=self.norm_cfg))

        self.conv_mask_list_x = nn.ModuleList()
        self.conv_mask_list_y = nn.ModuleList()
        for num_grid in self.num_grids:
            self.conv_mask_list_x.append(
                nn.Conv2d(self.feat_channels, num_grid, 3, padding=1))
            self.conv_mask_list_y.append(
                nn.Conv2d(self.feat_channels, num_grid, 3, padding=1))
        self.conv_cls = nn.Conv2d(
            self.feat_channels, self.cls_out_channels, 3, padding=1)
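
A small sketch of the decoupled design choice: per level, the vanilla head predicts num_grid**2 mask channels, while the decoupled head predicts only 2 * num_grid (one set per axis) and recovers an instance mask as an x/y product.

for g in [40, 36, 24, 16, 12]:
    print(f'grid {g}: coupled {g**2} channels vs decoupled {2 * g}')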
    def forward(self, feats):
        assert len(feats) == self.num_levels
        feats = self.resize_feats(feats)
        mask_preds_x = []
        mask_preds_y = []
        cls_preds = []
        for i in range(self.num_levels):
            x = feats[i]
            mask_feat = x
            cls_feat = x
            # generate and concat the coordinate
            coord_feat = generate_coordinate(mask_feat.size(),
                                             mask_feat.device)
            mask_feat_x = torch.cat([mask_feat, coord_feat[:, 0:1, ...]], 1)
            mask_feat_y = torch.cat([mask_feat, coord_feat[:, 1:2, ...]], 1)

            for mask_layer_x, mask_layer_y in \
                    zip(self.mask_convs_x, self.mask_convs_y):
                mask_feat_x = mask_layer_x(mask_feat_x)
                mask_feat_y = mask_layer_y(mask_feat_y)

            mask_feat_x = F.interpolate(
                mask_feat_x, scale_factor=2, mode='bilinear')
            mask_feat_y = F.interpolate(
                mask_feat_y, scale_factor=2, mode='bilinear')

            mask_pred_x = self.conv_mask_list_x[i](mask_feat_x)
            mask_pred_y = self.conv_mask_list_y[i](mask_feat_y)

            # cls branch
            for j, cls_layer in enumerate(self.cls_convs):
                if j == self.cls_down_index:
                    num_grid = self.num_grids[i]
                    cls_feat = F.interpolate(
                        cls_feat, size=num_grid, mode='bilinear')
                cls_feat = cls_layer(cls_feat)

            cls_pred = self.conv_cls(cls_feat)

            if not self.training:
                feat_wh = feats[0].size()[-2:]
                upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2)
                mask_pred_x = F.interpolate(
                    mask_pred_x.sigmoid(),
                    size=upsampled_size,
                    mode='bilinear')
                mask_pred_y = F.interpolate(
                    mask_pred_y.sigmoid(),
                    size=upsampled_size,
                    mode='bilinear')
                cls_pred = cls_pred.sigmoid()
                # get local maximum
                local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1)
                keep_mask = local_max[:, :, :-1, :-1] == cls_pred
                cls_pred = cls_pred * keep_mask

            mask_preds_x.append(mask_pred_x)
            mask_preds_y.append(mask_pred_y)
            cls_preds.append(cls_pred)
        return mask_preds_x, mask_preds_y, cls_preds
    def loss(self,
             mlvl_mask_preds_x,
             mlvl_mask_preds_y,
             mlvl_cls_preds,
             gt_labels,
             gt_masks,
             img_metas,
             gt_bboxes=None,
             **kwargs):
        """Calculate the loss of the total batch.

        Args:
            mlvl_mask_preds_x (list[Tensor]): Multi-level mask predictions
                from the x branch. Each element in the list has shape
                (batch_size, num_grids, h, w).
            mlvl_mask_preds_y (list[Tensor]): Multi-level mask predictions
                from the y branch. Each element in the list has shape
                (batch_size, num_grids, h, w).
            mlvl_cls_preds (list[Tensor]): Multi-level scores. Each element
                in the list has shape
                (batch_size, num_classes, num_grids, num_grids).
            gt_labels (list[Tensor]): Labels of multiple images.
            gt_masks (list[Tensor]): Ground truth masks of multiple images.
                Each has shape (num_instances, h, w).
            img_metas (list[dict]): Meta information of multiple images.
            gt_bboxes (list[Tensor]): Ground truth bboxes of multiple
                images. Default: None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        num_levels = self.num_levels
        num_imgs = len(gt_labels)
        featmap_sizes = [featmap.size()[-2:] for featmap in mlvl_mask_preds_x]

        pos_mask_targets, labels, xy_pos_indexes = multi_apply(
            self._get_targets_single,
            gt_bboxes,
            gt_labels,
            gt_masks,
            featmap_sizes=featmap_sizes)

        # change from the outside list meaning multiple images
        # to the outside list meaning multiple levels
        mlvl_pos_mask_targets = [[] for _ in range(num_levels)]
        mlvl_pos_mask_preds_x = [[] for _ in range(num_levels)]
        mlvl_pos_mask_preds_y = [[] for _ in range(num_levels)]
        mlvl_labels = [[] for _ in range(num_levels)]
        for img_id in range(num_imgs):
            for lvl in range(num_levels):
                mlvl_pos_mask_targets[lvl].append(
                    pos_mask_targets[img_id][lvl])
                mlvl_pos_mask_preds_x[lvl].append(
                    mlvl_mask_preds_x[lvl][
                        img_id, xy_pos_indexes[img_id][lvl][:, 1]])
                mlvl_pos_mask_preds_y[lvl].append(
                    mlvl_mask_preds_y[lvl][
                        img_id, xy_pos_indexes[img_id][lvl][:, 0]])
                mlvl_labels[lvl].append(labels[img_id][lvl].flatten())

        # cat multiple images
        temp_mlvl_cls_preds = []
        for lvl in range(num_levels):
            mlvl_pos_mask_targets[lvl] = torch.cat(
                mlvl_pos_mask_targets[lvl], dim=0)
            mlvl_pos_mask_preds_x[lvl] = torch.cat(
                mlvl_pos_mask_preds_x[lvl], dim=0)
            mlvl_pos_mask_preds_y[lvl] = torch.cat(
                mlvl_pos_mask_preds_y[lvl], dim=0)
            mlvl_labels[lvl] = torch.cat(mlvl_labels[lvl], dim=0)
            temp_mlvl_cls_preds.append(mlvl_cls_preds[lvl].permute(
                0, 2, 3, 1).reshape(-1, self.cls_out_channels))

        num_pos = 0.
        # dice loss
        loss_mask = []
        for pred_x, pred_y, target in \
                zip(mlvl_pos_mask_preds_x,
                    mlvl_pos_mask_preds_y, mlvl_pos_mask_targets):
            num_masks = pred_x.size(0)
            if num_masks == 0:
                # make sure the loss still receives a gradient
                loss_mask.append((pred_x.sum() + pred_y.sum()).unsqueeze(0))
                continue
            num_pos += num_masks
            pred_mask = pred_y.sigmoid() * pred_x.sigmoid()
            loss_mask.append(
                self.loss_mask(pred_mask, target, reduction_override='none'))
        if num_pos > 0:
            loss_mask = torch.cat(loss_mask).sum() / num_pos
        else:
            loss_mask = torch.cat(loss_mask).mean()

        # cate
        flatten_labels = torch.cat(mlvl_labels)
        flatten_cls_preds = torch.cat(temp_mlvl_cls_preds)

        loss_cls = self.loss_cls(
            flatten_cls_preds, flatten_labels, avg_factor=num_pos + 1)
        return dict(loss_mask=loss_mask, loss_cls=loss_cls)
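
A minimal sketch of the decoupled mask assembly in the loss above: the soft instance mask is the elementwise product of the sigmoid x-branch and y-branch predictions for the matching grid column and row.

import torch

pred_x = torch.randn(5, 100, 100)  # hypothetical x-branch logits
pred_y = torch.randn(5, 100, 100)  # matching y-branch logits
pred_mask = pred_y.sigmoid() * pred_x.sigmoid()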
    def _get_targets_single(self,
                            gt_bboxes,
                            gt_labels,
                            gt_masks,
                            featmap_sizes=None):
        """Compute targets for predictions of a single image.

        Args:
            gt_bboxes (Tensor): Ground truth bbox of each instance,
                shape (num_gts, 4).
            gt_labels (Tensor): Ground truth label of each instance,
                shape (num_gts,).
            gt_masks (Tensor): Ground truth mask of each instance,
                shape (num_gts, h, w).
            featmap_sizes (list[:obj:`torch.size`]): Size of each feature
                map from the feature pyramid, each element means
                (feat_h, feat_w). Default: None.

        Returns:
            Tuple: Usually returns a tuple containing targets for
            predictions.

                - mlvl_pos_mask_targets (list[Tensor]): Each element
                  represents the binary mask targets for positive points
                  in this level, has shape (num_pos, out_h, out_w).
                - mlvl_labels (list[Tensor]): Each element is the
                  classification labels for all points in this level,
                  has shape (num_grid, num_grid).
                - mlvl_xy_pos_indexes (list[Tensor]): Each element in the
                  list contains the index of positive samples in the
                  corresponding level, has shape (num_pos, 2); the last
                  dimension of 2 denotes (index_y, index_x).
        """
        mlvl_pos_mask_targets, mlvl_labels, mlvl_pos_masks = \
            super()._get_targets_single(
                gt_bboxes, gt_labels, gt_masks, featmap_sizes=featmap_sizes)

        mlvl_xy_pos_indexes = [(item - self.num_classes).nonzero()
                               for item in mlvl_labels]

        return mlvl_pos_mask_targets, mlvl_labels, mlvl_xy_pos_indexes
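
A minimal sketch of the positive-index trick above: background cells hold the value num_classes, so subtracting it makes only foreground cells non-zero; nonzero() then yields their (y, x) grid coordinates.

import torch

num_classes = 80
labels = torch.full((40, 40), num_classes, dtype=torch.long)
labels[3, 7] = 12  # one hypothetical positive cell with class 12
xy_pos = (labels - num_classes).nonzero()  # tensor([[3, 7]])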
    def get_results(self,
                    mlvl_mask_preds_x,
                    mlvl_mask_preds_y,
                    mlvl_cls_scores,
                    img_metas,
                    rescale=None,
                    **kwargs):
        """Get multi-image mask results.

        Args:
            mlvl_mask_preds_x (list[Tensor]): Multi-level mask predictions
                from the x branch. Each element in the list has shape
                (batch_size, num_grids, h, w).
            mlvl_mask_preds_y (list[Tensor]): Multi-level mask predictions
                from the y branch. Each element in the list has shape
                (batch_size, num_grids, h, w).
            mlvl_cls_scores (list[Tensor]): Multi-level scores. Each element
                in the list has shape
                (batch_size, num_classes, num_grids, num_grids).
            img_metas (list[dict]): Meta information of all images.

        Returns:
            list[:obj:`InstanceData`]: Processed results of multiple
            images. Each :obj:`InstanceData` usually contains the
            following keys.

                - scores (Tensor): Classification scores, has shape
                  (num_instances,).
                - labels (Tensor): Has shape (num_instances,).
                - masks (Tensor): Processed mask results, has shape
                  (num_instances, h, w).
        """
        mlvl_cls_scores = [
            item.permute(0, 2, 3, 1) for item in mlvl_cls_scores
        ]
        assert len(mlvl_mask_preds_x) == len(mlvl_cls_scores)
        num_levels = len(mlvl_cls_scores)

        results_list = []
        for img_id in range(len(img_metas)):
            cls_pred_list = [
                mlvl_cls_scores[i][img_id].view(
                    -1, self.cls_out_channels).detach()
                for i in range(num_levels)
            ]
            mask_pred_list_x = [
                mlvl_mask_preds_x[i][img_id] for i in range(num_levels)
            ]
            mask_pred_list_y = [
                mlvl_mask_preds_y[i][img_id] for i in range(num_levels)
            ]

            cls_pred_list = torch.cat(cls_pred_list, dim=0)
            mask_pred_list_x = torch.cat(mask_pred_list_x, dim=0)
            mask_pred_list_y = torch.cat(mask_pred_list_y, dim=0)

            results = self._get_results_single(
                cls_pred_list,
                mask_pred_list_x,
                mask_pred_list_y,
                img_meta=img_metas[img_id],
                cfg=self.test_cfg)
            results_list.append(results)
        return results_list
    def _get_results_single(self, cls_scores, mask_preds_x, mask_preds_y,
                            img_meta, cfg):
        """Get processed mask related results of a single image.

        Args:
            cls_scores (Tensor): Classification scores of all points in a
                single image, has shape (num_points, num_classes).
            mask_preds_x (Tensor): Mask predictions of the x branch for all
                points in a single image, has shape
                (sum_num_grids, feat_h, feat_w).
            mask_preds_y (Tensor): Mask predictions of the y branch for all
                points in a single image, has shape
                (sum_num_grids, feat_h, feat_w).
            img_meta (dict): Meta information of the corresponding image.
            cfg (dict): Config used in the test phase.

        Returns:
            :obj:`InstanceData`: Processed results of a single image. It
            usually contains the following keys.

                - scores (Tensor): Classification scores, has shape
                  (num_instances,).
                - labels (Tensor): Has shape (num_instances,).
                - masks (Tensor): Processed mask results, has shape
                  (num_instances, h, w).
        """

        def empty_results(results, cls_scores):
            """Generate empty results."""
            results.scores = cls_scores.new_ones(0)
            results.masks = cls_scores.new_zeros(0, *results.ori_shape[:2])
            results.labels = cls_scores.new_ones(0)
            return results

        cfg = self.test_cfg if cfg is None else cfg

        results = InstanceData(img_meta)
        img_shape = results.img_shape
        ori_shape = results.ori_shape
        h, w, _ = img_shape
        featmap_size = mask_preds_x.size()[-2:]
        upsampled_size = (featmap_size[0] * 4, featmap_size[1] * 4)

        score_mask = (cls_scores > cfg.score_thr)
        cls_scores = cls_scores[score_mask]
        inds = score_mask.nonzero()
        lvl_interval = inds.new_tensor(self.num_grids).pow(2).cumsum(0)
        num_all_points = lvl_interval[-1]
        lvl_start_index = inds.new_ones(num_all_points)
        num_grids = inds.new_ones(num_all_points)
        seg_size = inds.new_tensor(self.num_grids).cumsum(0)
        mask_lvl_start_index = inds.new_ones(num_all_points)
        strides = inds.new_ones(num_all_points)

        lvl_start_index[:lvl_interval[0]] *= 0
        mask_lvl_start_index[:lvl_interval[0]] *= 0
        num_grids[:lvl_interval[0]] *= self.num_grids[0]
        strides[:lvl_interval[0]] *= self.strides[0]

        for lvl in range(1, self.num_levels):
            lvl_start_index[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
                lvl_interval[lvl - 1]
            mask_lvl_start_index[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
                seg_size[lvl - 1]
            num_grids[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
                self.num_grids[lvl]
            strides[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
                self.strides[lvl]

        lvl_start_index = lvl_start_index[inds[:, 0]]
        mask_lvl_start_index = mask_lvl_start_index[inds[:, 0]]
        num_grids = num_grids[inds[:, 0]]
        strides = strides[inds[:, 0]]

        y_lvl_offset = (inds[:, 0] - lvl_start_index) // num_grids
        x_lvl_offset = (inds[:, 0] - lvl_start_index) % num_grids
        y_inds = mask_lvl_start_index + y_lvl_offset
        x_inds = mask_lvl_start_index + x_lvl_offset

        cls_labels = inds[:, 1]
        mask_preds = mask_preds_x[x_inds, ...] * mask_preds_y[y_inds, ...]

        masks = mask_preds > cfg.mask_thr
        sum_masks = masks.sum((1, 2)).float()
        keep = sum_masks > strides
        if keep.sum() == 0:
            return empty_results(results, cls_scores)

        masks = masks[keep]
        mask_preds = mask_preds[keep]
        sum_masks = sum_masks[keep]
        cls_scores = cls_scores[keep]
        cls_labels = cls_labels[keep]

        # maskness
        mask_scores = (mask_preds * masks).sum((1, 2)) / sum_masks
        cls_scores *= mask_scores

        scores, labels, _, keep_inds = mask_matrix_nms(
            masks,
            cls_labels,
            cls_scores,
            mask_area=sum_masks,
            nms_pre=cfg.nms_pre,
            max_num=cfg.max_per_img,
            kernel=cfg.kernel,
            sigma=cfg.sigma,
            filter_thr=cfg.filter_thr)
        mask_preds = mask_preds[keep_inds]
        mask_preds = F.interpolate(
            mask_preds.unsqueeze(0), size=upsampled_size,
            mode='bilinear')[:, :, :h, :w]
        mask_preds = F.interpolate(
            mask_preds, size=ori_shape[:2], mode='bilinear').squeeze(0)
        masks = mask_preds > cfg.mask_thr

        results.masks = masks
        results.labels = labels
        results.scores = scores
        return results
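
A simplified sketch of the index decomposition used above (ignoring the per-level start offsets): a flattened position i inside one num_grid x num_grid level splits into a row (y) and column (x) offset, which select the matching y- and x-branch masks.

num_grid = 40
i = 163  # hypothetical flattened grid index within one level
y_lvl_offset, x_lvl_offset = i // num_grid, i % num_grid  # -> (4, 3)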
@HEADS.register_module()
class DecoupledSOLOLightHead(DecoupledSOLOHead):
    """Decoupled Light SOLO mask head used in `SOLO: Segmenting Objects by
    Locations <https://arxiv.org/abs/1912.04488>`_.

    Args:
        dcn_cfg (dict): Config of DCN applied to the last stacked conv of
            mask_convs and cls_convs. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 *args,
                 dcn_cfg=None,
                 init_cfg=[
                     dict(type='Normal', layer='Conv2d', std=0.01),
                     dict(
                         type='Normal',
                         std=0.01,
                         bias_prob=0.01,
                         override=dict(name='conv_mask_list_x')),
                     dict(
                         type='Normal',
                         std=0.01,
                         bias_prob=0.01,
                         override=dict(name='conv_mask_list_y')),
                     dict(
                         type='Normal',
                         std=0.01,
                         bias_prob=0.01,
                         override=dict(name='conv_cls'))
                 ],
                 **kwargs):
        assert dcn_cfg is None or isinstance(dcn_cfg, dict)
        self.dcn_cfg = dcn_cfg
        super(DecoupledSOLOLightHead, self).__init__(
            *args, init_cfg=init_cfg, **kwargs)

    def _init_layers(self):
        self.mask_convs = nn.ModuleList()
        self.cls_convs = nn.ModuleList()

        for i in range(self.stacked_convs):
            if self.dcn_cfg is not None \
                    and i == self.stacked_convs - 1:
                conv_cfg = self.dcn_cfg
            else:
                conv_cfg = None

            chn = self.in_channels + 2 if i == 0 else self.feat_channels
            self.mask_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg))

            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg))

        self.conv_mask_list_x = nn.ModuleList()
        self.conv_mask_list_y = nn.ModuleList()
        for num_grid in self.num_grids:
            self.conv_mask_list_x.append(
                nn.Conv2d(self.feat_channels, num_grid, 3, padding=1))
            self.conv_mask_list_y.append(
                nn.Conv2d(self.feat_channels, num_grid, 3, padding=1))
        self.conv_cls = nn.Conv2d(
            self.feat_channels, self.cls_out_channels, 3, padding=1)
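
A hedged instantiation sketch (assumed values, not from this file): passing dcn_cfg applies deformable convolution to the last stacked conv of each branch. Using type='DCNv2' is an assumption and requires mmcv built with the DCN ops; the loss configs again mirror the official SOLO configs.

head = DecoupledSOLOLightHead(
    num_classes=80,
    in_channels=256,
    dcn_cfg=dict(type='DCNv2'),  # assumed; needs mmcv DCN ops
    loss_mask=dict(type='DiceLoss', use_sigmoid=True, loss_weight=3.0),
    loss_cls=dict(type='FocalLoss', use_sigmoid=True, loss_weight=1.0))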
    def forward(self, feats):
        assert len(feats) == self.num_levels
        feats = self.resize_feats(feats)
        mask_preds_x = []
        mask_preds_y = []
        cls_preds = []
        for i in range(self.num_levels):
            x = feats[i]
            mask_feat = x
            cls_feat = x
            # generate and concat the coordinate
            coord_feat = generate_coordinate(mask_feat.size(),
                                             mask_feat.device)
            mask_feat = torch.cat([mask_feat, coord_feat], 1)

            for mask_layer in self.mask_convs:
                mask_feat = mask_layer(mask_feat)

            mask_feat = F.interpolate(
                mask_feat, scale_factor=2, mode='bilinear')

            mask_pred_x = self.conv_mask_list_x[i](mask_feat)
            mask_pred_y = self.conv_mask_list_y[i](mask_feat)

            # cls branch
            for j, cls_layer in enumerate(self.cls_convs):
                if j == self.cls_down_index:
                    num_grid = self.num_grids[i]
                    cls_feat = F.interpolate(
                        cls_feat, size=num_grid, mode='bilinear')
                cls_feat = cls_layer(cls_feat)

            cls_pred = self.conv_cls(cls_feat)

            if not self.training:
                feat_wh = feats[0].size()[-2:]
                upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2)
                mask_pred_x = F.interpolate(
                    mask_pred_x.sigmoid(),
                    size=upsampled_size,
                    mode='bilinear')
                mask_pred_y = F.interpolate(
                    mask_pred_y.sigmoid(),
                    size=upsampled_size,
                    mode='bilinear')
                cls_pred = cls_pred.sigmoid()
                # get local maximum
                local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1)
                keep_mask = local_max[:, :, :-1, :-1] == cls_pred
                cls_pred = cls_pred * keep_mask

            mask_preds_x.append(mask_pred_x)
            mask_preds_y.append(mask_pred_y)
            cls_preds.append(cls_pred)
        return mask_preds_x, mask_preds_y, cls_preds