Pytorch中的CrossEntropyLoss()函数案例解读和结合one | 您所在的位置:网站首页 › 如何将数据onehot编码 › Pytorch中的CrossEntropyLoss()函数案例解读和结合one |
使用Pytorch框架进行深度学习任务,特别是分类任务时,经常会用到如下: import torch.nn as nn criterion = nn.CrossEntropyLoss().cuda() loss = criterion(output, target)即使用torch.nn.CrossEntropyLoss()作为损失函数。 那nn.CrossEntropyLoss()内部到底是啥??nn.CrossEntropyLoss()是torch.nn中包装好的一个类,对应torch.nn.functional中的cross_entropy。 此外,nn.CrossEntropyLoss()是nn.logSoftmax()和nn.NLLLoss()的整合(将两者结合到一个类中)。 nn.logSoftmax()定义如下: 从公式看,其实就是先softmax在log。 nn.NLLLoss()定义如下: 此loss期望的target是类别的索引 (0 to N-1, where N = number of classes)。 例子1: import torch.nn as nn m = nn.LogSoftmax() loss = nn.NLLLoss() # input is of size nBatch x nClasses = 3 x 5 input = autograd.Variable(torch.randn(3, 5), requires_grad=True) # each element in target has to have 0 |
CopyRight 2018-2019 实验室设备网 版权所有 |