Source code for mmcls.models.backbones.timm_backbone
# Copyright (c) OpenMMLab. All rights reserved.
try:
    import timm
except ImportError:
    timm = None

import warnings

from mmcv.cnn.bricks.registry import NORM_LAYERS

from ...utils import get_root_logger
from ..builder import BACKBONES
from .base_backbone import BaseBackbone
def print_timm_feature_info(feature_info):
    """Print the ``feature_info`` of a timm backbone to aid development and
    debugging.

    Args:
        feature_info (list[dict] | timm.models.features.FeatureInfo | None):
            feature_info of the timm backbone.
    """
    logger = get_root_logger()
    if feature_info is None:
        logger.warning('This backbone does not have feature_info')
    elif isinstance(feature_info, list):
        for feat_idx, each_info in enumerate(feature_info):
            logger.info(f'backbone feature_info[{feat_idx}]: {each_info}')
    else:
        try:
            logger.info(f'backbone out_indices: {feature_info.out_indices}')
            logger.info(f'backbone out_channels: {feature_info.channels()}')
            logger.info(f'backbone out_strides: {feature_info.reduction()}')
        except AttributeError:
            logger.warning('Unexpected format of backbone feature_info')
@BACKBONES.register_module()
class TIMMBackbone(BaseBackbone):
    """Wrapper to use backbones from the timm library.

    More details can be found in
    `timm <https://github.com/rwightman/pytorch-image-models>`_.
    See especially the document for `feature extraction
    <https://rwightman.github.io/pytorch-image-models/feature_extraction/>`_.

    Args:
        model_name (str): Name of the timm model to instantiate.
        features_only (bool): Whether to extract the feature pyramid
            (multi-scale feature maps from the deepest layer at each stride).
            For Vision Transformer models that do not support this argument,
            set it to False. Defaults to False.
        pretrained (bool): Whether to load pretrained weights.
            Defaults to False.
        checkpoint_path (str): Path of the checkpoint to load at the end of
            ``timm.create_model``. Defaults to an empty string, which means
            no checkpoint is loaded.
        in_channels (int): Number of input image channels. Defaults to 3.
        init_cfg (dict or list[dict], optional): Initialization config dict of
            OpenMMLab projects. Defaults to None.
        **kwargs: Other timm & model-specific arguments.
    """
    def __init__(self,
                 model_name,
                 features_only=False,
                 pretrained=False,
                 checkpoint_path='',
                 in_channels=3,
                 init_cfg=None,
                 **kwargs):
        if timm is None:
            raise RuntimeError(
                'Failed to import timm. Please run "pip install timm". '
                '"pip install dataclasses" may also be needed for Python 3.6.')
        if not isinstance(pretrained, bool):
            raise TypeError('pretrained must be bool, not str for model path')
        if features_only and checkpoint_path:
            warnings.warn(
                'Using both features_only and checkpoint_path will cause error'
                ' in timm. See '
                'https://github.com/rwightman/pytorch-image-models/issues/488')

        super(TIMMBackbone, self).__init__(init_cfg)
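        # Allow ``norm_layer`` to be passed as an MMCV registry name (e.g.
        # 'SyncBN'); resolve it to the registered layer class so that timm
        # receives a callable instead of a string.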
        if 'norm_layer' in kwargs:
            kwargs['norm_layer'] = NORM_LAYERS.get(kwargs['norm_layer'])
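        # Create the timm model. With ``features_only=True`` timm returns a
        # feature-extraction wrapper whose forward pass yields multi-scale
        # feature maps instead of classification logits.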
        self.timm_model = timm.create_model(
            model_name=model_name,
            features_only=features_only,
            pretrained=pretrained,
            in_chans=in_channels,
            checkpoint_path=checkpoint_path,
            **kwargs)
        # Reset the classifier so that the wrapped model only computes
        # features and keeps no unused classification head.
        if hasattr(self.timm_model, 'reset_classifier'):
            self.timm_model.reset_classifier(0, '')

        # Hack to use pretrained weights from timm: mark the module as
        # initialized so that MMCV's ``init_weights`` does not overwrite them.
        if pretrained or checkpoint_path:
            self._is_init = True

        feature_info = getattr(self.timm_model, 'feature_info', None)
        print_timm_feature_info(feature_info)
    def forward(self, x):
        features = self.timm_model(x)
        # Always return a tuple so that single-output and multi-output
        # (``features_only=True``) backbones share the same interface.
        if isinstance(features, (list, tuple)):
            features = tuple(features)
        else:
            features = (features, )
        return features
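
A minimal usage sketch, not part of the module above and assuming the standard MMClassification 0.x builders; the model name 'resnet50' and the input size are placeholders chosen for illustration:

import torch

from mmcls.models import build_backbone

# Build the wrapper through the BACKBONES registry; pretrained=False avoids
# downloading weights for this sketch.
backbone = build_backbone(
    dict(
        type='TIMMBackbone',
        model_name='resnet50',
        features_only=True,
        pretrained=False))
backbone.eval()

# forward() always returns a tuple; with features_only=True each element is
# one level of the feature pyramid produced by timm.
with torch.no_grad():
    feats = backbone(torch.rand(1, 3, 224, 224))
print([tuple(f.shape) for f in feats])

The same backbone can be selected from a model config file by setting, for example, backbone=dict(type='TIMMBackbone', model_name='resnet50', pretrained=True) and pairing it with a suitable neck and head.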