出处:Paddle文档平台 - GAN 基本概念

1. 博弈论

博弈论可以被认为是两个或多个理性的代理人或玩家之间相互作用的模型。

理性这个关键字,因为它是博弈论的基础。我们可以简单地把理性称为一种理解,即每个行为人都知道所有其他行为人都和他/她一样理性,拥有相同的理解和知识水平。同时,理性指的是,考虑到其他行为人的行为,行为人总是倾向于更高的报酬/回报。

既然我们已经知道了理性意味着什么,让我们来看看与博弈论相关的其他一些关键词:

  • 游戏:一般来说,游戏是由一组玩家,行动/策略和最终收益组成。例如:拍卖、象棋、政治等。
  • 玩家:玩家是参与任何游戏的理性实体。例如:在拍卖会的投标人、石头剪刀布的玩家、参加选举的政治家等。
  • 收益:收益是所有玩家在获得特定结果时所获得的奖励。它可以是正的,也可以是负的。正如我们之前所讨论的,每个代理都是自私的,并且想要最大化他们的收益:

2. 纳什均衡

纳什均衡(或者纳什平衡),Nash equilibrium ,又称为非合作博弈均衡,是人工智能博弈论方法的“基石”。

所谓纳什均衡,指的是参与者的一种策略组合,在该策略上,任何参与人单独改变策略都不会得到好处,即每个人的策略都是对其他人的策略的最优反应。换句话说,如果在一个策略组合上,当所有其他人都不改变策略时,没有人会改变自己的策略,则该策略组合就是一个纳什均衡。

经典的例子就是囚徒困境

背景:一个案子的两个嫌疑犯A和B被警官分开审讯,所以A和B没有机会进行串供的;

奖惩:警官分别告诉A和B,如果都不招供,则各判3年;如果两人均招供,均判5年;如果你招供、而对方不招供,则你判1年,对方10年。

结果:A和B都选择招供,各判5年,这个便是此时的纳什均衡。

从奖惩说明看都不招供才是最优解,判刑最少。其实并不是这样,A和B无法沟通,于是从各自的利益角度出发:

嫌疑犯A想法:

  • 如果B招供,如果我招供只判5年,不招供的话就判10年;
  • 如果B不招供,如果我招供只判1年,不招供的话就判3年;

所以无论B是否招供,A只要招供了,对A而言是最优的策略。

同上,嫌疑犯B想法也是相同的,都依据各自的理性而选择招供,这种情况就被称为纳什均衡点。

3. GAN生成器的输入为什么是噪声?

GAN生成器Generator的输入是随机噪声,目的是每次生成不同的图片。但如果完全随机,就不知道生成的图像有什么特征,结果就会不可控,因此通常从一个先验的随机分布产生噪声。常用的随机分布:

  • 高斯分布:连续变量中最广泛使用的概率分布;
  • 均匀分布:连续变量x的一种简单分布。

引入随机噪声使得生成的图片具有多样性,比如下图不同的噪声z可以产生不同的数字:

4. 生成器Generator

生成器G是一个生成图片的网络,可以采用多层感知机、卷积网络、自编码器等。它接收一个随机的噪声z,通过这个噪声生成图片,记做G(z)。通过下图模型结构讲解生成器如何一步步将噪声生成一张图片:

[1] - 输入:100维的向量;

[2] - 经过两个全连接层Fc1和Fc2、一个Resize,将噪声向量放大,得到128个7*7大小的特征图;

[3] - 进行上采样,以扩大特征图,得到128个14*14大小的特征图;

[4] - 经过第一个卷积Conv1,得到64个14*14的特征图;

[5] - 进行上采样,以扩大特征图,得到64个28*28大小的特征图;

[6] - 经过第二个卷积Conv2,将输入的噪声Z逐渐转化为1*28*28的单通道图片输出,得到生成的手写数字。

Tips:

全连接层作用:维度变换,变为高维,方便将噪声向量放大。因为全连接层计算量稍大,后序改进的GAN移除全连接层。

Tips:

最后一层激活函数通常使用tanh():既起到激活作用,又起到归一作用,将生成器的输出归一化至[-1,1],作为判别器的输入。也使GAN的训练更稳定,收敛速度更快,生成质量确实更高。

5. 判别器Discriminator

判别器D的输入为真实图像和生成器生成的图像,其目的是将生成的图像从真实图像中尽可能的分辨出来。属于二分类问题,通过下图模型结构讲解判别器如何区分真假图片:

  • 输入:单通道图像,尺寸为28*28像素(非固定值,根据实际情况修改即可)。
  • 输出:二分类,样本是真或假。

[1] - 输入:28*28*1像素的图像;

[2] - 经过第一个卷积conv1,得到64个26*26的特征图,然后进行最大池化pool1,得到64个13*13的特征图;

[3] - 经过第二个卷积conv2,得到128个11*11的特征图,然后进行最大池化pool2,得到128个5*5的特征图;

[4] - 通过Resize将多维输入一维化;

[5] - 再经过两个全连接层fc1和fc2,得到原始图像的向量表达;

[6] - 最后通过Sigmoid激活函数,输出判别概率,即图片是真是假的二分类结果。

6. GAN损失函数

在训练过程中,生成器G(Generator)的目标就是尽量生成真实的图片去欺骗判别器D(Discriminator)。而D的目标就是尽量把G生成的图片和真实的图片区分开。这样,G和D构成了一个动态的“博弈过程”。

