您现在的位置是:首页 > 正文

Focal Loss pytorch实现

2024-02-29 15:13:00阅读 0

引用自知乎   以备后用

实现

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class FocalLoss(nn.Module):
    r"""
        This criterion is a implemenation of Focal Loss, which is proposed in
        Focal Loss for Dense Object Detection.

            Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])

        The losses are averaged across observations for each minibatch.

        Args:
            alpha(1D Tensor, Variable) : the scalar factor for this criterion
            gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5),
                                   putting more focus on hard, misclassified examples
            size_average(bool): By default, the losses are averaged over observations for each minibatch.
                                However, if the field size_average is set to False, the losses are
                                instead summed for each minibatch.


    """
    def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        if alpha is None:
            self.alpha = Variable(torch.ones(class_num, 1))
        else:
            if isinstance(alpha, Variable):
                self.alpha = alpha
            else:
                self.alpha = Variable(alpha)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average

    def forward(self, inputs, targets):
        N = inputs.size(0)
        C = inputs.size(1)
        P = F.softmax(inputs)

        class_mask = inputs.data.new(N, C).fill_(0)
        class_mask = Variable(class_mask)
        ids = targets.view(-1, 1)
        class_mask.scatter_(1, ids.data, 1.)

        if inputs.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.cuda()
        alpha = self.alpha[ids.data.view(-1)]

        probs = (P*class_mask).sum(1).view(-1,1)

        log_p = probs.log()

        batch_loss = -alpha*(torch.pow((1-probs), self.gamma))*log_p

        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        return loss

网站文章

  • 基于java的网络版坦克大战游戏系统设计与实现(项目报告+答辩PPT+源代码+数据库+部署视频)

    基于java的网络版坦克大战游戏系统设计与实现(项目报告+答辩PPT+源代码+数据库+部署视频)

    基于Java的坦克大战游戏的设计与实现目 录1.引言.............................................................................

    2024-02-29 15:12:30
  • Eclipse快速上手指南之使用JUnit

    Eclipse快速上手指南之使用JUnit

    测试对于保证软件开发质量有着非常重要的作用,单元测试更是必不可少,JUnit是一个非常强大的单元测试包,可以对一个/多个类的单个/多个方法测试,还可以将不同的TestCase组合成TestSuit,使测试任务自动化。Eclipse同样集成了JUnit,可以非常方便地编写TestCase。  我们创建一个Java工程,添加一个example.Hello类,首先我们给Hello类添加一个abs()方法

    2024-02-29 15:12:23
  • 前端学习笔记(1)-vue相关格式化插件集成

    前端学习笔记(1)-vue相关格式化插件集成

    vue相关格式化插件集成

    2024-02-29 15:12:14
  • 为程序员量身定制的12个目标

    为程序员量身定制的12个目标ugmbbc发布于 2012-01-15 11:53:19|13029 次阅读 字体:大 小 打印预览 [分享至腾讯微博] [分享到QQ空间] 分享至新浪微博 转贴到开心网 分享到校内人人网 添加到Google书签cnBeta 博文精选感谢伯乐的投递对程序员们来说挑战自我非常重要,要么不断创新,要么技术停滞不前。新年伊始,我整理了12个月的目标,...

    2024-02-29 15:11:44
  • 学习JQuery的toggle()遇到的问题

    翠和他都让我平时多写博客记录自己遇到的问题,总结的方法等。我总是没有践行,不以为意。今天开始记录我遇到的问题,以后再来解决,不知道堵在哪里了。最后发现,竟然是JQuery库文件的原因,换了一个库文件,马上就可以执行了。 初始化 $(function(){ $("#btn").bind("click",function(){ var $content=$(

    2024-02-29 15:11:35
  • 影片推荐系统思考以及用spark.mllib.ALS实现最简单的推荐

    影片推荐系统思考以及用spark.mllib.ALS实现最简单的推荐

    影片推荐系统思考 1、用户信息的补充和处理 背景:智能电视通过机顶盒向用户分发电视节目。事先采集的用户信息及其有限。且电视节目的用户大多以家庭为单位,用户画像也相应呈现出家庭的特征,各项属性,如年龄,...

    2024-02-29 15:11:29
  • DB9接口详解---DB9引脚在 UART,CAN,RS485中的定义

    DB9接口详解---DB9引脚在 UART,CAN,RS485中的定义

    DB9端口的线缆在串行通信中使用较为普遍,本文将围绕DB9端口的定义、测试、产品以及连接方式等内容,详细介绍DB9端口。

    2024-02-29 15:10:59
  • 第三屏屏闪的原因分析

    第三屏屏闪的原因分析

    第三屏屏闪原因分析

    2024-02-29 15:10:52
  • 作为Java后台,这些都不会的话,就别去面试了 热门推荐

    作为Java后台,这些都不会的话,就别去面试了 热门推荐

    还有,ConcurrentHashMap的设计思路和HashMap是同步的,也就是说,ConcurrentHashMap除了锁机制这块的处理与HashMap不同,数组+链表(+红黑树)是和HashMa...

    2024-02-29 15:10:45
  • 【机器学习】,正则化惩罚(实例+图示+解释)

    【机器学习】,正则化惩罚(实例+图示+解释)

    在机器学习特别是深度学习中,我们通过大量数据集希望训练得到精确、泛化能力强的模型,对于生活中的对象越简洁、抽象就越容易描述和分别,相反,对象越具体、复杂、明显就越不容易描述区分,描述区分的泛化能力就越...

    2024-02-29 15:10:37