
Source code for mmdet.models.losses.kd_loss

# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import weighted_loss


@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def knowledge_distillation_kl_div_loss(pred,
                                       soft_label,
                                       T,
                                       detach_target=True):
    r"""Loss function for knowledge distilling using KL divergence.

    Args:
        pred (Tensor): Predicted logits with shape (N, n + 1).
        soft_label (Tensor): Target logits with shape (N, n + 1).
        T (int): Temperature for distillation.
        detach_target (bool): Remove soft_label from automatic
            differentiation. Defaults to True.

    Returns:
        torch.Tensor: Loss tensor with shape (N,).
    """
    assert pred.size() == soft_label.size()
    target = F.softmax(soft_label / T, dim=1)
    if detach_target:
        target = target.detach()

    kd_loss = F.kl_div(
        F.log_softmax(pred / T, dim=1), target, reduction='none').mean(1) * (
            T * T)

    return kd_loss

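Because the function is wrapped by ``weighted_loss``, the public call also accepts ``weight``, ``reduction`` and ``avg_factor`` arguments, as used in the module's forward below. The standalone sketch that follows (plain PyTorch; the temperature value and tensor shapes are illustrative assumptions) reproduces the per-sample computation: the teacher logits are softened with temperature ``T`` and detached, then compared against the temperature-scaled student log-probabilities, with the ``T * T`` factor keeping gradient magnitudes comparable across temperatures.

import torch
import torch.nn.functional as F

T = 2                             # illustrative temperature (assumption)
pred = torch.randn(4, 81)         # student logits, e.g. 80 classes + background
soft_label = torch.randn(4, 81)   # teacher logits, same shape as pred

# soften the teacher distribution and cut it out of the autograd graph
target = F.softmax(soft_label / T, dim=1).detach()
# per-sample KL divergence, averaged over classes and rescaled by T^2
kd_loss = F.kl_div(
    F.log_softmax(pred / T, dim=1), target, reduction='none').mean(1) * (T * T)
print(kd_loss.shape)  # torch.Size([4]) -- one loss value per sample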

@LOSSES.register_module()
class KnowledgeDistillationKLDivLoss(nn.Module):
    """Loss function for knowledge distilling using KL divergence.

    Args:
        reduction (str): Options are `'none'`, `'mean'` and `'sum'`.
        loss_weight (float): Loss weight of current loss.
        T (int): Temperature for distillation.
    """

    def __init__(self, reduction='mean', loss_weight=1.0, T=10):
        super(KnowledgeDistillationKLDivLoss, self).__init__()
        assert T >= 1
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.T = T

    def forward(self,
                pred,
                soft_label,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (Tensor): Predicted logits with shape (N, n + 1).
            soft_label (Tensor): Target logits with shape (N, n + 1).
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')

        reduction = (
            reduction_override if reduction_override else self.reduction)

        loss_kd = self.loss_weight * knowledge_distillation_kl_div_loss(
            pred,
            soft_label,
            weight,
            reduction=reduction,
            avg_factor=avg_factor,
            T=self.T)

        return loss_kd
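A hedged usage sketch follows: it instantiates the module directly and runs a forward/backward pass on dummy logits. The tensor shapes, the class count and the import path ``mmdet.models.losses`` are assumptions for illustration; in a config file the loss is typically selected by its registered name, e.g. ``dict(type='KnowledgeDistillationKLDivLoss', T=10, loss_weight=1.0)``, via the ``LOSSES`` registry used above.

import torch
from mmdet.models.losses import KnowledgeDistillationKLDivLoss  # assumed import path

loss_fn = KnowledgeDistillationKLDivLoss(reduction='mean', loss_weight=1.0, T=10)
student_logits = torch.randn(8, 81, requires_grad=True)  # illustrative shapes
teacher_logits = torch.randn(8, 81)

loss = loss_fn(student_logits, teacher_logits)
loss.backward()  # gradients reach only the student logits; the teacher target is detached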