pytorch中的几个loss

最近跟着Tensor_Yu学习pytorch,理一理里面很多东西。
当我们想训练一个网络时,最重要的几个步骤是:如何载入数据,怎么定义并调用网络结构,最后,优化什么目标函数以及用什么方式优化?
因此,针对于最容易被忽略的部分——损失函数loss,做一些学习的整理,后面会慢慢把其他几个部分补上。

L1_loss

$L1~Loss$,顾名思义,源于L1范数。计算target与output的差的绝对值。

L2_loss, MSE_Loss

$L2~Loss$, 又称MSE均方误差损失,常用于回归任务的目标优化中。

NLL_Loss, Negative Log Likehood loss

PyTroch中,在用$NLL~Loss$之前,输入需要先经过$logsoftmax = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)$,得到每一类$softmax$归一化后的对数似然;然后进行下列公式计算即可。

当然也可以直接采用$CrossEntroy~Loss$,因为内部集成了$logsoftmax$函数,而无需再手动将输入进行$logsoftmax$变换。
$NLL~Loss$和$CrossEntroy~Loss$适用二分类任务以及多类分类。

>>> # classification task
>>> m = nn.LogSoftmax(dim=1)
>>> loss = nn.NLLLoss()
>>> # input is of size N x C = 3 x 5
>>> input = torch.randn(3, 5, requires_grad=True)
>>> # each element in target has to have 0 <= value < C
>>> target = torch.tensor([1, 0, 4])
>>> output = loss(m(input), target)
>>> output.backward()
>>>
>>> # 2D loss example (used, for example, with image inputs, segmentation task)
>>> N, C = 5, 4
>>> loss = nn.NLLLoss()
>>> # input is of size N x C x height x width
>>> data = torch.randn(N, 16, 10, 10)
>>> conv = nn.Conv2d(16, C, (3, 3))
>>> m = nn.LogSoftmax(dim=1)
>>> # each element in target has to have 0 <= value < C
>>> target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
>>> output = loss(m(conv(data)), target)
>>> output.backward()

CrossEntroy_Loss, 交叉熵loss

其实就是$logsoftmax$函数和$NLL~Loss$的集成。用$NLL~Loss$时需要先对网络forward后的原始输出进行$logsoftmax$再传入;而用$CrossEntroy~Loss$直接传入原始输出即可,内部进行归一化。

BCE_Loss

$BCE$,binary_cross_entroy, 是交叉熵损失的二分类下的特例。应用之前,需要先对输入经过$sigmoid$函数,得到归一化值。现在来想一想,把$x_n = sigmoid(x_n)$算入在内,将下面公式中的$\log x_n$看作是一个整体,那么$\log x_n$与$ \log(1 - x_n)$其实可以理解为二类下原始数据先求$softmax$再求$log$,和$cross~entory~loss$一样。
至于这里为什么用$sigmoid$函数而不用$softmax$函数?
$softmax= \frac{\exp(x_i) }{ \sum_j \exp(x_j)} = \frac{1}{1+ \exp(x_0 - x_1)}$要比$sigmoid(x_1) =\frac{1}{1+ \exp(x_1)}$需要每个batch的$x_0$缓存空间,但是$x_0$作为常数不参与反向求导,所以两者的本质是一样的。所以更多的采用sigmoid函数。
还不懂为啥本质一样?有详解

BCEWithLogits_Loss

其实就是$Sigmoid$函数和$BCE~Loss$的集成。

Conv 输出尺寸与输入关系

这个部分与本文无关,卷积操作后的尺寸问题,记录一下防止忘记

总结

这里就介绍6个Loss,但是实质上就是4个Loss,再实质上就是两个Loss,一个范数Loss一个交叉熵Loss。一个适用回归,一个适用分类。这两个任务下基于Loss的改进其实整体上还是基于他们改进。像FocalLoss等等。
另外还有三元组损失KLD损失什么的,用不上就不介绍了。