导图社区 半监督学习
这是一篇关于半监督学习的思维导图,主要是半监督学习的三类基本假设和三类方法,有需要的同学,可以收藏下哟!
编辑于2021-09-16 12:19:00半监督学习
三个基本假设
The Smoothness Assumption
如果两个样本 x1,x2 相似,则它们的相应输出 y1,y2 也应如此。这意味着如果两个输入相同类,并且属于同一簇,则它们相应的输出需要相近,反之亦成立。
The Cluster Assumption
假设输入数据点形成簇,每个簇对应于一个输出类,那么如果点在同一个簇中,则它们可以认为属于同一类。聚类假设也可以被视为低密度分离假设,即:给定的决策边界位于低密度地区。
The Manifold Assumption
a)输入空间由多个低维流形组成,所有数据点均位于其上;(b)位于同一流形上的数据点具有相同标签。
三类方法
一致性正则化Consistency Regularization
对于无标签图像,添加噪声之后模型预测也应该保持不变。核心思想:最小化未标记数据与其扰动输出两者之间的距离d
Pi-Model (ICLR2017)
Temporal Ensembling for Semi-Supervised Learning
由于正则化技术(例如 data augment 和 dropout)通常不会改变模型输出的概率分布,Pi-Model 正是利用神经网络中这种预测函数的特性,对于任何给定的输入 x,使用不同的正则化然后预测两次,而目标是减小两次预测之间的距离,提升模型在不同扰动下的一致性,Pi-Model 使用 MSE 做为两个概率分布之间的损失函数。
子主题
对每一个参与训练的样本,在训练阶段,Pi-Model 需要进行两次前向推理。此处的前向运算,包含一次随机增强变换和不做增强的前向运算。由于增强变换是随机的,同时模型采用了 Dropout,这两个因素都会造成两次前向运算结果的不同。
子主题
损失函数:由两部分构成,其中第一项含有一个时变系数 w,用来逐步释放此项的权重,x 是未标记数据,由两次前向运算结果的均方误差(MSE)构成。第二项由交叉熵构成,x 是标记数据,y 是对应标签,仅用来评估有标签数据的误差。可见,第一项即是用来实现一致性正则。
https://openreview.net/forum?id=BJ6oOfqge¬eId=BJ6oOfqge
https://github.com/s-laine/tempens
Temporal Ensembling (ICLR2017)
Temporal Ensembling for Semi-Supervised Learning
在 Pi-Model 的基础上进一步提出了Temporal Ensembling,其整体框架与 Pi-model 类似,在获取无标签数据的处理上采用了相同的思想
创新点:在目标函数的无监督一项中, Pi-Model 是两次前向计算结果的均方差,而在temporal ensembling 模型中,使用时序组合模型,采用的是当前模型预测结果与历史预测结果的平均值做均方差计算。有效地保留历史了信息,消除了扰动并稳定了当前值。
对于一个目标 yhat,在每次训练迭代中,当前输出 yhat通过 EMA(exponential moving averag,指数滑动平均)更新被累加到整体输出中 yema:
子主题
用空间来换取时间,总的前向推理次数减少了一半,因而减少了训练时间;
通过历史预测做平均,有利于平滑单次预测中的噪声。
https://openreview.net/forum?id=BJ6oOfqge¬eId=BJ6oOfqge
https://github.com/s-laine/tempens
Mean teachers (NIPS 2017)
Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results
Mean Teachers 则是 Temporal Ensembling 的改进版,Temporal Ensembling 对模型的预测值进行 EMA(exponential moving averag),而Mean Teachers 采用了对 studenet 模型权重进行 EMA
teacher model
子主题
损失的计算是有监督和无监督损失的总和:
子主题
子主题
https://arxiv.org/abs/1703.01780
https://github.com/CuriousAI/mean-teacher
Unsupervised Data Augmentation
Unsupervised Data Augmentation for Consistency Training
之前的工作中对未标记的数据加入噪声增强的方式主要是采用简单的随机噪声,但是这篇文章发现对输入 x 增加的噪声 α 对模型的性能提升有着重要的影响
提出对未标记的数据采取更多样化更真实的数据增强方式,并且对未标记的数据上优化相同的平滑度或一致性增强目标
最小化未标记数据和增强未标记数据上预测分布之间的 KL 差异:
x 是原始未标记数据的输入, xhat 是对未标签数据进行增强(如:图像上进行 AutoAugmen,文本进行反翻译)后的数据。
子主题
为了同时使用有标记的数据和未标记的数据,添加了标记数据的 Supervised Cross-entropy Loss 和上式中定义的一致性/平滑性目标 Unsupervised Consistency Loss,权重因子 λ 为我们的训练目标,最终目标的一致性损失函数定义如下
子主题
UDA 证明了针对性的数据增强效果明显优于无针对性的数据增强,这一点和监督学习的 AutoAugment、RandAugment 的结论是一致的。
https://arxiv.org/pdf/1904.12848v2.pdf
https://github.com/google-research/uda
给定一个未标记的数据点 x 及其扰动的形式x* ,两个输出之间的距离 d( f(x), f(x*) )
流行的距离测量d通常是均方误差(MSE),Kullback-Leiber散度(KL)和Jensen-Shannon散度(JS),其中 m=0.5 ( f(x)+f(x*) )
代理标签法Proxy-label Methods
使用预测模型或它的某些变体生成一些代理标签,这些代理标签和有标记的数据混合一起,提供一些额外的训练信息,即使生成标签通常包含嘈杂,不能反映实际情况。
self-training(模型本身生成代理标签)
Step1:首先,用少量的标签数据 L 训练 Model;也就是上图的虚线以上部分;
Step2:然后,使用训练后的 Model 给未标记的数据点 x∈U 分配 Pseudo-label(伪标签);
锐化方法
在保持预测值分布的同时使分布有些极端
子主题
Argmax 方法
仅使用对预测具有最高置信度的预测标签进行标记
对无标签数据进行过滤,如果预测结果大于预定阈值 τ,再将其添加训练中。对无标签数据进行过滤,如果预测结果大于预定阈值 τ,再将其添加训练中。
Setp3:通过交叉熵损失计算模型预测和伪标签的损失;
Step4:最后,使用训练好的模型为 U 的其余部分生成代理标签,一直循环,直到模型无法生成代理标签为止。
子主题
主要缺点是:模型无法纠正自己的错误。如果模型对自己预测的结果很有“自信”,但这种自信是盲目的,那么结果就是错的,这种偏差就会在训练中得到放大。
multi-view learning(代理标签是由根据不同数据视图训练的模型生成的)
多视图数据可以通过不同的测量方法(例如颜色信息和纹理)收集不同的视图图片信息,或通过创建原始数据的有限视图来实现。
目标是学习独特的预测函数 fθi 为数据点 x 的给定视图 vi(x) 建模,并共同优化所有用于提高泛化性能的功能
Co-training
一个模型会为另一个模型的输入提供标签:m1 和 m2 两个模型,它们分别在不同的数据集上训练。每轮迭代中,如果两个模型里的一个模型,比如模型 m1 认为自己对样本 x 的分类是可信的,置信度高,分类概率大于阈值 τ ,那 m1 会为它生成伪标签,然后把它放入 m2 的训练集。
Tri-Training
首先对有标记示例集进行可重复取样(bootstrap sampling)以获得三个有标记训练集,然后从每个训练集产生一个分类器。
在协同训练过程中,各分类器所获得的新标记示例都由其余两个分类器协作提供:如果两个分类器对同一个未标记示例的预测相同,则该示例就被认为具有较高的标记置信度,并在标记后被加入第三个分类器的有标记训练集
混合方法Holistic Methods
试图在一个框架中整合当前的 SSL 的主要方法,从而获得更好的性能
MixMatch【NeurIPS 2019】
Setp 1:Data Augmentation 标记的和未标记的数据都使用数据增强。数据增强只是标准的裁剪和翻转。
Step 2:Label Guessing 对于的每个未标记的训练数据,MixMatch 使用模型的预测为样本的生成一个“guess”标签,这个“guess”标签被用于无监督损失计算。
计算了该模型预测的分类分布在所有 K 个增量上的平均值,每个未标记的输入数据只增加两次扩增(K=2)
Step 3:Sharpening 一个很重要的过程,这个思想相当于深度学习中的 relu 过程。在给定预测的平均值的基础上,应用锐化函数减小了标签分布的熵。
sharpen 函数实际上只是一个“温度调整”,建议将温度参数 T 保持为 0.5
Step 4:MixUp 将有标签数据 X 和无标签数据 U 混合在一起形成一个混合数据 W,然后,有标签数据 X 和 W 中的前 #X 个进行 mixup 后,得到的数据作为有标签数据,作为 label group,记为 X',同样,无标签数据 U 和 W 中的后 #U 个进行 mixup 后,得到的数据作为无标签数据,作为 unlabel group,记为 U'。
Loss function:对于有标签的数据,使用交叉熵;“guess”标签的数据使用MSE;然后将两者加权组合
FixMatch(Google Brain)
FixMatch 是对弱增强图像与强增强图像之间的进行一致性正则化,使用交叉熵将 weakly augment 和 strong augment 的无标签数据进行比较
弱增强:用标准的翻转和平移策略。
强增强:输出严重失真的输入图像,先使用 RandAugment 或 CTAugment,再使用 CutOut 增强。
创新点:一致性正则化使用的是交叉熵损失函数
FixMatch 使用弱增强的数据制作了伪标签,仅使用具有高置信度的未标记数据参与训练
FixMatch使用 Wide-Resnet 变体作为基础体系结构,记为 Wide-Resnet-28-2,其深度为 28,扩展因子为 2。因此,此模型的宽度是 ResNet 的两倍。
子主题
Input:准备了 batch=B 的有标签数据和 batch=μB 的无标签数据,其中 μ 是无标签数据的比例
监督训练:对于在标注数据的监督训练,将常规的交叉熵损失 H() 用于分类任务。有标签数据的损失记为 ls,如伪代码中第 2 行所示;
生成伪标签:对无标签数据分别应用弱增强和强增强得到增强后的图像,再送给模型得到预测值,并将弱增强对应的预测值通过 argmax 获得伪标签;
一致性正则化:将强增强对应的预测值与弱增强对应的伪标签进行交叉熵损失计算,未标注数据的损失由 lu 表示,如伪代码中的第 7 行所示;式 τ 表示伪标签的阈值;
完整损失函数:最后,我们将 ls 和 lu 损失相结合,如伪代码第 8 行所示,对其进行优化以改进模型,其中,λu 是未标记数据对应损失的权重。
总结
当标注的数据较少时模型训练容易出现过拟合
一致性正则化方法通过鼓励无标签数据扰动前后的预测相同使学习的决策边界位于低密度区域,很好缓解了过拟合这一现象
代理标签法通过对未标记数据制作伪标签然后加入训练,以得到更好的决策边界
众多方法中,混合方法表现出了良好的性能,是近来的研究热点。
参考文献
Takeru M , Shin-Ichi M , Shin I , et al. Virtual Adversarial Training: A Regularization Method for Supervised and Semi-Supervised Learning[J]. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2018:1-1.
Verma V , Lamb A , Kannala J , et al. Interpolation Consistency Training for Semi-Supervised Learning[J]. 2019.
Avrim Blum and Tom Mitchell. Combining labeled and unlabeled data with co-training. In Proceedings of the eleventh annual conference on Computational learning theory, pages 92–100, 1998.
Zhi-Hua Zhou and Ming Li. Tri-training: Exploiting unlabeled data using three classififiers. IEEE Transactions on knowledge and Data Engineering, 17(11):1529–1541, 2005.
Dong-Hyun Lee. Pseudo-label: The simple and effiffifficient semi-supervised learning method for deep neural networks. In Workshop on challenges in representation learning, ICML, volume 3, page 2, 2013.