Shortcuts

Source code for mmdet.models.detectors.kd_one_stage

# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch
from mmcv.runner import load_checkpoint

from .. import build_detector
from ..builder import DETECTORS
from .single_stage import SingleStageDetector


@DETECTORS.register_module()
class KnowledgeDistillationSingleStageDetector(SingleStageDetector):
    r"""Implementation of `Distilling the Knowledge in a Neural Network.
    <https://arxiv.org/abs/1503.02531>`_.

    A student single-stage detector whose ``bbox_head`` is trained with the
    head outputs of a separately built teacher detector. The teacher is
    stored as a plain attribute (see ``__setattr__``), so its parameters are
    not returned by ``self.parameters()`` and are not optimized.

    Args:
        backbone: Student backbone config, forwarded to
            ``SingleStageDetector``.
        neck: Student neck config, forwarded to ``SingleStageDetector``.
        bbox_head: Student bbox head config, forwarded to
            ``SingleStageDetector``. Its ``forward_train`` receives the
            teacher head outputs as its second positional argument.
        teacher_config (str | dict): Config file path or the config object
            of teacher model. Its ``'model'`` entry is used to build the
            teacher.
        teacher_ckpt (str, optional): Checkpoint path of teacher model.
            If left as None, the model will not load any weights.
        eval_teacher (bool): If True, the teacher is kept in eval mode
            regardless of the mode requested for the student.
            Default: True.
        train_cfg (optional): Forwarded to ``SingleStageDetector``.
        test_cfg (optional): Forwarded to ``SingleStageDetector``.
        pretrained (optional): Forwarded to ``SingleStageDetector``.
    """

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 teacher_config,
                 teacher_ckpt=None,
                 eval_teacher=True,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super().__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
                         pretrained)
        self.eval_teacher = eval_teacher
        # Build teacher model
        if isinstance(teacher_config, str):
            teacher_config = mmcv.Config.fromfile(teacher_config)
        self.teacher_model = build_detector(teacher_config['model'])
        if teacher_ckpt is not None:
            load_checkpoint(
                self.teacher_model, teacher_ckpt, map_location='cpu')
        # A freshly built module is in training mode. Honour ``eval_teacher``
        # immediately instead of waiting for the first ``train()`` call.
        # Otherwise teacher BatchNorm running stats could be updated by its
        # forward pass (``torch.no_grad`` does not prevent that).
        if self.eval_teacher:
            self.teacher_model.eval()

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None):
        """
        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.
            gt_bboxes (list[Tensor]): Each item are the truth boxes for each
                image in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each box
            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
                boxes can be ignored when computing the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        x = self.extract_feat(img)
        # The teacher only provides targets for the student head. Running it
        # under no_grad keeps its outputs out of the autograd graph.
        with torch.no_grad():
            teacher_x = self.teacher_model.extract_feat(img)
            out_teacher = self.teacher_model.bbox_head(teacher_x)
        losses = self.bbox_head.forward_train(x, out_teacher, img_metas,
                                              gt_bboxes, gt_labels,
                                              gt_bboxes_ignore)
        return losses

    def cuda(self, device=None):
        """Since teacher_model is registered as a plain object, it is necessary
        to put the teacher model to cuda when calling cuda function."""
        self.teacher_model.cuda(device=device)
        return super().cuda(device=device)

    def to(self, *args, **kwargs):
        """Move/cast the teacher together with the student.

        ``nn.Module.to`` only reaches registered submodules. The teacher is
        deliberately unregistered (see ``__setattr__``), so the call is
        forwarded to it explicitly, mirroring ``cuda``.

        Returns:
            nn.Module: ``self``, as returned by ``nn.Module.to``.
        """
        self.teacher_model.to(*args, **kwargs)
        return super().to(*args, **kwargs)

    def cpu(self):
        """Move the teacher to CPU together with the student.

        Needed for the same reason as ``cuda``: the teacher is not a
        registered submodule.

        Returns:
            nn.Module: ``self``, as returned by ``nn.Module.cpu``.
        """
        self.teacher_model.cpu()
        return super().cpu()

    def train(self, mode=True):
        """Set the same train mode for teacher and student model.

        When ``eval_teacher`` is True, the teacher is always put in eval mode
        whatever ``mode`` is.

        Returns:
            nn.Module: ``self``. This matches ``torch.nn.Module.train``, so
            ``model.train()`` / ``model.eval()`` can be chained or reassigned.
        """
        if self.eval_teacher:
            self.teacher_model.train(False)
        else:
            self.teacher_model.train(mode)
        return super().train(mode)

    def __setattr__(self, name, value):
        """Set attribute, i.e. self.name = value

        This override prevents the teacher model from being registered as a
        nn.Module. The teacher module is registered as a plain object, so
        that the teacher parameters will not show up when calling
        ``self.parameters``, ``self.modules``, ``self.children`` methods.
        """
        if name == 'teacher_model':
            # Bypass nn.Module.__setattr__, which would register the value
            # in self._modules.
            object.__setattr__(self, name, value)
        else:
            super().__setattr__(name, value)
Read the Docs v: v2.19.0
Versions
latest
stable
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
v2.12.0
v2.11.0
v2.10.0
v2.9.0
v2.8.0
v2.7.0
v2.6.0
v2.5.0
v2.4.0
v2.3.0
v2.2.1
v2.2.0
v2.1.0
v2.0.0
v1.2.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.