备注
您正在阅读 MMClassification 0.x 版本的文档。MMClassification 0.x 会在 2022 年末被切换为次要分支。建议您升级到 MMClassification 1.0 版本,体验更多新特性和新功能。请查阅 MMClassification 1.0 的安装教程、迁移教程以及更新日志。
mmcls.core.hook.class_num_check_hook 源代码
# Copyright (c) OpenMMLab. All rights reserved
from mmcv.runner import IterBasedRunner
from mmcv.runner.hooks import HOOKS, Hook
from mmcv.utils import is_seq_of
@HOOKS.register_module()
class ClassNumCheckHook(Hook):
    """Check that classification heads agree with the dataset's classes.

    Before each training/validation iteration (iter-based runners) or
    epoch (epoch-based runners), every sub-module of ``runner.model`` that
    exposes a ``num_classes`` attribute is compared with
    ``len(dataset.CLASSES)``. A mismatch raises ``AssertionError``; a
    dataset whose ``CLASSES`` is ``None`` only triggers a warning.
    """

    def _check_head(self, runner, dataset):
        """Check whether the `num_classes` in head matches the length of
        `CLASSES` in `dataset`.

        Args:
            runner (:obj:`EpochBasedRunner` | :obj:`IterBasedRunner`):
                runner object.
            dataset (:obj:`BaseDataset`): the dataset to check.

        Raises:
            AssertionError: If ``dataset.CLASSES`` is not a sequence of
                ``str``, or if any module's ``num_classes`` differs from
                ``len(dataset.CLASSES)``.
        """
        model = runner.model
        dataset_name = dataset.__class__.__name__

        if dataset.CLASSES is None:
            # Nothing to compare against: warn instead of failing.
            runner.logger.warning(
                f'Please set `CLASSES` '
                f'in the {dataset_name} and '
                f'check if it is consistent with the `num_classes` '
                f'of head')
            return

        # Raise explicitly rather than using `assert`, so the check is not
        # stripped under `python -O`; AssertionError is kept so existing
        # callers/tests that expect it keep working.
        if not is_seq_of(dataset.CLASSES, str):
            raise AssertionError(
                f'`CLASSES` in {dataset_name} '
                f'should be a tuple of str.')

        num_dataset_classes = len(dataset.CLASSES)
        for _, module in model.named_modules():
            if not hasattr(module, 'num_classes'):
                continue
            if module.num_classes != num_dataset_classes:
                raise AssertionError(
                    f'The `num_classes` ({module.num_classes}) in '
                    f'{module.__class__.__name__} of '
                    f'{model.__class__.__name__} does not match '
                    f'the length of `CLASSES` '
                    f'({num_dataset_classes}) in '
                    f'{dataset_name}')

    def _check_iter_runner(self, runner):
        """Run :meth:`_check_head` only when ``runner`` is iter-based.

        Args:
            runner (:obj:`IterBasedRunner`): Iter based Runner. Runners of
                any other type are ignored.
        """
        if not isinstance(runner, IterBasedRunner):
            return
        # NOTE(review): assumes `runner.data_loader` is a wrapper exposing the
        # real DataLoader via the private `_dataloader` attribute.
        self._check_head(runner, runner.data_loader._dataloader.dataset)

    def before_train_iter(self, runner):
        """Check whether the training dataset is compatible with head.

        Args:
            runner (:obj:`IterBasedRunner`): Iter based Runner.
        """
        self._check_iter_runner(runner)

    def before_val_iter(self, runner):
        """Check whether the eval dataset is compatible with head.

        Args:
            runner (:obj:`IterBasedRunner`): Iter based Runner.
        """
        self._check_iter_runner(runner)

    def before_train_epoch(self, runner):
        """Check whether the training dataset is compatible with head.

        Args:
            runner (:obj:`EpochBasedRunner`): Epoch based Runner.
        """
        self._check_head(runner, runner.data_loader.dataset)

    def before_val_epoch(self, runner):
        """Check whether the eval dataset is compatible with head.

        Args:
            runner (:obj:`EpochBasedRunner`): Epoch based Runner.
        """
        self._check_head(runner, runner.data_loader.dataset)