原文 - Deep learning unbalanced training data?Solve it like this.
原文翻译 - 深度学习训练数据不平衡问题,怎么解决?- AI 研习社
当解决任何机器学习问题时,面临的最大问题之一是训练数据不平衡.
不平衡数据的问题在于学术界对于相同的定义、含义和可能的解决方案存在分歧.
这里尝试用图像分类问题来解开训练数据中不平衡类别的奥秘.
不平衡类会有什么问题?
在一个分类问题中,如果在所有想要预测的类别里有一个或者多个类别的样本量非常少,那这样的数据也许就面临不平衡类别的问题.
举例:
- 1.欺诈预测(欺诈的数量远远小于真实交易的数量)
- 2.自然灾害预测(不好的事情远远小于好的事情)
- 3.在图像分类中识别恶性肿瘤(训练样本中含有肿瘤的图像远比没有肿瘤的图像少)
不平衡类别会造成问题有两个主要原因:
- 1.对于不平衡类别,很难得到实时的最优结果,因为模型/算法从来没有充分地考察隐含类.
- 2.它对验证和测试样本的获取造成了一个问题,因为在一些类观测极少的情况下,很难在类中有代表性.
解决这个问题有哪些不同方法?
现在有三种主要建议的方法,它们各有利弊:
- 1.欠采样 - 随机删除观测数量足够多的类,使得两个类别间的相对比例是显著的. 虽然这种方法使用起来非常简单,但很有可能被我们删除了的数据包含着预测类的重要信息.
- 2.过采样 - 对于不平衡的类别,我们使用拷贝现有样本的方法随机增加观测数量. 理想情况下这种方法给了我们足够的样本数,但过采样可能导致过拟合训练数据.
- 3.合成采样(SMOTE)-该技术要求我们用合成方法得到不平衡类别的观测,与现有的使用最近邻分类方法很类似. 问题在于当一个类别的观测数量极度稀少时该怎么做. 比如说,要用图片分类问题确定一个稀有物种,但可能只有一幅这个稀有物种的图片.
尽管每种方法都有各自的优点,但没有什么特定的启发式方法来说明什么时候使用哪种方法.
现在将使用深度学习特定的图像分类问题详细研究这个问题.
图像分类中的不平衡类
这里选取一个图像分类问题,其存在不平衡类问题,然后将使用一种简单有效的技术来解决.
问题 - 在 kaggle 网站上选择 [座头鲸识别挑战],期望解决不平衡类别的挑战(理想情况下,所分类的鲸鱼数量少于未分类的鲸类,并且也有少数罕见鲸类我们有的图像数量更少.)
来自 kaggle :
在这场比赛中,你面临着建立一个算法来识别图像中的鲸鱼种类的挑战.
将分析 Happy Whale 数据库中的超过25,000张图像,这些数据来自研究机构和公共贡献者.
将会帮助打开有关全球海洋哺乳动物种群动态丰富的理解领域.
看看数据分布
这是一个多标签图像分类问题,首先检查数据在各个类别间的分布情况.
上面的图表表明,在4251个训练图片中,有超过2000个类别中只有一张图片. 还有一些类中有2-5个图片.
这是一个严重的不平衡类问题.
不能指望用每个类别的一张图片对深度学习模型进行训练(虽然有些算法可能正是用来做这个的,例如 one-shot 分类问题,但现在先忽略这一点).
也会产生一个问题,即如何划分训练样本和验证样本.
理想情况下,希望每个类都在训练和验证样本中有所体现.
现在应该做什么?
特别考虑了两个选项:
- 选项1 - 对训练样本进行严格的数据增强(只需要针对特定类的数据增强,但可能无法完全达到希望的目的). 因此,选择了看起来很简单的选项2.
- 选项2 - 类似于上面提到的过采样选项. 仅仅使用不同的图像增强技术将不平衡类的图像在训练数据中复制了15次.
在开始选项2之前,先看看训练样本中的一些图像.
这些图像都是鲸鱼的尾巴. 因此,识别很可能与特定的图片方向有关.
也注意到在数据中有很多图像是黑白图片或只有R/B/G通道.
根据这些观察结果,编写下面的代码,对训练样本中不平衡类的图像进行小幅改动并保存它们:
import os
from PIL import Image
from PIL import ImageFilter
filelist = train['Image'].loc[(train['cnt_freq']<10)].tolist()
for count in range(0,2):
for imagefile in filelist:
os.chdir('/home/paperspace/fastai/courses/dl1/data/humpback/train')
im=Image.open(imagefile)
im=im.convert("RGB")
r,g,b=im.split()
r=r.convert("RGB")
g=g.convert("RGB")
b=b.convert("RGB")
im_blur=im.filter(ImageFilter.GaussianBlur)
im_unsharp=im.filter(ImageFilter.UnsharpMask)
os.chdir('/home/paperspace/fastai/courses/dl1/data/humpback/copy')
r.save(str(count)+'r_'+imagefile)
g.save(str(count)+'g_'+imagefile)
b.save(str(count)+'b_'+imagefile)
im_blur.save(str(count)+'bl_'+imagefile)
im_unsharp.save(str(count)+'un_'+imagefile)
以上代码块对不平衡类(数量小于10)中的每个图像都进行如下处理:
- 1.将每张图片的 R、G、B 通道分别保存为增强副本
- 2.保存每张图片非锐化的增强副本
- 3.保存每张图片非锐化的增强副本
在上面的代码中可以看到,在这个练习中严格使用了 pillow ( python 图像库).
现在在每个不平衡类中都至少有了10个样本.
继续进行训练.
- 图像增强 - 希望确保我们的模型能够获得鲸鱼尾的详细视图. 为此,我们将变焦图包含到图像增强中.
- learning rate finder - 将学习率定为0.01,如图.
采用 Resnet50 模型进行了很少的迭代(先 frozen 模型,再 unfrozen).
发现 frozen 模型对于这个问题也非常有用,因为 imagenet 中有鲸鱼尾图像.
epoch trn_loss val_loss accuracy
0 1.827677 0.492113 0.895976
1 0.93804 0.188566 0.964128
2 0.844708 0.175866 0.967555
3 0.571255 0.126632 0.977614
4 0.458565 0.116253 0.979991
5 0.410907 0.113607 0.980544
6 0.42319 0.109893 0.981097
在测试数据上的表现
最终我们在 kaggle 排行榜上获得了真相.
提出的解决方案在本次比赛中排名34,前五的平均精确度为0.41928.
结论
有时,最简单的方法是最合理的(如果没有更多的数据,只需稍加变化地拷贝现有的数据,假装对模型来说这一类别的大多数观测与它们基本类似).
它们最有效并且可以更容易和直观地完成工作.