Shortcuts

备注

您正在阅读 MMClassification 0.x 版本的文档。MMClassification 0.x 会在 2022 年末被切换为次要分支。建议您升级到 MMClassification 1.0 版本,体验更多新特性和新功能。请查阅 MMClassification 1.0 的安装教程、迁移教程以及更新日志。

mmcls.models.losses.focal_loss 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import convert_to_one_hot, weight_reduce_loss


def sigmoid_focal_loss(pred,
                       target,
                       weight=None,
                       gamma=2.0,
                       alpha=0.25,
                       reduction='mean',
                       avg_factor=None):
    r"""Sigmoid focal loss.

    The element-wise binary cross entropy on the logits is scaled by
    ``alpha_t * pt ** gamma``, where ``pt`` is the sigmoid probability
    assigned to the wrong label and ``alpha_t`` is ``alpha`` for positive
    targets and ``1 - alpha`` for negative targets.

    Args:
        pred (torch.Tensor): The prediction (logits) with shape (N, \*).
        target (torch.Tensor): The ground truth label of the prediction with
            shape (N, \*). It is cast to the dtype of ``pred``.
        weight (torch.Tensor, optional): Sample-wise loss weight with shape
            (N, ). Defaults to None.
        gamma (float): The gamma for calculating the modulating factor.
            Defaults to 2.0.
        alpha (float): A balanced form for Focal Loss. Defaults to 0.25.
        reduction (str): The method used to reduce the loss.
            Options are "none", "mean" and "sum". If reduction is 'none' ,
            loss is same shape as pred and label. Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.

    Returns:
        torch.Tensor: Loss.

    Raises:
        AssertionError: If ``pred`` and ``target`` differ in shape, or if
            ``weight`` is given and is not 1-D.
    """
    assert pred.shape == \
        target.shape, 'pred and target should be in the same shape.'
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    # pt is the probability mass placed on the wrong label, so confidently
    # correct elements get a modulating factor close to zero.
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    if weight is not None:
        assert weight.dim() == 1
        weight = weight.float()
        if pred.dim() > 1:
            # Append one singleton axis per non-batch axis of ``pred`` so the
            # per-sample weight lines up with the batch axis of the (N, *)
            # loss. A fixed (-1, 1) shape is only right for 2-D ``pred``:
            # for higher ranks it would align with the trailing axes.
            # NOTE(review): assumes weight_reduce_loss applies ``weight``
            # element-wise with broadcasting - confirm against .utils.
            weight = weight.reshape((-1, ) + (1, ) * (pred.dim() - 1))
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss


@LOSSES.register_module()
class FocalLoss(nn.Module):
    """Focal loss module built on :func:`sigmoid_focal_loss`.

    Args:
        gamma (float): Focusing parameter in focal loss.
            Defaults to 2.0.
        alpha (float): The parameter in balanced form of focal
            loss. Defaults to 0.25.
        reduction (str): The method used to reduce the loss into
            a scalar. Options are "none", "mean" and "sum".
            Defaults to 'mean'.
        loss_weight (float): Weight of loss. Defaults to 1.0.
    """

    def __init__(self,
                 gamma=2.0,
                 alpha=0.25,
                 reduction='mean',
                 loss_weight=1.0):
        super().__init__()
        # Hyper-parameters are stored as given and read back in forward().
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        r"""Sigmoid focal loss.

        Args:
            pred (torch.Tensor): The prediction with shape (N, \*).
            target (torch.Tensor): The ground truth label of the prediction
                with shape (N, \*), N or (N,1). Targets of shape N or (N,1)
                are converted to one-hot with ``pred.shape[-1]`` classes.
            weight (torch.Tensor, optional): Sample-wise loss weight. It is
                forwarded unchanged to ``sigmoid_focal_loss``, which
                requires a 1-D tensor of shape (N, ). Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The method used to reduce the
                loss into a scalar. Options are "none", "mean" and "sum".
                When None, the ``reduction`` given at construction is used.
                Defaults to None.

        Returns:
            torch.Tensor: Loss, scaled by ``loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        # A per-call override wins over the value configured in __init__.
        reduction = reduction_override or self.reduction

        # Targets given as class indices (shape N or (N, 1)) are expanded to
        # one-hot so they match the shape of ``pred``.
        is_index_target = target.dim() == 1 or (
            target.dim() == 2 and target.shape[1] == 1)
        if is_index_target:
            target = convert_to_one_hot(target.view(-1, 1), pred.shape[-1])

        unscaled_loss = sigmoid_focal_loss(
            pred,
            target,
            weight,
            gamma=self.gamma,
            alpha=self.alpha,
            reduction=reduction,
            avg_factor=avg_factor)
        return self.loss_weight * unscaled_loss
Read the Docs v: latest
Versions
master
latest
1.x
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.