Pytorch中的CrossEntropyLoss()函数案例解读和结合one 您所在的位置:网站首页 如何将数据onehot编码 Pytorch中的CrossEntropyLoss()函数案例解读和结合one

Pytorch中的CrossEntropyLoss()函数案例解读和结合one

2023-10-07 03:51| 来源: 网络整理| 查看: 265

使用Pytorch框架进行深度学习任务,特别是分类任务时,经常会用到如下:

import torch.nn as nn criterion = nn.CrossEntropyLoss().cuda() loss = criterion(output, target)

即使用torch.nn.CrossEntropyLoss()作为损失函数。

那nn.CrossEntropyLoss()内部到底是啥??

nn.CrossEntropyLoss()是torch.nn中包装好的一个类,对应torch.nn.functional中的cross_entropy。 此外,nn.CrossEntropyLoss()是nn.logSoftmax()和nn.NLLLoss()的整合(将两者结合到一个类中)。

nn.logSoftmax()

定义如下: 在这里插入图片描述 从公式看,其实就是先softmax在log。

nn.NLLLoss()

定义如下: 在这里插入图片描述 此loss期望的target是类别的索引 (0 to N-1, where N = number of classes)。

例子1:

import torch.nn as nn m = nn.LogSoftmax() loss = nn.NLLLoss() # input is of size nBatch x nClasses = 3 x 5 input = autograd.Variable(torch.randn(3, 5), requires_grad=True) # each element in target has to have 0


【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有