0%

对比学习 Contrastive Learning

对比学习 Contrastive Learning笔记

对比学习

记住的事物特征,不一定是像素级别的,而是更高维度的。更具体来说,比如用编码去做分类任务,我们不需要知道每个数据的细节,只要抓住每个类别的主要特征,自然就能把他们分开了。

不重构数据,那如何衡量表示 Z的好坏呢?这时也可以用互信息 I(X,Z),代表我们知道了 Z之后,X 的信息量减少了多少。 如果对最大化互信息的目标进行推导,就会得到对比学习的loss(也称InfoNCE),其核心是通过计算样本表示间的距离,拉近正样本,拉远负样本。也就是说,当我们能够区分该样本的正负例时,得到的表示就够用了。

具体的做法是,输入N个图片,用不同的数据增强方法为每个图片生成两个view,每个图片的两个不同aug方式,就作为正样本,分别对它们编码得到y和y’。我们对上下两批表示两两计算cosine,得到NxN的矩阵,每一行的对角线位置代表y和y’的相似度,其余代表y和N-1个负例的相似度。

64hSFgOjXobR1G8

摘自知乎

MoCo (CVPR2020)

主要思想

  • 维护一个queue(可以比minibatch大很多),里面放的都是负样本,每次都对这个queue里面的做对比学习。存在queue里面embedding更新的问题,因为迭代过程中,queue里面的是不更新的,就带来不一致的问题。所以提出momentum update,做的就是用一个系数,让队列里面的embed缓慢更新,更加平滑。

image.png

还尝试了三种不同的方式,不同就是键值的保持和键值编码器的更新方式不同

AlN6zcoV9qRm45j

  • a方法,字典大小和mini-batch大小相同,受限于GPU显存,对大的mini-batch进行优化也是挑战,有些pretexts进行了一些调整,能够使用更大的字典,但是这样不方便进行迁移使用
  • b方法,Memory Bank包含数据集所有数据的特征表示,从Memory Bank中采样数据不需要进行反向传播,所以能支持比较大的字典,然而一个样本的特征表示只在它出现时才在Memory Bank更新,所以一个epoch只会更新一次,但模型在训练过程中不断迭代,这个特征就会“过时”,因此具有更少的一致性,而且它的更新只是进行特征表示的更新,不涉及encoder。

Pretext Task

将一对查询query和以及键值key组成样本对,如果它们出自同一图像,那么是正样本对,否则为负样本对。查询和键值分别编码自 fq 和 fk。在随机数据增强下从同一图像中任意提取两个”view”构建正样本对,负样本取自队列。

Technical details. 使用ResNet作为编码器,最后一层输出为128D向量,即查询query和键值key的表示。

Shuffling BN. 在实验中发现Batch Norm会阻止模型学到良好的特征表示。模型似乎会欺骗pretext task并容易找到低损失的解决方案。可能是因为由BN导致的intra-batch communication among samples泄露了信息。

作者通过Shuffling BN来解决该问题。在训练时使用多个GPU,在每个GPU上分别进行BN(常规操作),对于键值编码器 fk ,在当前mini-batch中打乱样本的顺序,再把它们送到GPU上分别进行BN,然后再恢复样本的顺序;对于查询编码器 fq,不改变样本的顺序。这能够保证用于计算查询和其正键值的批统计值出自两个不同的子集。

伪代码

hTrbPW4Jeu8fm9w

SimCLR

摘自segmentfault
对比学习的目标是让相似样本产生相同的表示,不相似的样本产生不同的表示。对比学习的核心是噪声对比估计损失(Noise Contrastive Estimator (NCE) loss),其其表示如下:

TKxdpW2EkfZsD7L

其中x+是输入x的相似点,(x,x+)又可称为正对。通常x+由x变换得来,如图像裁剪,旋转变换或其他的数据增广手段。反之,x-则是x的不相似样本,则有负对(x,x-),NCE loss会使得负对与正对区别开。一般对于每组正对,都会有K组负对。负对的数目对对比学习效果影响很大。

