You are reading the documentation for MMClassification 0.x, which will be deprecated at the end of 2022. We recommend upgrading to MMClassification 1.0 to benefit from the new features and better performance brought by OpenMMLab 2.0. See the installation tutorial, migration tutorial, and changelog for more details.

Source code for mmcls.models.heads.conformer_head

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.utils.weight_init import trunc_normal_

from ..builder import HEADS
from .cls_head import ClsHead

@HEADS.register_module()
class ConformerHead(ClsHead):
    """Classification head with one linear classifier per Conformer branch.

    The head owns two ``nn.Linear`` layers: ``conv_cls_head`` classifies the
    convolution feature and ``trans_cls_head`` classifies the transformer
    feature. Their scores are summed for prediction and accuracy, and their
    losses are averaged for training.

    Args:
        num_classes (int): Number of categories excluding the background
            category. Must be positive.
        in_channels (Sequence[int]): Number of channels of the two input
            features, ordered as ``[conv_dim, trans_dim]``.
        init_cfg (dict | optional): The extra init config of layers.
            Defaults to use ``dict(type='Normal', layer='Linear', std=0.01)``.
        *args: Extra positional arguments forwarded to ``ClsHead``.
        **kwargs: Extra keyword arguments forwarded to ``ClsHead``.

    Raises:
        ValueError: If ``num_classes`` is not a positive number.
    """

    def __init__(
            self,
            num_classes,
            in_channels,  # [conv_dim, trans_dim]
            init_cfg=dict(type='Normal', layer='Linear', std=0.01),
            *args,
            **kwargs):
        # The parent is built with ``init_cfg=None``; the real config is
        # attached to the instance right after construction.
        super(ConformerHead, self).__init__(init_cfg=None, *args, **kwargs)

        self.in_channels = in_channels
        self.num_classes = num_classes
        self.init_cfg = init_cfg

        if self.num_classes <= 0:
            raise ValueError(
                f'num_classes={num_classes} must be a positive integer')

        self.conv_cls_head = nn.Linear(self.in_channels[0], num_classes)
        self.trans_cls_head = nn.Linear(self.in_channels[1], num_classes)

    def _init_weights(self, m):
        """Initialize one module; only ``nn.Linear`` layers are touched.

        Weights are drawn with ``trunc_normal_(std=.02)`` and the bias, when
        present, is set to zero.
        """
        if not isinstance(m, nn.Linear):
            return
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def init_weights(self):
        """Run the parent initialization, then the head's own default init.

        The default init (``_init_weights`` applied through ``self.apply``)
        is skipped when ``init_cfg`` is a dict of type ``'Pretrained'``.
        """
        super(ConformerHead, self).init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress default init if use pretrained model.
            return
        self.apply(self._init_weights)

    def pre_logits(self, x):
        """Return the last element if ``x`` is a tuple, else ``x`` itself."""
        if isinstance(x, tuple):
            x = x[-1]
        return x

    def simple_test(self, x, softmax=True, post_process=True):
        """Inference without augmentation.

        Args:
            x (tuple[tuple[tensor, tensor]]): The input features.
                Multi-stage inputs are acceptable but only the last stage
                will be used to classify. Every item should be a pair which
                includes convolution features and transformer features.
                The shape of them should be ``(num_samples, in_channels[0])``
                and ``(num_samples, in_channels[1])``.
            softmax (bool): Whether to softmax the classification score.
                If False, the two per-branch scores are returned separately
                instead of being summed.
            post_process (bool): Whether to do post processing the
                inference results. It will convert the output to a list.

        Returns:
            Tensor | list: The inference results.

                - If ``softmax`` is True and no post processing, the output
                  is a tensor with shape ``(num_samples, num_classes)``.
                - If ``softmax`` is True and post processing, the output is a
                  multi-dimensional list of float and the dimensions are
                  ``(num_samples, num_classes)``.
                - If ``softmax`` is False, the output is a two-item list
                  ``[conv_cls_score, tran_cls_score]``; with post processing
                  each item is passed through ``post_process``.
        """
        x = self.pre_logits(x)
        # There are two outputs in the Conformer model
        assert len(x) == 2, \
            'There should be two outputs in the Conformer model'

        conv_cls_score = self.conv_cls_head(x[0])
        tran_cls_score = self.trans_cls_head(x[1])

        if not softmax:
            pred = [conv_cls_score, tran_cls_score]
            if post_process:
                pred = list(map(self.post_process, pred))
            return pred

        # The summed score is always a tensor here, so no None guard needed.
        pred = F.softmax(conv_cls_score + tran_cls_score, dim=1)
        if post_process:
            pred = self.post_process(pred)
        return pred

    def forward_train(self, x, gt_label):
        """Compute the training losses from the two branch features.

        Args:
            x: The input features; after ``pre_logits`` this must be a list
                or tuple holding the convolution feature and the transformer
                feature, in that order.
            gt_label: The ground-truth labels, passed through to ``loss``.

        Returns:
            dict: The dictionary produced by ``loss``.
        """
        x = self.pre_logits(x)
        # Accept tuple pairs as well as lists, consistent with simple_test.
        assert isinstance(x, (list, tuple)) and len(x) == 2, \
            'There should be two outputs in the Conformer model'

        conv_cls_score = self.conv_cls_head(x[0])
        tran_cls_score = self.trans_cls_head(x[1])

        return self.loss([conv_cls_score, tran_cls_score], gt_label)

    def loss(self, cls_score, gt_label):
        """Average the per-branch losses and optionally report accuracy.

        Args:
            cls_score (list): The classification scores of the two branches.
            gt_label: The ground-truth labels.

        Returns:
            dict: Contains ``'loss'`` and, when ``self.cal_acc`` is truthy,
            ``'accuracy'`` mapping ``'top-k'`` names to values computed on
            the sum of both branch scores.
        """
        num_samples = len(cls_score[0])
        num_branches = len(cls_score)
        losses = dict()

        # Each branch contributes equally to the final loss.
        loss = sum(
            self.compute_loss(score, gt_label, avg_factor=num_samples) /
            num_branches for score in cls_score)

        if self.cal_acc:
            # Accuracy is measured on the fused (summed) prediction.
            acc = self.compute_accuracy(cls_score[0] + cls_score[1], gt_label)
            assert len(acc) == len(self.topk)
            losses['accuracy'] = {
                f'top-{k}': a
                for k, a in zip(self.topk, acc)
            }
        losses['loss'] = loss
        return losses
Read the Docs v: latest
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.