
Source code for mmdet.models.roi_heads.mask_heads.grid_head

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule

from mmdet.models.builder import HEADS, build_loss


@HEADS.register_module()
class GridHead(BaseModule):

    def __init__(self,
                 grid_points=9,
                 num_convs=8,
                 roi_feat_size=14,
                 in_channels=256,
                 conv_kernel_size=3,
                 point_feat_channels=64,
                 deconv_kernel_size=4,
                 class_agnostic=False,
                 loss_grid=dict(
                     type='CrossEntropyLoss', use_sigmoid=True,
                     loss_weight=15),
                 conv_cfg=None,
                 norm_cfg=dict(type='GN', num_groups=36),
                 init_cfg=[
                     dict(type='Kaiming', layer=['Conv2d', 'Linear']),
                     dict(
                         type='Normal',
                         layer='ConvTranspose2d',
                         std=0.001,
                         override=dict(
                             type='Normal',
                             name='deconv2',
                             std=0.001,
                             bias=-np.log(0.99 / 0.01)))
                 ]):
        super(GridHead, self).__init__(init_cfg)
        self.grid_points = grid_points
        self.num_convs = num_convs
        self.roi_feat_size = roi_feat_size
        self.in_channels = in_channels
        self.conv_kernel_size = conv_kernel_size
        self.point_feat_channels = point_feat_channels
        self.conv_out_channels = self.point_feat_channels * self.grid_points
        self.class_agnostic = class_agnostic
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        if isinstance(norm_cfg, dict) and norm_cfg['type'] == 'GN':
            assert self.conv_out_channels % norm_cfg['num_groups'] == 0
        assert self.grid_points >= 4
        self.grid_size = int(np.sqrt(self.grid_points))
        if self.grid_size * self.grid_size != self.grid_points:
            raise ValueError('grid_points must be a square number')

        # the predicted heatmap is half of whole_map_size
        if not isinstance(self.roi_feat_size, int):
            raise ValueError('Only square RoIs are supported in Grid R-CNN')
        self.whole_map_size = self.roi_feat_size * 4

        # compute point-wise sub-regions
        self.sub_regions = self.calc_sub_regions()

        self.convs = []
        for i in range(self.num_convs):
            in_channels = (
                self.in_channels if i == 0 else self.conv_out_channels)
            stride = 2 if i == 0 else 1
            padding = (self.conv_kernel_size - 1) // 2
            self.convs.append(
                ConvModule(
                    in_channels,
                    self.conv_out_channels,
                    self.conv_kernel_size,
                    stride=stride,
                    padding=padding,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=True))
        self.convs = nn.Sequential(*self.convs)

        self.deconv1 = nn.ConvTranspose2d(
            self.conv_out_channels,
            self.conv_out_channels,
            kernel_size=deconv_kernel_size,
            stride=2,
            padding=(deconv_kernel_size - 2) // 2,
            groups=grid_points)
        self.norm1 = nn.GroupNorm(grid_points, self.conv_out_channels)
        self.deconv2 = nn.ConvTranspose2d(
            self.conv_out_channels,
            grid_points,
            kernel_size=deconv_kernel_size,
            stride=2,
            padding=(deconv_kernel_size - 2) // 2,
            groups=grid_points)

        # find the 4-neighbor of each grid point
        self.neighbor_points = []
        grid_size = self.grid_size
        for i in range(grid_size):  # i-th column
            for j in range(grid_size):  # j-th row
                neighbors = []
                if i > 0:  # left: (i - 1, j)
                    neighbors.append((i - 1) * grid_size + j)
                if j > 0:  # up: (i, j - 1)
                    neighbors.append(i * grid_size + j - 1)
                if j < grid_size - 1:  # down: (i, j + 1)
                    neighbors.append(i * grid_size + j + 1)
                if i < grid_size - 1:  # right: (i + 1, j)
                    neighbors.append((i + 1) * grid_size + j)
                self.neighbor_points.append(tuple(neighbors))
        # total edges in the grid
        self.num_edges = sum([len(p) for p in self.neighbor_points])

        self.forder_trans = nn.ModuleList()  # first-order feature transition
        self.sorder_trans = nn.ModuleList()  # second-order feature transition
        for neighbors in self.neighbor_points:
            fo_trans = nn.ModuleList()
            so_trans = nn.ModuleList()
            for _ in range(len(neighbors)):
                # each transition module consists of a 5x5 depth-wise conv
                # and a 1x1 conv.
                fo_trans.append(
                    nn.Sequential(
                        nn.Conv2d(
                            self.point_feat_channels,
                            self.point_feat_channels,
                            5,
                            stride=1,
                            padding=2,
                            groups=self.point_feat_channels),
                        nn.Conv2d(self.point_feat_channels,
                                  self.point_feat_channels, 1)))
                so_trans.append(
                    nn.Sequential(
                        nn.Conv2d(
                            self.point_feat_channels,
                            self.point_feat_channels,
                            5,
                            1,
                            2,
                            groups=self.point_feat_channels),
                        nn.Conv2d(self.point_feat_channels,
                                  self.point_feat_channels, 1)))
            self.forder_trans.append(fo_trans)
            self.sorder_trans.append(so_trans)

        self.loss_grid = build_loss(loss_grid)
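    # A minimal construction sketch (not part of the original source): with
    # the default arguments the derived sizes work out as below. Building
    # GridHead() assumes the mmdet loss registry (e.g. 'CrossEntropyLoss')
    # is importable.
    #
    #     >>> head = GridHead()  # grid_points=9, roi_feat_size=14
    #     >>> head.grid_size, head.whole_map_size, head.conv_out_channels
    #     (3, 56, 576)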
    def forward(self, x):
        assert x.shape[-1] == x.shape[-2] == self.roi_feat_size
        # RoI feature transformation, downsample 2x
        x = self.convs(x)

        c = self.point_feat_channels
        # first-order fusion
        x_fo = [None for _ in range(self.grid_points)]
        for i, points in enumerate(self.neighbor_points):
            x_fo[i] = x[:, i * c:(i + 1) * c]
            for j, point_idx in enumerate(points):
                x_fo[i] = x_fo[i] + self.forder_trans[i][j](
                    x[:, point_idx * c:(point_idx + 1) * c])

        # second-order fusion
        x_so = [None for _ in range(self.grid_points)]
        for i, points in enumerate(self.neighbor_points):
            x_so[i] = x[:, i * c:(i + 1) * c]
            for j, point_idx in enumerate(points):
                x_so[i] = x_so[i] + self.sorder_trans[i][j](x_fo[point_idx])

        # predicted heatmap with fused features
        x2 = torch.cat(x_so, dim=1)
        x2 = self.deconv1(x2)
        x2 = F.relu(self.norm1(x2), inplace=True)
        heatmap = self.deconv2(x2)

        # predicted heatmap with original features
        # (applicable during training)
        if self.training:
            x1 = x
            x1 = self.deconv1(x1)
            x1 = F.relu(self.norm1(x1), inplace=True)
            heatmap_unfused = self.deconv2(x1)
        else:
            heatmap_unfused = heatmap

        return dict(fused=heatmap, unfused=heatmap_unfused)
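    # Usage sketch for forward() (illustrative tensors, not from the
    # original source). With the defaults, a 14x14 RoI feature map is
    # downsampled to 7x7 by the first stride-2 conv, then upsampled by the
    # two stride-2 deconvs to a 28x28 heatmap, i.e. half of whole_map_size:
    #
    #     >>> x = torch.rand(2, 256, 14, 14)  # (num_rois, C, H, W)
    #     >>> out = GridHead()(x)
    #     >>> out['fused'].shape
    #     torch.Size([2, 9, 28, 28])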
    def calc_sub_regions(self):
        """Compute point specific representation regions.

        See Grid R-CNN Plus (https://arxiv.org/abs/1906.05688) for details.
        """
        # to make it consistent with the original implementation, half_size
        # is computed as 2 * quarter_size, which is smaller
        half_size = self.whole_map_size // 4 * 2
        sub_regions = []
        for i in range(self.grid_points):
            x_idx = i // self.grid_size
            y_idx = i % self.grid_size
            if x_idx == 0:
                sub_x1 = 0
            elif x_idx == self.grid_size - 1:
                sub_x1 = half_size
            else:
                ratio = x_idx / (self.grid_size - 1) - 0.25
                sub_x1 = max(int(ratio * self.whole_map_size), 0)
            if y_idx == 0:
                sub_y1 = 0
            elif y_idx == self.grid_size - 1:
                sub_y1 = half_size
            else:
                ratio = y_idx / (self.grid_size - 1) - 0.25
                sub_y1 = max(int(ratio * self.whole_map_size), 0)
            sub_regions.append(
                (sub_x1, sub_y1, sub_x1 + half_size, sub_y1 + half_size))
        return sub_regions
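    # Sketch of the values this computes for the default 3x3 grid (derived
    # from the code above, shown only for illustration): each sub-region is
    # a 28x28 window of the 56x56 whole map. The first grid-point column,
    # for example, yields:
    #
    #     >>> GridHead().calc_sub_regions()[:3]
    #     [(0, 0, 28, 28), (0, 14, 28, 42), (0, 28, 28, 56)]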
    def get_targets(self, sampling_results, rcnn_train_cfg):
        # mix all samples (across images) together.
        pos_bboxes = torch.cat([res.pos_bboxes for res in sampling_results],
                               dim=0).cpu()
        pos_gt_bboxes = torch.cat(
            [res.pos_gt_bboxes for res in sampling_results], dim=0).cpu()
        assert pos_bboxes.shape == pos_gt_bboxes.shape

        # expand pos_bboxes to 2x of original size
        x1 = pos_bboxes[:, 0] - (pos_bboxes[:, 2] - pos_bboxes[:, 0]) / 2
        y1 = pos_bboxes[:, 1] - (pos_bboxes[:, 3] - pos_bboxes[:, 1]) / 2
        x2 = pos_bboxes[:, 2] + (pos_bboxes[:, 2] - pos_bboxes[:, 0]) / 2
        y2 = pos_bboxes[:, 3] + (pos_bboxes[:, 3] - pos_bboxes[:, 1]) / 2
        pos_bboxes = torch.stack([x1, y1, x2, y2], dim=-1)
        pos_bbox_ws = (pos_bboxes[:, 2] - pos_bboxes[:, 0]).unsqueeze(-1)
        pos_bbox_hs = (pos_bboxes[:, 3] - pos_bboxes[:, 1]).unsqueeze(-1)

        num_rois = pos_bboxes.shape[0]
        map_size = self.whole_map_size
        # this is not the final target shape
        targets = torch.zeros((num_rois, self.grid_points, map_size,
                               map_size),
                              dtype=torch.float)

        # pre-compute interpolation factors for all grid points.
        # the first item is the factor of x-dim, and the second is y-dim.
        # for a 9-point grid, factors are like (1, 0), (0.5, 0.5), (0, 1)
        factors = []
        for j in range(self.grid_points):
            x_idx = j // self.grid_size
            y_idx = j % self.grid_size
            factors.append((1 - x_idx / (self.grid_size - 1),
                            1 - y_idx / (self.grid_size - 1)))

        radius = rcnn_train_cfg.pos_radius
        radius2 = radius**2
        for i in range(num_rois):
            # ignore small bboxes
            if (pos_bbox_ws[i] <= self.grid_size
                    or pos_bbox_hs[i] <= self.grid_size):
                continue
            # for each grid point, mark a small circle as positive
            for j in range(self.grid_points):
                factor_x, factor_y = factors[j]
                gridpoint_x = factor_x * pos_gt_bboxes[i, 0] + (
                    1 - factor_x) * pos_gt_bboxes[i, 2]
                gridpoint_y = factor_y * pos_gt_bboxes[i, 1] + (
                    1 - factor_y) * pos_gt_bboxes[i, 3]
                cx = int((gridpoint_x - pos_bboxes[i, 0]) / pos_bbox_ws[i] *
                         map_size)
                cy = int((gridpoint_y - pos_bboxes[i, 1]) / pos_bbox_hs[i] *
                         map_size)
                for x in range(cx - radius, cx + radius + 1):
                    for y in range(cy - radius, cy + radius + 1):
                        if x >= 0 and x < map_size and y >= 0 and y < map_size:
                            if (x - cx)**2 + (y - cy)**2 <= radius2:
                                targets[i, j, y, x] = 1

        # reduce the target heatmap size by a half,
        # proposed in Grid R-CNN Plus (https://arxiv.org/abs/1906.05688).
        sub_targets = []
        for i in range(self.grid_points):
            sub_x1, sub_y1, sub_x2, sub_y2 = self.sub_regions[i]
            sub_targets.append(targets[:, [i], sub_y1:sub_y2, sub_x1:sub_x2])
        sub_targets = torch.cat(sub_targets, dim=1)
        sub_targets = sub_targets.to(sampling_results[0].pos_bboxes.device)
        return sub_targets

    def loss(self, grid_pred, grid_targets):
        loss_fused = self.loss_grid(grid_pred['fused'], grid_targets)
        loss_unfused = self.loss_grid(grid_pred['unfused'], grid_targets)
        loss_grid = loss_fused + loss_unfused
        return dict(loss_grid=loss_grid)

    def get_bboxes(self, det_bboxes, grid_pred, img_metas):
        # TODO: refactoring
        assert det_bboxes.shape[0] == grid_pred.shape[0]
        det_bboxes = det_bboxes.cpu()
        cls_scores = det_bboxes[:, [4]]
        det_bboxes = det_bboxes[:, :4]
        grid_pred = grid_pred.sigmoid().cpu()

        R, c, h, w = grid_pred.shape
        half_size = self.whole_map_size // 4 * 2
        assert h == w == half_size
        assert c == self.grid_points

        # find the point with max scores in the half-sized heatmap
        grid_pred = grid_pred.view(R * c, h * w)
        pred_scores, pred_position = grid_pred.max(dim=1)
        xs = pred_position % w
        ys = pred_position // w

        # get the position in the whole heatmap instead of half-sized heatmap
        for i in range(self.grid_points):
            xs[i::self.grid_points] += self.sub_regions[i][0]
            ys[i::self.grid_points] += self.sub_regions[i][1]

        # reshape to (num_rois, grid_points)
        pred_scores, xs, ys = tuple(
            map(lambda x: x.view(R, c), [pred_scores, xs, ys]))

        # get expanded pos_bboxes
        widths = (det_bboxes[:, 2] - det_bboxes[:, 0]).unsqueeze(-1)
        heights = (det_bboxes[:, 3] - det_bboxes[:, 1]).unsqueeze(-1)
        x1 = (det_bboxes[:, 0, None] - widths / 2)
        y1 = (det_bboxes[:, 1, None] - heights / 2)
        # map the grid point to the absolute coordinates
        abs_xs = (xs.float() + 0.5) / w * widths + x1
        abs_ys = (ys.float() + 0.5) / h * heights + y1

        # get the grid points indices that fall on the bbox boundaries
        x1_inds = [i for i in range(self.grid_size)]
        y1_inds = [i * self.grid_size for i in range(self.grid_size)]
        x2_inds = [
            self.grid_points - self.grid_size + i
            for i in range(self.grid_size)
        ]
        y2_inds = [(i + 1) * self.grid_size - 1 for i in range(self.grid_size)]

        # voting of all grid points on some boundary
        bboxes_x1 = (abs_xs[:, x1_inds] * pred_scores[:, x1_inds]).sum(
            dim=1, keepdim=True) / (
                pred_scores[:, x1_inds].sum(dim=1, keepdim=True))
        bboxes_y1 = (abs_ys[:, y1_inds] * pred_scores[:, y1_inds]).sum(
            dim=1, keepdim=True) / (
                pred_scores[:, y1_inds].sum(dim=1, keepdim=True))
        bboxes_x2 = (abs_xs[:, x2_inds] * pred_scores[:, x2_inds]).sum(
            dim=1, keepdim=True) / (
                pred_scores[:, x2_inds].sum(dim=1, keepdim=True))
        bboxes_y2 = (abs_ys[:, y2_inds] * pred_scores[:, y2_inds]).sum(
            dim=1, keepdim=True) / (
                pred_scores[:, y2_inds].sum(dim=1, keepdim=True))

        bbox_res = torch.cat(
            [bboxes_x1, bboxes_y1, bboxes_x2, bboxes_y2, cls_scores], dim=1)
        bbox_res[:, [0, 2]].clamp_(min=0, max=img_metas[0]['img_shape'][1])
        bbox_res[:, [1, 3]].clamp_(min=0, max=img_metas[0]['img_shape'][0])

        return bbox_res
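

# End-to-end loss sketch (illustrative, not part of the original source).
# In real training `grid_targets` comes from get_targets(), which needs
# sampling results and rcnn_train_cfg.pos_radius; random binary targets are
# used here only to show the expected shapes and the returned dict:
#
#     >>> head = GridHead()
#     >>> pred = head(torch.rand(2, 256, 14, 14))
#     >>> targets = torch.randint(0, 2, (2, 9, 28, 28)).float()
#     >>> sorted(head.loss(pred, targets))  # a dict with a scalar loss
#     ['loss_grid']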