论文阅读:Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup
基本信息 论文标题:Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup 作者单位:CMU 论文链接:https://arxiv.org/pdf/2101.06983 来源:arxiv 一、问题 对比学习通常使用InfoNCE loss进行训练,公式如下(1): 其中: \(s_i\)是anchor,\(f(s_i)\)是anchor embedding,\(f\)是anchor encoder \(t_{r_i}\)是\(s_i\)对应的positive,\(g(t_{r_i})\)是positive embedding,\(g\)是positive encoder \(t_j \in T\)是batch内其他\(s\)对应的positives,作为\(s_i\)的in-batch negatives。在没有hard negative的情况下,即每条样本是<anchor, positive>这种二元组的情况下,\(|T|\)等于batchsize \(\tau\)是温度系数,通常是一个常数,为讨论方便,后续省略该参数 在经典的双塔对比学习场景下,函数\(f\)和\(g\)通常是两个不同的网络(比如CLIP);在LLM/VLM emb场景下,函数\(f\)和\(g\)通常是共享参数的 对于对比学习,通常有in-batch negatives数量\(|T|\)越大,效果越好。但\(|T|\)越大,意味着batchsize也越大,训练时占用的显存也越多。如何在增大\(|T|\)的情况下不显著增加显存占用,是个很大的挑战。 二、方法 通常我们会使用梯度累积的方法在不增加显存的情况下增大batchsize,但梯度累积只适用于instance-wise loss,即每条样本的loss计算是独立的,这样可以把大batch拆成多个小batch分别计算梯度,然后累加起来。 但是由于对比学习loss计算时涉及到anchor和in-batch negative的运算,即对比学习loss是batch-wise loss,直接把大batch拆成多个小batch不加额外处理的话,小batch内的in-batch negatives就少了,影响对比学习效果。 本文的核心思想是,把公式1中的 对比学习loss 对 模型参数 的梯度求解过程拆分成:loss对表征\(f(s)\)的梯度 乘以 表征 对 模型参数 的梯度。由于只有loss计算需要batch-wise的运算,故上述拆解只有前半部分需要batch-wise运算,后半部分仍然可以instance-wise的计算然后梯度累加。下面来看具体过程。 为方便讨论,我们只看loss对\(f\)的参数的梯度求解过程(比如在\(f\)和\(g\)共享参数的情况下),对\(g\)的梯度求解过程的分析类似。 根据链式法则,对比学习loss \(\mathcal{L}\)对\(f\)的参数\(\Theta\)的梯度求解过程如下:把它拆解成第一项是\(\mathcal{L}\) 对表征\(f(s_i)\)的梯度,第二项是表征\(f(s_i)\)对模型参数\(\Theta\)的梯度,两项相乘就是loss \(\mathcal{L}\)对模型参数\(\Theta\)的梯度。 通过上述拆解为什么能显著降低显存呢,详细分析如下: 对于公式(2)右边第一项\(\frac{\partial \mathcal{L}}{\partial f\left(s_{i}\right)}\),根据公式(1)可以得到第一项的梯度如下公式(5)。也就是说第一项的梯度只与batch内所有的表征\(f(s)\)和\(g(t)\)有关,而与模型参数\(\Theta\)无关。所以我们可以先进行一次不含梯度的前向过程,拿到batch内所有的表征\(f(s)\)和\(g(t)\),由此可计算出公式(5),即loss \(\mathcal{L}\)对表征的梯度。由于这次前向不包含梯度(类似inference过程),所以不用记录各种中间激活值(activations)和梯度,可以大大节省显存,详细可看https://mingchao.wang/4KTgtnFc 的分析(即模型训练过程中activations是占显存的大头)。 计算完公式(5)之后,可以把这部分梯度缓存起来,用于后续计算。这部分缓存的梯度只需要额外占用\((|S|d+|T|d)\)的显存,显著小于海量的模型参数和中间activations的参数量。 对于公式(2)右边第二项\(\frac{\partial f\left(s_{i}\right)}{\partial \Theta}\),这部分梯度就是表征\(f(s_i)\)对模型参数\(\Theta\)的梯度,和常规梯度没什么两样,是instance-wise的,即每个样本的这个梯度计算是独立的。因此可以像常规梯度累积一样,进行mini-batch的计算,然后累加起来。为了完成这第二个过程,需要对每个样本\(s_i\)重新进行一次前向计算,由于需要对参数\(\Theta\)求梯度,所以这一次前向需要记录所有梯度和activations中间值。但是由于这个过程每个样本\(s_i\)可以独立计算,所以可以像梯度累积一样,把大batch拆分成多个小batch,每个mini-batch进行前向计算并进行梯度反向传播,所以显存峰值由mini-batch size决定,也不会太大。 ...