Source code for mmdet.core.bbox.coder.delta_xywh_bbox_coder

import mmcv
import numpy as np
import torch

from ..builder import BBOX_CODERS
from .base_bbox_coder import BaseBBoxCoder


@BBOX_CODERS.register_module()
class DeltaXYWHBBoxCoder(BaseBBoxCoder):
    """Delta XYWH BBox coder.

    Following the practice in `R-CNN <https://arxiv.org/abs/1311.2524>`_,
    this coder encodes bbox (x1, y1, x2, y2) into delta (dx, dy, dw, dh) and
    decodes delta (dx, dy, dw, dh) back to original bbox (x1, y1, x2, y2).

    Args:
        target_means (Sequence[float]): Denormalizing means of target for
            delta coordinates.
        target_stds (Sequence[float]): Denormalizing standard deviation of
            target for delta coordinates.
        clip_border (bool, optional): Whether to clip the objects outside the
            border of the image. Defaults to True.
    """

    def __init__(self,
                 target_means=(0., 0., 0., 0.),
                 target_stds=(1., 1., 1., 1.),
                 clip_border=True):
        super(DeltaXYWHBBoxCoder, self).__init__()
        self.means = target_means
        self.stds = target_stds
        self.clip_border = clip_border
    def encode(self, bboxes, gt_bboxes):
        """Get box regression transformation deltas that can be used to
        transform the ``bboxes`` into the ``gt_bboxes``.

        Args:
            bboxes (torch.Tensor): Source boxes, e.g., object proposals.
            gt_bboxes (torch.Tensor): Target of the transformation, e.g.,
                ground-truth boxes.

        Returns:
            torch.Tensor: Box transformation deltas.
        """
        assert bboxes.size(0) == gt_bboxes.size(0)
        assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
        encoded_bboxes = bbox2delta(bboxes, gt_bboxes, self.means, self.stds)
        return encoded_bboxes
    def decode(self,
               bboxes,
               pred_bboxes,
               max_shape=None,
               wh_ratio_clip=16 / 1000):
        """Apply transformation `pred_bboxes` to `bboxes`.

        Args:
            bboxes (torch.Tensor): Basic boxes. Shape (B, N, 4) or (N, 4).
            pred_bboxes (torch.Tensor): Encoded offsets with respect to each
                roi. Has shape (B, N, num_classes * 4) or (B, N, 4) or
                (N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H
                when rois is a grid of anchors. Offset encoding follows
                [1]_.
            max_shape (Sequence[int] or torch.Tensor or Sequence[
                Sequence[int]], optional): Maximum bounds for boxes,
                specifies (H, W, C) or (H, W). If bboxes shape is (B, N, 4),
                then the max_shape should be a Sequence[Sequence[int]] and
                the length of max_shape should also be B.
            wh_ratio_clip (float, optional): The allowed ratio between width
                and height.

        Returns:
            torch.Tensor: Decoded boxes.
        """
        assert pred_bboxes.size(0) == bboxes.size(0)
        if pred_bboxes.ndim == 3:
            assert pred_bboxes.size(1) == bboxes.size(1)
        decoded_bboxes = delta2bbox(bboxes, pred_bboxes, self.means,
                                    self.stds, max_shape, wh_ratio_clip,
                                    self.clip_border)
        return decoded_bboxes
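# Usage sketch (illustration only, not part of the original module; the
# tensor values below are made up): a round trip through encode()/decode()
# recovers the ground-truth boxes, since decode() inverts encode() when the
# same means/stds are used.
#
#   coder = DeltaXYWHBBoxCoder(target_means=(0., 0., 0., 0.),
#                              target_stds=(0.1, 0.1, 0.2, 0.2))
#   proposals = torch.Tensor([[0., 0., 10., 10.]])
#   gt = torch.Tensor([[1., 1., 12., 12.]])
#   deltas = coder.encode(proposals, gt)       # shape (1, 4)
#   decoded = coder.decode(proposals, deltas)  # ~= gt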
@mmcv.jit(coderize=True)
def bbox2delta(proposals, gt, means=(0., 0., 0., 0.), stds=(1., 1., 1., 1.)):
    """Compute deltas of proposals w.r.t. gt.

    We usually compute the deltas of x, y, w, h of proposals w.r.t ground
    truth bboxes to get regression target.
    This is the inverse function of :func:`delta2bbox`.

    Args:
        proposals (Tensor): Boxes to be transformed, shape (N, ..., 4)
        gt (Tensor): Gt bboxes to be used as base, shape (N, ..., 4)
        means (Sequence[float]): Denormalizing means for delta coordinates
        stds (Sequence[float]): Denormalizing standard deviation for delta
            coordinates

    Returns:
        Tensor: deltas with shape (N, 4), where columns represent dx, dy,
            dw, dh.
    """
    assert proposals.size() == gt.size()

    proposals = proposals.float()
    gt = gt.float()
    # Center and size of each proposal
    px = (proposals[..., 0] + proposals[..., 2]) * 0.5
    py = (proposals[..., 1] + proposals[..., 3]) * 0.5
    pw = proposals[..., 2] - proposals[..., 0]
    ph = proposals[..., 3] - proposals[..., 1]
    # Center and size of each ground-truth box
    gx = (gt[..., 0] + gt[..., 2]) * 0.5
    gy = (gt[..., 1] + gt[..., 3]) * 0.5
    gw = gt[..., 2] - gt[..., 0]
    gh = gt[..., 3] - gt[..., 1]
    # Offsets are relative to the proposal size; scales are in log space
    dx = (gx - px) / pw
    dy = (gy - py) / ph
    dw = torch.log(gw / pw)
    dh = torch.log(gh / ph)
    deltas = torch.stack([dx, dy, dw, dh], dim=-1)
    # Normalize the deltas with the given means/stds
    means = deltas.new_tensor(means).unsqueeze(0)
    stds = deltas.new_tensor(stds).unsqueeze(0)
    deltas = deltas.sub_(means).div_(stds)

    return deltas


@mmcv.jit(coderize=True)
def delta2bbox(rois,
               deltas,
               means=(0., 0., 0., 0.),
               stds=(1., 1., 1., 1.),
               max_shape=None,
               wh_ratio_clip=16 / 1000,
               clip_border=True):
    """Apply deltas to shift/scale base boxes.

    Typically the rois are anchor or proposed bounding boxes and the deltas
    are network outputs used to shift/scale those boxes.
    This is the inverse function of :func:`bbox2delta`.

    Args:
        rois (Tensor): Boxes to be transformed. Has shape (N, 4) or
            (B, N, 4).
        deltas (Tensor): Encoded offsets with respect to each roi.
            Has shape (B, N, num_classes * 4) or (B, N, 4) or
            (N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H
            when rois is a grid of anchors. Offset encoding follows [1]_.
        means (Sequence[float]): Denormalizing means for delta coordinates.
        stds (Sequence[float]): Denormalizing standard deviation for delta
            coordinates.
        max_shape (Sequence[int] or torch.Tensor or Sequence[
            Sequence[int]], optional): Maximum bounds for boxes, specifies
            (H, W, C) or (H, W). If rois shape is (B, N, 4), then the
            max_shape should be a Sequence[Sequence[int]] and the length of
            max_shape should also be B.
        wh_ratio_clip (float): Maximum aspect ratio for boxes.
        clip_border (bool, optional): Whether to clip the objects outside the
            border of the image. Defaults to True.

    Returns:
        Tensor: Boxes with shape (B, N, num_classes * 4) or (B, N, 4) or
            (N, num_classes * 4) or (N, 4), where 4 represent tl_x, tl_y,
            br_x, br_y.

    References:
        .. [1] https://arxiv.org/abs/1311.2524

    Example:
        >>> rois = torch.Tensor([[ 0.,  0.,  1.,  1.],
        >>>                      [ 0.,  0.,  1.,  1.],
        >>>                      [ 0.,  0.,  1.,  1.],
        >>>                      [ 5.,  5.,  5.,  5.]])
        >>> deltas = torch.Tensor([[  0.,   0.,   0.,   0.],
        >>>                        [  1.,   1.,   1.,   1.],
        >>>                        [  0.,   0.,   2.,  -1.],
        >>>                        [ 0.7, -1.9, -0.5,  0.3]])
        >>> delta2bbox(rois, deltas, max_shape=(32, 32, 3))
        tensor([[0.0000, 0.0000, 1.0000, 1.0000],
                [0.1409, 0.1409, 2.8591, 2.8591],
                [0.0000, 0.3161, 4.1945, 0.6839],
                [5.0000, 5.0000, 5.0000, 5.0000]])
    """
    # Denormalize the deltas; means/stds are tiled across the class dim
    means = deltas.new_tensor(means).view(1, -1).repeat(
        1, deltas.size(-1) // 4)
    stds = deltas.new_tensor(stds).view(1, -1).repeat(
        1, deltas.size(-1) // 4)
    denorm_deltas = deltas * stds + means
    dx = denorm_deltas[..., 0::4]
    dy = denorm_deltas[..., 1::4]
    dw = denorm_deltas[..., 2::4]
    dh = denorm_deltas[..., 3::4]
    # Clamp dw/dh so that exp() cannot produce degenerate boxes
    max_ratio = np.abs(np.log(wh_ratio_clip))
    dw = dw.clamp(min=-max_ratio, max=max_ratio)
    dh = dh.clamp(min=-max_ratio, max=max_ratio)
    x1, y1 = rois[..., 0], rois[..., 1]
    x2, y2 = rois[..., 2], rois[..., 3]
    # Compute center of each roi
    px = ((x1 + x2) * 0.5).unsqueeze(-1).expand_as(dx)
    py = ((y1 + y2) * 0.5).unsqueeze(-1).expand_as(dy)
    # Compute width/height of each roi
    pw = (x2 - x1).unsqueeze(-1).expand_as(dw)
    ph = (y2 - y1).unsqueeze(-1).expand_as(dh)
    # Use exp(network energy) to enlarge/shrink each roi
    gw = pw * dw.exp()
    gh = ph * dh.exp()
    # Use network energy to shift the center of each roi
    gx = px + pw * dx
    gy = py + ph * dy
    # Convert center-xy/width/height to top-left, bottom-right
    x1 = gx - gw * 0.5
    y1 = gy - gh * 0.5
    x2 = gx + gw * 0.5
    y2 = gy + gh * 0.5
    bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size())

    if clip_border and max_shape is not None:
        # Clamp x coordinates to [0, W] and y coordinates to [0, H]
        if not isinstance(max_shape, torch.Tensor):
            max_shape = x1.new_tensor(max_shape)
        max_shape = max_shape[..., :2].type_as(x1)
        if max_shape.ndim == 2:
            assert bboxes.ndim == 3
            assert max_shape.size(0) == bboxes.size(0)
        min_xy = x1.new_tensor(0)
        max_xy = torch.cat(
            [max_shape] * (deltas.size(-1) // 2),
            dim=-1).flip(-1).unsqueeze(-2)
        bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
        bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)

    return bboxes
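# Round-trip sketch for the functional API (illustration only; the tensor
# values are made up): with matching means/stds, delta2bbox() inverts
# bbox2delta(), and max_shape clamps the decoded boxes to the image bounds.
#
#   proposals = torch.Tensor([[0., 0., 10., 10.]])
#   gt = torch.Tensor([[2., 2., 12., 14.]])
#   deltas = bbox2delta(proposals, gt)         # (dx, dy, dw, dh)
#   recovered = delta2bbox(proposals, deltas)  # ~= gt
#   clipped = delta2bbox(proposals, deltas, max_shape=(8, 8))  # clamped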