Source code for mmdet.models.detectors.detr

import torch

from ..builder import DETECTORS
from .single_stage import SingleStageDetector


@DETECTORS.register_module()
class DETR(SingleStageDetector):
    r"""DETR detector built on top of :class:`SingleStageDetector`.

    Implements `DETR: End-to-End Object Detection with Transformers
    <https://arxiv.org/pdf/2005.12872>`_.
    """

    def __init__(self,
                 backbone,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        # The parent's second positional argument is deliberately passed as
        # ``None``; every other argument is forwarded unchanged.
        super().__init__(
            backbone,
            None,
            bbox_head,
            train_cfg,
            test_cfg,
            pretrained,
            init_cfg,
        )

    # `onnx_export` is overridden here for two reasons:
    # (1) the bbox_head forward pass needs `img_metas`;
    # (2) the bbox_head forward behaves differently for torch and ONNX
    #     models (e.g. how `masks` are constructed).
    def onnx_export(self, img, img_metas):
        """Run the ONNX-export test path, without test-time augmentation.

        Note:
            ``img_metas[0]`` is modified in place: the key
            ``'img_shape_for_onnx'`` is written before the head's
            ``onnx_export`` is called.

        Args:
            img (torch.Tensor): input images.
            img_metas (list[dict]): List of image information.

        Returns:
            tuple[Tensor, Tensor]: dets of shape [N, num_det, 5]
                and class labels of shape [N, num_det].
        """
        feats = self.extract_feat(img)

        # This head's forward needs the image metas as well as the features.
        head_outs = self.bbox_head.forward_onnx(feats, img_metas)

        # Store the image shape as a tensor, keeping only the dimensions
        # from index 2 onward (presumably H, W for NCHW input -- confirm).
        img_metas[0]['img_shape_for_onnx'] = torch._shape_as_tensor(img)[2:]

        bboxes, labels = self.bbox_head.onnx_export(*head_outs, img_metas)
        return bboxes, labels