Source code for mmdet.models.detectors.base

import warnings
from abc import ABCMeta, abstractmethod

import mmcv
import numpy as np
import torch.nn as nn
from mmcv.utils import print_log

from mmdet.core import auto_fp16
from mmdet.utils import get_root_logger


class BaseDetector(nn.Module, metaclass=ABCMeta):
    """Base class for detectors.

    Concrete detectors implement :meth:`extract_feat`, :meth:`forward_train`,
    :meth:`simple_test` and :meth:`aug_test`. This base class provides the
    train/test dispatch in :meth:`forward`, input validation for the test
    paths, and result visualisation in :meth:`show_result`.
    """

    def __init__(self):
        super().__init__()
        self.fp16_enabled = False

    @property
    def with_neck(self):
        """bool: whether the detector has a non-None ``neck`` attribute."""
        return getattr(self, 'neck', None) is not None

    # TODO: these properties need to be carefully handled
    # for both single stage & two stage detectors
    #
    # ``roi_head`` is looked up with a default so that detectors which never
    # define it (e.g. ones that only set ``bbox_head``) get ``False`` / fall
    # through to their own heads instead of raising AttributeError.
    @property
    def with_shared_head(self):
        """bool: whether ``self.roi_head`` has a non-None ``shared_head``."""
        roi_head = getattr(self, 'roi_head', None)
        return getattr(roi_head, 'shared_head', None) is not None

    @property
    def with_bbox(self):
        """bool: whether ``self.roi_head`` or ``self`` has a non-None
        ``bbox_head``."""
        roi_head = getattr(self, 'roi_head', None)
        return (getattr(roi_head, 'bbox_head', None) is not None
                or getattr(self, 'bbox_head', None) is not None)

    @property
    def with_mask(self):
        """bool: whether ``self.roi_head`` or ``self`` has a non-None
        ``mask_head``."""
        roi_head = getattr(self, 'roi_head', None)
        return (getattr(roi_head, 'mask_head', None) is not None
                or getattr(self, 'mask_head', None) is not None)

    @abstractmethod
    def extract_feat(self, imgs):
        """Extract features from ``imgs``; implemented by subclasses."""
        pass

    def extract_feats(self, imgs):
        """Apply :meth:`extract_feat` to every element of a list.

        Args:
            imgs (list): Inputs, each passed to :meth:`extract_feat`.

        Returns:
            list: One :meth:`extract_feat` result per element of ``imgs``.
        """
        assert isinstance(imgs, list)
        return [self.extract_feat(img) for img in imgs]

    @abstractmethod
    def forward_train(self, imgs, img_metas, **kwargs):
        """
        Args:
            imgs (list[Tensor]): List of tensors of shape (1, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys, see
                :class:`mmdet.datasets.pipelines.Collect`.
            kwargs (keyword arguments): Specific to concrete implementation.
        """
        pass

    async def async_simple_test(self, img, img_metas, **kwargs):
        """Async single-augmentation test; not implemented in the base
        class."""
        raise NotImplementedError

    @abstractmethod
    def simple_test(self, img, img_metas, **kwargs):
        """Test without augmentation; implemented by subclasses."""
        pass

    @abstractmethod
    def aug_test(self, imgs, img_metas, **kwargs):
        """Test with augmentations; implemented by subclasses."""
        pass

    def init_weights(self, pretrained=None):
        """Log the checkpoint path when ``pretrained`` is given.

        The base implementation only logs; it does not load any weights
        itself.

        Args:
            pretrained (str, optional): Value reported in the log message.
                Default: None.
        """
        if pretrained is not None:
            logger = get_root_logger()
            print_log(f'load model from: {pretrained}', logger=logger)

    async def aforward_test(self, *, img, img_metas, **kwargs):
        """Async counterpart of :meth:`forward_test`.

        Only the single-augmentation case is supported.

        Args:
            img (list): Outer list indicates test-time augmentations.
            img_metas (list): Outer list indicates test-time augmentations;
                must have the same length as ``img``.

        Raises:
            TypeError: If ``img`` or ``img_metas`` is not a list.
            ValueError: If ``img`` and ``img_metas`` differ in length.
            NotImplementedError: If more than one augmentation is given.
        """
        for var, name in [(img, 'img'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError(f'{name} must be a list, but got {type(var)}')

        num_augs = len(img)
        if num_augs != len(img_metas):
            raise ValueError(f'num of augmentations ({len(img)}) '
                             f'!= num of image metas ({len(img_metas)})')
        # TODO: remove the restriction of samples_per_gpu == 1 when prepared
        samples_per_gpu = img[0].size(0)
        assert samples_per_gpu == 1

        if num_augs == 1:
            return await self.async_simple_test(img[0], img_metas[0], **kwargs)
        else:
            raise NotImplementedError

    def forward_test(self, imgs, img_metas, **kwargs):
        """
        Args:
            imgs (List[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains all images in the batch.
            img_metas (List[List[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch.
            kwargs: May contain ``proposals`` (List[List[Tensor]]), only in
                the single-augmentation case: the outer list indicates
                test-time augs (multiscale, flip, etc.) and the inner list
                indicates images in a batch. The Tensor should have a shape
                Px4, where P is the number of proposals.

        Raises:
            TypeError: If ``imgs`` or ``img_metas`` is not a list.
            ValueError: If ``imgs`` and ``img_metas`` differ in length.
        """
        for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError(f'{name} must be a list, but got {type(var)}')

        num_augs = len(imgs)
        if num_augs != len(img_metas):
            raise ValueError(f'num of augmentations ({len(imgs)}) '
                             f'!= num of image meta ({len(img_metas)})')
        # TODO: remove the restriction of samples_per_gpu == 1 when prepared
        samples_per_gpu = imgs[0].size(0)
        assert samples_per_gpu == 1

        if num_augs == 1:
            # Strip the (single) augmentation level from the proposals so
            # they line up with imgs[0] / img_metas[0].
            if 'proposals' in kwargs:
                kwargs['proposals'] = kwargs['proposals'][0]
            return self.simple_test(imgs[0], img_metas[0], **kwargs)
        else:
            # TODO: support test augmentation for predefined proposals
            assert 'proposals' not in kwargs
            return self.aug_test(imgs, img_metas, **kwargs)

    @auto_fp16(apply_to=('img', ))
    def forward(self, img, img_metas, return_loss=True, **kwargs):
        """
        Calls either forward_train or forward_test depending on whether
        return_loss=True. Note this setting will change the expected inputs.
        When `return_loss=True`, img and img_metas are single-nested (i.e.
        Tensor and List[dict]), and when `return_loss=False`, img and
        img_metas should be double nested (i.e. List[Tensor],
        List[List[dict]]), with the outer list indicating test time
        augmentations.
        """
        if return_loss:
            return self.forward_train(img, img_metas, **kwargs)
        else:
            return self.forward_test(img, img_metas, **kwargs)

    def show_result(self,
                    img,
                    result,
                    score_thr=0.3,
                    bbox_color='green',
                    text_color='green',
                    thickness=1,
                    font_scale=0.5,
                    win_name='',
                    show=False,
                    wait_time=0,
                    out_file=None):
        """Draw `result` over `img`.

        Args:
            img (str or Tensor): The image to be displayed.
            result (Tensor or tuple): The results to draw over `img`
                bbox_result or (bbox_result, segm_result).
            score_thr (float, optional): Minimum score of bboxes to be shown.
                Default: 0.3.
            bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.
            text_color (str or tuple or :obj:`Color`): Color of texts.
            thickness (int): Thickness of lines.
            font_scale (float): Font scales of texts.
            win_name (str): The window name.
            wait_time (int): Value of waitKey param.
                Default: 0.
            show (bool): Whether to show the image. It is forced to False
                when `out_file` is given. Default: False.
            out_file (str or None): The filename to write the image.
                Default: None.

        Returns:
            The image array, only when `show` is False and `out_file` is not
            given (a warning is emitted in that case); otherwise None.
        """
        img = mmcv.imread(img)
        img = img.copy()
        if isinstance(result, tuple):
            bbox_result, segm_result = result
            if isinstance(segm_result, tuple):
                segm_result = segm_result[0]  # ms rcnn
        else:
            bbox_result, segm_result = result, None
        bboxes = np.vstack(bbox_result)
        # bbox_result is indexed by class: entry i yields label i for each
        # of its rows.
        labels = [
            np.full(bbox.shape[0], i, dtype=np.int32)
            for i, bbox in enumerate(bbox_result)
        ]
        labels = np.concatenate(labels)
        # draw segmentation masks
        if segm_result is not None and len(labels) > 0:  # non empty
            segms = mmcv.concat_list(segm_result)
            inds = np.where(bboxes[:, -1] > score_thr)[0]
            # Use a private, fixed-seed generator: the per-class colours stay
            # reproducible without reseeding NumPy's global RNG.
            rng = np.random.RandomState(42)
            color_masks = [
                rng.randint(0, 256, (1, 3), dtype=np.uint8)
                for _ in range(max(labels) + 1)
            ]
            for i in inds:
                i = int(i)
                color_mask = color_masks[labels[i]]
                mask = segms[i]
                # 50/50 blend of the image and the class colour on the mask
                img[mask] = img[mask] * 0.5 + color_mask * 0.5
        # if out_file specified, do not show image in window
        if out_file is not None:
            show = False
        # draw bounding boxes
        mmcv.imshow_det_bboxes(
            img,
            bboxes,
            labels,
            class_names=self.CLASSES,
            score_thr=score_thr,
            bbox_color=bbox_color,
            text_color=text_color,
            thickness=thickness,
            font_scale=font_scale,
            win_name=win_name,
            show=show,
            wait_time=wait_time,
            out_file=out_file)

        if not (show or out_file):
            warnings.warn('show==False and out_file is not specified, only '
                          'result image will be returned')
            return img