mmcls.apis
This package provides high-level APIs for classification tasks.
Model
- mmcls.apis.list_models(pattern=None)
List all models available in MMClassification.
- Parameters
pattern (str | None) – A wildcard pattern to match model names.
- Returns
A list of model names.
- Return type
List[str]
Examples
List all models:
>>> from mmcls import list_models
>>> print(list_models())
List ResNet-50 models trained on the ImageNet-1k dataset:
>>> from mmcls import list_models
>>> print(list_models('resnet*in1k'))
['resnet50_8xb32_in1k', 'resnet50_8xb32-fp16_in1k', 'resnet50_8xb256-rsb-a1-600e_in1k', 'resnet50_8xb256-rsb-a2-300e_in1k', 'resnet50_8xb256-rsb-a3-100e_in1k']
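The wildcard behaviour described above can be illustrated without installing the library. The following is a minimal sketch that assumes shell-style matching as provided by Python's `fnmatch` module; the actual matching rules inside `list_models` may differ. `MODEL_NAMES` and `filter_model_names` are hypothetical names introduced only for this illustration, and the model names are copied from the example output.

```python
from fnmatch import fnmatchcase
from typing import List, Optional

# Hypothetical stand-in for the library's model registry.
MODEL_NAMES = [
    'resnet50_8xb32_in1k',
    'resnet50_8xb32-fp16_in1k',
    'resnet50_8xb256-rsb-a1-600e_in1k',
    'swin-base_16xb64_in1k',
]


def filter_model_names(names: List[str],
                       pattern: Optional[str] = None) -> List[str]:
    """Return the names matching a shell-style wildcard.

    With pattern=None every name is returned, mirroring the documented
    default of list_models. Input order is preserved.
    """
    if pattern is None:
        return list(names)
    # fnmatchcase is case-sensitive on every platform.
    return [name for name in names if fnmatchcase(name, pattern)]


# '*' matches any run of characters, so this keeps the three ResNet entries.
print(filter_model_names(MODEL_NAMES, 'resnet*in1k'))
```

Under this assumption, a pattern such as `'swin*'` would select only the Swin Transformer entry.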
- mmcls.apis.get_model(model_name, pretrained=False, device=None, **kwargs)
Get a pre-defined model by its name.
- Parameters
model_name (str) – The name of the model.
pretrained (bool | str) – If True, load the pre-defined pretrained weights. If a string, load the weights from the checkpoint it points to. Defaults to False.
device (str | torch.device | None) – Transfer the model to the target device. Defaults to None.
**kwargs – Other keyword arguments of the model config.
- Returns
The resulting model.
- Return type
Examples
Get a ResNet-50 model and extract image features:
>>> import torch
>>> from mmcls import get_model
>>> inputs = torch.rand(16, 3, 224, 224)
>>> model = get_model('resnet50_8xb32_in1k', pretrained=True, backbone=dict(out_indices=(0, 1, 2, 3)))
>>> feats = model.extract_feat(inputs)
>>> for feat in feats:
...     print(feat.shape)
torch.Size([16, 256])
torch.Size([16, 512])
torch.Size([16, 1024])
torch.Size([16, 2048])
Get a Swin Transformer model with pre-trained weights and run inference:
>>> from mmcls import get_model, inference_model
>>> model = get_model('swin-base_16xb64_in1k', pretrained=True)
>>> result = inference_model(model, 'demo/demo.JPEG')
>>> print(result['pred_class'])
'sea snake'
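Two of the `get_model` parameters carry logic worth spelling out: `pretrained` accepts either a flag or a string, and `**kwargs` overrides fields of the model config (as `backbone=dict(out_indices=...)` does in the first example). The sketch below is one plausible reading of those semantics, not the library's implementation. `merge_overrides`, `resolve_checkpoint`, the base config, and the default checkpoint path are all hypothetical.

```python
import copy
from typing import Any, Dict, Optional, Union


def merge_overrides(base: Dict[str, Any],
                    overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively merge keyword overrides into a copy of a config dict."""
    merged = copy.deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Nested dicts are merged key by key instead of replaced.
            merged[key] = merge_overrides(merged[key], value)
        else:
            merged[key] = value
    return merged


def resolve_checkpoint(pretrained: Union[bool, str],
                       default: Optional[str]) -> Optional[str]:
    """Map the documented `pretrained` values to a checkpoint location."""
    if pretrained is True:
        return default      # use the pre-defined weights
    if isinstance(pretrained, str):
        return pretrained   # use the caller-supplied weights
    return None             # False: load no weights


# Hypothetical base config, for illustration only.
base_cfg = {'backbone': {'type': 'ResNet', 'depth': 50, 'out_indices': (3,)}}
cfg = merge_overrides(base_cfg, {'backbone': {'out_indices': (0, 1, 2, 3)}})
print(cfg['backbone'])
print(resolve_checkpoint(True, 'path/to/default.pth'))
```

The merge keeps untouched fields such as `depth` while replacing only `out_indices`, which is the behaviour the feature-extraction example relies on.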
- mmcls.apis.init_model(config, checkpoint=None, device=None, **kwargs)
Initialize a classifier from a config file.
- Parameters
config (str | mmengine.Config) – Config file path or the config object.
checkpoint (str, optional) – Checkpoint path. If left as None, the model will not load any weights.
device (str | torch.device | None) – Transfer the model to the target device. Defaults to None.
**kwargs – Other keyword arguments of the model config.
- Returns
The constructed model.
- Return type
nn.Module
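Because `checkpoint=None` silently yields a model without loaded weights, a thin wrapper that warns in that case can be convenient. The following is a hedged sketch: `load_classifier` is a hypothetical helper, the paths in the usage comment are placeholders, and only the `init_model` signature shown above is assumed.

```python
import warnings
from typing import Any, Optional


def load_classifier(config: str,
                    checkpoint: Optional[str] = None,
                    device: Optional[str] = None,
                    **kwargs: Any):
    """Build a classifier via init_model, warning when no weights are given."""
    if checkpoint is None:
        warnings.warn('No checkpoint given: no weights will be loaded.')
    # Imported lazily so this file can be loaded without MMClassification.
    from mmcls.apis import init_model
    return init_model(config, checkpoint, device=device, **kwargs)


# Usage (both paths are hypothetical placeholders):
# model = load_classifier('path/to/config.py', 'path/to/checkpoint.pth',
#                         device='cpu')
```

Extra keyword arguments are forwarded unchanged, so model-config overrides keep working through the wrapper.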
Inference
- mmcls.apis.inference_model(model, img)
Run inference on an image with the classifier.
- Parameters
model (BaseClassifier) – The loaded classifier.
img (str | ndarray) – The image filename or a loaded image.
- Returns
The classification results, containing class_name, pred_label and pred_score.
- Return type
dict
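The returned dict can be turned into a one-line summary for logging. This is a minimal sketch with a hypothetical helper, `summarize_result`. The Swin Transformer example reads the class name from `'pred_class'`, while the return description names `class_name`, so the helper accepts either key. The values in `example` are made up for illustration and are not real model output.

```python
from typing import Any, Dict


def summarize_result(result: Dict[str, Any]) -> str:
    """Format a classification result dict as a short human-readable line."""
    # Accept either key for the class name (see the note above).
    name = result.get('pred_class', result.get('class_name', '<unknown>'))
    label = result.get('pred_label')
    score = result.get('pred_score')
    score_text = f'{float(score):.2%}' if score is not None else 'n/a'
    return f'{name} (label={label}, score={score_text})'


# Hypothetical result values, for illustration only.
example = {'pred_class': 'sea snake', 'pred_label': 65, 'pred_score': 0.6649}
print(summarize_result(example))  # sea snake (label=65, score=66.49%)
```

Using `.get` keeps the helper tolerant of missing keys, which matters if the set of returned fields differs between library versions.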