
Note

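You are reading the documentation for MMClassification 0.x. MMClassification 0.x will become a secondary branch at the end of 2022. We recommend upgrading to MMClassification 1.0 for more new features and functionality. See the MMClassification 1.0 installation tutorial, migration tutorial, and changelog.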

mmcls.models.VAN

class mmcls.models.VAN(arch='tiny', patch_sizes=[7, 3, 3, 3], in_channels=3, drop_rate=0.0, drop_path_rate=0.0, out_indices=(3,), frozen_stages=-1, norm_eval=False, norm_cfg={'type': 'LN'}, block_cfgs={}, init_cfg=None)[source]

Visual Attention Network.

A PyTorch implementation of the paper "Visual Attention Network".

Inspired by https://github.com/Visual-Attention-Network/VAN-Classification

Parameters
  • arch (str | dict) –

    Visual Attention Network architecture. If a string, choose from ‘b0’, ‘b1’, ‘b2’, ‘b3’, etc.; if a dict, it should have the following keys:

    • embed_dims (List[int]): The embedding dimensions of each stage.

    • depths (List[int]): The number of blocks in each stage.

    • ffn_ratios (List[int]): The expansion ratios of the feedforward network hidden-layer channels in each stage.

    Defaults to ‘tiny’.

  • patch_sizes (List[int | tuple]) – The patch size of the patch embedding in each stage. Defaults to [7, 3, 3, 3].

  • in_channels (int) – The number of input channels. Defaults to 3.

  • drop_rate (float) – Dropout rate after embedding. Defaults to 0.

  • drop_path_rate (float) – Stochastic depth rate. Defaults to 0.

  • out_indices (Sequence[int]) – The indices of the stages whose outputs are returned. Defaults to (3, ).

  • frozen_stages (int) – Stages to be frozen (stop gradients and set to eval mode). -1 means not freezing any parameters. Defaults to -1.

  • norm_eval (bool) – Whether to set norm layers to eval mode, namely, freeze running stats (mean and var). Note: this only affects Batch Norm and its variants. Defaults to False.

  • norm_cfg (dict) – Config dict for the normalization layer applied to all output features. Defaults to dict(type='LN').

  • block_cfgs (Sequence[dict] | dict) – The extra config of each block. Defaults to an empty dict.

  • init_cfg (dict, optional) – The config for initialization. Defaults to None.
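The `arch` dict and `drop_path_rate` both act at the block level. As a rough, hypothetical sketch (plain Python, not the mmcls code), the helpers below check that a custom `arch` dict has the three documented keys with one entry per stage, and split `drop_path_rate` into per-block rates grouped by stage. The linear schedule from 0 to `drop_path_rate` is an assumption borrowed from common stochastic-depth practice; consult the VAN source code for what it actually does. `check_arch` and `drop_path_schedule` are invented names.

```python
from typing import Dict, List, Sequence

# Keys documented above for a custom ``arch`` dict.
REQUIRED_KEYS = ('embed_dims', 'depths', 'ffn_ratios')


def check_arch(arch: Dict[str, List[int]]) -> int:
    """Check a custom arch dict and return its number of stages (hypothetical helper)."""
    missing = [key for key in REQUIRED_KEYS if key not in arch]
    if missing:
        raise KeyError(f'arch dict is missing keys: {missing}')
    lengths = {len(arch[key]) for key in REQUIRED_KEYS}
    if len(lengths) != 1:
        raise ValueError('embed_dims, depths and ffn_ratios need one entry per stage')
    return lengths.pop()


def drop_path_schedule(depths: Sequence[int], drop_path_rate: float) -> List[List[float]]:
    """Assumed linear stochastic-depth schedule, grouped by stage (hypothetical helper)."""
    total = sum(depths)
    if total <= 1:
        flat = [0.0] * total
    else:
        step = drop_path_rate / (total - 1)
        # 0.0 for the first block, drop_path_rate for the last block.
        flat = [i * step for i in range(total)]
    per_stage, start = [], 0
    for depth in depths:
        per_stage.append(flat[start:start + depth])
        start += depth
    return per_stage


# Illustrative values only; they are not claimed to match any mmcls preset.
arch = dict(embed_dims=[32, 64], depths=[1, 2], ffn_ratios=[8, 8])
print(check_arch(arch))                         # 2
print(drop_path_schedule(arch['depths'], 0.2))  # [[0.0], [0.1, 0.2]]
```

With the documented default `drop_path_rate=0.0`, every block in this sketch gets a rate of 0, so stochastic depth is effectively off.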

Examples

>>> from mmcls.models import VAN
>>> import torch
>>> model = VAN(arch='b0')
>>> inputs = torch.rand(1, 3, 224, 224)
>>> outputs = model(inputs)
>>> for out in outputs:
...     print(out.size())
torch.Size([1, 256, 7, 7])
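Because `out_indices` selects which stage outputs the backbone returns, it can help to predict their shapes before building a model. The plain-Python sketch below does this under assumptions that this page does not state: each patch embedding is a convolution with kernel `patch_sizes[i]`, padding `patch_sizes[i] // 2`, and stride 4 in the first stage and 2 afterwards. Under those assumptions it reproduces the 1 x 256 x 7 x 7 output of the example above. `conv_out` and `van_output_shapes` are hypothetical helpers, and only the last embedding dimension (256) is confirmed by the example.

```python
from typing import List, Sequence, Tuple


def conv_out(size: int, kernel: int, stride: int) -> int:
    """Output length of a convolution padded with kernel // 2 on each side."""
    padding = kernel // 2
    return (size + 2 * padding - kernel) // stride + 1


def van_output_shapes(
    input_hw: Tuple[int, int],
    embed_dims: Sequence[int],
    patch_sizes: Sequence[int] = (7, 3, 3, 3),
    strides: Sequence[int] = (4, 2, 2, 2),  # assumption: not stated in the docs above
    out_indices: Sequence[int] = (3,),
    batch: int = 1,
) -> List[Tuple[int, int, int, int]]:
    """Predict (N, C, H, W) of each returned stage output (hypothetical helper)."""
    h, w = input_hw
    shapes = []
    for i, (dim, kernel, stride) in enumerate(zip(embed_dims, patch_sizes, strides)):
        # Each stage downsamples the feature map once, in its patch embedding.
        h, w = conv_out(h, kernel, stride), conv_out(w, kernel, stride)
        if i in out_indices:
            shapes.append((batch, dim, h, w))
    return shapes


# Only the last value (256) is confirmed by the example output; the rest are placeholders.
dims = [32, 64, 160, 256]
print(van_output_shapes((224, 224), dims))
# [(1, 256, 7, 7)]
print(van_output_shapes((224, 224), dims, out_indices=(0, 1, 2, 3)))
# [(1, 32, 56, 56), (1, 64, 28, 28), (1, 160, 14, 14), (1, 256, 7, 7)]
```

If the stride assumption holds, passing `out_indices=(0, 1, 2, 3)` to `VAN` would return one feature map per stage at 1/4, 1/8, 1/16 and 1/32 of the input resolution.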