
Source code for mmdet.utils.split_batch

# Copyright (c) OpenMMLab. All rights reserved.
import torch


def split_batch(img, img_metas, kwargs):
    """Split data_batch by tags.

    Code is modified from
    <https://github.com/microsoft/SoftTeacher/blob/main/ssod/utils/structure_utils.py> # noqa: E501

    Args:
        img (Tensor): of shape (N, C, H, W) encoding input images.
            Typically these should be mean centered and std scaled.
        img_metas (list[dict]): List of image info dict where each dict
            has: 'img_shape', 'scale_factor', 'flip', and may also contain
            'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
            For details on the values of these keys, see
            :class:`mmdet.datasets.pipelines.Collect`.
        kwargs (dict): Specific to concrete implementation.

    Returns:
        data_groups (dict): a dict of data_batch split by tags,
            such as 'sup', 'unsup_teacher', and 'unsup_student'.
    """

    # only stack img in the batch
    def fuse_list(obj_list, obj):
        return torch.stack(obj_list) if isinstance(obj,
                                                   torch.Tensor) else obj_list

    # select data with tag from data_batch
    def select_group(data_batch, current_tag):
        group_flag = [tag == current_tag for tag in data_batch['tag']]
        return {
            k: fuse_list([vv for vv, gf in zip(v, group_flag) if gf], v)
            for k, v in data_batch.items()
        }

    kwargs.update({'img': img, 'img_metas': img_metas})
    kwargs.update({'tag': [meta['tag'] for meta in img_metas]})
    tags = list(set(kwargs['tag']))
    data_groups = {tag: select_group(kwargs, tag) for tag in tags}
    for tag, group in data_groups.items():
        group.pop('tag')
    return data_groups
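
The grouping step in `split_batch` can be illustrated with a small sketch that runs without PyTorch. It is not part of mmdet: `group_by_tag` and the sample batch are hypothetical, and plain strings stand in for image tensors, so nothing is stacked with `torch.stack` as the real function does for tensor values. The sketch also sorts the tags for reproducible output, whereas `split_batch` iterates over a `set`, and it does not mutate its input the way `split_batch` updates `kwargs`. Each entry of `img_metas` is assumed to carry a `'tag'` key, which the real function also requires.

```python
def group_by_tag(batch):
    """Split a dict of per-sample lists into one dict per tag.

    Hypothetical pure-Python stand-in for the selection logic in
    split_batch; values are filtered but never stacked into tensors.
    """
    tags = batch['tag']
    groups = {}
    for tag in sorted(set(tags)):
        # Boolean mask marking the samples that carry the current tag.
        keep = [t == tag for t in tags]
        groups[tag] = {
            key: [v for v, k in zip(values, keep) if k]
            for key, values in batch.items()
            if key != 'tag'  # the tag list itself is dropped from each group
        }
    return groups


# Hypothetical batch of four samples; strings stand in for image tensors.
batch = {
    'img': ['img0', 'img1', 'img2', 'img3'],
    'img_metas': [
        {'tag': 'sup'},
        {'tag': 'unsup_teacher'},
        {'tag': 'unsup_student'},
        {'tag': 'sup'},
    ],
    'gt_labels': [[1], [], [], [2, 3]],
}
batch['tag'] = [meta['tag'] for meta in batch['img_metas']]

groups = group_by_tag(batch)
print(sorted(groups))
print(groups['sup']['img'])
```

Every per-sample field (`img`, `img_metas`, and any extra keyword such as the hypothetical `gt_labels`) is filtered with the same mask, so the samples in each group stay aligned across fields.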