一. GANs的基础概念
1.生成模型与判别模型
理解对抗网络,首先要了解生成模型和判别模型。判别模型比较好理解,就像分类一样,有一个判别界限,通过这个判别界限去区分样本。从概率角度分析就是获得样本x属于类别y的概率,是一个条件概率 $P(y|x)$。 而生成模型是需要在整个条件内去产生数据的分布,就像高斯分布一样, 需要去拟合整个分布,从概率角度分析就是样本 x 在整个分布中的产生的概率,即联合概率 $P(xy)$。
2.对抗网络思想
朴素来说,GAN的主要结构包括一个生成器G(Generator)和一个判别器D(Discriminator)。通过生成器G生成的数据来训练判别器D,使其尽可能好地准确判别真实样本和生成样本,尽可能大地区分正确样本和生成的样本;同时训练生成器G使其尽可能能够减小生成样本与真实样本之间的差距,也相当于尽量使得判别器D判别错误。
3.训练流程
- 初始化判别器D的参数和生成器G的参数。
- 从真实样本中采样得到样本,从先验分布噪声中采样得到噪音样本,并通过生成器获取生成样本 。固定生成器G,训练判别器D尽可能好地准确判别真实样本和生成样本,尽可能大地区分正确样本和生成的样本。
- 循环k次更新判别器之后,使用较小的学习率来更新一次生成器的参数,训练生成器使其尽可能能够减小生成样本与真实样本之间的差距,也相当于尽量使得判别器判别错误。
- 多次更新迭代之后,最终理想情况是使得判别器判别不出样本来自于生成器的输出还是真实的输出。亦即最终样本判别概率均为0.5。
Tips: 之所以要训练k次判别器,再训练生成器,是因为要先拥有一个好的判别器,使得能够教好地区分出真实样本和生成样本之后,才好更为准确地对生成器进行更新。
3.1 前向传播阶段
1. 模型输入
- 我们随机产生一个随机向量作为生成模型的数据,然后经过生成模型后产生一个新的向量,作为 Fake Image ,记作 $D(z)$ 。
- 从数据集中随机选择一张图片,将图片转化成向量,作为 Real Image, 记作 $x$ 。
2. 模型输出
- 由上面模型引入的1或者2产生的输出,作为判别网络的输入,经过判别网络后输 出 值为一个 0 到 1之间的数,用于表示输入图片为 Real Image 的概率, real 为 1 ,fake 为 0。
- 使用得到的概率值计算损失函数,解释损失函数之前,我们先解释下判别模型的输入。根据输入的图片类型是 Fake Image 或 Real Image 将判别模型的输入数据的 label 标记为 0或者 1 。即判别模型的输入类型为 $(xfake,0)$ 或者 $(xreal,1)$ 。
3.2 反向传播阶段
1. 优化目标
原文给出的一个优化函数:
我们可以先优化判别器D再优化生成器G,所以可以拆解为两步。
- 第一步:优化判别器D
优化判别器D,即优化判别网络时,没有生成网络什么事,后面的 G(z) 就相当于已经得到的假样本。优化 D 的公式的第一项,使得真样本 x 输入的时候,得到的结果越大越好,因为真样本的预测结果越接近 1 越好;对于假样本 $G(z)$ ,需要优化的是其结果越小越好,也就是 $D(G(z))$ 越小越好,因为它的标签为 0 。但是第一项越大,第二项越小,就矛盾了,所以把第二项改为 $1-D(G(z))$,这样就是越大越好。
- 第二步:优化生成器G
在优化生成器G的时候,这个时候没有真样本,所以把第一项直接去掉,这时候只有假样本,但是这个时候希望假样本的 标签是 1 ,所以是 $D(G(z)$ 越大越好,但是为了统一成 $1-D(G(z))$ 的形式,那么只能是最小化 $1-D(G(z))$,本质上没有区别,只是为了形式的统一。
之后这两个优化模型可以合并起来写,就变成最开始的最大最小目标函数了
2.判别器D的损失函数
当输入的是从数据集中取出的 Real Image 数据时,我们只需要考虑第二部分, D(x) 为判别器D的输出,表示输入 x 为 Real 数据的概率,我们的目的是让判别模型的输出D(x)的输出尽量靠近1。
当输入的为 fake 数据时,我们只计算第一部分,$G(z)$ 是生成模型的输出,输出的是一张 Fake Image 。我们要做的是让 $D(G(z))$ 的输出尽可能趋向于 0 。这样才能表示判别模型是有区分力的。
相对判别模型来说,这个损失函数其实就是交叉熵损失函数。计算loss ,进行梯度反传。 这里的梯度反传可以使用任何一种梯度修正的方法。
当更新完判别模型的参数后,我们再去更新生成模型的参数。
3. 生成器G的损失函数
对于生成模型来说,我们要做的是让 $G(z)$ 产生的数据尽可能的和数据集中的数据一样,就是所谓的同样的数据分布。 那么我们要做的就是最小化生成模型的误差,即只将由 $G(z)$产生的误差传给生成器G 。
但是针对判别器D的预测结果,要对梯度变化的方向进行改变。当判别器D认为 $G(z)$ 输出为真实数据集的时候和认为输出为噪声数据的时候,梯度更新方向要进行改变。
最终的损失函数:
其中 $\overline{D}$ 表示判别器D的预测类别,对预测概率取整,为0或者1;用于更改梯度方向,阈值可以自己设置,或者正常的话就是 0.5 。
4.反向传播
我们已经得到了生成器G和判别器D的损失函数,这样分开看其实就是两个单独的模型,
针对不同的模型可以按照自己的需要去是实现不同的误差修正(可以选择最常用的BP 做为误差修正算法,更新模型参数)。
生成对抗网络的生成器G和判别器D是没有任何限制,生成对抗网络提出的只是一种网络结构,我们可以使用任何的生成器G和判别器D去实现一个生成对抗网络。当得到损失函数后就对单个模型的更新方法进行修正即可。
2 条评论
wx哥哥⌇●﹏●⌇爱了爱了
llz哥哥!ヾ(≧∇≦*)ゝ