Note
You are reading the documentation for MMClassification 0.x, which will be deprecated at the end of 2022. We recommend upgrading to MMClassification 1.0 to benefit from the new features and better performance brought by OpenMMLab 2.0. Check the installation tutorial, migration tutorial and changelog for more details.
VAN
- class mmcls.models.VAN(arch='tiny', patch_sizes=[7, 3, 3, 3], in_channels=3, drop_rate=0.0, drop_path_rate=0.0, out_indices=(3,), frozen_stages=-1, norm_eval=False, norm_cfg={'type': 'LN'}, block_cfgs={}, init_cfg=None)
Visual Attention Network.
A PyTorch implementation of: Visual Attention Network
Inspired by https://github.com/Visual-Attention-Network/VAN-Classification
- Parameters
arch (str | dict) – Visual Attention Network architecture. If a string, choose from 'b0', 'b1', 'b2', 'b3', etc. If a dict, it should have the following keys:
- embed_dims (List[int]): The dimensions of embedding.
- depths (List[int]): The number of blocks in each stage.
- ffn_ratios (List[int]): The expansion ratio of the feedforward network hidden layer channels.
Defaults to 'tiny'.
patch_sizes (List[int | tuple]) – The patch size in patch embeddings. Defaults to [7, 3, 3, 3].
in_channels (int) – The number of input channels. Defaults to 3.
drop_rate (float) – Dropout rate after embedding. Defaults to 0.
drop_path_rate (float) – Stochastic depth rate. Defaults to 0.
out_indices (Sequence[int]) – Output from which stages. Defaults to (3, ).
frozen_stages (int) – Stages to be frozen (stop grad and set eval mode). -1 means not freezing any parameters. Defaults to -1.
norm_eval (bool) – Whether to set norm layers to eval mode, namely, freeze running stats (mean and var). Note: This only affects Batch Norm and its variants. Defaults to False.
norm_cfg (dict) – Config dict for normalization layer for all output features. Defaults to dict(type='LN').
block_cfgs (Sequence[dict] | dict) – The extra config of each block. Defaults to empty dicts.
init_cfg (dict, optional) – The Config for initialization. Defaults to None.
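To illustrate the dict form of arch, the fragment below sketches a backbone config in the OpenMMLab dict style. It is only a sketch: the embed_dims, depths and ffn_ratios values are assumptions (they follow the commonly cited VAN-Tiny settings and are not taken from this page), and the final check assumes that each list carries one entry per stage, matching the length of patch_sizes.

```python
# Hypothetical custom architecture; the numeric values are illustrative
# assumptions, not values documented on this page.
custom_arch = dict(
    embed_dims=[32, 64, 160, 256],  # channels per stage
    depths=[3, 3, 5, 2],            # blocks per stage
    ffn_ratios=[8, 8, 4, 4],        # FFN hidden-layer expansion per stage
)

# Backbone config fragment using the documented parameter names.
backbone = dict(
    type='VAN',
    arch=custom_arch,
    patch_sizes=[7, 3, 3, 3],
    drop_path_rate=0.1,
    out_indices=(3, ),
)

# Sanity checks: all three documented keys are present, and (assumption)
# every list has one entry per stage.
required_keys = {'embed_dims', 'depths', 'ffn_ratios'}
missing = required_keys - set(custom_arch)
assert not missing, f'arch dict is missing keys: {missing}'

num_stages = len(backbone['patch_sizes'])
assert all(len(custom_arch[key]) == num_stages for key in required_keys)
```

A dict like backbone would normally sit inside a larger model config; passing arch='tiny' instead of custom_arch selects a predefined architecture by name.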
Examples
>>> from mmcls.models import VAN
>>> import torch
>>> model = VAN(arch='b0')
>>> inputs = torch.rand(1, 3, 224, 224)
>>> outputs = model(inputs)
>>> for out in outputs:
...     print(out.size())
(1, 256, 7, 7)
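The 7 x 7 spatial size in the example output can be reasoned about with ordinary convolution arithmetic. The sketch below is pure Python and does not import mmcls. It assumes (this is not stated on this page) that each stage's patch embedding is a convolution with kernel size k taken from patch_sizes, stride k // 2 + 1 and padding k // 2; conv_out_size and van_stage_sizes are hypothetical helper names.

```python
def conv_out_size(size, kernel, stride, padding):
    """Standard convolution output-size formula (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1


def van_stage_sizes(input_size=224, patch_sizes=(7, 3, 3, 3)):
    """Spatial size after each stage.

    ASSUMPTION: stride = k // 2 + 1 and padding = k // 2 for patch size k.
    """
    sizes = []
    size = input_size
    for k in patch_sizes:
        size = conv_out_size(size, kernel=k, stride=k // 2 + 1, padding=k // 2)
        sizes.append(size)
    return sizes


print(van_stage_sizes())  # [56, 28, 14, 7]
```

Under these assumptions a 224 x 224 input is downsampled by 4 in the first stage and by 2 in each later stage, so out_indices=(3, ) yields the 7 x 7 map shown in the example.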