写算子时记录的笔记。
NLLLoss(Negative Log Likelihood Loss,负对数似然损失)是用于训练具有 C 类的分类问题的重要损失函数,支持加权以应对不平衡的训练集,并能够处理高维输入。该损失函数需要输入每一分类的对数概率,并且目标应该是特定的类索引,若指定 ignore_index,则该索引的损失不计入梯度计算。
数理原理
在深度学习中,NLLLoss(负对数似然损失)是最大似然概率问题的一个代理问题。实际上,它的数学本质就是将正确类别的对数概率求和 ← sum up the correct log probabilities。
朴素的 Maximum Likelihood Estimation
让我们先回顾一下什么是最大似然问题。
最大似然概率问题是统计学中的一个基本问题,它旨在找到一组参数,使得在这些参数下,给定的观测数据出现的概率最大。我们熟悉的形式是最大似然估计(Maximum Likelihood Estimation, MLE),其目标是找到参数 的最佳值,使得在这些参数下,观测数据的概率最大。
假设我们有一个由 个独立同分布的数据点组成的样本 ,并且我们假设这些数据点来自于某个参数化的概率分布 。那么就可以写出似然函数(其定义是给定参数 时观测数据 出现的联合概率):
对于独立同分布的数据点,似然函数显然可以写成:
最大似然估计的目标是找到参数 使得对数似然函数 达到最大值。也就是说,求解如下优化问题即可:
深度学习中的 Maximum Likelihood Estimation
二分类场景
讨论深度学习中的 MLE 应用时,该公式会产生一定的变化,但原理并没有发生改变。依旧给出一个由 个独立同分布的数据点组成的样本 ,以二分类场景为例,最大似然估计的目标是找到模型参数 ,使得观察到的数据的标签预测情况与真实标签 match 的程度的似然最大化,以使得模型的预测正确率在参数调整后得到提升。
此时,假设对于每个样本 ,预测其为正类 or 负类样本的概率表示为:
其中 函数是一个用于将分布在 范围上的样本值映射到 范围上的 map 函数,一般会选择使用 sigmoid 激活函数 .
最朴素的最大似然写法是与上一节相同的概率连乘。不过,此时的似然函数 通常会写成:
其中,幂指的含义是:我们只想关心那些与 true label 相关的预测值,如果某一项的 label 是 false,那么我们不希望该项对 likelihood 造成任何影响,后续调整参数的时候我们只要调整那些跟 true label 有关的样本的参数就可以了。这意味着训练出的模型会在“预测 true label”这一角度得到正确率的提升,至于能不能正确预测 false label,就不关心了。从数学上考虑,这有点像一种目标是 true label 样本的 mask select。
在大多数情况下,我们只关心对于正确标签预测的结果有多好。
跟传统的最大似然估计相同,为了简化计算,我们总是对似然函数取对数进行后续的最大化计算,得到对数似然(log likelihood):
结合幂指的含义,我们可以继续把对数似然改写成以下形式:
其中,对于二分类场景,,.
多分类场景
在多分类场景下, 的取值不再只是 0 或 1,如若给定 个分类标签,那么真实标签 的取值当然也可以分布在 范围内。此时,我们要保证 可以形成一个有效的概率分布,即所有类别的概率和应该要为 1。二分类场景应用的 sigmoid 激活函数和分类标签隐含的“非正即负”情景也是为了满足这个条件。
通常来说,我们会使用 softmax 激活函数完成这个目标:. 实际上,如果设置 c = 2 且 (二分类场景),就可以从此完全复原前述的 sigmoid 激活函数。
后续的计算过程与二分类场景类似,不再赘述。
此时 的含义就变成了:如果 为真实标签,则 ,否则为 .
NLLLoss
有了 log likelihood 的定义,给出 negative log likelihood 的定义就不是什么难事:它实际上就是上面所示的对数似然公式的负数版本!
不同之处在于,我们的目标是最大化在给定参数设置 下观察到数据的对数似然,因此相反地,我们希望最小化负对数似然。这就是 NLLLoss 在深度学习中的作用:通过调整参数 、最小化 NLLLoss,我们实际上是在最大化模型在给定数据集上的正确性😊。
在 pyTorch 的实现中,该公式被进一步简化了,它将直接要求前向调用给定的输入数据里已经包含对数概率的信息(即输入 而不是输入原始得分,例如 或者 ),不再在此函数中进一步计算概率的对数值。这就是为什么一般在神经网络中调用 NLLLoss 时,需要在前面再加一个 LogSoftmax 层。
🤔不过为什么非要引入 NLLLoss,而不是直接计算 MLE 呢?
一个或许合理的解释:多元正态分布的负对数似然函数是正定二次型, 所以如果初值取得比较合适, 负对数似然函数与多元正态分布的负对数似然函数相近, 接近于正定二次型, 这时求的最小值点会比较容易。同时我们还可以考虑凸性,负对数似然是关于未知参数的高阶连续可导的凸函数,便于求其全局最优解。而且一般来说应用于最小化问题的优化算法比较多,例如,可以用梯度下降求最小值。
CrossEntropyLoss
首先我们给出交叉熵(Cross Entropy)的定义。给定两个概率分布函数 和 ,那么定义它们的交叉熵如下:
可以看到这个形式与我们前面推导出的对数似然公式非常相似!对比一下,就可以发现负对数似然完全就是预测标签 和标签预测概率 的交叉熵。
NLLLoss 与 CELoss 的对比
前文讨论多分类场景下的 MLE 时实际上已经给出了答案。在 pyTorch 中,这两个损失函数的区别就在于 CELoss 为输入的原始预测值 施加了一个 softmax 激活函数归一化预测值,并且对预测概率做了取对数的处理(LogSoftmax)。
在 pyTorch 中选择使用 NLLLoss 还是 CELoss 取决于输入数据的形式:NLLLoss 适用于已经包含对数概率的输入数据,而 CELoss 适用于原始预测值,不需要做其他特殊处理。如果用户倾向于不在网络中再添加一个额外的 Layer,那么更推荐使用 CELoss。
Torch 中的函数定义
See in https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html#torch.nn.NLLLoss.
其中 size_average 参数已过时,不必考虑。文档中需要注意的只是对 格式的约束和推导公式。
文档释义
对于 ,pyTorch 规定其必须是形状为 或者 的 Tensor,后者用于处理高维的 ,例如 2D 图像的像素。
(minibatch = 4, C = 3) | Class_1 | Class_2 | Class_3 |
---|---|---|---|
sample_1 | x_11 | x_12 | x_13 |
sample_2 | x_21 | x_22 | x_23 |
sample_3 | x_31 | x_32 | x_33 |
sample_4 | x_41 | x_42 | x_43 |
对比前文的手推公式,可以发现 pyTorch 函数中 的本质就是前面提到的 ,这也是为什么 NLLLoss 要求在 中预先 encode log-probabilities 的信息。在 CELoss 中,其输入的 则为前面提到的 ,经 softmax 归一化的过程和计算 log-probabilities 的过程都在函数内进行。
再看到推导公式和文档中的注释,可以发现,,也即推导式中的 ,就是前面提到的真实标签 。 的写法,就完全等同于前面提到的 mask select principle。
pyTorch 也引入了一些与手推公式的不同点:
- 引入分类的权重信息 weight。
- 指定参数 ignore_index 时,将忽略该分类的权重信息。根据推导式最后一项,该分类的权重应为 0。
- 引入 reduce 操作。
- 从推导式可以看出 pyTorch 函数的输出是一个张量,整个计算过程截止于原始公式中的 。而结合前面的手推公式,可以看到在数学定义中默认了一个 sum 的 reduce 操作。
- pyTorch 中默认的 reduce 操作是 mean,需要对 sum 的 reduce 结果再除一个全体真实标签的权重和。
NLLLoss 反向传播推导
CELoss 的反向传播的推导与此完全相同。
首先简单重复一遍正向的过程。假设此时我们有 个分类,为便于演示,此时 ,不考虑 ,模型输出的未归一化得分(logits)。然后我们需要通过 softmax 激活函数将未归一化的得分 转换为概率分布 ,其中 表示样本属于类别 的概率:
然后对其做 log 变换:
这一步得到的实际上就是 pyTorch 函数的输入 了。遵从文档写法的话,等式左边就是 .
假设真实标签 为 ,则 NLLLoss 对于单个样本输出的定义为:
对于 pyTorch 中的反向传播求导过程,应先查看:
https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#gradients
中介绍的 pyTorch 自动求梯度机制(autograd)。后文仅介绍对未归一化得分输入的梯度求解过程,对应的代码应类似:
1 | import torch |
如果要推导 NLLLoss 对未归一化得分 的梯度,则先来看损失函数对于 的导数:
然后,我们计算 softmax 激活函数的输出 对于 logits 的导数:
- 对于预测符合真实标签的正确类别 :
- 对于错误类别 :
结合链式求导法则,计算 NLLLoss 对 logits 的梯度:
- 对于预测符合真实标签的正确类别 :
- 对于错误类别 :
故 NLLLoss 对于 logits 的梯度可以表示为:
其中 是指示函数,当 时为 1,否则为 0。
含 reduce 的场景以此类推,此处不再赘述。