对比学习 您所在的位置:网站首页 熵变对温度求导 对比学习

对比学习

2024-06-02 07:59| 来源: 网络整理| 查看: 265

导读

        在文章《对比学习(Contrastive Learning),必知必会》和《CIKM2021 当推荐系统遇上对比学习,谷歌SSL算法精读》中,我们都提到过两个思考:

        (1)对比学习常用的损失函数InfoNCE loss和cross entropy loss是否有联系?

        (2)对比损失InfoNCE loss中有一个温度系数,其作用是什么?温度系数的设置对效果如何产生影响?

        个人认为,这两个问题可以作为对比学习相关项目面试的考点,本文我们就一起盘一盘这两个问题,[2]是本文的重点参考。

1. InfoNCE loss公式

        InfoNCE Loss(Noise Contrastive Estimation Loss)是一种用于自监督学习的损失函数,通常用于学习特征表示或者表征学习。它基于信息论的思想,通过对比正样本和负样本的相似性来学习模型参数。 

        对比学习损失函数有多种,其中比较常用的一种是InfoNCE loss,下面我们借用恺明大佬在他的论文MoCo里定义的InfoNCE loss公式来说明。

        论文MoCo提出,我们可以把对比学习看成是一个字典查询的任务,即训练一个编码器从而去做字典查询的任务。

        假设已经有一个编码好的 query q(一个特征),以及一系列编码好的样本 k0,k1,k2......

        那么k0,k1,k2......可以看作是字典里的key。

        假设字典里只有一个 key 即 k+ (称为 positive)是跟 q 是匹配的,那么和就互为正样本对,其余的key为 q 的负样本。

        一旦定义好了正负样本对,就需要一个对比学习的损失函数来指导模型来进行学习。这个损失函数需要满足这些要求,即当 query q 和唯一的正样本相似,并且和其他所有负样本key都不相似的时候,这个loss的值应该比较低。反之,如果和不相似,或者和其他负样本的key相似了,那么loss就应该大,从而惩罚模型,促使模型进行参数更新。 

        

        MoCo采用的对比学习损失函数就是InfoNCE loss,以此来训练模型,公式如下:

        

        

2. InfoNCE loss和交叉熵损失的关系

         我们先从softmax说起,下面是softmax公式:

        

         交叉熵损失函数如下:

        

        在有监督学习下,ground truth是一个one-hot向量,softmax的结果取 -log ,再与ground truth相乘之后,即得到如下交叉熵损失:

        

         上式中的在有监督学习里指的是这个数据集一共有多少类别,比如CV的ImageNet数据集有1000类,k就是1000。

        交叉熵损失说明:让我们假设有一个四分类问题,类别标签有四个,我们可以分别标记为[0, 1, 2, 3]。在这种情况下,如果我们想用独热编码(one-hot encoding)来表示这四个类别,它们将会是这样的:

如果一个样本的真实标签是0,那么它的独热编码就是[1, 0, 0, 0]如果一个样本的真实标签是1,那么它的独热编码就是[0, 1, 0, 0]如果一个样本的真实标签是2,那么它的独热编码就是[0, 0, 1, 0]如果一个样本的真实标签是3,那么它的独热编码就是[0, 0, 0, 1]

        每条独热编码的长度即是分类的数量,所在的位置指示了标签类别,位置上的“1”表示这个样本属于该类别。

        所以如果我们的模型预测一个属于第二个类别的样本的概率分布为[0.1, 0.5, 0.3, 0.1],那么计算交叉熵损失时,我们会取的是第二个位置上0.5的log值,即-log(0.5)。因为在这种情况下,独热编码是[0, 1, 0, 0],只有第二个元素为1,其他元素都为0,所以只有这个元素参与到交叉熵损失中。

        交叉熵损失是衡量两个概率分布之间差异的一种指标。在分类问题中,我们通常有一个真实的概率分布 P (通常是一个独热编码向量,代表了样本的真实标签分布),和一个模型预测的概率分布 Q 。交叉熵损失用于衡量这两个概率分布之间的差异。

        

        对于对比学习来说,理论上也是可以用上式去计算loss,但是实际上是行不通的。为什么呢?

        还是拿CV领域的ImageNet数据集来举例,该数据集一共有128万张图片,我们使用数据增强手段(例如,随机裁剪、随机颜色失真、随机高斯模糊)来产生对比学习正样本对,每张图片就是单独一类,那k就是128万类,而不是1000类了,有多少张图就有多少类。

        但是softmax操作在如此多类别上进行计算是非常耗时的,再加上有指数运算的操作,当向量的维度是几百万的时候,计算复杂度是相当高的。所以对比学习用上式去计算loss是行不通的。

         怎么办呢?NCE loss可以解决这个问题。

         NCE(noise contrastive estimation)核心思想是将多分类问题转化成二分类问题,一个类是数据类别 data sample,另一个类是噪声类别 noisy sample,通过学习数据样本和噪声样本之间的区别,将数据样本去和噪声样本做对比,也就是“噪声对比(noise contrastive)”,从而发现数据中的一些特性。

        但是,如果把整个数据集剩下的数据都当作负样本(即噪声样本),虽然解决了类别多的问题,计算复杂度还是没有降下来,解决办法就是做负样本采样来计算loss,这就是estimation的含义,也就是说它只是估计和近似。一般来说,负样本选取的越多,就越接近整个数据集,效果自然会更好。

         NCE loss常用在NLP模型中,公式如下:

        

         上述公式细节详见:NCE loss

         

        有了NCE loss,为什么还要用Info NCE loss呢?

        Info NCE loss是NCE的一个简单变体,它认为如果你只把问题看作是一个二分类,只有数据样本和噪声样本的话,可能对模型学习不友好,因为很多噪声样本可能本就不是一个类,因此还是把它看成一个多分类问题比较合理(但这里的多分类 k 指代的是负采样之后负样本的数量,下面会解释)。

        于是就有了InfoNCE loss,公式如下:

        

        上式中,qk 是模型出来的logits,相当于上文softmax公式中的 z ,τ 是一个温度超参数,是个标量,假设我们忽略 τ,那么infoNCE loss其实就是cross entropy loss。

        唯一的区别是,在cross entropy loss里,k指代的是数据集里类别的数量,而在对比学习InfoNCE loss里,这个k指的是负样本的数量。

        上式分母中的sum是在1个正样本和 k 个负样本上做的,从0到k,所以共k+1个样本,也就是字典里所有的key。

        恺明大佬在MoCo里提到,InfoNCE loss其实就是一个cross entropy loss,做的是一个k+1类的分类任务,目的就是想把这个图片分到这个类。

        另外,我们看下图中MoCo的伪代码,MoCo这个loss的实现就是基于cross entropy loss。

        

        

 

3. 温度系数的作用

        温度系数τ虽然只是一个超参数,但它的设置是非常讲究的,直接影响了模型的效果。 上式Info NCE loss中的 qk相当于是logits,温度系数可以用来控制logits的分布形状。

        对于既定的logits分布的形状,当 τ 值变大,则 1/τ 就变小, qk/τ 则会使得原来logits分布里的数值都变小,且经过指数运算之后,就变得更小了,导致原来的logits分布变得更平滑。

        相反,如果 τ 取得值小,1/τ 就变大,原来的logits分布里的数值就相应的变大,经过指数运算之后,就变得更大,使得这个分布变得更集中,更peak。

        注:在统计学和概率论中,当我们讨论概率分布更 peak 时,通常指的是分布更加集中在特定数值或区间上,具有更大的峰值,表示这个数值或区间上的事件更加可能发生或更具有代表性。这可能意味着概率分布更加狭窄或者更集中在某些数值附近,而不是分散在整个范围内。

        如果温度系数设的越大,logits分布变得越平滑,那么对比损失会对所有的负样本一视同仁,导致模型学习没有轻重。

        如果温度系数设的过小,则模型会越关注特别困难的负样本,这些“特别困难”的负样本指的是距离对应正样本较近,但被模型误认为是负样本的样本。其实那些负样本很可能是潜在的正样本,这样会导致模型很难收敛或者泛化能力差。

        总之,温度系数的作用就是它控制了模型对负样本的区分度。

参考

[1]Momentum Contrast for Unsupervised Visual Representation Learning.

[2]https://www.bilibili.com/video/BV1C3411s7t9 (墙裂推荐!bryanyzhu大佬出品~)

相关文章推荐

对比学习(Contrastive Learning),必知必会​

CIKM2021 当推荐系统遇上对比学习,谷歌SSL算法精读​

 转载自对比学习损失(InfoNCE loss)与交叉熵损失的联系,以及温度系数的作用 - 知乎

InfoNCE Loss公式及源码理解-CSDN博客 



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有