# Source code for mmdet.models.roi_heads.base_roi_head

from abc import ABCMeta, abstractmethod

from mmcv.runner import BaseModule

from ..builder import build_shared_head


class BaseRoIHead(BaseModule, metaclass=ABCMeta):
    """Base class for RoIHeads.

    Concrete heads implement :meth:`init_bbox_head`, :meth:`init_mask_head`,
    :meth:`init_assigner_sampler` and :meth:`forward_train`. The constructor
    calls the ``init_*`` hooks for whichever branches are configured.

    Args:
        bbox_roi_extractor (optional): Config of the bbox RoI extractor.
            Forwarded to :meth:`init_bbox_head`; ignored when ``bbox_head``
            is ``None``.
        bbox_head (optional): Config of the bbox head. When ``None`` the bbox
            branch is not built and :attr:`with_bbox` is ``False``.
        mask_roi_extractor (optional): Config of the mask RoI extractor.
            Forwarded to :meth:`init_mask_head`; ignored when ``mask_head``
            is ``None``.
        mask_head (optional): Config of the mask head. When ``None`` the mask
            branch is not built and :attr:`with_mask` is ``False``.
        shared_head (optional): Config passed to ``build_shared_head``. Its
            ``'pretrained'`` entry is overwritten IN PLACE with
            ``pretrained`` before building, so the caller's object changes.
        train_cfg (optional): Training config, stored as ``self.train_cfg``.
        test_cfg (optional): Testing config, stored as ``self.test_cfg``.
        pretrained (optional): Only used to set ``shared_head['pretrained']``.
        init_cfg (optional): Forwarded unchanged to ``BaseModule.__init__``.
    """

    def __init__(self,
                 bbox_roi_extractor=None,
                 bbox_head=None,
                 mask_roi_extractor=None,
                 mask_head=None,
                 shared_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(init_cfg)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        if shared_head is not None:
            # Item assignment works for a plain ``dict`` as well as for
            # dict-subclass configs; attribute assignment fails on ``dict``.
            shared_head['pretrained'] = pretrained
            self.shared_head = build_shared_head(shared_head)

        if bbox_head is not None:
            self.init_bbox_head(bbox_roi_extractor, bbox_head)

        if mask_head is not None:
            self.init_mask_head(mask_roi_extractor, mask_head)

        # Called unconditionally, even when neither branch is configured.
        self.init_assigner_sampler()

    @property
    def with_bbox(self):
        """bool: whether the RoI head contains a `bbox_head`"""
        return hasattr(self, 'bbox_head') and self.bbox_head is not None

    @property
    def with_mask(self):
        """bool: whether the RoI head contains a `mask_head`"""
        return hasattr(self, 'mask_head') and self.mask_head is not None

    @property
    def with_shared_head(self):
        """bool: whether the RoI head contains a `shared_head`"""
        return hasattr(self, 'shared_head') and self.shared_head is not None

    @abstractmethod
    def init_bbox_head(self, bbox_roi_extractor, bbox_head):
        """Initialize ``bbox_head``.

        Called from ``__init__`` with the ``bbox_roi_extractor`` and
        ``bbox_head`` constructor arguments, only when ``bbox_head`` is not
        ``None``.
        """

    @abstractmethod
    def init_mask_head(self, mask_roi_extractor, mask_head):
        """Initialize ``mask_head``.

        Called from ``__init__`` with the ``mask_roi_extractor`` and
        ``mask_head`` constructor arguments, only when ``mask_head`` is not
        ``None``.
        """

    @abstractmethod
    def init_assigner_sampler(self):
        """Initialize assigner and sampler.

        Called unconditionally at the end of ``__init__``.
        """

    @abstractmethod
    def forward_train(self,
                      x,
                      img_meta,
                      proposal_list,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None,
                      **kwargs):
        """Forward function during training.

        Must be implemented by subclasses; the base class only declares the
        signature.
        """

    async def async_simple_test(self,
                                x,
                                proposal_list,
                                img_metas,
                                proposals=None,
                                rescale=False,
                                **kwargs):
        """Asynchronized test function.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def simple_test(self,
                    x,
                    proposal_list,
                    img_meta,
                    proposals=None,
                    rescale=False,
                    **kwargs):
        """Test without augmentation.

        The base implementation does nothing and returns ``None``; override
        it in subclasses that support testing.
        """

    def aug_test(self, x, proposal_list, img_metas, rescale=False, **kwargs):
        """Test with augmentations.

        If rescale is False, then returned bboxes and masks will fit the scale
        of imgs[0].

        The base implementation does nothing and returns ``None``; override
        it in subclasses that support augmented testing.
        """