    [Model-Training Nanny] How to Add Weights to a Loss Function

    • 173****7719 (last edited by 173****7719)

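      A detailed look at PyTorch's CrossEntropyLoss

      Without weight

      N is the batch size (the number of samples in the batch)
      C is the number of classes

      According to the formula in the official documentation, each sample in the batch gets its own loss, and when reduction='mean' the final loss is the average of those per-sample losses.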
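
      For reference, the unweighted formula can be written out as below. The notation is an assumption that follows the usual convention rather than the post's original figure: x is the (N, C) input of raw scores and y_n is the target class of sample n.

```latex
\[
\ell_n = -\log \frac{\exp(x_{n,y_n})}{\sum_{c=1}^{C} \exp(x_{n,c})},
\qquad n = 1, \dots, N
\]
\[
\text{loss} = \frac{1}{N} \sum_{n=1}^{N} \ell_n
\qquad (\texttt{reduction='mean'})
\]
```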

      Let's work through a three-class problem with 4 samples.

      # input shape:  (N, C)
      # target shape: (N,)
      # Assume a three-class problem with batch size 4, i.e. N = 4, C = 3
      import math

      import numpy as np
      import torch
      from torch import nn

      N = 4  # batch size
      C = 3  # number of classes
      input = torch.randn(N, C)
      '''
      Example values from one run (torch.randn is random, so yours will differ):
      input: tensor([[0.2150, -0.1513, -0.1051],
                     [-0.7047, -0.9339, 0.2430],
                     [-0.9102, -0.4769, -0.1792],
                     [0.2268, 1.9147, 1.8868]])
      '''
      target = torch.empty(N, dtype=torch.long).random_(C)
      # target (same run): tensor([0, 0, 0, 1])
      loss = nn.CrossEntropyLoss()
      output = loss(input, target)
      # output: tensor(1.1647)


      # Verify by hand against the formula in the official documentation:
      # https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss

      l = []
      for i in range(N):
          b = 0  # denominator: sum of exp over all classes
          for j in range(C):
              b += math.exp(input[i, j].item())
          l.append(-math.log(math.exp(input[i, target[i]].item()) / b))

      print(np.mean(l))
      # 1.1647, matches the library result
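
      The hand check above exponentiates the raw scores directly, which can overflow when the scores are large. Below is a minimal sketch of the same per-sample loss using the log-sum-exp trick, written in plain Python so it can be run without PyTorch. The function names log_sum_exp and cross_entropy_mean are illustrative, not part of any library, and the scores are the example values printed above.

```python
import math

def log_sum_exp(row):
    """Compute log(sum(exp(x))) stably by factoring out the row maximum."""
    m = max(row)
    return m + math.log(sum(math.exp(x - m) for x in row))

def cross_entropy_mean(logits, targets):
    """Unweighted cross-entropy averaged over the batch (reduction='mean')."""
    # -log softmax(row)[t] == log_sum_exp(row) - row[t]
    losses = [log_sum_exp(row) - row[t] for row, t in zip(logits, targets)]
    return sum(losses) / len(losses)

# Example values copied from the post (rounded to 4 decimals)
logits = [[0.2150, -0.1513, -0.1051],
          [-0.7047, -0.9339, 0.2430],
          [-0.9102, -0.4769, -0.1792],
          [0.2268, 1.9147, 1.8868]]
targets = [0, 0, 0, 1]

mean_loss = cross_entropy_mean(logits, targets)
print(mean_loss)

# Large scores would overflow math.exp without the max subtraction
big_loss = cross_entropy_mean([[1000.0, 0.0, 0.0]], [0])
print(big_loss)
```

      Because the scores are rounded copies of the printed tensor, the result should agree with the 1.1647 above only to about three decimal places.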
      

      With weight
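
      With a per-class weight vector w, the formula changes in two places: each sample's loss is scaled by the weight of its target class, and the mean divides by the total weight instead of by N. The notation is an assumption following the usual convention: x is the (N, C) input of raw scores and y_n is the target class of sample n.

```latex
\[
\ell_n = -\,w_{y_n} \log \frac{\exp(x_{n,y_n})}{\sum_{c=1}^{C} \exp(x_{n,c})},
\qquad
\text{loss} = \frac{\sum_{n=1}^{N} \ell_n}{\sum_{n=1}^{N} w_{y_n}}
\qquad (\texttt{reduction='mean'})
\]
```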

      # Note: weight holds one weight per class, not per sample
      weights = torch.tensor([9, 15, 80], dtype=torch.float32)
      loss = nn.CrossEntropyLoss(weight=weights)
      output = loss(input, target)
      # output: 1.1081

      # Verify by hand again

      l = torch.empty(N, dtype=torch.float64)

      each_weight = torch.empty(N, dtype=torch.float64)
      # each_weight is the weight each sample receives based on its label
      # Since target is [0, 0, 0, 1],
      # each_weight = [9, 9, 9, 15]

      for i in range(N):
          each_weight[i] = weights[target[i]]
          b = 0  # denominator: sum of exp over all classes
          for j in range(C):
              b += math.exp(input[i, j].item())
          l[i] = -each_weight[i] * math.log(math.exp(input[i, target[i]].item()) / b)

      # With reduction='mean', the weighted sum is divided by the total weight, not by N
      output = l.sum() / each_weight.sum()
      # 1.1081, matches the library result
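
      Two consequences of dividing by the total weight are easy to miss, so here is a minimal sketch in plain Python that checks them on the example values from this post. The function name weighted_cross_entropy is illustrative, not a library API. It suggests that (1) equal weights of any size give the same result as no weights at all, and (2) the weight of a class that never appears in target (class 2 here) has no effect on the loss.

```python
import math

def weighted_cross_entropy(logits, targets, class_weights):
    """Weighted cross-entropy with reduction='mean' semantics:
    sum of w[y_n] * l_n divided by the sum of w[y_n]."""
    weighted_sum = 0.0
    weight_total = 0.0
    for row, t in zip(logits, targets):
        log_denominator = math.log(sum(math.exp(x) for x in row))
        sample_loss = log_denominator - row[t]  # -log softmax(row)[t]
        weighted_sum += class_weights[t] * sample_loss
        weight_total += class_weights[t]
    return weighted_sum / weight_total

# Example values copied from the post (rounded to 4 decimals)
logits = [[0.2150, -0.1513, -0.1051],
          [-0.7047, -0.9339, 0.2430],
          [-0.9102, -0.4769, -0.1792],
          [0.2268, 1.9147, 1.8868]]
targets = [0, 0, 0, 1]

weighted = weighted_cross_entropy(logits, targets, [9.0, 15.0, 80.0])
unit = weighted_cross_entropy(logits, targets, [1.0, 1.0, 1.0])
uniform = weighted_cross_entropy(logits, targets, [5.0, 5.0, 5.0])
# Same weights for classes 0 and 1, different weight for the unused class 2
unused_changed = weighted_cross_entropy(logits, targets, [9.0, 15.0, 1.0])

print(weighted, unit, uniform, unused_changed)
```

      In this batch three of the four samples carry weight 9 and one carries weight 15, so the single class-1 sample contributes 15/42 of the total rather than 1/4.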
      