Shortcuts

Source code for mmdet.models.dense_heads.ld_head

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.runner import force_fp32

from mmdet.core import bbox_overlaps, multi_apply, reduce_mean
from ..builder import HEADS, build_loss
from .gfl_head import GFLHead


@HEADS.register_module()
class LDHead(GFLHead):
    """Localization distillation Head. (Short description)

    It utilizes the learned bbox distributions to transfer the localization
    dark knowledge from teacher to student. Original paper: `Localization
    Distillation for Object Detection. <https://arxiv.org/abs/2102.12252>`_

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        loss_ld (dict): Config of Localization Distillation Loss (LD),
            T is the temperature for distillation.
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 loss_ld=dict(
                     type='LocalizationDistillationLoss',
                     loss_weight=0.25,
                     T=10),
                 **kwargs):
        super(LDHead, self).__init__(num_classes, in_channels, **kwargs)
        # The distillation loss is built on top of the losses that the
        # parent head already sets up.
        self.loss_ld = build_loss(loss_ld)

    def loss_single(self, anchors, cls_score, bbox_pred, labels,
                    label_weights, bbox_targets, stride, soft_targets,
                    num_total_samples):
        """Compute loss of a single scale level.

        Args:
            anchors (Tensor): Box reference for each scale level with shape
                (N, num_total_anchors, 4).
            cls_score (Tensor): Cls and quality joint scores for each scale
                level has shape (N, num_classes, H, W).
            bbox_pred (Tensor): Box distribution logits for each scale
                level with shape (N, 4*(n+1), H, W), n is max value of
                integral set.
            labels (Tensor): Labels of each anchors with shape
                (N, num_total_anchors).
            label_weights (Tensor): Label weights of each anchor with shape
                (N, num_total_anchors)
            bbox_targets (Tensor): BBox regression targets of each anchor
                weight shape (N, num_total_anchors, 4).
            stride (tuple): Stride in this scale level.
            soft_targets (Tensor): Soft targets for the box distribution of
                this scale level. It is permuted and flattened exactly like
                ``bbox_pred``, so it is expected to have the same
                (N, 4*(n+1), H, W) layout (presumably the teacher's box
                distribution logits -- confirm against the caller).
            num_total_samples (int): Number of positive samples that is
                reduced over all GPUs.

        Returns:
            tuple[Tensor]: ``(loss_cls, loss_bbox, loss_dfl, loss_ld,
            weight_sum)`` where ``weight_sum`` is the sum of the per-positive
            weight targets of this level (a zero tensor when the level has
            no positive samples).
        """
        assert stride[0] == stride[1], 'h stride is not equal to w stride!'
        # Flatten everything to one row per anchor location.
        anchors = anchors.reshape(-1, 4)
        cls_score = cls_score.permute(0, 2, 3,
                                      1).reshape(-1, self.cls_out_channels)
        bbox_pred = bbox_pred.permute(0, 2, 3,
                                      1).reshape(-1, 4 * (self.reg_max + 1))
        soft_targets = soft_targets.permute(0, 2, 3,
                                            1).reshape(-1,
                                                       4 * (self.reg_max + 1))

        bbox_targets = bbox_targets.reshape(-1, 4)
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)

        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = ((labels >= 0)
                    & (labels < bg_class_ind)).nonzero().squeeze(1)
        # Quality target for the classification loss; stays 0 for every
        # non-positive anchor.
        score = label_weights.new_zeros(labels.shape)

        if len(pos_inds) > 0:
            pos_bbox_targets = bbox_targets[pos_inds]
            pos_bbox_pred = bbox_pred[pos_inds]
            pos_anchors = anchors[pos_inds]
            # Work in units of the stride for decoding/encoding.
            pos_anchor_centers = self.anchor_center(pos_anchors) / stride[0]

            # Weight every positive by its (detached) best class score.
            weight_targets = cls_score.detach().sigmoid()
            weight_targets = weight_targets.max(dim=1)[0][pos_inds]
            pos_bbox_pred_corners = self.integral(pos_bbox_pred)
            pos_decode_bbox_pred = self.bbox_coder.decode(
                pos_anchor_centers, pos_bbox_pred_corners)
            pos_decode_bbox_targets = pos_bbox_targets / stride[0]
            # Overlap between the detached prediction and its target is used
            # as the quality score of the positive anchors.
            score[pos_inds] = bbox_overlaps(
                pos_decode_bbox_pred.detach(),
                pos_decode_bbox_targets,
                is_aligned=True)
            # One row of (reg_max + 1) logits per box side.
            pred_corners = pos_bbox_pred.reshape(-1, self.reg_max + 1)
            pos_soft_targets = soft_targets[pos_inds]
            soft_corners = pos_soft_targets.reshape(-1, self.reg_max + 1)

            target_corners = self.bbox_coder.encode(pos_anchor_centers,
                                                    pos_decode_bbox_targets,
                                                    self.reg_max).reshape(-1)

            # regression loss
            loss_bbox = self.loss_bbox(
                pos_decode_bbox_pred,
                pos_decode_bbox_targets,
                weight=weight_targets,
                avg_factor=1.0)

            # dfl loss
            loss_dfl = self.loss_dfl(
                pred_corners,
                target_corners,
                weight=weight_targets[:, None].expand(-1, 4).reshape(-1),
                avg_factor=4.0)

            # ld loss
            loss_ld = self.loss_ld(
                pred_corners,
                soft_corners,
                weight=weight_targets[:, None].expand(-1, 4).reshape(-1),
                avg_factor=4.0)

        else:
            # No positives on this level: emit zero-valued losses that are
            # still computed from ``bbox_pred``.
            loss_ld = bbox_pred.sum() * 0
            loss_bbox = bbox_pred.sum() * 0
            loss_dfl = bbox_pred.sum() * 0
            weight_targets = bbox_pred.new_tensor(0)

        # cls (qfl) loss
        loss_cls = self.loss_cls(
            cls_score, (labels, score),
            weight=label_weights,
            avg_factor=num_total_samples)

        return loss_cls, loss_bbox, loss_dfl, loss_ld, weight_targets.sum()

    def forward_train(self,
                      x,
                      out_teacher,
                      img_metas,
                      gt_bboxes,
                      gt_labels=None,
                      gt_bboxes_ignore=None,
                      proposal_cfg=None,
                      **kwargs):
        """Run the head on ``x`` and compute the training losses.

        Args:
            x (list[Tensor]): Features from FPN.
            out_teacher (tuple): Outputs of the teacher head. Only the
                element at index 1 is used; it is forwarded to :meth:`loss`
                as the soft target.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes (Tensor): Ground truth bboxes of the image,
                shape (num_gts, 4).
            gt_labels (Tensor): Ground truth labels of each box,
                shape (num_gts,).
            gt_bboxes_ignore (Tensor): Ground truth bboxes to be
                ignored, shape (num_ignored_gts, 4).
            proposal_cfg (mmcv.Config): Test / postprocessing configuration,
                if None, test_cfg would be used

        Returns:
            dict | tuple[dict, list]: Only the losses when ``proposal_cfg``
            is None, otherwise ``(losses, proposal_list)``.

            - losses (dict[str, Tensor]): A dictionary of loss components.
            - proposal_list (list[Tensor]): Proposals of each image.
        """
        outs = self(x)
        soft_target = out_teacher[1]
        # ``loss`` takes gt_labels, soft_target and img_metas positionally,
        # so gt_labels must always occupy its slot. Dropping it when it is
        # None (as was done before) shifted soft_target/img_metas by one
        # position and made the call fail with a missing ``img_metas``.
        # NOTE(review): assumes the inherited ``get_targets`` accepts
        # ``gt_labels_list=None`` -- confirm against GFLHead.
        loss_inputs = outs + (gt_bboxes, gt_labels, soft_target, img_metas)
        losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
        if proposal_cfg is None:
            return losses
        else:
            proposal_list = self.get_bboxes(*outs, img_metas, cfg=proposal_cfg)
            return losses, proposal_list

    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             soft_target,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute losses of the head.

        Args:
            cls_scores (list[Tensor]): Cls and quality scores for each scale
                level has shape (N, num_classes, H, W).
            bbox_preds (list[Tensor]): Box distribution logits for each scale
                level with shape (N, 4*(n+1), H, W), n is max value of integral
                set.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box
            soft_target (list[Tensor]): Soft targets handed to
                :meth:`loss_single` level by level, alongside ``bbox_preds``
                (presumably one tensor per scale level with the same shape
                as the matching ``bbox_preds`` entry -- confirm against the
                caller).
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (list[Tensor] | None): specify which bounding
                boxes can be ignored when computing the loss.

        Returns:
            dict[str, Tensor] | None: A dictionary of loss components with
            keys ``loss_cls``, ``loss_bbox``, ``loss_dfl`` and ``loss_ld``,
            or None when ``get_targets`` returns None.
        """

        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels

        device = cls_scores[0].device
        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas, device=device)
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1

        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels)
        if cls_reg_targets is None:
            return None

        (anchor_list, labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets

        # Average the positive count across processes and keep it >= 1 so it
        # is a safe normaliser for the classification loss.
        num_total_samples = reduce_mean(
            torch.tensor(num_total_pos, dtype=torch.float,
                         device=device)).item()
        num_total_samples = max(num_total_samples, 1.0)

        (losses_cls, losses_bbox, losses_dfl, losses_ld,
         avg_factor) = multi_apply(
             self.loss_single,
             anchor_list,
             cls_scores,
             bbox_preds,
             labels_list,
             label_weights_list,
             bbox_targets_list,
             self.prior_generator.strides,
             soft_target,
             num_total_samples=num_total_samples)

        # Per-level weight sums; the epsilon keeps the divisor non-zero when
        # no level had a positive sample.
        avg_factor = sum(avg_factor) + 1e-6
        avg_factor = reduce_mean(avg_factor).item()
        # Only the bbox and dfl losses are normalised here; loss_ld is
        # returned exactly as produced by ``loss_single``.
        losses_bbox = [x / avg_factor for x in losses_bbox]
        losses_dfl = [x / avg_factor for x in losses_dfl]
        return dict(
            loss_cls=losses_cls,
            loss_bbox=losses_bbox,
            loss_dfl=losses_dfl,
            loss_ld=losses_ld)
Read the Docs v: v2.19.1
Versions
latest
stable
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
v2.12.0
v2.11.0
v2.10.0
v2.9.0
v2.8.0
v2.7.0
v2.6.0
v2.5.0
v2.4.0
v2.3.0
v2.2.1
v2.2.0
v2.1.0
v2.0.0
v1.2.0
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.