Shortcuts

Source code for mmdet.models.roi_heads.bbox_heads.sabl_head

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, force_fp32

from mmdet.core import build_bbox_coder, multi_apply, multiclass_nms
from mmdet.models.builder import HEADS, build_loss
from mmdet.models.losses import accuracy


@HEADS.register_module()
class SABLHead(BaseModule):
    """Side-Aware Boundary Localization (SABL) for RoI-Head.

    Side-Aware features are extracted by conv layers
    with an attention mechanism.
    Boundary Localization with Bucketing and Bucketing Guided Rescoring
    are implemented in BucketingBBoxCoder.

    Please refer to https://arxiv.org/abs/1912.04260 for more details.

    Args:
        num_classes (int): Number of classes in dataset. Required, there
            is no default value.
        cls_in_channels (int): Input channels of cls RoI feature.
            Defaults to 256.
        reg_in_channels (int): Input channels of reg RoI feature.
            Defaults to 256.
        roi_feat_size (int): Size of RoI features. Defaults to 7.
        reg_feat_up_ratio (int): Upsample ratio of reg features.
            Must be at least 2 (asserted). Defaults to 2.
        reg_pre_kernel (int): Kernel of 2D conv layers before
            attention pooling. Defaults to 3.
        reg_post_kernel (int): Kernel of 1D conv layers after
            attention pooling. Defaults to 3.
        reg_pre_num (int): Number of pre convs. Defaults to 2.
        reg_post_num (int): Number of post convs. Defaults to 1.
        cls_out_channels (int): Hidden channels in cls fcs.
            Defaults to 1024.
        reg_offset_out_channels (int): Hidden and output channel
            of reg offset branch. Defaults to 256.
        reg_cls_out_channels (int): Hidden and output channel
            of reg cls branch. Defaults to 256.
        num_cls_fcs (int): Number of fcs for cls branch. Defaults to 1.
        num_reg_fcs (int): Number of fcs for reg branch. Defaults to 0.
        reg_class_agnostic (bool): Class agnostic regression or not.
            Only ``True`` is accepted (asserted). Defaults to True.
        norm_cfg (dict): Config of norm layers. Defaults to None.
        bbox_coder (dict): Config of bbox coder. Defaults to
            'BucketingBBoxCoder'. ``roi_feat_size * reg_feat_up_ratio``
            must equal its ``num_buckets`` (asserted).
        loss_cls (dict): Config of classification loss.
        loss_bbox_cls (dict): Config of classification loss for bbox branch.
        loss_bbox_reg (dict): Config of regression loss for bbox branch.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            When None, a default config is built in ``__init__``.
            Default: None
    """

    def __init__(self,
                 num_classes,
                 cls_in_channels=256,
                 reg_in_channels=256,
                 roi_feat_size=7,
                 reg_feat_up_ratio=2,
                 reg_pre_kernel=3,
                 reg_post_kernel=3,
                 reg_pre_num=2,
                 reg_post_num=1,
                 cls_out_channels=1024,
                 reg_offset_out_channels=256,
                 reg_cls_out_channels=256,
                 num_cls_fcs=1,
                 num_reg_fcs=0,
                 reg_class_agnostic=True,
                 norm_cfg=None,
                 bbox_coder=dict(
                     type='BucketingBBoxCoder',
                     num_buckets=14,
                     scale_factor=1.7),
                 loss_cls=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 loss_bbox_cls=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.0),
                 loss_bbox_reg=dict(
                     type='SmoothL1Loss', beta=0.1, loss_weight=1.0),
                 init_cfg=None):
        super(SABLHead, self).__init__(init_cfg)
        self.cls_in_channels = cls_in_channels
        self.reg_in_channels = reg_in_channels
        self.roi_feat_size = roi_feat_size
        self.reg_feat_up_ratio = int(reg_feat_up_ratio)
        self.num_buckets = bbox_coder['num_buckets']
        assert self.reg_feat_up_ratio // 2 >= 1
        # The upsampled 1D feature has one position per bucket, so its
        # length has to match the coder's bucket count exactly.
        self.up_reg_feat_size = roi_feat_size * self.reg_feat_up_ratio
        assert self.up_reg_feat_size == bbox_coder['num_buckets']
        self.reg_pre_kernel = reg_pre_kernel
        self.reg_post_kernel = reg_post_kernel
        self.reg_pre_num = reg_pre_num
        self.reg_post_num = reg_post_num
        self.num_classes = num_classes
        self.cls_out_channels = cls_out_channels
        self.reg_offset_out_channels = reg_offset_out_channels
        self.reg_cls_out_channels = reg_cls_out_channels
        self.num_cls_fcs = num_cls_fcs
        self.num_reg_fcs = num_reg_fcs
        self.reg_class_agnostic = reg_class_agnostic
        assert self.reg_class_agnostic
        self.norm_cfg = norm_cfg

        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox_cls = build_loss(loss_bbox_cls)
        self.loss_bbox_reg = build_loss(loss_bbox_reg)

        self.cls_fcs = self._add_fc_branch(self.num_cls_fcs,
                                           self.cls_in_channels,
                                           self.roi_feat_size,
                                           self.cls_out_channels)

        # Number of buckets belonging to one side of a box edge pair.
        self.side_num = int(np.ceil(self.num_buckets / 2))

        if self.reg_feat_up_ratio > 1:
            self.upsample_x = nn.ConvTranspose1d(
                reg_in_channels,
                reg_in_channels,
                self.reg_feat_up_ratio,
                stride=self.reg_feat_up_ratio)
            self.upsample_y = nn.ConvTranspose1d(
                reg_in_channels,
                reg_in_channels,
                self.reg_feat_up_ratio,
                stride=self.reg_feat_up_ratio)

        # 2D convs applied to the RoI feature before attention pooling.
        self.reg_pre_convs = nn.ModuleList()
        for _ in range(self.reg_pre_num):
            reg_pre_conv = ConvModule(
                reg_in_channels,
                reg_in_channels,
                kernel_size=reg_pre_kernel,
                padding=reg_pre_kernel // 2,
                norm_cfg=norm_cfg,
                act_cfg=dict(type='ReLU'))
            self.reg_pre_convs.append(reg_pre_conv)

        # (1, k) convs applied to the pooled x-direction feature.
        self.reg_post_conv_xs = nn.ModuleList()
        for _ in range(self.reg_post_num):
            reg_post_conv_x = ConvModule(
                reg_in_channels,
                reg_in_channels,
                kernel_size=(1, reg_post_kernel),
                padding=(0, reg_post_kernel // 2),
                norm_cfg=norm_cfg,
                act_cfg=dict(type='ReLU'))
            self.reg_post_conv_xs.append(reg_post_conv_x)

        # (k, 1) convs applied to the pooled y-direction feature.
        self.reg_post_conv_ys = nn.ModuleList()
        for _ in range(self.reg_post_num):
            reg_post_conv_y = ConvModule(
                reg_in_channels,
                reg_in_channels,
                kernel_size=(reg_post_kernel, 1),
                padding=(reg_post_kernel // 2, 0),
                norm_cfg=norm_cfg,
                act_cfg=dict(type='ReLU'))
            self.reg_post_conv_ys.append(reg_post_conv_y)

        # 1x1 convs producing a single-channel attention map per direction.
        self.reg_conv_att_x = nn.Conv2d(reg_in_channels, 1, 1)
        self.reg_conv_att_y = nn.Conv2d(reg_in_channels, 1, 1)

        # One extra output for the background class.
        self.fc_cls = nn.Linear(self.cls_out_channels, self.num_classes + 1)
        self.relu = nn.ReLU(inplace=True)

        self.reg_cls_fcs = self._add_fc_branch(self.num_reg_fcs,
                                               self.reg_in_channels, 1,
                                               self.reg_cls_out_channels)
        self.reg_offset_fcs = self._add_fc_branch(self.num_reg_fcs,
                                                  self.reg_in_channels, 1,
                                                  self.reg_offset_out_channels)
        # NOTE(review): when num_reg_fcs == 0 these two layers are fed
        # reg_in_channels-sized features directly (see ``reg_pred``), so
        # reg_cls_out_channels / reg_offset_out_channels must then equal
        # reg_in_channels. The defaults (256 everywhere) satisfy this.
        self.fc_reg_cls = nn.Linear(self.reg_cls_out_channels, 1)
        self.fc_reg_offset = nn.Linear(self.reg_offset_out_channels, 1)

        if init_cfg is None:
            self.init_cfg = [
                dict(
                    type='Xavier',
                    layer='Linear',
                    distribution='uniform',
                    override=[
                        dict(type='Normal', name='reg_conv_att_x', std=0.01),
                        dict(type='Normal', name='reg_conv_att_y', std=0.01),
                        dict(type='Normal', name='fc_reg_cls', std=0.01),
                        dict(type='Normal', name='fc_cls', std=0.01),
                        dict(type='Normal', name='fc_reg_offset', std=0.001)
                    ])
            ]
            if self.reg_feat_up_ratio > 1:
                self.init_cfg += [
                    dict(
                        type='Kaiming',
                        distribution='normal',
                        override=[
                            dict(name='upsample_x'),
                            dict(name='upsample_y')
                        ])
                ]

    def _add_fc_branch(self, num_branch_fcs, in_channels, roi_feat_size,
                       fc_out_channels):
        """Build a stack of fully connected layers.

        Args:
            num_branch_fcs (int): Number of ``nn.Linear`` layers to create.
            in_channels (int): Channels of the feature before flattening.
            roi_feat_size (int): Spatial size of the feature; the first
                layer takes ``in_channels * roi_feat_size ** 2`` inputs.
            fc_out_channels (int): Output width of every layer.

        Returns:
            nn.ModuleList: The created layers (empty when
            ``num_branch_fcs`` is 0).
        """
        in_channels = in_channels * roi_feat_size * roi_feat_size
        branch_fcs = nn.ModuleList()
        for i in range(num_branch_fcs):
            fc_in_channels = (in_channels if i == 0 else fc_out_channels)
            branch_fcs.append(nn.Linear(fc_in_channels, fc_out_channels))
        return branch_fcs

    def cls_forward(self, cls_x):
        """Compute classification scores from RoI features.

        The input is flattened per sample, passed through ``cls_fcs``
        (each followed by ReLU) and finally through ``fc_cls``.

        Args:
            cls_x (Tensor): RoI features; the first dimension is the
                number of RoIs.

        Returns:
            Tensor: Classification scores produced by ``fc_cls``.
        """
        cls_x = cls_x.view(cls_x.size(0), -1)
        for fc in self.cls_fcs:
            cls_x = self.relu(fc(cls_x))
        cls_score = self.fc_cls(cls_x)
        return cls_score

    def attention_pool(self, reg_x):
        """Extract direction-specific features fx and fy with attention
        mechanism.

        Args:
            reg_x (Tensor): 4D regression feature map.

        Returns:
            tuple[Tensor, Tensor]: ``reg_fx`` (dim 2 pooled away) and
            ``reg_fy`` (dim 3 pooled away).
        """
        reg_fx = reg_x
        reg_fy = reg_x
        reg_fx_att = self.reg_conv_att_x(reg_fx).sigmoid()
        reg_fy_att = self.reg_conv_att_y(reg_fy).sigmoid()
        # Normalise the attention so that it sums to one along the
        # dimension that is about to be pooled.
        reg_fx_att = reg_fx_att / reg_fx_att.sum(dim=2).unsqueeze(2)
        reg_fy_att = reg_fy_att / reg_fy_att.sum(dim=3).unsqueeze(3)
        reg_fx = (reg_fx * reg_fx_att).sum(dim=2)
        reg_fy = (reg_fy * reg_fy_att).sum(dim=3)
        return reg_fx, reg_fy

    def side_aware_feature_extractor(self, reg_x):
        """Refine and extract side-aware features without split them.

        Args:
            reg_x (Tensor): 4D regression feature map.

        Returns:
            tuple[Tensor, Tensor]: Contiguous x- and y-direction features
            with the channel dimension moved last (dims 1 and 2 swapped).
        """
        for reg_pre_conv in self.reg_pre_convs:
            reg_x = reg_pre_conv(reg_x)
        reg_fx, reg_fy = self.attention_pool(reg_x)

        if self.reg_post_num > 0:
            # Re-insert the pooled dimension so the 2D post convs apply.
            reg_fx = reg_fx.unsqueeze(2)
            reg_fy = reg_fy.unsqueeze(3)
            for post_conv_x, post_conv_y in zip(self.reg_post_conv_xs,
                                                self.reg_post_conv_ys):
                reg_fx = post_conv_x(reg_fx)
                reg_fy = post_conv_y(reg_fy)
            reg_fx = reg_fx.squeeze(2)
            reg_fy = reg_fy.squeeze(3)
        if self.reg_feat_up_ratio > 1:
            reg_fx = self.relu(self.upsample_x(reg_fx))
            reg_fy = self.relu(self.upsample_y(reg_fy))
        reg_fx = torch.transpose(reg_fx, 1, 2)
        reg_fy = torch.transpose(reg_fy, 1, 2)
        return reg_fx.contiguous(), reg_fy.contiguous()

    def reg_pred(self, x, offset_fcs, cls_fcs):
        """Predict bucketing estimation (cls_pred) and fine regression (offset
        pred) with side-aware features.

        Args:
            x (Tensor): Side-aware features whose last dimension is
                ``reg_in_channels``; the first dimension is the number
                of RoIs.
            offset_fcs (Iterable[nn.Module]): Layers of the offset branch.
            cls_fcs (Iterable[nn.Module]): Layers of the bucket cls branch.

        Returns:
            tuple[Tensor, Tensor]: ``(offset_pred, cls_pred)``, each
            reshaped to ``(x.size(0), -1)``.
        """
        x_offset = x.view(-1, self.reg_in_channels)
        x_cls = x.view(-1, self.reg_in_channels)

        for fc in offset_fcs:
            x_offset = self.relu(fc(x_offset))
        for fc in cls_fcs:
            x_cls = self.relu(fc(x_cls))
        offset_pred = self.fc_reg_offset(x_offset)
        cls_pred = self.fc_reg_cls(x_cls)

        offset_pred = offset_pred.view(x.size(0), -1)
        cls_pred = cls_pred.view(x.size(0), -1)

        return offset_pred, cls_pred

    def side_aware_split(self, feat):
        """Split side-aware features aligned with orders of bucketing
        targets.

        The first half along dim 1 is kept as is, the second half is
        flipped, and both are concatenated along the last dimension.

        Args:
            feat (Tensor): Predictions indexed by bucket along dim 1.

        Returns:
            Tensor: The re-ordered predictions.
        """
        l_end = int(np.ceil(self.up_reg_feat_size / 2))
        r_start = int(np.floor(self.up_reg_feat_size / 2))
        feat_fl = feat[:, :l_end]
        feat_fr = feat[:, r_start:].flip(dims=(1, ))
        feat_fl = feat_fl.contiguous()
        feat_fr = feat_fr.contiguous()
        feat = torch.cat([feat_fl, feat_fr], dim=-1)
        return feat

    def bbox_pred_split(self, bbox_pred, num_proposals_per_img):
        """Split batch bbox prediction back to each image.

        Args:
            bbox_pred (tuple[Tensor, Tensor]): Batched
                ``(bucket_cls_preds, bucket_offset_preds)``.
            num_proposals_per_img (int or list[int]): Split sizes along
                dim 0, forwarded to ``Tensor.split``.

        Returns:
            tuple[tuple[Tensor, Tensor]]: One ``(cls, offset)`` pair per
            image.
        """
        bucket_cls_preds, bucket_offset_preds = bbox_pred
        bucket_cls_preds = bucket_cls_preds.split(num_proposals_per_img, 0)
        bucket_offset_preds = bucket_offset_preds.split(
            num_proposals_per_img, 0)
        bbox_pred = tuple(zip(bucket_cls_preds, bucket_offset_preds))
        return bbox_pred

    def reg_forward(self, reg_x):
        """Run the regression branch.

        Args:
            reg_x (Tensor): 4D regression RoI feature map.

        Returns:
            tuple[Tensor, Tensor]: ``(edge_cls_preds, edge_offset_preds)``
            with the x-direction predictions followed by the y-direction
            predictions along the last dimension.
        """
        reg_fx, reg_fy = self.side_aware_feature_extractor(reg_x)
        # The same fc branches are shared between both directions.
        offset_pred_x, cls_pred_x = self.reg_pred(reg_fx, self.reg_offset_fcs,
                                                  self.reg_cls_fcs)
        offset_pred_y, cls_pred_y = self.reg_pred(reg_fy, self.reg_offset_fcs,
                                                  self.reg_cls_fcs)
        offset_pred_x = self.side_aware_split(offset_pred_x)
        offset_pred_y = self.side_aware_split(offset_pred_y)
        cls_pred_x = self.side_aware_split(cls_pred_x)
        cls_pred_y = self.side_aware_split(cls_pred_y)
        edge_offset_preds = torch.cat([offset_pred_x, offset_pred_y], dim=-1)
        edge_cls_preds = torch.cat([cls_pred_x, cls_pred_y], dim=-1)

        return (edge_cls_preds, edge_offset_preds)

    def forward(self, x):
        """Run both branches on the same RoI features.

        Args:
            x (Tensor): RoI features.

        Returns:
            tuple: ``(cls_score, bbox_pred)`` where ``bbox_pred`` is the
            ``(edge_cls_preds, edge_offset_preds)`` tuple returned by
            ``reg_forward``.
        """
        bbox_pred = self.reg_forward(x)
        cls_score = self.cls_forward(x)

        return cls_score, bbox_pred

    def get_targets(self, sampling_results, gt_bboxes, gt_labels,
                    rcnn_train_cfg):
        """Build training targets from sampling results.

        Args:
            sampling_results (list): Per-image sampling results providing
                ``pos_bboxes``, ``neg_bboxes``, ``pos_gt_bboxes`` and
                ``pos_gt_labels``.
            gt_bboxes: Not read by this method.
            gt_labels: Not read by this method.
            rcnn_train_cfg: Training config forwarded to
                ``bucket_target``.

        Returns:
            tuple: ``(labels, label_weights,
            (bucket_cls_targets, bucket_offset_targets),
            (bucket_cls_weights, bucket_offset_weights))``.
        """
        pos_proposals = [res.pos_bboxes for res in sampling_results]
        neg_proposals = [res.neg_bboxes for res in sampling_results]
        pos_gt_bboxes = [res.pos_gt_bboxes for res in sampling_results]
        pos_gt_labels = [res.pos_gt_labels for res in sampling_results]
        cls_reg_targets = self.bucket_target(pos_proposals, neg_proposals,
                                             pos_gt_bboxes, pos_gt_labels,
                                             rcnn_train_cfg)
        (labels, label_weights, bucket_cls_targets, bucket_cls_weights,
         bucket_offset_targets, bucket_offset_weights) = cls_reg_targets
        return (labels, label_weights, (bucket_cls_targets,
                                        bucket_offset_targets),
                (bucket_cls_weights, bucket_offset_weights))

    def bucket_target(self,
                      pos_proposals_list,
                      neg_proposals_list,
                      pos_gt_bboxes_list,
                      pos_gt_labels_list,
                      rcnn_train_cfg,
                      concat=True):
        """Compute bucketing targets for every image of a batch.

        Args:
            pos_proposals_list (list[Tensor]): Positive proposals per image.
            neg_proposals_list (list[Tensor]): Negative proposals per image.
            pos_gt_bboxes_list (list[Tensor]): Assigned gt boxes per image.
            pos_gt_labels_list (list[Tensor]): Assigned gt labels per image.
            rcnn_train_cfg: Config passed to ``_bucket_target_single`` as
                ``cfg``.
            concat (bool): Whether to concatenate the per-image results
                along dim 0. Defaults to True.

        Returns:
            tuple: ``(labels, label_weights, bucket_cls_targets,
            bucket_cls_weights, bucket_offset_targets,
            bucket_offset_weights)``; each entry is a single Tensor when
            ``concat`` is True, otherwise the per-image collection
            returned by ``multi_apply``.
        """
        (labels, label_weights, bucket_cls_targets, bucket_cls_weights,
         bucket_offset_targets, bucket_offset_weights) = multi_apply(
             self._bucket_target_single,
             pos_proposals_list,
             neg_proposals_list,
             pos_gt_bboxes_list,
             pos_gt_labels_list,
             cfg=rcnn_train_cfg)

        if concat:
            labels = torch.cat(labels, 0)
            label_weights = torch.cat(label_weights, 0)
            bucket_cls_targets = torch.cat(bucket_cls_targets, 0)
            bucket_cls_weights = torch.cat(bucket_cls_weights, 0)
            bucket_offset_targets = torch.cat(bucket_offset_targets, 0)
            bucket_offset_weights = torch.cat(bucket_offset_weights, 0)
        return (labels, label_weights, bucket_cls_targets, bucket_cls_weights,
                bucket_offset_targets, bucket_offset_weights)

    def _bucket_target_single(self, pos_proposals, neg_proposals,
                              pos_gt_bboxes, pos_gt_labels, cfg):
        """Compute bucketing estimation targets and fine regression targets
        for a single image.

        Args:
            pos_proposals (Tensor): positive proposals of a single image,
                 Shape (n_pos, 4)
            neg_proposals (Tensor): negative proposals of a single image,
                 Shape (n_neg, 4).
            pos_gt_bboxes (Tensor): gt bboxes assigned to positive proposals
                 of a single image, Shape (n_pos, 4).
            pos_gt_labels (Tensor): gt labels assigned to positive proposals
                 of a single image, Shape (n_pos, ).
            cfg (dict): Config of calculating targets. Not read by this
                 method.

        Returns:
            tuple:

                - labels (Tensor): Labels in a single image.
                    Shape (n,). Negatives keep the background index
                    ``num_classes``.
                - label_weights (Tensor): Label weights in a single image.
                    Shape (n,)
                - bucket_cls_targets (Tensor): Bucket cls targets in
                    a single image. Shape (n, 4 * side_num).
                - bucket_cls_weights (Tensor): Bucket cls weights in
                    a single image. Shape (n, 4 * side_num).
                - bucket_offset_targets (Tensor): Bucket offset targets
                    in a single image. Shape (n, 4 * side_num).
                - bucket_offset_weights (Tensor): Bucket offset weights
                    in a single image. Shape (n, 4 * side_num).
        """
        num_pos = pos_proposals.size(0)
        num_neg = neg_proposals.size(0)
        num_samples = num_pos + num_neg
        # Every sample starts as background; positives are filled below.
        labels = pos_gt_bboxes.new_full((num_samples, ),
                                        self.num_classes,
                                        dtype=torch.long)
        label_weights = pos_proposals.new_zeros(num_samples)
        bucket_cls_targets = pos_proposals.new_zeros(num_samples,
                                                     4 * self.side_num)
        bucket_cls_weights = pos_proposals.new_zeros(num_samples,
                                                     4 * self.side_num)
        bucket_offset_targets = pos_proposals.new_zeros(
            num_samples, 4 * self.side_num)
        bucket_offset_weights = pos_proposals.new_zeros(
            num_samples, 4 * self.side_num)
        if num_pos > 0:
            labels[:num_pos] = pos_gt_labels
            label_weights[:num_pos] = 1.0
            (pos_bucket_offset_targets, pos_bucket_offset_weights,
             pos_bucket_cls_targets,
             pos_bucket_cls_weights) = self.bbox_coder.encode(
                 pos_proposals, pos_gt_bboxes)
            bucket_cls_targets[:num_pos, :] = pos_bucket_cls_targets
            bucket_cls_weights[:num_pos, :] = pos_bucket_cls_weights
            bucket_offset_targets[:num_pos, :] = pos_bucket_offset_targets
            bucket_offset_weights[:num_pos, :] = pos_bucket_offset_weights
        if num_neg > 0:
            # Negatives contribute to the classification loss only; their
            # bucket weights stay zero.
            label_weights[-num_neg:] = 1.0
        return (labels, label_weights, bucket_cls_targets, bucket_cls_weights,
                bucket_offset_targets, bucket_offset_weights)

    def loss(self,
             cls_score,
             bbox_pred,
             rois,
             labels,
             label_weights,
             bbox_targets,
             bbox_weights,
             reduction_override=None):
        """Compute the classification and bucketing losses.

        Args:
            cls_score (Tensor or None): Classification scores. When None,
                ``loss_cls`` and ``acc`` are skipped.
            bbox_pred (tuple[Tensor, Tensor] or None):
                ``(bucket_cls_preds, bucket_offset_preds)``. When None,
                the bbox losses are skipped.
            rois: Not read by this method.
            labels (Tensor): Classification targets.
            label_weights (Tensor): Classification weights.
            bbox_targets (tuple[Tensor, Tensor]):
                ``(bucket_cls_targets, bucket_offset_targets)``.
            bbox_weights (tuple[Tensor, Tensor]):
                ``(bucket_cls_weights, bucket_offset_weights)``.
            reduction_override (str, optional): Forwarded to every loss
                module. Defaults to None.

        Returns:
            dict: May contain ``loss_cls``, ``acc``, ``loss_bbox_cls`` and
            ``loss_bbox_reg``.
        """
        losses = dict()
        if cls_score is not None:
            # Guard against division by zero when no sample is weighted.
            avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.)
            losses['loss_cls'] = self.loss_cls(
                cls_score,
                labels,
                label_weights,
                avg_factor=avg_factor,
                reduction_override=reduction_override)
            losses['acc'] = accuracy(cls_score, labels)

        if bbox_pred is not None:
            bucket_cls_preds, bucket_offset_preds = bbox_pred
            bucket_cls_targets, bucket_offset_targets = bbox_targets
            bucket_cls_weights, bucket_offset_weights = bbox_weights
            # edge cls: regroup so that every row holds the side_num
            # bucket scores of a single box edge.
            bucket_cls_preds = bucket_cls_preds.view(-1, self.side_num)
            bucket_cls_targets = bucket_cls_targets.view(-1, self.side_num)
            bucket_cls_weights = bucket_cls_weights.view(-1, self.side_num)
            losses['loss_bbox_cls'] = self.loss_bbox_cls(
                bucket_cls_preds,
                bucket_cls_targets,
                bucket_cls_weights,
                avg_factor=bucket_cls_targets.size(0),
                reduction_override=reduction_override)

            losses['loss_bbox_reg'] = self.loss_bbox_reg(
                bucket_offset_preds,
                bucket_offset_targets,
                bucket_offset_weights,
                avg_factor=bucket_offset_targets.size(0),
                reduction_override=reduction_override)

        return losses

    @force_fp32(apply_to=('cls_score', 'bbox_pred'))
    def get_bboxes(self,
                   rois,
                   cls_score,
                   bbox_pred,
                   img_shape,
                   scale_factor,
                   rescale=False,
                   cfg=None):
        """Decode predictions into boxes and optionally run NMS.

        Args:
            rois (Tensor): RoIs whose first column is skipped and whose
                remaining columns are the box coordinates.
            cls_score (Tensor or list[Tensor] or None): Classification
                scores; a list is averaged element-wise.
            bbox_pred (tuple[Tensor, Tensor] or None): Bucketing
                predictions. When None the RoIs themselves are used as
                boxes.
            img_shape (tuple or None): Image shape indexed as
                ``(height, width, ...)``.
            scale_factor (float or array-like): Factor the boxes are
                divided by when ``rescale`` is True.
            rescale (bool): Whether to divide boxes by ``scale_factor``.
                Defaults to False.
            cfg (optional): Test config providing ``score_thr``, ``nms``
                and ``max_per_img``. When None, NMS is skipped.

        Returns:
            tuple: ``(bboxes, scores)`` when ``cfg`` is None, otherwise
            ``(det_bboxes, det_labels)`` from ``multiclass_nms``.
        """
        if isinstance(cls_score, list):
            cls_score = sum(cls_score) / float(len(cls_score))
        scores = F.softmax(cls_score, dim=1) if cls_score is not None else None

        if bbox_pred is not None:
            # The coder is handed img_shape itself.
            bboxes, confidences = self.bbox_coder.decode(
                rois[:, 1:], bbox_pred, img_shape)
        else:
            bboxes = rois[:, 1:].clone()
            confidences = None
            if img_shape is not None:
                # Indexing with a list returns a copy, so an in-place
                # ``clamp_`` on it would leave ``bboxes`` untouched.
                # Assign the clamped values back instead.
                bboxes[:, [0, 2]] = bboxes[:, [0, 2]].clamp(
                    min=0, max=img_shape[1] - 1)
                bboxes[:, [1, 3]] = bboxes[:, [1, 3]].clamp(
                    min=0, max=img_shape[0] - 1)

        if rescale and bboxes.size(0) > 0:
            if isinstance(scale_factor, float):
                bboxes /= scale_factor
            else:
                # new_tensor accepts ndarrays, sequences and scalars and
                # matches the dtype/device of ``bboxes``.
                bboxes /= bboxes.new_tensor(scale_factor)

        if cfg is None:
            return bboxes, scores
        else:
            det_bboxes, det_labels = multiclass_nms(
                bboxes,
                scores,
                cfg.score_thr,
                cfg.nms,
                cfg.max_per_img,
                score_factors=confidences)

            return det_bboxes, det_labels

    @force_fp32(apply_to=('bbox_preds', ))
    def refine_bboxes(self, rois, labels, bbox_preds, pos_is_gts, img_metas):
        """Refine bboxes during training.

        Args:
            rois (Tensor): Shape (n*bs, 5), where n is image number per GPU,
                and bs is the sampled RoIs per image.
            labels (Tensor): Shape (n*bs, ).
            bbox_preds (list[Tensor]): Shape [(n*bs, num_buckets*2),
                (n*bs, num_buckets*2)].
            pos_is_gts (list[Tensor]): Flags indicating if each positive bbox
                is a gt bbox.
            img_metas (list[dict]): Meta info of each image.

        Returns:
            list[Tensor]: Refined bboxes of each image in a mini-batch.
        """
        img_ids = rois[:, 0].long().unique(sorted=True)
        assert img_ids.numel() == len(img_metas)

        # Loop-invariant: the batched predictions are indexed per image.
        edge_cls_preds, edge_offset_preds = bbox_preds

        bboxes_list = []
        for i, img_meta_ in enumerate(img_metas):
            inds = torch.nonzero(
                rois[:, 0] == i, as_tuple=False).squeeze(dim=1)
            num_rois = inds.numel()

            bboxes_ = rois[inds, 1:]
            label_ = labels[inds]
            edge_cls_preds_ = edge_cls_preds[inds]
            edge_offset_preds_ = edge_offset_preds[inds]
            bbox_pred_ = [edge_cls_preds_, edge_offset_preds_]
            pos_is_gts_ = pos_is_gts[i]

            bboxes = self.regress_by_class(bboxes_, label_, bbox_pred_,
                                           img_meta_)
            # filter gt bboxes: the flags cover only the leading
            # (positive) RoIs, every other RoI is kept.
            pos_keep = 1 - pos_is_gts_
            keep_inds = pos_is_gts_.new_ones(num_rois)
            keep_inds[:len(pos_is_gts_)] = pos_keep

            bboxes_list.append(bboxes[keep_inds.type(torch.bool)])

        return bboxes_list

    @force_fp32(apply_to=('bbox_pred', ))
    def regress_by_class(self, rois, label, bbox_pred, img_meta):
        """Regress the bbox for the predicted class. Used in Cascade R-CNN.

        Args:
            rois (Tensor): shape (n, 4) or (n, 5)
            label (Tensor): shape (n, ). Not read by this method.
            bbox_pred (list[Tensor]): shape [(n, num_buckets *2),
                (n, num_buckets *2)]
            img_meta (dict): Image meta info; ``img_meta['img_shape']`` is
                passed to the bbox coder.

        Returns:
            Tensor: Regressed bboxes, the same shape as input rois.
        """
        assert rois.size(1) == 4 or rois.size(1) == 5

        if rois.size(1) == 4:
            new_rois, _ = self.bbox_coder.decode(rois, bbox_pred,
                                                 img_meta['img_shape'])
        else:
            bboxes, _ = self.bbox_coder.decode(rois[:, 1:], bbox_pred,
                                               img_meta['img_shape'])
            # Keep the leading batch-index column of the 5-column RoIs.
            new_rois = torch.cat((rois[:, [0]], bboxes), dim=1)

        return new_rois
Read the Docs v: v2.19.1
Versions
latest
stable
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
v2.12.0
v2.11.0
v2.10.0
v2.9.0
v2.8.0
v2.7.0
v2.6.0
v2.5.0
v2.4.0
v2.3.0
v2.2.1
v2.2.0
v2.1.0
v2.0.0
v1.2.0
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.