Shortcuts

Source code for mmdet.models.detectors.lad

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.runner import load_checkpoint

from ..builder import DETECTORS, build_backbone, build_head, build_neck
from .kd_one_stage import KnowledgeDistillationSingleStageDetector


@DETECTORS.register_module()
class LAD(KnowledgeDistillationSingleStageDetector):
    """Implementation of `LAD <https://arxiv.org/pdf/2108.10520.pdf>`_.

    A student detector (``backbone`` / ``neck`` / ``bbox_head``) is trained
    with the label assignment produced by a separately built teacher
    (``teacher_backbone`` / ``teacher_neck`` / ``teacher_bbox_head``).

    Args:
        backbone: Student backbone config, forwarded to the base
            initializer.
        neck: Student neck config, forwarded to the base initializer.
        bbox_head: Student head config, forwarded to the base initializer.
        teacher_backbone: Config passed to ``build_backbone`` for the
            teacher.
        teacher_neck: Config passed to ``build_neck`` for the teacher, or
            ``None`` to build the teacher without a neck.
        teacher_bbox_head: Config passed to ``build_head`` for the teacher.
            ``train_cfg`` and ``test_cfg`` are injected into a copy of it;
            the caller's object is left untouched.
        teacher_ckpt: Checkpoint loaded into the teacher with
            ``map_location='cpu'``; loading is skipped when ``None``.
        eval_teacher (bool): Stored as ``self.eval_teacher``. It is not read
            inside this class (presumably consumed by the parent class's
            train/eval mode handling -- confirm there). Default: True.
        train_cfg: Training config, forwarded to the base initializer and
            to the teacher head. Default: None.
        test_cfg: Testing config, forwarded to the base initializer and to
            the teacher head. Default: None.
        pretrained: Forwarded to the base initializer. Default: None.
    """

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 teacher_backbone,
                 teacher_neck,
                 teacher_bbox_head,
                 teacher_ckpt,
                 eval_teacher=True,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        # The MRO lookup starts *after*
        # KnowledgeDistillationSingleStageDetector, so that class's own
        # __init__ is not run; the teacher is assembled below instead.
        super(KnowledgeDistillationSingleStageDetector, self).__init__(
            backbone, neck, bbox_head, train_cfg, test_cfg, pretrained)
        self.eval_teacher = eval_teacher

        # A bare nn.Module serves as the container for the teacher parts, so
        # the checkpoint keys are matched against backbone/neck/bbox_head.
        self.teacher_model = nn.Module()
        self.teacher_model.backbone = build_backbone(teacher_backbone)
        if teacher_neck is not None:
            self.teacher_model.neck = build_neck(teacher_neck)

        # Inject train/test settings into a shallow copy so the config
        # object supplied by the caller is not mutated as a side effect.
        teacher_bbox_head = teacher_bbox_head.copy()
        teacher_bbox_head.update(train_cfg=train_cfg)
        teacher_bbox_head.update(test_cfg=test_cfg)
        self.teacher_model.bbox_head = build_head(teacher_bbox_head)

        if teacher_ckpt is not None:
            load_checkpoint(
                self.teacher_model, teacher_ckpt, map_location='cpu')

    @property
    def with_teacher_neck(self):
        """bool: whether the detector has a teacher_neck"""
        return getattr(self.teacher_model, 'neck', None) is not None

    def extract_teacher_feat(self, img):
        """Directly extract teacher features from the backbone+neck.

        Args:
            img (Tensor): Input images fed to the teacher backbone.

        Returns:
            The teacher backbone output, passed through the teacher neck
            when one was built.
        """
        x = self.teacher_model.backbone(img)
        if self.with_teacher_neck:
            x = self.teacher_model.neck(x)
        return x

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None):
        """Compute student losses using the teacher's label assignment.

        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.
            gt_bboxes (list[Tensor]): Each item holds the ground-truth boxes
                of one image in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each box
            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
                boxes can be ignored when computing the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        # Get the label assignment from the teacher. Gradients are disabled
        # for the whole teacher pass.
        with torch.no_grad():
            x_teacher = self.extract_teacher_feat(img)
            outs_teacher = self.teacher_model.bbox_head(x_teacher)
            label_assignment_results = \
                self.teacher_model.bbox_head.get_label_assignment(
                    *outs_teacher, gt_bboxes, gt_labels, img_metas,
                    gt_bboxes_ignore)

        # The student uses the label assignment from the teacher to learn.
        x = self.extract_feat(img)
        losses = self.bbox_head.forward_train(x, label_assignment_results,
                                              img_metas, gt_bboxes, gt_labels,
                                              gt_bboxes_ignore)
        return losses
Read the Docs v: v2.21.0
Versions
latest
stable
v2.21.0
v2.20.0
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
v2.12.0
v2.11.0
v2.10.0
v2.9.0
v2.8.0
v2.7.0
v2.6.0
v2.5.0
v2.4.0
v2.3.0
v2.2.1
v2.2.0
v2.1.0
v2.0.0
v1.2.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.