最后博弈的结果是什么?在最理想的状态下,G可以生成足以“以假乱真”的图片G(z)。对于D来说,它难以判定G生成的图片究竟是不是真实的,因此D(G(z)) = 0.5。

用公式表示如下:

$$\begin{equation} \mathop{min}\limits_{G}\mathop{max}\limits_{D}V(D,G) = Ε_{x\sim p_{data}(x)} \left[\log D\left(x\right)\right]+Ε_{z\sim p_{z}(z)}\left[\log \left(1 - D\left(G\left(z\right)\right)\right)\right]\end{equation} \tag{1}$$

公式左边V(D,G)表示生成图像和真实图像的差异度,采用二分类(真、假两个类别)的交叉熵损失函数。包含minG和maxD两部分:

[1] - $\mathop{max}\limits_{D}V(D,G)$ 表示固定生成器G训练判别器D,通过最大化交叉熵损失V(D,G)来更新判别器D的参数。D的训练目标是正确区分真实图片x和生成图片G(z),D的鉴别能力越强,D(x)应该越大,右边第一项更大,D(G(x))应该越小,右边第二项更大。这时V(D,G)会变大,因此式子对于D来说是求最大(maxD)。

[2] - $\mathop{min}\limits_{G}\mathop{max}\limits_{D}V(D,G)$ 表示固定判别器D训练生成器G,生成器要在判别器最大化真、假图片交叉熵损失V(D,G)的情况下,最小化这个交叉熵损失。此时右边只有第二项有用, G希望自己生成的图片“越接近真实越好”,能够欺骗判别器,即D(G(z))尽可能得大,这时V(D, G)会变小。因此式子对于G来说是求最小(min_G)。

  • $x\sim p_{data}(x)$:表示真实图像;
  • $z\sim p_{z}(z)$:表示高斯分布的样本,即噪声;
  • D(x)代表x为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。

等式的右边其实就是将等式左边的交叉商损失公式展开,并写成概率分布的期望形式。

详细的推导请参见原论文Generative Adversarial Nets

7. 模型训练

GAN包含生成器G和判别器D两个网络,那么我们如何训练两个网络?

训练时先训练鉴别器D:将真实图片打上真标签1和生成器G生成的假图片打上假标签0,一同组成batch送入判别器D,对判别器进行训练。计算loss时使判别器对真实图像输入的判别趋近于真,对生成的假图片的判别趋近于假。此过程中只更新判别器的参数,不更新生成器的参数。

然后再训练生成器G:将高斯分布的噪声z送入生成器G,将生成的假图片打上真标签1送入判别器D。计算loss时使判别器对生成的假图片的判别趋近于真。此过程中只更新生成器的参数,不更新判别器的参数。

注意:训练初期,当G的生成效果很差时,D会以高置信度来拒绝生成样本,因为它们与训练数据明显不同。因此,log(1−D(G(z)))饱和(即为常数,梯度为0)。因此我们选择最大化logD(G(z))而不是最小化log(1−D(G(z)))来训练G,和公示(1)右边第二项比较。

8. 模型训练不稳定

GAN训练不稳定的原因如下:

  • 不收敛:很难使两个模型G和D同时收敛;
  • 模式崩溃:生成器G生成单个或有限模式(4.2节详细介绍);
  • 慢速训练:生成器G的梯度消失。

训练GAN的时候,可以采取以下训练技巧:

[1] - 生成器最后一层的激活函数用tanh(),输出归一化至[-1, 1];

[2] - 真实图像也归一化到[-1,1]之间;

[3] - 学习率不要设置太大,初始1e-4可以参考,另外可以随着训练进行不断缩小学习率;

[4] - 优化器尽量选择Adam,因为SGD解决的是一个寻找最小值的问题,GAN是一个博弈问题,使用SGD容易震荡;

[5] - 避免使用ReLU和MaxPool,减少稀疏梯度的可能性,可以使用Leak Re LU激活函数,下采样可以用Average Pooling或者Convolution + stride替代。上采样可以用PixelShuffle, ConvTranspose2d + stride;

[6] - 加噪声:在真实图像和生成图像中添加噪声,增加鉴别器训练难度,有利于提升稳定性;

[7] - 如果有标签数据,尽量使用标签信息来训练;

[8] - 标签平滑:如果真实图像的标签设置为1,我们将它更改为一个较低的值,比如0.9,避免鉴别器对其分类过于自信 。

9. 模式崩溃

如何判断GAN模型发生崩溃?

表现:生成器开始退化,总是生成同样的样本点,无法继续学习,这种情况称为模式崩溃。例如,在MNIST数据集上训练GAN之后,我们的生成器可能无法生成与数字4不同的数字。

可以尝试以下解决办法:

[1] - 学习率:尝试使用一个较低的学习率从头开始训练;

[2] - Mini-batch discrimination:在判别器的中间层建立一个mini-batch layer用于计算基于L1距离的样本统计量,通过建立该统计量,实现了一个batch内某个样本与其他样本有多接近。这个信息可以被判别器利用到,从而甄别出哪些缺乏多样性的样本。对生成器而言,则要试图生成具有多样性的样本;

[3] - 使用GAN改进算法:Multi agent diverse GAN(MAD-GAN)采用多个生成器,一个判别器以保障样本生成的多样性等。

Last modification:September 18th, 2021 at 11:11 pm