sim(.)代表相似度度量。通常其使用内积或余弦相似度。g(.)是一个卷积神经网络。有的对比学习会用siamese网络。

Vr8uByA5PstcQUq

流程模拟

  1. 首先,我们从原始图像生成批大小为N的batch。为了简单起见,我们取一批大小为N = 2的数据。在论文中,他们使用了8192的大batch
  2. 论文中定义了一个随机变换函数T,该函数取一幅图像并应用 random (crop + flip + color jitter + grayscale)。对于这个batch中的每一幅图像,使用随机变换函数得到一对图像。因此,对于batch大小为2的情况,我们得到2N = 4张总图像。CNcmJP5gkUsSV37
  3. 每一对中的增强过的图像都通过一个编码器来获得图表示。所使用的编码器是通用的,可与其他架构替换。下面显示的两个编码器有共享的权值,我们得到向量$h_i$$h_j$。可以参考上面的图。论文里面用ReNet-50作为ConvNet编码器,输出2048维向量h
  4. Dense layer部分是MLP,linear-bn-relu-linear-bn,非线性变换。研究发现encoder编码后的 h 会保留和数据增强变换相关的信息,而非线性层的作用就是去掉这些信息,让表示回归数据的本质。注意非线性层只在无监督训练时用,在迁移到其他任务时不使用
  5. 计算两两间的余弦相似度
  6. 计算损失,SimCLR使用NT-Xent损失:归一化温度-尺度交叉熵损失。

2oEaWngvFMzk8bR

这个softmax计算等价于第二个增强的猫图像与图像对中的第一个猫图像最相似的概率。这里,batch中所有剩余的图像都被采样为不相似的图像(负样本对)。然后,通过取上述计算的对数的负数来计算这一对图像的损失。这个公式就是噪声对比估计(NCE)损失:

Py5fTuBdYgGNzpM

swjqeVOHBhMxzfv

在图像位置互换的情况下,我们再次计算同一对图像的损失。

16vWJ5ZAt8lYoEC

最后,我们计算Batch size N=2的所有配对的损失并取平均值。

31tQTiA79xdMp4s
7. 结果:在ImageNet ilsvvc -2012上,实现了76.5%的top-1准确率,比之前的SOTA自监督方法Contrastive Predictive Coding提高了7%,与有监督的ResNet50持平。当训练1%的标签时,它达到85.8%的top-5精度,超过了AlexNet,但使用带标签的数据少了100倍。

hIsUBJw1NRqmHjb

MoCo v2

bRG1BLlTHsYAVW4

SimCLR v2(NIPS20)

image.png

Swav

7lqnS4HJCGm9DTb

上面的都是把每个样本作为一类来看的,但是这样耗时也比较耗资源,就提出了一个新想法,在minibatch里面聚类,然后只要区分每个类的类簇就可以。每个类簇里面的样本按道理应该要和类簇的相似度为1,但是这样太严格了,用soft label更好点,这也就是是上图的Codes。

除了这个,还提出一种新的数据增强方式:mix不同分辨率的view。

6a7HUBeudJxoN29

SEER

  1. 就是Swav作者做的新活
  2. 就是数据集更大了,然后模型更好了。
  3. 用了10亿的Instagram图片训练,然后用ResNet结合了一点神经网络搜索NAS
  4. 不细看了
  5. 8CtcDhV4iEfGadQ

BYOL

