mmcls.structures
This package includes basic data structures for classification tasks.
ClsDataSample
- class mmcls.structures.ClsDataSample(*, metainfo=None, **kwargs)
A data structure interface for classification tasks.
It serves as the common interface for passing data between different components.
- Meta fields
img_shape (Tuple) – The shape of the corresponding input image. Used for visualization.
ori_shape (Tuple) – The original shape of the corresponding image. Used for visualization.
num_classes (int) – The number of all categories. Used for label format conversion.
- Data fields
gt_label (LabelData) – The ground truth label.
pred_label (LabelData) – The predicted label.
Examples
>>> import torch
>>> from mmcls.structures import ClsDataSample
>>>
>>> img_meta = dict(img_shape=(960, 720), num_classes=5)
>>> data_sample = ClsDataSample(metainfo=img_meta)
>>> data_sample.set_gt_label(3)
>>> print(data_sample)
<ClsDataSample(
    META INFORMATION
    num_classes = 5
    img_shape = (960, 720)
    DATA FIELDS
    gt_label: <LabelData(
            META INFORMATION
            num_classes: 5
            DATA FIELDS
            label: tensor([3])
        ) at 0x7f21fb1b9190>
) at 0x7f21fb1b9880>
>>> # For multi-label data
>>> data_sample.set_gt_label([0, 1, 4])
>>> print(data_sample.gt_label)
<LabelData(
    META INFORMATION
    num_classes: 5
    DATA FIELDS
    label: tensor([0, 1, 4])
) at 0x7fd7d1b41970>
>>> # Set one-hot format score
>>> score = torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1])
>>> data_sample.set_pred_score(score)
>>> print(data_sample.pred_label)
<LabelData(
    META INFORMATION
    num_classes: 5
    DATA FIELDS
    score: tensor([0.1, 0.1, 0.6, 0.1, 0.1])
) at 0x7fd7d1b41970>
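The num_classes meta field is described as being used for label format conversion. The following is a minimal sketch of what such a conversion could look like, written in plain Python rather than torch so it runs without mmcls installed. The helper names to_multi_hot, to_indices and top1 are hypothetical and are not part of the mmcls API; the library's own conversion may differ.

```python
from typing import List, Sequence


def to_multi_hot(labels: Sequence[int], num_classes: int) -> List[int]:
    """Convert category indices to a 0/1 vector of length num_classes.

    Hypothetical helper for illustration; not an mmcls function.
    """
    vec = [0] * num_classes
    for idx in labels:
        if not 0 <= idx < num_classes:
            raise ValueError(f"label {idx} is outside [0, {num_classes})")
        vec[idx] = 1
    return vec


def to_indices(vec: Sequence[int]) -> List[int]:
    """Inverse conversion: positions of the non-zero entries."""
    return [i for i, v in enumerate(vec) if v]


def top1(score: Sequence[float]) -> int:
    """Index of the highest score, i.e. the top-1 predicted category."""
    return max(range(len(score)), key=score.__getitem__)


# Same values as in the examples above, with num_classes = 5.
single = to_multi_hot([3], num_classes=5)        # [0, 0, 0, 1, 0]
multi = to_multi_hot([0, 1, 4], num_classes=5)   # [1, 1, 0, 0, 1]
pred = top1([0.1, 0.1, 0.6, 0.1, 0.1])           # 2
```

The index form cannot be expanded into a fixed-length vector without knowing the total number of categories, which is presumably why num_classes is carried in the meta information alongside the label.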