Shortcuts

Source code for mmdet.models.utils.misc

# Copyright (c) OpenMMLab. All rights reserved.
from torch.autograd import Function
from torch.nn import functional as F


class SigmoidGeometricMean(Function):
    """Geometric mean of two sigmoids with a hand-written gradient.

    Computes ``sqrt(sigmoid(x) * sigmoid(y))`` element-wise. The analytical
    backward pass replaces the autograd graph of
    ``(x.sigmoid() * y.sigmoid()).sqrt()``. That graph yields NaN gradients
    when the product of the sigmoids underflows to zero, e.g. for very
    negative inputs: the derivative of ``sqrt`` at zero is infinite and gets
    multiplied by a zero sigmoid derivative.
    """

    @staticmethod
    def forward(ctx, x, y):
        sig_x, sig_y = x.sigmoid(), y.sigmoid()
        mean = (sig_x * sig_y).sqrt()
        # The backward pass is written purely in terms of the two sigmoids
        # and the output, so these are all that needs to be kept around.
        ctx.save_for_backward(sig_x, sig_y, mean)
        return mean

    @staticmethod
    def backward(ctx, grad_output):
        sig_x, sig_y, mean = ctx.saved_tensors
        # Since sigmoid'(t) = sigmoid(t) * (1 - sigmoid(t)), the chain rule
        # simplifies to d/dx sqrt(s(x) s(y)) = sqrt(s(x) s(y)) * (1 - s(x)) / 2.
        # No division by the (possibly zero) output is needed.
        scaled = grad_output * mean
        return scaled * (1 - sig_x) / 2, scaled * (1 - sig_y) / 2


# Functional alias, so callers can write ``sigmoid_geometric_mean(x, y)``.
sigmoid_geometric_mean = SigmoidGeometricMean.apply


def interpolate_as(source, target, mode='bilinear', align_corners=False):
    """Interpolate the `source` to the shape of the `target`.

    The `source` must be a Tensor, but the `target` can be a Tensor or a
    np.ndarray with the shape (..., target_h, target_w).

    Args:
        source (Tensor): A 3D/4D Tensor with the shape (N, H, W) or
            (N, C, H, W).
        target (Tensor | np.ndarray): The interpolation target with the shape
            (..., target_h, target_w). Only its last two dimensions are used.
        mode (str): Algorithm used for interpolation. The options are the
            same as those in F.interpolate(). Default: ``'bilinear'``.
        align_corners (bool): The same as the argument in F.interpolate().
            It is only meaningful for ``'linear'``, ``'bilinear'``,
            ``'bicubic'`` and ``'trilinear'``. For other modes (e.g.
            ``'nearest'``, ``'area'``) a falsy value such as the default
            ``False`` is ignored.

    Returns:
        Tensor: The interpolated source Tensor. If the source's last two
        dimensions already equal the target's, no interpolation is done.
    """
    assert len(target.shape) >= 2

    # F.interpolate raises ValueError whenever align_corners is not None for
    # modes that do not use it. Passing the default False through would make
    # 'nearest', 'area', etc. unusable, so translate a falsy value to None.
    # An explicit True is still forwarded, so real misuse keeps failing.
    if (mode not in ('linear', 'bilinear', 'bicubic', 'trilinear')
            and not align_corners):
        align_corners = None

    def _interpolate_as(source, target, mode='bilinear', align_corners=False):
        """Interpolate the `source` (4D) to the shape of the `target`."""
        target_h, target_w = target.shape[-2:]
        source_h, source_w = source.shape[-2:]
        if target_h != source_h or target_w != source_w:
            source = F.interpolate(
                source,
                size=(target_h, target_w),
                mode=mode,
                align_corners=align_corners)
        return source

    if len(source.shape) == 3:
        # F.interpolate needs a channel dimension: add one, then strip it.
        source = source[:, None, :, :]
        source = _interpolate_as(source, target, mode, align_corners)
        return source[:, 0, :, :]
    else:
        return _interpolate_as(source, target, mode, align_corners)
Read the Docs v: latest
Versions
latest
stable
v2.25.0
v2.24.1
v2.24.0
v2.23.0
v2.22.0
v2.21.0
v2.20.0
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.