Source code for mmdet.models.dense_heads.ssd_head

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import xavier_init

from mmdet.core import (build_anchor_generator, build_assigner,
                        build_bbox_coder, build_sampler, multi_apply)
from ..builder import HEADS
from ..losses import smooth_l1_loss
from .anchor_head import AnchorHead


# TODO: add loss evaluator for SSD
@HEADS.register_module()
class SSDHead(AnchorHead):

    def __init__(self,
                 num_classes=80,
                 in_channels=(512, 1024, 512, 256, 256, 256),
                 anchor_generator=dict(
                     type='SSDAnchorGenerator',
                     scale_major=False,
                     input_size=300,
                     strides=[8, 16, 32, 64, 100, 300],
                     ratios=([2], [2, 3], [2, 3], [2, 3], [2], [2]),
                     basesize_ratio_range=(0.1, 0.9)),
                 background_label=None,
                 bbox_coder=dict(
                     type='DeltaXYWHBBoxCoder',
                     target_means=[.0, .0, .0, .0],
                     target_stds=[1.0, 1.0, 1.0, 1.0],
                 ),
                 reg_decoded_bbox=False,
                 train_cfg=None,
                 test_cfg=None):
        super(AnchorHead, self).__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.cls_out_channels = num_classes + 1  # add background class
        self.anchor_generator = build_anchor_generator(anchor_generator)
        num_anchors = self.anchor_generator.num_base_anchors

        reg_convs = []
        cls_convs = []
        for i in range(len(in_channels)):
            reg_convs.append(
                nn.Conv2d(
                    in_channels[i],
                    num_anchors[i] * 4,
                    kernel_size=3,
                    padding=1))
            cls_convs.append(
                nn.Conv2d(
                    in_channels[i],
                    num_anchors[i] * (num_classes + 1),
                    kernel_size=3,
                    padding=1))
        self.reg_convs = nn.ModuleList(reg_convs)
        self.cls_convs = nn.ModuleList(cls_convs)

        self.background_label = (
            num_classes if background_label is None else background_label)
        # background_label should be either 0 or num_classes
        assert (self.background_label == 0
                or self.background_label == num_classes)

        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.reg_decoded_bbox = reg_decoded_bbox
        self.use_sigmoid_cls = False
        self.cls_focal_loss = False
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        # set sampling=False for anchor_target
        self.sampling = False
        if self.train_cfg:
            self.assigner = build_assigner(self.train_cfg.assigner)
            # SSD sampling=False so use PseudoSampler
            sampler_cfg = dict(type='PseudoSampler')
            self.sampler = build_sampler(sampler_cfg, context=self)
        self.fp16_enabled = False

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                xavier_init(m, distribution='uniform', bias=0)
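
    # Shape note: for level i, reg_convs[i] maps (N, in_channels[i], H_i, W_i)
    # features to (N, num_anchors[i] * 4, H_i, W_i) box deltas, and
    # cls_convs[i] to (N, num_anchors[i] * (num_classes + 1), H_i, W_i)
    # class scores, i.e. one (num_classes + 1)-way prediction per anchor
    # per spatial location.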

    def forward(self, feats):
        cls_scores = []
        bbox_preds = []
        for feat, reg_conv, cls_conv in zip(feats, self.reg_convs,
                                            self.cls_convs):
            cls_scores.append(cls_conv(feat))
            bbox_preds.append(reg_conv(feat))
        return cls_scores, bbox_preds
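
    # loss_single below implements SSD-style online hard negative mining:
    # every positive anchor contributes to the classification loss, while
    # negative anchors are ranked by their per-anchor cross-entropy loss and
    # only the top `neg_pos_ratio * num_pos` hardest ones are kept, so easy
    # background anchors do not dominate the objective.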
    def loss_single(self, cls_score, bbox_pred, anchor, labels, label_weights,
                    bbox_targets, bbox_weights, num_total_samples):
        loss_cls_all = F.cross_entropy(
            cls_score, labels, reduction='none') * label_weights
        # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
        pos_inds = ((labels >= 0) &
                    (labels < self.background_label)).nonzero().reshape(-1)
        neg_inds = (labels == self.background_label).nonzero().view(-1)

        num_pos_samples = pos_inds.size(0)
        num_neg_samples = self.train_cfg.neg_pos_ratio * num_pos_samples
        if num_neg_samples > neg_inds.size(0):
            num_neg_samples = neg_inds.size(0)
        topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples)
        loss_cls_pos = loss_cls_all[pos_inds].sum()
        loss_cls_neg = topk_loss_cls_neg.sum()
        loss_cls = (loss_cls_pos + loss_cls_neg) / num_total_samples

        if self.reg_decoded_bbox:
            bbox_pred = self.bbox_coder.decode(anchor, bbox_pred)

        loss_bbox = smooth_l1_loss(
            bbox_pred,
            bbox_targets,
            bbox_weights,
            beta=self.train_cfg.smoothl1_beta,
            avg_factor=num_total_samples)
        return loss_cls[None], loss_bbox

    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.anchor_generator.num_levels

        device = cls_scores[0].device

        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas, device=device)
        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=1,
            unmap_outputs=False)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets

        num_images = len(img_metas)
        all_cls_scores = torch.cat([
            s.permute(0, 2, 3, 1).reshape(
                num_images, -1, self.cls_out_channels) for s in cls_scores
        ], 1)
        all_labels = torch.cat(labels_list, -1).view(num_images, -1)
        all_label_weights = torch.cat(label_weights_list,
                                      -1).view(num_images, -1)
        all_bbox_preds = torch.cat([
            b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
            for b in bbox_preds
        ], -2)
        all_bbox_targets = torch.cat(bbox_targets_list,
                                     -2).view(num_images, -1, 4)
        all_bbox_weights = torch.cat(bbox_weights_list,
                                     -2).view(num_images, -1, 4)

        # concat all level anchors to a single tensor
        all_anchors = []
        for i in range(num_images):
            all_anchors.append(torch.cat(anchor_list[i]))

        # check NaN and Inf
        assert torch.isfinite(all_cls_scores).all().item(), \
            'classification scores become infinite or NaN!'
        assert torch.isfinite(all_bbox_preds).all().item(), \
            'bbox predictions become infinite or NaN!'

        losses_cls, losses_bbox = multi_apply(
            self.loss_single,
            all_cls_scores,
            all_bbox_preds,
            all_anchors,
            all_labels,
            all_label_weights,
            all_bbox_targets,
            all_bbox_weights,
            num_total_samples=num_total_pos)
        return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
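

# Minimal usage sketch (not part of the original module): assumes an mmdet
# v2.x environment where the builders imported above are available. It
# instantiates the head with its default SSD300 configuration and runs a
# forward pass on random feature maps whose channel and spatial sizes match
# a VGG16-SSD300 backbone; under these defaults the per-level base-anchor
# counts from SSDAnchorGenerator should be (4, 6, 6, 6, 4, 4).
if __name__ == '__main__':
    head = SSDHead(
        num_classes=80, in_channels=(512, 1024, 512, 256, 256, 256))
    head.init_weights()
    # Random multi-level features for a batch of 2 images at 300x300 input.
    channels = (512, 1024, 512, 256, 256, 256)
    sizes = (38, 19, 10, 5, 3, 1)
    feats = [torch.rand(2, c, s, s) for c, s in zip(channels, sizes)]
    cls_scores, bbox_preds = head(feats)
    for cls_score, bbox_pred in zip(cls_scores, bbox_preds):
        # cls_score: (2, num_anchors_i * 81, H_i, W_i)
        # bbox_pred: (2, num_anchors_i * 4, H_i, W_i)
        print(cls_score.shape, bbox_pred.shape)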