import torch
from ..builder import BBOX_CODERS
from .base_bbox_coder import BaseBBoxCoder
@BBOX_CODERS.register_module()
class TBLRBBoxCoder(BaseBBoxCoder):
    """TBLR BBox coder.

    Following the practice in `FSAF <https://arxiv.org/abs/1903.00621>`_,
    this coder encodes gt bboxes (x1, y1, x2, y2) into (top, bottom, left,
    right) distances and decodes them back to the original format.

    Args:
        normalizer (list | float): Normalization factor the coordinates are
            divided by when encoding (and multiplied by when decoding). If
            it is a list, it should have length 4, giving one factor per
            tblr dimension. Otherwise it is a unified factor for all dims.
            Default: 4.0
    """

    def __init__(self, normalizer=4.0):
        # Name this class (not the parent) in ``super`` so that the parent
        # initializer is actually part of the call chain.
        super(TBLRBBoxCoder, self).__init__()
        self.normalizer = normalizer

    def encode(self, bboxes, gt_bboxes):
        """Get box regression transformation deltas that can be used to
        transform the `bboxes` into the `gt_bboxes` in the (top, bottom,
        left, right) order.

        Args:
            bboxes (torch.Tensor): Source boxes, e.g., object proposals.
            gt_bboxes (torch.Tensor): Target of the transformation, e.g.,
                ground truth boxes.

        Returns:
            torch.Tensor: Box transformation deltas.
        """
        # Both inputs must describe the same number of 4-coordinate boxes.
        assert bboxes.size(0) == gt_bboxes.size(0)
        assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
        encoded_bboxes = bboxes2tblr(
            bboxes, gt_bboxes, normalizer=self.normalizer)
        return encoded_bboxes

    def decode(self, bboxes, pred_bboxes, max_shape=None):
        """Apply transformation `pred_bboxes` to `bboxes`.

        Args:
            bboxes (torch.Tensor): Basic boxes the predictions are relative
                to.
            pred_bboxes (torch.Tensor): Encoded boxes in (top, bottom, left,
                right) form, one row per entry of `bboxes`.
            max_shape (tuple[int], optional): Maximum shape of boxes, read
                as (height, width, ...). Decoded boxes are clamped to it.
                Defaults to None.

        Returns:
            torch.Tensor: Decoded boxes.
        """
        assert pred_bboxes.size(0) == bboxes.size(0)
        decoded_bboxes = tblr2bboxes(
            bboxes,
            pred_bboxes,
            normalizer=self.normalizer,
            max_shape=max_shape)
        return decoded_bboxes
def bboxes2tblr(priors, gts, normalizer=4.0, normalize_by_wh=True):
    """Encode ground truth boxes to tblr coordinates.

    The gt coordinates are first converted to tblr format
    (top, bottom, left, right), i.e. distances from the prior box centers
    to the four gt box sides. The tblr coordinates are divided by the side
    lengths of the prior boxes if `normalize_by_wh` is True, and are then
    divided by the `normalizer` factor.

    Args:
        priors (Tensor): Prior boxes in point form (x1, y1, x2, y2).
            Shape: (num_proposals, 4).
        gts (Tensor): Coords of ground truth for each prior in point form.
            Shape: (num_proposals, 4).
        normalizer (Sequence[float] | float | int): Normalization parameter
            of encoded boxes. If it is a sequence, it has to have
            length = 4 (one factor per tblr dim). Default: 4.0
        normalize_by_wh (bool): Whether to normalize tblr coordinates by the
            side length (wh) of prior bboxes.

    Return:
        encoded boxes (Tensor), Shape: (num_proposals, 4)

    Raises:
        AssertionError: If a sequence `normalizer` does not have length 4,
            or `priors` and `gts` differ in their first dimension.
    """
    if isinstance(normalizer, (int, float)):
        # Treat any plain number as the unified factor. Checking only for
        # ``float`` would send an int such as ``4`` down the sequence path,
        # where ``len()`` of the resulting 0-d tensor raises TypeError.
        normalizer = float(normalizer)
    else:
        normalizer = torch.tensor(normalizer, device=priors.device)
        assert len(normalizer) == 4, 'Normalizer must have length = 4'
    assert priors.size(0) == gts.size(0)

    # dist b/t match center and prior's center
    prior_centers = (priors[:, 0:2] + priors[:, 2:4]) / 2
    center_x = prior_centers[:, 0].unsqueeze(1)
    center_y = prior_centers[:, 1].unsqueeze(1)
    xmin, ymin, xmax, ymax = gts.split(1, dim=1)
    top = center_y - ymin
    bottom = ymax - center_y
    left = center_x - xmin
    right = xmax - center_x
    # ``torch.cat`` allocates a new tensor, so the in-place divisions below
    # never touch the caller's inputs.
    loc = torch.cat((top, bottom, left, right), dim=1)
    if normalize_by_wh:
        # Normalize tblr by anchor width and height
        wh = priors[:, 2:4] - priors[:, 0:2]
        w, h = torch.split(wh, 1, dim=1)
        loc[:, :2] /= h  # tb is normalized by h
        loc[:, 2:] /= w  # lr is normalized by w
    # Normalize tblr by the given normalization factor
    return loc / normalizer
