Shortcuts

Source code for mmdet.structures.reid_data_sample

# Copyright (c) OpenMMLab. All rights reserved.
from numbers import Number
from typing import Optional, Sequence, Union

import mmengine
import numpy as np
import torch
from mmengine.structures import BaseDataElement, LabelData


def format_label(value: Union[torch.Tensor, np.ndarray, Sequence, int],
                 num_classes: Optional[int] = None) -> LabelData:
    """Convert label of various python types to :obj:`mmengine.LabelData`.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence` (except strings), :class:`int` and numpy integer
    scalars (e.g. ``np.int64``). 0-dim arrays/tensors and scalars are
    converted to a 1-element ``torch.LongTensor``.

    Args:
        value (torch.Tensor | numpy.ndarray | Sequence | int): Label value.
        num_classes (int, optional): The number of classes. If not None, set
            it to the metainfo and check that no label is greater than or
            equal to it. Defaults to None.

    Returns:
        :obj:`mmengine.LabelData`: The formatted label data.

    Raises:
        TypeError: If ``value`` is of an unsupported type (strings included).
        ValueError: If ``num_classes`` is given and some label is greater
            than or equal to it.
    """

    # Handle single number. numpy integer scalars (e.g. the result of
    # indexing an integer array) are not ``int`` instances, so route them
    # through the same path as 0-dim arrays instead of rejecting them.
    is_single_number = (
        isinstance(value, (torch.Tensor, np.ndarray, np.integer))
        and value.ndim == 0)
    if is_single_number:
        value = int(value.item())

    if isinstance(value, np.ndarray):
        value = torch.from_numpy(value)
    elif isinstance(value, Sequence) and not mmengine.utils.is_str(value):
        value = torch.tensor(value)
    elif isinstance(value, int):
        value = torch.LongTensor([value])
    elif not isinstance(value, torch.Tensor):
        raise TypeError(f'Type {type(value)} is not an available label type.')

    metainfo = {}
    if num_classes is not None:
        metainfo['num_classes'] = num_classes
        # ``max()`` of an empty tensor raises; an empty label set trivially
        # satisfies the upper bound, so only check non-empty labels.
        if value.numel() > 0 and value.max() >= num_classes:
            raise ValueError(f'The label data ({value}) should not '
                             f'exceed num_classes ({num_classes}).')
    label = LabelData(label=value, metainfo=metainfo)
    return label


class ReIDDataSample(BaseDataElement):
    """A data structure interface of ReID task.

    It's used as interfaces between different components.

    Meta field:
        img_shape (Tuple): The shape of the corresponding input image.
            Used for visualization.
        ori_shape (Tuple): The original shape of the corresponding image.
            Used for visualization.
        num_classes (int): The number of all categories.
            Used for label format conversion.

    Data field:
        gt_label (LabelData): The ground truth label. Its ``label`` is set by
            :meth:`set_gt_label` and its ``score`` by :meth:`set_gt_score`.
        pred_feature (torch.Tensor): The predicted feature; only
            ``torch.Tensor`` values are accepted.
    """

    @property
    def gt_label(self):
        return self._gt_label

    @gt_label.setter
    def gt_label(self, value: LabelData):
        # Stored under the private name; ``dtype`` makes ``set_field``
        # reject values that are not ``LabelData``.
        self.set_field(value, '_gt_label', dtype=LabelData)

    @gt_label.deleter
    def gt_label(self):
        del self._gt_label

    def set_gt_label(
        self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number]
    ) -> 'ReIDDataSample':
        """Set label of ``gt_label``.

        Args:
            value (np.ndarray | torch.Tensor | Sequence[Number] | Number):
                The ground truth label, converted with ``format_label``
                (validated against ``num_classes`` when that meta field is
                present).

        Returns:
            ReIDDataSample: ``self``, to allow call chaining.
        """
        label = format_label(value, self.get('num_classes'))
        if 'gt_label' in self:  # setting for the second time
            # Keep the existing LabelData (and any score on it); only
            # replace its ``label`` entry.
            self.gt_label.label = label.label
        else:  # setting for the first time
            self.gt_label = label
        return self

    def set_gt_score(self, value: torch.Tensor) -> 'ReIDDataSample':
        """Set score of ``gt_label``.

        Args:
            value (torch.Tensor): A 1-D tensor. When the ``num_classes``
                meta field is set, its length must equal ``num_classes``.

        Returns:
            ReIDDataSample: ``self``, to allow call chaining.

        Raises:
            AssertionError: If ``value`` is not a 1-D ``torch.Tensor`` or its
                length does not match ``num_classes``.
        """
        assert isinstance(value, torch.Tensor), \
            f'The value should be a torch.Tensor but got {type(value)}.'
        assert value.ndim == 1, \
            f'The dims of value should be 1, but got {value.ndim}.'

        if 'num_classes' in self:
            assert value.size(0) == self.num_classes, \
                f"The length of value ({value.size(0)}) doesn't "\
                f'match the num_classes ({self.num_classes}).'
            metainfo = {'num_classes': self.num_classes}
        else:
            # Without an explicit num_classes, infer it from the score size.
            metainfo = {'num_classes': value.size(0)}

        if 'gt_label' in self:  # setting for the second time
            self.gt_label.score = value
        else:  # setting for the first time
            self.gt_label = LabelData(score=value, metainfo=metainfo)
        return self

    @property
    def pred_feature(self):
        return self._pred_feature

    @pred_feature.setter
    def pred_feature(self, value: torch.Tensor):
        # Only ``torch.Tensor`` values are accepted for this field.
        self.set_field(value, '_pred_feature', dtype=torch.Tensor)

    @pred_feature.deleter
    def pred_feature(self):
        del self._pred_feature
Read the Docs v: dev-3.x
Versions
latest
stable
3.x
v3.3.0
v3.2.0
v3.1.0
v3.0.0
v3.0.0rc0
v2.28.2
v2.28.1
v2.28.0
v2.27.0
v2.26.0
v2.25.3
v2.25.2
v2.25.1
v2.25.0
v2.24.1
v2.24.0
v2.23.0
v2.22.0
v2.21.0
v2.20.0
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
v2.12.0
v2.11.0
v2.10.0
v2.9.0
v2.8.0
v2.7.0
v2.6.0
v2.5.0
v2.4.0
v2.3.0
v2.2.1
v2.2.0
v2.1.0
v2.0.0
v1.2.0
test-3.0.0rc0
main
dev-3.x
dev
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.