mmcls.apis

This package provides high-level APIs for classification tasks.

Model

mmcls.apis.list_models(pattern=None)

List all models available in MMClassification.

Parameters

pattern (str | None) – A wildcard pattern to match model names. Defaults to None, which lists all models.

Returns

A list of model names.

Return type

List[str]

Examples

List all models:

>>> from mmcls import list_models
>>> print(list_models())

List the ResNet-50 models trained on the ImageNet-1k dataset:

>>> from mmcls import list_models
>>> print(list_models('resnet*in1k'))
['resnet50_8xb32_in1k',
 'resnet50_8xb32-fp16_in1k',
 'resnet50_8xb256-rsb-a1-600e_in1k',
 'resnet50_8xb256-rsb-a2-300e_in1k',
 'resnet50_8xb256-rsb-a3-100e_in1k']
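To illustrate what the wildcard pattern does, the sketch below filters a list of model names with Python's standard `fnmatch` module. It assumes that `list_models` uses shell-style wildcard semantics (`*`, `?`, `[...]`); the helper `filter_model_names` and the short `MODEL_NAMES` catalog (built from names that appear on this page) are hypothetical and not part of MMClassification.

```python
import fnmatch
from typing import List, Optional

# Hypothetical catalog built from model names that appear on this page.
MODEL_NAMES = [
    'resnet50_8xb32_in1k',
    'resnet50_8xb32-fp16_in1k',
    'resnet50_8xb256-rsb-a1-600e_in1k',
    'swin-base_16xb64_in1k',
]


def filter_model_names(names: List[str], pattern: Optional[str] = None) -> List[str]:
    """Return the sorted names that match a shell-style wildcard pattern.

    Assumption: pattern=None means "no filter", mirroring list_models().
    """
    if pattern is None:
        return sorted(names)
    # fnmatchcase is case-sensitive on every OS, unlike fnmatch.filter.
    return sorted(n for n in names if fnmatch.fnmatchcase(n, pattern))


print(filter_model_names(MODEL_NAMES, 'resnet*in1k'))
print(filter_model_names(MODEL_NAMES, 'swin*'))
```

Because `*` matches any run of characters, `'resnet*in1k'` would also match other ResNet depths if they were in the catalog; a narrower pattern such as `'resnet50*in1k'` restricts the match to ResNet-50.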
mmcls.apis.get_model(model_name, pretrained=False, device=None, **kwargs)

Get a pre-defined model by its name.

Parameters
  • model_name (str) – The name of model.

  • pretrained (bool | str) – If True, load the pre-defined pretrained weights. If a string, load the weights from it. Defaults to False.

  • device (str | torch.device | None) – Transfer the model to the target device. Defaults to None.

  • **kwargs – Other keyword arguments of the model config.

Returns

The resulting model.

Return type

mmengine.model.BaseModel

Examples

Get a ResNet-50 model and extract image features:

>>> import torch
>>> from mmcls import get_model
>>> inputs = torch.rand(16, 3, 224, 224)
>>> model = get_model('resnet50_8xb32_in1k', pretrained=True, backbone=dict(out_indices=(0, 1, 2, 3)))
>>> feats = model.extract_feat(inputs)
>>> for feat in feats:
...     print(feat.shape)
torch.Size([16, 256])
torch.Size([16, 512])
torch.Size([16, 1024])
torch.Size([16, 2048])
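In the example above, `backbone=dict(out_indices=(0, 1, 2, 3))` changes a single backbone setting while the rest of the predefined config is kept. Assuming the extra keyword arguments are merged recursively into the model config (in the spirit of mmengine's `Config.merge_from_dict`), the merge can be sketched in plain Python. `merge_model_kwargs` and the simplified `resnet50_cfg` dict are hypothetical illustrations, not the real config of `resnet50_8xb32_in1k`.

```python
import copy
from typing import Any, Dict


def merge_model_kwargs(model_cfg: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
    """Return a copy of model_cfg with kwargs merged in recursively.

    When both sides hold a dict, keys are merged one by one; any other
    value replaces the existing entry. The input dict is not modified.
    """
    merged = copy.deepcopy(model_cfg)
    for key, value in kwargs.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_model_kwargs(merged[key], **value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Hypothetical, heavily simplified ResNet-50 classifier config.
resnet50_cfg = {
    'type': 'ImageClassifier',
    'backbone': {'type': 'ResNet', 'depth': 50, 'out_indices': (3,)},
    'neck': {'type': 'GlobalAveragePooling'},
    'head': {'type': 'LinearClsHead', 'num_classes': 1000, 'in_channels': 2048},
}

cfg = merge_model_kwargs(resnet50_cfg, backbone=dict(out_indices=(0, 1, 2, 3)))
print(cfg['backbone'])
# {'type': 'ResNet', 'depth': 50, 'out_indices': (0, 1, 2, 3)}
```

With merge semantics, the backbone keeps `depth=50` and every other predefined field when only `out_indices` is passed; a plain replacement would discard them.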

Get a Swin-Transformer model with pre-trained weights and run inference:

>>> from mmcls import get_model, inference_model
>>> model = get_model('swin-base_16xb64_in1k', pretrained=True)
>>> result = inference_model(model, 'demo/demo.JPEG')
>>> print(result['pred_class'])
sea snake
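The `pretrained` argument of `get_model` accepts either a bool or a string. A minimal sketch of how such an argument can be turned into a checkpoint location follows; `resolve_checkpoint` and its `default_weights` parameter (a stand-in for the model's pre-defined weights) are hypothetical, and the real `get_model` may differ in details such as error handling.

```python
from typing import Optional, Union


def resolve_checkpoint(pretrained: Union[bool, str],
                       default_weights: Optional[str]) -> Optional[str]:
    """Map a `pretrained` argument to a checkpoint path/URL, or None.

    default_weights is a hypothetical stand-in for the model's
    pre-defined weights (None if the model has none).
    """
    # Check for str first: a non-empty string is also truthy.
    if isinstance(pretrained, str):
        return pretrained
    if pretrained:
        if default_weights is None:
            raise ValueError('this model has no pre-defined pretrained weights')
        return default_weights
    # pretrained=False: build the model without loading any weights.
    return None
```

Checking for `str` before truthiness matters because any non-empty string is truthy and would otherwise be treated like `True`.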
mmcls.apis.init_model(config, checkpoint=None, device=None, **kwargs)

Initialize a classifier from a config file.

Parameters
  • config (str | mmengine.Config) – Config file path or the config object.

  • checkpoint (str, optional) – Checkpoint path. If left as None, the model will not load any weights.

  • device (str | torch.device | None) – Transfer the model to the target device. Defaults to None.

  • **kwargs – Other keyword arguments of the model config.

Returns

The constructed model.

Return type

nn.Module
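`init_model` constructs the classifier described by a config. Config-driven construction in OpenMMLab projects is generally based on a registry that maps each config's `type` string to a class. The toy sketch below illustrates only that pattern; it is not mmengine's `Registry`, and `TOY_MODELS`, `register`, `build_toy`, `ToyBackbone` and `ToyClassifier` are hypothetical names.

```python
from typing import Any, Dict, Type

# Hypothetical toy registry: maps the 'type' string of a config to a class.
TOY_MODELS: Dict[str, Type] = {}


def register(cls: Type) -> Type:
    """Class decorator that records cls in TOY_MODELS under its class name."""
    TOY_MODELS[cls.__name__] = cls
    return cls


def build_toy(cfg: Dict[str, Any]) -> Any:
    """Instantiate cfg['type'] with the remaining keys as constructor arguments.

    Nested dicts that carry their own 'type' key are built first.
    """
    cfg = dict(cfg)  # shallow copy so the caller's dict keeps its 'type'
    cls = TOY_MODELS[cfg.pop('type')]
    kwargs = {
        key: build_toy(value) if isinstance(value, dict) and 'type' in value else value
        for key, value in cfg.items()
    }
    return cls(**kwargs)


@register
class ToyBackbone:
    def __init__(self, depth: int) -> None:
        self.depth = depth


@register
class ToyClassifier:
    def __init__(self, backbone: ToyBackbone, num_classes: int) -> None:
        self.backbone = backbone
        self.num_classes = num_classes


model_cfg = {
    'type': 'ToyClassifier',
    'backbone': {'type': 'ToyBackbone', 'depth': 50},
    'num_classes': 1000,
}
model = build_toy(model_cfg)
```

Keeping construction behind a `type` lookup is what lets a config file, rather than Python code, choose which classes are built.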

Inference

mmcls.apis.inference_model(model, img)

Run inference on an image with the classifier.

Parameters
  • model (BaseClassifier) – The loaded classifier.

  • img (str | ndarray) – The image filename or a loaded image.

Returns

The classification result, a dict containing pred_class, pred_label and pred_score.

Return type

dict
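The returned dict can be pictured as a post-processing step over the classifier's per-class scores. Assuming the prediction is the highest-scoring class, a minimal sketch in plain Python follows; `build_result`, the three-class score vector and the class names are hypothetical, and the key names follow the `result['pred_class']` usage in the `get_model` example.

```python
from typing import Dict, Sequence, Union


def build_result(scores: Sequence[float],
                 classes: Sequence[str]) -> Dict[str, Union[int, float, str]]:
    """Turn a per-class score vector into a result dict (top-1 prediction)."""
    if len(scores) != len(classes):
        raise ValueError('scores and classes must have the same length')
    # The index of the highest score is the predicted label.
    pred_label = max(range(len(scores)), key=lambda i: scores[i])
    return {
        'pred_label': pred_label,
        'pred_score': float(scores[pred_label]),
        'pred_class': classes[pred_label],
    }


# Hypothetical three-class scores, e.g. after a softmax.
result = build_result([0.1, 0.7, 0.2], ['cat', 'sea snake', 'dog'])
print(result['pred_class'])  # sea snake
```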