4F3WYagE57iDBXq

  1. 上面的都需要负样本,这提出在没有负样本的情况下做constrastive learning。
  2. 和SimCLR,MoCo区别
    1. SimCLR提出了nonlinear projection head的概念,nonlinear projection head里面没有BN;
    2. SimCLR提取两张经过augmentation的图片的CNN网络是相同的,而MOCO训练的时候则是有两个不同的网络, 其中一个网络根据另一个网络的parameters慢慢进行更新,且这个网络会提供一个memory bank,会提供更多大量的negative sample。因为contrastive learning是非常依仗negative sample的数量,所以negative sample数量越多,contrastive task越难,最终提取到的representation就越好。
    3. BYOL则是在MOCO的基础上直接去掉了negative sample。如下图所示,前面的结构都和MOCO相同(除了 g theta 的结构,加入了BN,后面会提到),两个不同的网络分别为上面的online网络和target网络。不同之处是online网络在经过projection得到 z theta 后加了一个predictor(由1层或2层FC组成),然后用这个predictor来预测target网络得到的 z’ ,相当于一个回归任务,loss函数采用MSE(值得注意的是,z theta 和 z’ 都经过了L2 normalize)
  3. 8G5PeQEgoKVuJnm
  4. ToZOtILmXwrbndC
  5. 没有负样本是非常奇怪的,因为loss就是为了从负样本和正样本中选出那个正样本。一般用的那个softmax cross entropy loss可以芬姐成两个部分。qTvSwibDEmcXg6U这个意思就是说,我可以把contrastive loss分解成两个部分,第一部分叫做alignment,就是希望positive pair的feature接近,第二部分叫做uniformity,就是希望所有点的feature尽量均匀的分部在unit sphere上面,都挺好理解的吧?这两部分理论上是都需要的,假如只有alignment,没有uniformity,那就很容易都坍缩到0,就是退化解。所以BYOL就是去掉uniformity,只保留了alignment。这听起来似乎不科学,因为模型很容易学到trivial solution:就是使online网络和target网络永远都输出同样的constant。所以模型为什么会work呢?看了一些大佬的分享(详见Reference),总结大概有以下几点:EMA,predictor,BN
  6. 关于EMA(exponential moving average)
    1. 详细的EMA,看这里EMA的原理和实现
    2. MA可能在帮助悄悄scatter feature,因为τ取值比较大时,target网络的更新是比较慢的,迭代次数较多的时候,使得online network和target network是不同的,进而帮助阻止模型塌陷
  7. Predictor,后面接的那个全连接层,会让online net输出和target不那么一样,而是靠ffn来match,比较灵活
  8. BN层
    1. BYOL是在MoCo基础上做的,但是相比来说加入了BN层
    2. BN实际上就是规范化一个batch的分布。得到的mean和variance都和batch里面所有的image有关,所以BN相当于一个隐性的contrastive learning:每一个image都和batch的mean做contrastive learning。
    3. 2kFiv8N4HhLqGcu
    4. 从uniformity的角度来理解,BN会把不同点的特征scatter开来,就等于默默地在做dispersion。正是因为在BN之后,batch中的所有样本都不能采用相同的值,所以可以防止模型坍塌。
    5. 复现结果来看,没有BN的话,结果和random baseline差不多,并且没有BN的时候,online net和target是非常像的

数据增强方式

  1. Si吗CLR中提出,不同的增强方式组合会更好,单一增强太简单。并且Crop和Color的组合最好,因为大多数图像中的颜色是比较一致的,即使裁剪也会容易辨认,如果去掉颜色会增加任务难度
  2. 颜色越浅,效果越好,分数越高。Yod7lbkiEWNZ9VD

BatchNorm的影响

BatchNorm导致的信息泄露

  • 在分布式训练中,BN都是分别在各个设备上做的。而对比学习的正例对在一个机器上计算,会出现信息泄露。个人认为,在BN去除batch内共同特征时,很可能被归一化到相似的分布,降低任务难度。
  • MoCo的解决办法是把一边的样本重新shuffle再并行,另一边顺序不变,这样batch的统计量就不同了;SimCLR的解法是计算一个全局的BN值。

MLP在里面的作用

在SimCLR以后的工作中都使用了MLP,SimSiam也对它的作用进行了探究:
8vDmBTx1LKktUqo

没有MLP的时候很差,MLP可能承担了估计整体期望的功能,这同SimCLR最初增加MLP时的发现是一致的,核心思想还是过滤表示中的无效特征,得到本质,服务于“对比”任务。

上面大部分内容摘自自监督对比学习(Contrastive Learning)综述+代码以及上面提到的几篇博客