# Source code for mmdet.models.losses.mse_loss

import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import weighted_loss


@weighted_loss
def mse_loss(pred, target):
    """Return the un-reduced, element-wise squared error of ``pred`` vs ``target``.

    No reduction happens here (``reduction='none'``). The ``weighted_loss``
    decorator is what lets callers also pass ``weight``, ``reduction`` and
    ``avg_factor`` (as ``MSELoss.forward`` does). Presumably it applies
    weighting and reduction on top of this result; confirm in ``.utils``.
    """
    squared_error = F.mse_loss(input=pred, target=target, reduction='none')
    return squared_error


@LOSSES.register_module()
class MSELoss(nn.Module):
    """Mean squared error loss module, registered in the ``LOSSES`` registry.

    Args:
        reduction (str, optional): Reduction mode forwarded to ``mse_loss``
            when no per-call override is given. Defaults to ``'mean'``.
        loss_weight (float, optional): Scalar multiplier applied to the
            computed loss. Defaults to 1.0.
    """

    # Reduction modes accepted as a per-call override in ``forward``.
    _VALID_REDUCTIONS = (None, 'none', 'mean', 'sum')

    def __init__(self, reduction='mean', loss_weight=1.0):
        super().__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Compute the weighted mean squared error loss.

        Args:
            pred (torch.Tensor): Predictions.
            target (torch.Tensor): Regression targets. Presumably the same
                shape as ``pred``; TODO confirm against callers.
            weight (torch.Tensor, optional): Weights forwarded to
                ``mse_loss``. Defaults to None.
            avg_factor (optional): Normalization factor forwarded to
                ``mse_loss``. Defaults to None.
            reduction_override (str, optional): If not None, used instead of
                ``self.reduction`` for this call only. One of
                ``'none'``, ``'mean'`` or ``'sum'``. Defaults to None.

        Returns:
            torch.Tensor: ``self.loss_weight`` times the loss returned by
            ``mse_loss``.

        Raises:
            ValueError: If ``reduction_override`` is not a supported value.
        """
        if reduction_override not in self._VALID_REDUCTIONS:
            raise ValueError(
                f'reduction_override must be one of '
                f'{self._VALID_REDUCTIONS}, got {reduction_override!r}')
        # The per-call override wins; otherwise use the configured default,
        # so existing callers (which never pass an override) are unaffected.
        reduction = (
            reduction_override
            if reduction_override is not None else self.reduction)
        loss = self.loss_weight * mse_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss