mmcls.models.PCPVT

class mmcls.models.PCPVT(arch, in_channels=3, out_indices=(3,), qkv_bias=False, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_cfg={'type': 'LN'}, norm_after_stage=False, init_cfg=None)

The backbone of Twins-PCPVT.

This backbone implements the architecture from "Twins: Revisiting the Design of Spatial Attention in Vision Transformers".

Parameters
  • arch (dict, str) –

    PCPVT architecture: either a str naming an entry in the arch zoo, or a detailed configuration dict with the following 7 keys. All values in the dict must have the same length:

    • depths (List[int]): The number of encoder layers in each stage.

    • embed_dims (List[int]): Embedding dimension in each stage.

    • patch_sizes (List[int]): The patch sizes in each stage.

    • num_heads (List[int]): The number of attention heads in each stage.

    • strides (List[int]): The strides in each stage.

    • mlp_ratios (List[int]): The ratios of the MLP hidden dimension to the embedding dimension in each stage.

    • sr_ratios (List[int]): The spatial-reduction ratios of the GSA-encoder layers in each stage.

  • in_channels (int) – Number of input channels. Default: 3.

  • out_indices (tuple[int]) – Indices of the stages whose outputs are returned. Default: (3,).

  • qkv_bias (bool) – Enable bias for qkv if True. Default: False.

  • drop_rate (float) – Probability of an element being zeroed. Default: 0.0.

  • attn_drop_rate (float) – The dropout rate for the attention layer. Default: 0.0.

  • drop_path_rate (float) – Stochastic depth rate. Default: 0.0.

  • norm_cfg (dict) – Config dict for the normalization layer. Default: dict(type='LN').

  • norm_after_stage (bool, List[bool]) – Whether to add an extra norm layer after each stage. A list sets this per stage. Default: False.

  • init_cfg (dict, optional) – Config for initialization. Default: None.
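
The constraints stated for a custom arch dict (all 7 keys present, all values the same length) can be checked before the dict is passed to the backbone. The following is a minimal sketch: validate_arch is a hypothetical helper, not part of mmcls, and the numeric values in custom_arch are illustrative assumptions rather than an official preset.

```python
# Hypothetical helper, not part of mmcls: checks the constraints that the
# parameter description states for a custom `arch` dict.
REQUIRED_KEYS = (
    'depths', 'embed_dims', 'patch_sizes', 'num_heads',
    'strides', 'mlp_ratios', 'sr_ratios',
)


def validate_arch(arch):
    """Return the number of stages, or raise ValueError if `arch` is invalid."""
    missing = [key for key in REQUIRED_KEYS if key not in arch]
    if missing:
        raise ValueError(f'arch is missing keys: {missing}')
    lengths = {len(arch[key]) for key in REQUIRED_KEYS}
    if len(lengths) != 1:
        raise ValueError('all values in arch must have the same length')
    return lengths.pop()


# Illustrative four-stage configuration (values are assumptions).
custom_arch = dict(
    depths=[3, 4, 6, 3],
    embed_dims=[64, 128, 320, 512],
    patch_sizes=[4, 2, 2, 2],
    num_heads=[1, 2, 5, 8],
    strides=[4, 2, 2, 2],
    mlp_ratios=[8, 8, 4, 4],
    sr_ratios=[8, 4, 2, 1],
)
num_stages = validate_arch(custom_arch)
```

With mmcls installed, a dict that passes this check could then be supplied as PCPVT(arch=custom_arch) in place of a string such as "small".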

Examples

>>> from mmcls.models import PCPVT
>>> import torch
>>> pcpvt_cfg = {'arch': "small",
...              'norm_after_stage': [False, False, False, True]}
>>> model = PCPVT(**pcpvt_cfg)
>>> x = torch.rand(1, 3, 224, 224)
>>> outputs = model(x)
>>> print(outputs[-1].shape)
torch.Size([1, 512, 7, 7])
>>> pcpvt_cfg['norm_after_stage'] = [True, True, True, True]
>>> pcpvt_cfg['out_indices'] = (0, 1, 2, 3)
>>> model = PCPVT(**pcpvt_cfg)
>>> outputs = model(x)
>>> for feat in outputs:
...     print(feat.shape)
torch.Size([1, 64, 56, 56])
torch.Size([1, 128, 28, 28])
torch.Size([1, 320, 14, 14])
torch.Size([1, 512, 7, 7])
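
The shapes printed above follow from the per-stage embedding dimensions and strides: each stage divides the spatial size by its stride and sets the channel count to its embedding dimension. The sketch below reproduces that arithmetic without importing mmcls. expected_shapes is a hypothetical helper, and the strides [4, 2, 2, 2] are an assumption inferred from the printed shapes (224 → 56 → 28 → 14 → 7); it also assumes the input size is exactly divisible at every stage.

```python
def expected_shapes(embed_dims, strides, img_size=224, batch=1):
    """Return one (N, C, H, W) tuple per stage.

    Assumes a square input whose side is exactly divisible by the
    cumulative stride at every stage.
    """
    shapes = []
    size = img_size
    for dim, stride in zip(embed_dims, strides):
        size //= stride  # each stage downsamples by its stride
        shapes.append((batch, dim, size, size))
    return shapes


# Embedding dims taken from the example output; strides are assumed.
shapes = expected_shapes([64, 128, 320, 512], [4, 2, 2, 2])
for shape in shapes:
    print(shape)
```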