原文:深度学习任务面临非平衡数据问题?试试这个简单方法 - 2018.05.30
出处:数盟 - 微信公众号
对于数据科学或机器学习研究者而言,当解决任何机器学习问题时,可能面临的最大问题之一就是训练数据不平衡的问题.
本文将尝试使用图像分类问题来揭示训练数据中不平衡类别的奥秘.
1. 数据不平衡问题是什么?
在一个分类问题中,当你想要预测一个或多个类中的样本数量极少时,可能会遇到数据中类不平衡的问题,即部分类的样本数量远远大于其它类中的样本数量. 例如:
- 欺诈预测(真实交易的欺诈数量要低得多);
- 自然灾害预测(坏事件发生的频率将远远低于好事);
- 识别图像分类中的恶性肿瘤(具有肿瘤的图像将比训练样本内的无肿瘤的图像少得多);
2. 为什么这会是个问题?
不平衡课程造成问题主要是由于以下两个原因:
[1] - 由于模型/算法从来没有充分地查看全部类别信息,对于实时不平衡的类别没有得到最优化的结果;
[2] - 由于少数样本类的观察次数极少,这会产生一个验证或测试样本的问题,即很难在类中进行表示;
3. 解决这个问题的方法有哪些?
解决这个问题的方法主要有三种,三种各有各自的优缺点:
[1] - 下采样(Undersampling):
随机删除具有足够观察多样本的类,以便数据中类的数量比较平衡. 虽然这种方法非常简单,但很有可能删除的数据中可能包含有关预测的重要信息.
[2] - 过采样(Oversampling):
对于不平衡类(样本数少的类),随机地增加观测样本的数量,这些观测样本只是现有样本的副本,虽然增加了样本的数量,但过采样可能导致训练数据过拟合.
[3] - 合成取样(SMOT):
该技术要求综合地制造不平衡类的样本,类似于使用最近邻分类. 问题是当观察的数目是极其罕见的类时不知道怎么做.
尽管每种方法都有各自的优点,但没有什么固定的使用方式,需要根据实际问题不断自己尝试. 现在将使用深度学习特定的图像分类问题来详细研究这个问题.
4. 图像分类中的不平衡类
在本节中,将分析一个图像分类问题(其中存在不平衡类问题),然后使用一种简单有效的技术来解决它.
问题:在 kaggle上选择了“驼背鲸识别挑战”任务,期望解决不平衡类别的挑战(理想情况下,所分类的鲸鱼数量少于未分类的鲸类).
Kagele上任务说明:在这场比赛中,面临的挑战是要建立一个算法来识别图像中的鲸鱼种类. 将分析Happy Whale数据库(包含25,000多张图像),这些数据来自研究机构和公共贡献者. 通过竞赛,有助于为全球海洋哺乳动物种群动态开启丰富的理解领域.
4.1. 查看Happy Whale数据集
由于这是一个多标签图像分类问题,首先想要检查数据是如何在类中分布的.
上图表明,在4251张训练图像中,每个类只有一张图像的超过了2000张. 还有一些类只有2~5张图像. 可见这是一个严重的不平衡类问题.
我们不能期望深度学习模型每个类别仅使用一张图像进行训练. 这也会产生一个问题,即如何在训练和验证样本之间创建一个分界线,理想情况下希望每个类都在训练样本和验证样本中都有表示.
4.2. 接下来应该做什么?
本文考虑了两个特别的选项:
[1] - 选项1:对训练样本进行严格的数据增强(只需要针对特定类的数据增强,单这可能无法完全解决本文的问题).
[2] - 选项2:类似于之前提到的过采样技术. 只是使用不同的图像增强技术将不平衡类的图像复制到训练数据中15次.
在开始使用选项2处理数据之前,可以从训练样本中查看少量图像.
从图像中可以看到,图像是特定于鲸鱼的尾巴,因此,识别将可能与图像的方向有关. 同时注意到数据中有很多图像是特定的黑白或只有R/G/B通道.
根据这些观察结果,使用以下代码对训练样本中不平衡类的图像进行小幅改动并保存:
import os
from PIL import Image, ImageFilter
filelist = train["Image"].loc[(train["cnf_freq"]<10)].tolist()
for count in range(0, 2):
for imagefile in filelist:
os.chdir("/path/to/data/humpback/train")
im = Image.open(imagefile).convert("RGB")
r, g, b = im.split()
r = r.convert("RGB")
g = g.convert("RGB")
b = b.convert("RGB")
im_blur = im.filter(ImageFilter.GaussianBlur)
im_unsharp = im.fliter(ImageFilter.UnSharpMask)
os.chdir("/path/to/data/humpback/copy")
r.save(str(count) + "r_" + imagefile)
g.save(str(count) + "g_" + imagefile)
b.save(str(count) + "b_" + imagefile)
im_blur.save(str(count) + "bl_" + imagefile)
im_unsharp.save(str(count) + "un_" + imagefile)
以上代码对不平衡类中的每张图像(频率小于10)都进行如下处理:
- 将每张图像的增强副本保存为R/B/G ;
- 保存每张图像的增强副本;
- 保存每张图像未锐化的增强副本;
在上面的代码中可以看到,使用pillow库来严格执行此练习,现在已经为所有不平衡的类分配了至少10个样本. 接下来进行训练.
图像增强:只想确保模型能够获得鲸鱼fluke的详细视图. 为此,将缩放合并成图像增强.
学习率设定:从图中可以看到,将学习率定为0.01时效果最好.
使用 Resnet50 模型(第一层参数不变)进行了很少的迭代训练就能取得很好的效果,这是由于 imagenet 数据库中也有鲸鱼图像.
4.3. 测试数据集上效果如何?
在 kaggle 排行榜上可以看到模型在测试集上的效果,本文提出的解决方案在本次比赛中排名 34,平均精度均值(MAP)为0.41928.
5. 结论
有时候,最简单的方法是最合乎逻辑的(如果没有更多的数据,只需要复制现有的数据,并有轻微的变化即可),也是最有效的.