
Source code for mmdet.models.utils.point_sample

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.ops import point_sample


def get_uncertainty(mask_pred, labels):
    """Estimate uncertainty based on pred logits.

    We estimate uncertainty as L1 distance between 0.0 and the logits
    prediction in 'mask_pred' for the foreground class in `classes`.

    Args:
        mask_pred (Tensor): Mask prediction logits, shape (num_rois,
            num_classes, mask_height, mask_width).
        labels (list[Tensor]): Either predicted or ground truth label for
            each predicted mask, of length num_rois.

    Returns:
        scores (Tensor): Uncertainty scores with the most uncertain
            locations having the highest uncertainty score, shape
            (num_rois, 1, mask_height, mask_width).
    """
    if mask_pred.shape[1] == 1:
        gt_class_logits = mask_pred.clone()
    else:
        inds = torch.arange(mask_pred.shape[0], device=mask_pred.device)
        gt_class_logits = mask_pred[inds, labels].unsqueeze(1)
    return -torch.abs(gt_class_logits)
def get_uncertain_point_coords_with_randomness(mask_pred, labels, num_points,
                                               oversample_ratio,
                                               importance_sample_ratio):
    """Get ``num_points`` most uncertain points with random points during
    train.

    Sample points in [0, 1] x [0, 1] coordinate space based on their
    uncertainty. The uncertainties are calculated for each point using the
    'get_uncertainty()' function that takes the point's logit prediction as
    input.

    Args:
        mask_pred (Tensor): A tensor of shape (num_rois, num_classes,
            mask_height, mask_width) for class-specific or class-agnostic
            prediction.
        labels (list): The ground truth class for each instance.
        num_points (int): The number of points to sample.
        oversample_ratio (int): Oversampling parameter.
        importance_sample_ratio (float): Ratio of points that are sampled
            via importance sampling.

    Returns:
        point_coords (Tensor): A tensor of shape (num_rois, num_points, 2)
            that contains the coordinates of the sampled points.
    """
    assert oversample_ratio >= 1
    assert 0 <= importance_sample_ratio <= 1
    batch_size = mask_pred.shape[0]
    num_sampled = int(num_points * oversample_ratio)
    point_coords = torch.rand(
        batch_size, num_sampled, 2, device=mask_pred.device)
    point_logits = point_sample(mask_pred, point_coords)
    # It is crucial to calculate uncertainty based on the sampled
    # prediction value for the points. Calculating uncertainties of the
    # coarse predictions first and sampling them for points leads to
    # incorrect results. To illustrate this: assume uncertainty func(
    # logits)=-abs(logits), a sampled point between two coarse
    # predictions with -1 and 1 logits has 0 logits, and therefore 0
    # uncertainty value. However, if we calculate uncertainties for the
    # coarse predictions first, both will have -1 uncertainty,
    # and the sampled point will get -1 uncertainty.
    point_uncertainties = get_uncertainty(point_logits, labels)
    num_uncertain_points = int(importance_sample_ratio * num_points)
    num_random_points = num_points - num_uncertain_points
    idx = torch.topk(
        point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
    shift = num_sampled * torch.arange(
        batch_size, dtype=torch.long, device=mask_pred.device)
    idx += shift[:, None]
    point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
        batch_size, num_uncertain_points, 2)
    if num_random_points > 0:
        rand_roi_coords = torch.rand(
            batch_size, num_random_points, 2, device=mask_pred.device)
        point_coords = torch.cat((point_coords, rand_roi_coords), dim=1)
    return point_coords
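
The sketch below (not part of the module above) shows how these two functions fit together during point-based mask training: uncertainty is highest where the class logit is closest to zero, and the coordinate sampler oversamples random points, keeps the most uncertain fraction, and fills the remainder with fresh random points. The tensor shapes and hyper-parameters are illustrative assumptions, not values taken from MMDetection configs.

# Minimal usage sketch; shapes and hyper-parameters are assumptions.
import torch

from mmdet.models.utils.point_sample import (
    get_uncertain_point_coords_with_randomness, get_uncertainty)

# Class-specific mask logits for 2 RoIs, 3 classes, on a 7x7 grid.
mask_pred = torch.randn(2, 3, 7, 7)
labels = torch.tensor([0, 2])  # one (predicted or GT) class id per RoI

# Per-pixel uncertainty: least negative where |logit| is smallest.
scores = get_uncertainty(mask_pred, labels)
assert scores.shape == (2, 1, 7, 7)

# Sample 12 points per RoI: oversample 3x, take the 9 most uncertain
# points (importance_sample_ratio=0.75), fill the remaining 3 with
# uniformly random coordinates in [0, 1] x [0, 1].
point_coords = get_uncertain_point_coords_with_randomness(
    mask_pred, labels, num_points=12, oversample_ratio=3,
    importance_sample_ratio=0.75)
assert point_coords.shape == (2, 12, 2)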