Shortcuts

Source code for mmdet.models.dense_heads.ascend_anchor_head

# Copyright (c) OpenMMLab. All rights reserved.
import torch

from ...core.bbox.assigners import AscendMaxIoUAssigner
from ...core.bbox.samplers import PseudoSampler
from ...utils import (batch_images_to_levels, get_max_num_gt_division_factor,
                      masked_fill)
from ..builder import HEADS
from .anchor_head import AnchorHead


[docs]@HEADS.register_module() class AscendAnchorHead(AnchorHead): """Ascend Anchor-based head (RetinaNet, SSD, etc.). Args: num_classes (int): Number of categories excluding the background category. in_channels (int): Number of channels in the input feature map. feat_channels (int): Number of hidden channels. Used in child classes. anchor_generator (dict): Config dict for anchor generator bbox_coder (dict): Config of bounding box coder. reg_decoded_bbox (bool): If true, the regression loss would be applied directly on decoded bounding boxes, converting both the predicted boxes and regression targets to absolute coordinates format. Default False. It should be `True` when using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head. loss_cls (dict): Config of classification loss. loss_bbox (dict): Config of localization loss. train_cfg (dict): Training config of anchor head. test_cfg (dict): Testing config of anchor head. init_cfg (dict or list[dict], optional): Initialization config dict. """ # noqa: W605 def __init__(self, num_classes, in_channels, feat_channels=256, anchor_generator=dict( type='AnchorGenerator', scales=[8, 16, 32], ratios=[0.5, 1.0, 2.0], strides=[4, 8, 16, 32, 64]), bbox_coder=dict( type='DeltaXYWHBBoxCoder', clip_border=True, target_means=(.0, .0, .0, .0), target_stds=(1.0, 1.0, 1.0, 1.0)), reg_decoded_bbox=False, loss_cls=dict( type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), loss_bbox=dict( type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0), train_cfg=None, test_cfg=None, init_cfg=dict(type='Normal', layer='Conv2d', std=0.01)): super(AscendAnchorHead, self).__init__( num_classes=num_classes, in_channels=in_channels, feat_channels=feat_channels, anchor_generator=anchor_generator, bbox_coder=bbox_coder, reg_decoded_bbox=reg_decoded_bbox, loss_cls=loss_cls, loss_bbox=loss_bbox, train_cfg=train_cfg, test_cfg=test_cfg, init_cfg=init_cfg)
[docs] def get_batch_gt_bboxes(self, gt_bboxes_list, num_images, gt_nums, device, max_gt_labels): """Get ground truth bboxes of all image. Args: gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image. num_images (int): The num of images. gt_nums(list[int]): The ground truth bboxes num of each image. device (torch.device | str): Device for returned tensors max_gt_labels(int): The max ground truth bboxes num of all image. Returns: batch_gt_bboxes: (Tensor): Ground truth bboxes of all image. """ # a static ground truth boxes. # Save static gt. Related to Ascend. Helps improve performance if not hasattr(self, 'batch_gt_bboxes'): self.batch_gt_bboxes = {} # a min anchor filled the excess anchor if not hasattr(self, 'min_anchor'): self.min_anchor = (-1354, -1344) if gt_bboxes_list is None: batch_gt_bboxes = None else: if self.batch_gt_bboxes.get(max_gt_labels) is None: batch_gt_bboxes = torch.zeros((num_images, max_gt_labels, 4), dtype=gt_bboxes_list[0].dtype, device=device) batch_gt_bboxes[:, :, :2] = self.min_anchor[0] batch_gt_bboxes[:, :, 2:] = self.min_anchor[1] self.batch_gt_bboxes[max_gt_labels] = batch_gt_bboxes.clone() else: batch_gt_bboxes = self.batch_gt_bboxes.get( max_gt_labels).clone() for index_imgs, gt_bboxes in enumerate(gt_bboxes_list): batch_gt_bboxes[index_imgs, :gt_nums[index_imgs]] = gt_bboxes return batch_gt_bboxes
[docs] def get_batch_gt_bboxes_ignore(self, gt_bboxes_ignore_list, num_images, gt_nums, device): """Ground truth bboxes to be ignored of all image. Args: gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be ignored. num_images (int): The num of images. gt_nums(list[int]): The ground truth bboxes num of each image. device (torch.device | str): Device for returned tensors Returns: batch_gt_bboxes_ignore: (Tensor): Ground truth bboxes to be ignored of all image. """ # TODO: support gt_bboxes_ignore_list if gt_bboxes_ignore_list is None: batch_gt_bboxes_ignore = None else: raise RuntimeError('gt_bboxes_ignore not support yet') return batch_gt_bboxes_ignore
[docs] def get_batch_gt_labels(self, gt_labels_list, num_images, gt_nums, device, max_gt_labels): """Ground truth bboxes to be ignored of all image. Args: gt_labels_list (list[Tensor]): Ground truth labels. num_images (int): The num of images. gt_nums(list[int]): The ground truth bboxes num of each image. device (torch.device | str): Device for returned tensors Returns: batch_gt_labels: (Tensor): Ground truth labels of all image. """ if gt_labels_list is None: batch_gt_labels = None else: batch_gt_labels = torch.zeros((num_images, max_gt_labels), dtype=gt_labels_list[0].dtype, device=device) for index_imgs, gt_labels in enumerate(gt_labels_list): batch_gt_labels[index_imgs, :gt_nums[index_imgs]] = gt_labels return batch_gt_labels
def _get_targets_concat(self, batch_anchors, batch_valid_flags, batch_gt_bboxes, batch_gt_bboxes_ignore, batch_gt_labels, img_metas, label_channels=1, unmap_outputs=True): """Compute regression and classification targets for anchors in all images. Args: batch_anchors (Tensor): anchors of all image, which are concatenated into a single tensor of shape (num_imgs, num_anchors ,4). batch_valid_flags (Tensor): valid flags of all image, which are concatenated into a single tensor of shape (num_imgs, num_anchors,). batch_gt_bboxes (Tensor): Ground truth bboxes of all image, shape (num_imgs, max_gt_nums, 4). batch_gt_bboxes_ignore (Tensor): Ground truth bboxes to be ignored, shape (num_imgs, num_ignored_gts, 4). batch_gt_labels (Tensor): Ground truth labels of each box, shape (num_imgs, max_gt_nums,). img_metas (list[dict]): Meta info of each image. label_channels (int): Channel of label. unmap_outputs (bool): Whether to map outputs back to the original set of anchors. Returns: tuple: batch_labels (Tensor): Labels of all level batch_label_weights (Tensor): Label weights of all level batch_bbox_targets (Tensor): BBox targets of all level batch_bbox_weights (Tensor): BBox weights of all level batch_pos_mask (Tensor): Positive samples mask in all images batch_neg_mask (Tensor): Negative samples mask in all images sampling_result (Sampling): The result of sampling, default: None. """ num_imgs, num_anchors, _ = batch_anchors.size() # assign gt and sample batch_anchors assign_result = self.assigner.assign( batch_anchors, batch_gt_bboxes, batch_gt_bboxes_ignore, None if self.sampling else batch_gt_labels, batch_bboxes_ignore_mask=batch_valid_flags) # TODO: support sampling_result sampling_result = None batch_pos_mask = assign_result.batch_pos_mask batch_neg_mask = assign_result.batch_neg_mask batch_anchor_gt_indes = assign_result.batch_anchor_gt_indes batch_anchor_gt_labels = assign_result.batch_anchor_gt_labels batch_anchor_gt_bboxes = torch.zeros( batch_anchors.size(), dtype=batch_anchors.dtype, device=batch_anchors.device) for index_imgs in range(num_imgs): batch_anchor_gt_bboxes[index_imgs] = torch.index_select( batch_gt_bboxes[index_imgs], 0, batch_anchor_gt_indes[index_imgs]) batch_bbox_targets = torch.zeros_like(batch_anchors) batch_bbox_weights = torch.zeros_like(batch_anchors) batch_labels = batch_anchors.new_full((num_imgs, num_anchors), self.num_classes, dtype=torch.int) batch_label_weights = batch_anchors.new_zeros((num_imgs, num_anchors), dtype=torch.float) if not self.reg_decoded_bbox: batch_pos_bbox_targets = self.bbox_coder.encode( batch_anchors, batch_anchor_gt_bboxes) else: batch_pos_bbox_targets = batch_anchor_gt_bboxes batch_bbox_targets = masked_fill(batch_bbox_targets, batch_pos_mask.unsqueeze(2), batch_pos_bbox_targets) batch_bbox_weights = masked_fill(batch_bbox_weights, batch_pos_mask.unsqueeze(2), 1.0) if batch_gt_labels is None: batch_labels = masked_fill(batch_labels, batch_pos_mask, 0.0) else: batch_labels = masked_fill(batch_labels, batch_pos_mask, batch_anchor_gt_labels) if self.train_cfg.pos_weight <= 0: batch_label_weights = masked_fill(batch_label_weights, batch_pos_mask, 1.0) else: batch_label_weights = masked_fill(batch_label_weights, batch_pos_mask, self.train_cfg.pos_weight) batch_label_weights = masked_fill(batch_label_weights, batch_neg_mask, 1.0) return (batch_labels, batch_label_weights, batch_bbox_targets, batch_bbox_weights, batch_pos_mask, batch_neg_mask, sampling_result)
[docs] def get_targets(self, anchor_list, valid_flag_list, gt_bboxes_list, img_metas, gt_bboxes_ignore_list=None, gt_labels_list=None, label_channels=1, unmap_outputs=True, return_sampling_results=False, return_level=True): """Compute regression and classification targets for anchors in multiple images. Args: anchor_list (list[list[Tensor]]): Multi level anchors of each image. The outer list indicates images, and the inner list corresponds to feature levels of the image. Each element of the inner list is a tensor of shape (num_anchors, 4). valid_flag_list (list[list[Tensor]]): Multi level valid flags of each image. The outer list indicates images, and the inner list corresponds to feature levels of the image. Each element of the inner list is a tensor of shape (num_anchors, ) gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image. img_metas (list[dict]): Meta info of each image. gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be ignored. gt_labels_list (list[Tensor]): Ground truth labels of each box. label_channels (int): Channel of label. unmap_outputs (bool): Whether to map outputs back to the original set of anchors. return_sampling_results (bool): Whether to return the result of sample. return_level (bool): Whether to map outputs back to the levels of feature map sizes. Returns: tuple: Usually returns a tuple containing learning targets. - labels_list (list[Tensor]): Labels of each level. - label_weights_list (list[Tensor]): Label weights of each level. - bbox_targets_list (list[Tensor]): BBox targets of each level. - bbox_weights_list (list[Tensor]): BBox weights of each level. - num_total_pos (int): Number of positive samples in all images. - num_total_neg (int): Number of negative samples in all images. additional_returns: This function enables user-defined returns from `self._get_targets_single`. These returns are currently refined to properties at each feature map (i.e. having HxW dimension). The results will be concatenated after the end """ assert gt_bboxes_ignore_list is None assert unmap_outputs is True assert return_sampling_results is False assert self.train_cfg.allowed_border < 0 assert isinstance(self.assigner, AscendMaxIoUAssigner) assert isinstance(self.sampler, PseudoSampler) num_imgs = len(img_metas) assert len(anchor_list) == len(valid_flag_list) == num_imgs device = anchor_list[0][0].device num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]] batch_anchor_list = [] batch_valid_flag_list = [] for i in range(num_imgs): assert len(anchor_list[i]) == len(valid_flag_list[i]) batch_anchor_list.append(torch.cat(anchor_list[i])) batch_valid_flag_list.append(torch.cat(valid_flag_list[i])) batch_anchors = torch.cat( [torch.unsqueeze(anchor, 0) for anchor in batch_anchor_list], 0) batch_valid_flags = torch.cat([ torch.unsqueeze(batch_valid_flag, 0) for batch_valid_flag in batch_valid_flag_list ], 0) gt_nums = [len(gt_bbox) for gt_bbox in gt_bboxes_list] max_gt_nums = get_max_num_gt_division_factor(gt_nums) batch_gt_bboxes = self.get_batch_gt_bboxes(gt_bboxes_list, num_imgs, gt_nums, device, max_gt_nums) batch_gt_bboxes_ignore = self.get_batch_gt_bboxes_ignore( gt_bboxes_ignore_list, num_imgs, gt_nums, device) batch_gt_labels = self.get_batch_gt_labels(gt_labels_list, num_imgs, gt_nums, device, max_gt_nums) results = self._get_targets_concat( batch_anchors, batch_valid_flags, batch_gt_bboxes, batch_gt_bboxes_ignore, batch_gt_labels, img_metas, label_channels=label_channels, unmap_outputs=unmap_outputs) (batch_labels, batch_label_weights, batch_bbox_targets, batch_bbox_weights, batch_pos_mask, batch_neg_mask, sampling_result) = results[:7] rest_results = list(results[7:]) # user-added return values # sampled anchors of all images min_num = torch.ones((num_imgs, ), dtype=torch.long, device=batch_pos_mask.device) num_total_pos = torch.sum( torch.max(torch.sum(batch_pos_mask, dim=1), min_num)) num_total_neg = torch.sum( torch.max(torch.sum(batch_neg_mask, dim=1), min_num)) if return_level is True: labels_list = batch_images_to_levels(batch_labels, num_level_anchors) label_weights_list = batch_images_to_levels( batch_label_weights, num_level_anchors) bbox_targets_list = batch_images_to_levels(batch_bbox_targets, num_level_anchors) bbox_weights_list = batch_images_to_levels(batch_bbox_weights, num_level_anchors) res = (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, num_total_pos, num_total_neg) if return_sampling_results: res = res + (sampling_result, ) for i, r in enumerate(rest_results): # user-added return values rest_results[i] = batch_images_to_levels(r, num_level_anchors) return res + tuple(rest_results) else: res = (batch_labels, batch_label_weights, batch_bbox_targets, batch_bbox_weights, batch_pos_mask, batch_neg_mask, sampling_result, num_total_pos, num_total_neg, batch_anchors) return res
Read the Docs v: dev
Versions
latest
stable
3.x
v3.0.0rc0
v2.28.2
v2.28.1
v2.28.0
v2.27.0
v2.26.0
v2.25.3
v2.25.2
v2.25.1
v2.25.0
v2.24.1
v2.24.0
v2.23.0
v2.22.0
v2.21.0
v2.20.0
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
v2.12.0
v2.11.0
v2.10.0
v2.9.0
v2.8.0
v2.7.0
v2.6.0
v2.5.0
v2.4.0
v2.3.0
v2.2.1
v2.2.0
v2.1.0
v2.0.0
v1.2.0
test-3.0.0rc0
dev-3.x
dev
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.