def tblr2bboxes(priors,
                tblr,
                normalizer=4.0,
                normalize_by_wh=True,
                max_shape=None):
    """Decode tblr outputs to prediction boxes.

    The process includes 3 steps: 1) De-normalize tblr coordinates by
    multiplying them with `normalizer`; 2) De-normalize tblr coordinates by
    the prior bbox width and height if `normalize_by_wh` is `True`;
    3) Convert the tblr (top, bottom, left, right) values, relative to the
    center of priors, back to (xmin, ymin, xmax, ymax) coordinates.

    Args:
        priors (Tensor): Prior boxes in point form (x0, y0, x1, y1).
            Shape: (n, 4).
        tblr (Tensor): Coords of network output in tblr form.
            Shape: (n, 4).
        normalizer (Sequence[float] | float | int): Normalization parameter
            of encoded boxes. By sequence, it represents the normalization
            factors at tblr dims (length 4). By number, it is the unified
            normalization factor at all dims. Default: 4.0
        normalize_by_wh (bool): Whether the tblr coordinates have been
            normalized by the side length (wh) of prior bboxes.
        max_shape (tuple, optional): Shape of the image, read as
            (height, width, ...). Decoded bboxes exceeding it are clamped.

    Return:
        decoded boxes (Tensor) in (xmin, ymin, xmax, ymax) form,
        Shape: (n, 4)

    Raises:
        AssertionError: If a sequence `normalizer` does not have length 4,
            or `priors` and `tblr` differ in their first dimension.
    """
    if isinstance(normalizer, (int, float)):
        # Treat any plain number as the unified factor. Checking only for
        # ``float`` would send an int such as ``4`` down the sequence path,
        # where ``len()`` of the resulting 0-d tensor raises TypeError.
        normalizer = float(normalizer)
    else:
        normalizer = torch.tensor(normalizer, device=priors.device)
        assert len(normalizer) == 4, 'Normalizer must have length = 4'
    assert priors.size(0) == tblr.size(0)

    # The multiplication yields a new tensor, so the in-place scaling below
    # does not modify the caller's ``tblr``.
    loc_decode = tblr * normalizer
    prior_centers = (priors[:, 0:2] + priors[:, 2:4]) / 2
    if normalize_by_wh:
        wh = priors[:, 2:4] - priors[:, 0:2]
        w, h = torch.split(wh, 1, dim=1)
        loc_decode[:, :2] *= h  # tb
        loc_decode[:, 2:] *= w  # lr
    top, bottom, left, right = loc_decode.split(1, dim=1)
    center_x = prior_centers[:, 0].unsqueeze(1)
    center_y = prior_centers[:, 1].unsqueeze(1)
    xmin = center_x - left
    xmax = center_x + right
    ymin = center_y - top
    ymax = center_y + bottom
    boxes = torch.cat((xmin, ymin, xmax, ymax), dim=1)
    if max_shape is not None:
        # max_shape[1] bounds the x coordinates, max_shape[0] the y ones.
        boxes[:, 0].clamp_(min=0, max=max_shape[1])
        boxes[:, 1].clamp_(min=0, max=max_shape[0])
        boxes[:, 2].clamp_(min=0, max=max_shape[1])
        boxes[:, 3].clamp_(min=0, max=max_shape[0])
    return boxes