Shortcuts

Source code for mmpretrain.models.backbones.timm_backbone

# Copyright (c) OpenMMLab. All rights reserved.
import warnings

from mmengine.logging import MMLogger

from mmpretrain.registry import MODELS
from mmpretrain.utils import require
from .base_backbone import BaseBackbone


def print_timm_feature_info(feature_info):
    """Log the ``feature_info`` of a timm backbone for development/debugging.

    The argument is dispatched on its form:

    * ``None``: a warning is logged saying the backbone has no feature info.
    * ``list``: every entry is logged on its own line together with its
      index.
    * anything else: it is treated as a ``FeatureInfo``-like object, and its
      ``out_indices`` attribute and ``channels()`` / ``reduction()`` results
      are logged in that order. If any of them is missing
      (``AttributeError``), a warning is logged instead. Lines already
      emitted before the failure are kept.

    Args:
        feature_info (list[dict] | timm.models.features.FeatureInfo | None):
            feature_info of timm backbone.
    """
    log = MMLogger.get_current_instance()

    # Nothing to describe at all.
    if feature_info is None:
        log.warning('This backbone does not have feature_info')
        return

    # Plain list form: one entry per line.
    if isinstance(feature_info, list):
        for position, entry in enumerate(feature_info):
            log.info(f'backbone feature_info[{position}]: {entry}')
        return

    # Object form: query the expected members one at a time so that
    # partial information is still logged if a later member is absent.
    try:
        log.info(f'backbone out_indices: {feature_info.out_indices}')
        log.info(f'backbone out_channels: {feature_info.channels()}')
        log.info(f'backbone out_strides: {feature_info.reduction()}')
    except AttributeError:
        log.warning('Unexpected format of backbone feature_info')


@MODELS.register_module()
class TIMMBackbone(BaseBackbone):
    """Wrapper to use backbones from timm library.

    More details can be found in
    `timm <https://github.com/rwightman/pytorch-image-models>`_.
    See especially the document for `feature extraction
    <https://rwightman.github.io/pytorch-image-models/feature_extraction/>`_.

    Args:
        model_name (str): Name of timm model to instantiate.
        features_only (bool): Whether to extract feature pyramid (multi-scale
            feature maps from the deepest layer at each stride). For Vision
            Transformer models that do not support this argument,
            set this False. Defaults to False.
        pretrained (bool): Whether to load pretrained weights.
            Defaults to False.
        checkpoint_path (str): Path of checkpoint to load at the last of
            ``timm.create_model``. Defaults to empty string, which means
            not loading.
        in_channels (int): Number of input image channels. Defaults to 3.
        init_cfg (dict or list[dict], optional): Initialization config dict of
            OpenMMLab projects. Defaults to None.
        **kwargs: Other timm & model specific arguments, forwarded to
            ``timm.create_model``. If ``norm_layer`` is a string, it is
            looked up in the ``MODELS`` registry and the registered class is
            passed to timm (wrapped in a plain function). Any other value
            (e.g. a class or factory callable) is forwarded unchanged.

    Raises:
        TypeError: If ``pretrained`` is not a bool.
        KeyError: If ``norm_layer`` is a string that is not registered in
            ``MODELS``.
    """

    @require('timm')
    def __init__(self,
                 model_name,
                 features_only=False,
                 pretrained=False,
                 checkpoint_path='',
                 in_channels=3,
                 init_cfg=None,
                 **kwargs):
        import timm

        if not isinstance(pretrained, bool):
            raise TypeError('pretrained must be bool, not str for model path')
        if features_only and checkpoint_path:
            warnings.warn(
                'Using both features_only and checkpoint_path will cause error'
                ' in timm. See '
                'https://github.com/rwightman/pytorch-image-models/issues/488')

        super(TIMMBackbone, self).__init__(init_cfg)

        if 'norm_layer' in kwargs and isinstance(kwargs['norm_layer'], str):
            norm_name = kwargs['norm_layer']
            norm_class = MODELS.get(norm_name)
            # Registry lookups return None for unknown names; fail here with
            # a clear message instead of a "'NoneType' object is not
            # callable" error from deep inside timm model construction.
            if norm_class is None:
                raise KeyError(
                    f'norm_layer {norm_name!r} is not registered in MODELS')

            # NOTE(review): the registered class is handed to timm through a
            # plain function rather than as the class object itself,
            # presumably so timm treats it as an opaque factory; confirm
            # before simplifying to ``kwargs['norm_layer'] = norm_class``.
            def build_norm(*args, **norm_kwargs):
                return norm_class(*args, **norm_kwargs)

            kwargs['norm_layer'] = build_norm

        self.timm_model = timm.create_model(
            model_name=model_name,
            features_only=features_only,
            pretrained=pretrained,
            in_chans=in_channels,
            checkpoint_path=checkpoint_path,
            **kwargs)

        # Reset the classifier, when the model supports it, with
        # num_classes=0 and global_pool='' (timm's reset_classifier API) so
        # the wrapped model is used as a feature extractor.
        if hasattr(self.timm_model, 'reset_classifier'):
            self.timm_model.reset_classifier(0, '')

        # Hack to use pretrained weights from timm: flag the module as already
        # initialized so the OpenMMLab init_cfg machinery presumably does not
        # overwrite the weights timm just loaded.
        if pretrained or checkpoint_path:
            self._is_init = True

        feature_info = getattr(self.timm_model, 'feature_info', None)
        print_timm_feature_info(feature_info)

    def forward(self, x):
        """Run the wrapped timm model and normalize its output to a tuple.

        Args:
            x: Input forwarded unchanged to the timm model (presumably an
                image batch tensor; confirm against callers).

        Returns:
            tuple: A list/tuple output (e.g. multi-scale features when
            ``features_only=True``) is converted to a tuple. Any single
            output is wrapped in a 1-tuple.
        """
        features = self.timm_model(x)
        if isinstance(features, (list, tuple)):
            features = tuple(features)
        else:
            features = (features, )
        return features
Read the Docs v: latest
Versions
latest
stable
mmcls-1.x
mmcls-0.x
dev
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.