Source code for mmdet.models.roi_heads.base_roi_head

from abc import ABCMeta, abstractmethod

import torch.nn as nn

from ..builder import build_shared_head


class BaseRoIHead(nn.Module, metaclass=ABCMeta):
    """Base class for RoIHeads.

    The constructor stores the train/test configs, optionally builds a shared
    head, and delegates construction of the bbox and mask branches to the
    ``init_bbox_head`` / ``init_mask_head`` hooks implemented by subclasses.
    ``init_assigner_sampler`` is always called last.

    Args:
        bbox_roi_extractor: Forwarded unchanged to ``init_bbox_head`` as the
            first positional argument.
        bbox_head: Forwarded to ``init_bbox_head`` as the second positional
            argument. When ``None`` the hook is not called at all.
        mask_roi_extractor: Forwarded unchanged to ``init_mask_head`` as the
            first positional argument.
        mask_head: Forwarded to ``init_mask_head`` as the second positional
            argument. When ``None`` the hook is not called at all.
        shared_head: When not ``None``, passed to ``build_shared_head`` and
            the result is stored as ``self.shared_head``.
        train_cfg: Stored as ``self.train_cfg`` without inspection.
        test_cfg: Stored as ``self.test_cfg`` without inspection.
    """

    def __init__(self,
                 bbox_roi_extractor=None,
                 bbox_head=None,
                 mask_roi_extractor=None,
                 mask_head=None,
                 shared_head=None,
                 train_cfg=None,
                 test_cfg=None):
        super().__init__()
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        if shared_head is not None:
            self.shared_head = build_shared_head(shared_head)

        if bbox_head is not None:
            self.init_bbox_head(bbox_roi_extractor, bbox_head)

        if mask_head is not None:
            self.init_mask_head(mask_roi_extractor, mask_head)

        # Runs unconditionally, after the optional heads have been set up.
        self.init_assigner_sampler()

    @property
    def with_bbox(self):
        """bool: whether a non-``None`` ``bbox_head`` attribute is set."""
        return getattr(self, 'bbox_head', None) is not None

    @property
    def with_mask(self):
        """bool: whether a non-``None`` ``mask_head`` attribute is set."""
        return getattr(self, 'mask_head', None) is not None

    @property
    def with_shared_head(self):
        """bool: whether a non-``None`` ``shared_head`` attribute is set."""
        return getattr(self, 'shared_head', None) is not None

    @abstractmethod
    def init_weights(self, pretrained):
        """Initialize the weights of the head (implemented by subclasses)."""
        pass

    @abstractmethod
    def init_bbox_head(self, *args, **kwargs):
        """Set up the bbox branch (implemented by subclasses).

        ``__init__`` invokes this hook as
        ``init_bbox_head(bbox_roi_extractor, bbox_head)``, so the abstract
        signature accepts arbitrary arguments instead of none.
        """
        pass

    @abstractmethod
    def init_mask_head(self, *args, **kwargs):
        """Set up the mask branch (implemented by subclasses).

        ``__init__`` invokes this hook as
        ``init_mask_head(mask_roi_extractor, mask_head)``, so the abstract
        signature accepts arbitrary arguments instead of none.
        """
        pass

    @abstractmethod
    def init_assigner_sampler(self):
        """Set up assigner and sampler (implemented by subclasses).

        Called without arguments at the end of ``__init__``.
        """
        pass

    @abstractmethod
    def forward_train(self,
                      x,
                      img_meta,
                      proposal_list,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None,
                      **kwargs):
        """Forward function during training."""
        pass

    async def async_simple_test(self, x, img_meta, **kwargs):
        """Asynchronous test hook; not implemented in the base class.

        Raises:
            NotImplementedError: Always, when the coroutine is run.
        """
        raise NotImplementedError

    def simple_test(self,
                    x,
                    proposal_list,
                    img_meta,
                    proposals=None,
                    rescale=False,
                    **kwargs):
        """Test without augmentation.

        The base implementation does nothing and returns ``None``.
        """
        pass

    def aug_test(self, x, proposal_list, img_metas, rescale=False, **kwargs):
        """Test with augmentations.

        If rescale is False, then returned bboxes and masks will fit the scale
        of imgs[0].

        The base implementation does nothing and returns ``None``.
        """
        pass