Fork me on GitHub

WGAN-GP

Improved Training of Wasserstein GANs(WGAN-GP)

Abstract

生成对抗网络是一种功能强大的生成模型,但训练不稳定。最近提出的Wasserstein GAN(WGAN)在GANs的稳定训练方面取得了进展,但有时仍然只能生成较差的样本或不能收敛。我们发现这些问题通常是由于在WGAN中使用了权值裁剪来加强对批评家的Lipschitz约束,这可能导致不希望的行为。我们提出了裁剪权重的替代方法:惩罚评论家关于其输入的渐变范数。我们提出的方法比标准WGAN表现更好,并且能够在几乎没有超参数调整的情况下对各种GAN架构进行稳定的训练,包括带有持续生成器的101层ResNets和语言模型。我们还在CIFAR-10和LSUN卧室上实现了高品质的生成器。

Introduction

生成性对抗网络(GAN)[9]是一类强大的生成模型,它将生成模型作为两个网络之间的游戏:生成器网络在给定一些噪声源的情况下生成合成数据,鉴别器网络区分生成器的输出和真实数据。GAN可以产生非常具有视觉吸引力的样本,但通常难以训练,并且最近关于该主题的大部分工作[23,19,2,21]致力于寻找稳定训练的方法。尽管如此,持续稳定GAN训练仍然是一个悬而未决的问题。

特别是,[1]提供了由GAN优化的值函数的收敛性质的分析。他们提出的替代方案,名为Wasserstein GAN(WGAN)[2],利用Wasserstein距离产生一个值函数,该函数具有比原始值更好的理论性质。WGAN要求鉴别者(在该工作中称为评论家)必须位于1-Lipschitz函数的空间内,作者通过权重削减强制执行。

我们的贡献如下:1.在玩具数据集上,我们展示了批评者权重裁剪如何导致不良行为。

2.我们提出了梯度惩罚(WGAN-GP),它没有遇到同样的问题。

3.我们展示了各种GAN架构的稳定培训,通过权值削减性能改进,高质量的图像生成,以及没有任何离散采样的字符级GAN语言模型。

Background

Generative adversarial networks

GAN的训练策略是定义两个网络之间竞争的游戏。生成器网络将噪声源映射到输入空间。鉴别器网络接收生成的样本或真实的数据样本,并且必须区分这两者。训练生成器以愚弄鉴别器。

形式上,生成器G和鉴别器D之间的游戏是最小极大目标:

其中Pr是数据分布,Pg是x‘~G(z),z~p(z)隐含定义的模型分布(生成器的输入z是从一些简单的噪声分布p中采样的,例如均匀分布或球形高斯分布)。

如果在每个生成器参数更新之前将鉴别器训练为最优性,则最小化值函数相当于最小化Pr和Pg之间的Jensen-Shannon散度[9],但这样做通常会导致梯度消失,因为鉴别器饱和。在实践中,[9]方法来规避这种困难。然而,即使这种修改的损失函数也可能在倡导者中行为不端,即生成器被训练以最大化

,这会产生良好的鉴别器[1]。

Waterstone GAN

[2]认为,GAN通常最小化散度可能是不连续的,关于生成器的参数,导致训练困难。他们建议改为使用Earth-Mover(也称为Wasserstein-1)距离W(q,p),它被非正式地定义为运输质量的最低成本,以便将分布q转换为分布p(其中成本是质量乘以运输距离)。在温和的假设下,W(q,p)在任何地方都是连续的,几乎无处不在。

使用Kantorovich-Rubinstein对偶[25]构造WGAN值函数以获得

其中D是1-Lipschitz函数的集合,Pg再次是由x‘~G(z),z~p(z)隐含定义的模型分布。在这种情况下,在最佳鉴别器(在论文中称为批评者,因为它没有经过分类训练)下,最小化关于生成器参数的值函数可以最小化W(Pr,Pg)。

WGAN值函数产生一个批评函数,其相对于其输入的梯度比其GAN对应物表现得更好,使得生成器的优化更容易。根据经验,还观察到WGAN值函数似乎与样本质量相关,而GAN则不然[2]。

为了对评论家强制执行Lipschitz约束,[2]建议将评论者的权重剪辑为k-Lipschitz函数,取决于c和评论体系结构。在紧凑的空间内[-c,c]。满足此约束的函数集是这些部分的子集,我们演示了此方法的一些问题并提出了替代方案。

Properties of the optimal WGAN critic

为了理解为什么权值削减在WGAN评论家中存在问题,以及激励我们的方法,我们强调了WGAN框架中最佳批评家的一些属性。我们在附录中证明了这些。

Difficulties with weight constraints

我们发现WGAN中的权重削减会导致优化困难,即使优化成功,最终的批评者也可能具有病态的表面价值。我们在下面解释这些问题并展示它们的影响;但是我们并不是说每问题总是在实践中出现,也不是说它们是唯一这样的机制。

我们的实验使用[2]中的特殊形式的权重约束(每个权重的大小的硬限幅),但我们也尝试了其他权重约束(L2范数裁剪,权重归一化)以及软约束(L1和L2的权重衰减)并发现他们表现出类似的问题。

在某种程度上,这些问题可以通过批评家中的批量标准化来减轻,[2]在他们的所有实验中使用。然而,即使批量归一化,我们也观察到非常深刻的WGAN批评者经常无法收敛。

Capacity underuse

通过权重削减实现k-Lipshitz约束使批评者偏向更简单的函数。如前面的推论1中所述,最佳WGAN批评家几乎在Pr和Pg下都有单位梯度范数;在权重限制约束下,我们观察到我们的神经网络架构试图获得其最大梯度范数k,最终学习极其简单的函数。

为了证明这一点,我们训练WGAN批评者在几个玩具分布上使用权重削减到最优,保持生成器分布Pg固定在实际分布加单位方差高斯噪声。我们绘制了图1a中批评家的值曲面。我们在批评家中省略了批量归一化。在每种情况下,使用权重削减训练的批评家,忽略数据分布的更高时刻,而是模拟对最优函数的非常简单的近似。相反,我们的方法不会受到这种行为的影响。

Exploding and vanishing gradients

我们观察到WGAN优化过程是困难的,不仔细调整限幅阈值c的话,因为权重约束和成本函数之间的相互作用,这导致消失或爆炸的梯度。

为了证明这一点,我们在Swiss Roll玩具数据集上训练WGAN,改变裁剪阈值c为[0.1, 0.01,0.001],并绘制关于连续激活层的批评者损失梯度的范数。生成器和判别器都是12层ReLU MLP,无需批量归一化。图1b显示,对于这些值中的每一个,随着我们在网络中向后移动,梯度会以指数方式增长或衰减。我们发现我们的方法产生更稳定的梯度,既不会消失也不会爆炸,从而可以训练更复杂的网络。

Gradient penalty

我们现在提出一种强制Lipschitz约束的替代方法。可微函数是1-Lipschtiz,当且仅当它具有最多1个范数的梯度时,所以我们考虑直接约束批评者输出相对于其输入的梯度范数。为了避免易处理性问题,我们强制执行约束的软版本,对随机样本

的梯度范数进行惩罚。我们的新目标是:

采样分布我们隐式地定义Px沿着从数据分布Pr和生成器分布Pg采样的点对之间的直线均匀采样。这是因为最优批评家包含直线,其中梯度范数1连接来自Pr和Pg的耦合点(参见命题1)。鉴于在任何地方强制执行单位梯度范数约束是难以处理的,仅沿着这些直线强制执行它似乎是足够的,并且在实验上导致良好的性能。

惩罚系数本文中的所有实验都使用λ=10,我们发现它可以很好地适用于从玩具任务到大型ImageNet CNN的各种架构和数据集。

没有批量归一化的批评家大多数先前的GAN实现[22,23,2]在生成器和鉴别器中都使用批量归一化来帮助稳定训练,但批量标准化会将鉴别器问题的形式从单个输入映射到单个输出变为从一批输入映射到一批输出[23]。我们的惩罚性训练目标在此设置中不再有效,因为我们会独立地惩罚批评家关于每个输入的梯度的标准,而不是整个批次。为了解决这个问题,我们在模型中忽略批评归一化,发现它们在没有它的情况下表现良好。我们的方法适用于规范化方案,这些方案不会引入示例之间的相关性。特别是,我们建议将层归一化[3]作为批量归一化的直接替代。

双向惩罚我们鼓励梯度的范数朝向1(双向惩罚),而不是仅仅保持在1以下(单侧惩罚)。根据经验,这似乎并没有过多地限制批评者,可能是因为最佳的WGAN批评家无论如何都在Pr和Pg之间的几乎所有地方都有1范式的渐变,并且在其间的大部分地区(见2.3小节)。在我们的早期观察中,我们发现它的表现略好一些,但我们并未对此进行全面调查。我们描述了附录中片面惩罚的实验。

Experiments

Training random architectures within a set

我们通过实验证明了我们的模型训练大量架构的能力,我们认为这些架构对训练有用。从DCGAN架构开始,我们通过将模型设置更改为表1中的随机对应值来定义一组架构变体。我们相信,对这一系列中的许多架构进行可靠的训练是一个有用的目标,但我们并不认为我们的集合是整个有用架构空间的公正或有代表性的样本:它旨在展示我们的成功制度。方法,读者应评估它是否包含与其预期应用类似的架构。

从这个集合中,我们对32X32的ImageNet中的200个体系结构进行了采样,并使用WGAN-GP和标准GAN目标进行训练。表2列出了以下任一情况的实例数:只有标准GAN成功,只有WGAN-GP成功,两者都成功和两者都失败,成功定义为inception score > min score。对于大多数得分阈值的选择,WGAN-GP成功地训练了许多我们无法用标准GAN目标训练的架构。我们在附录中提供了更多实验细节。

Training varied architectures on LSUN bedrooms

为了展示我们的模型能够以默认设置训练许多架构,我们在LSUN卧室数据集上训练了六种不同的GAN架构[31]。除了[22]的基线DCGAN架构外,我们选择了六种架构,我们展示了它们的成功训练:(1)生成器中没有BN和恒定数量的滤波器,如[2],(2)4层512维ReLU MLP生成器,如[2]中所述,(3)在鉴别器或生成器中没有归一化(4)门控乘法非线性,如[24],(5)tanh非线性,和(6)101层ResNet生成器和鉴别器。

虽然我们没有声称没有我们的方法是不可能的,但据我们所知,这是第一次在GAN设置中成功训练非常深的残差网络。对于每种架构,我们使用四种不同的GAN方法训练模型:WGAN-GP,带权重削减的WGAN,DCGAN[22]和最小二乘GAN[18]。对于每个目标,我们使用了该工作中推荐的默认优化器超参数集(除了LSGAN,我们搜索了学习率)。

对于WGAN-GP,我们用层规范化替换鉴别器中的任何批量归一化(参见第4节)。我们训练每个模型进行200K次迭代,并在图2中显示样本。我们只使用WGAN-GP成功地使用一组共享的超参数来训练每个架构。对于其他所有训练方法,其中一些架构不稳定或遭受模式崩溃。

Improved performance over weight clipping

我们的方法优于权重削减的一个优点是提高了训练速度和样本质量。为了证明这一点,我们在图3中的训练过程中训练WGAN进行了权重削减和CIFAR10
[13]的梯度惩罚以及初始得分[23]。对于WGAN-GP,我们训练一个模型使用相同的优化器(RMSProp)和学习率作为WGAN进行权重削减,另一个模型使用Adam和更高的学习率。即使使用相同的优化器,我们的方法收敛速度更快,并且比权重削减更好。使用Adam进一步提高了性能。我们还绘制了DCGAN[22]的性能,并发现我们的方法比DCGAN收敛得更慢(在挂钟时间内),但其收敛在收敛时更稳定。

Sample quality on CIFAR-10 and LSUN bedrooms

对于等效架构,我们的方法实现了与标准GAN目标相当的样本质量。然而,增加的稳定性使我们能够通过探索更广泛的架构来提高样品质量。为了证明这一点,我们找到了一种架构,它在无人监督的CIFAR-10上建立了一种新的最先进的入门分数(表3)。当我们添加标签信息时(使用[20]中的方法),相同的架构优于除SGAN之外的所有其他已发布模型。

我们还在128×128 LSUN卧室中训练了一个deepResNet,如图4所示。我们相信,这些样本与迄今为止任何分辨率的数据集最佳报告至少是有竞争力。

Modeling discrete data with a continuous generator

为了证明我们的方法对简并分布进行建模的能力,我们考虑了一个复杂离散分布的建模问题,这个复杂离散分布的生成元是在连续空间上定义的。作为该问题的一个实例,我们在谷歌十亿字数据集[6]上训练了一个字符级的GAN语言模型。我们的生成器是一个简单的一维CNN,它通过一维卷积确定性地将一个潜在向量转换成32个一维字符向量序列。我们在输出上应用了一个softmax非线性,但是没有使用采样步骤:在训练期间,softmax输出是直接传递给批评家(同样是简单的1D CNN)。当解码样本时,我们只取每个输出向量的argmax。

我们提供了表4中模型的样本。我们的模型经常出现拼写错误(很可能是因为它必须独立地输出每个字符),但是仍然能够学到很多关于语言统计的知识。我们无法得出与标准GAN相当的结果的目标,虽然我们没有说这样做是不可能的。

WGAN与其他GANs之间的性能差异可以解释如下。假设采样∆n={p∈Rn:pi≥0,Σipi=1},集合中单一最高点(或一维向量)Vn={p∈Rn:pi∈{0,1},Σipi=1}⊆∆n。如果我们有一个大小为n的词汇表我们有一个大小为T的序列的分布Pr,我们有Pr是一个分布在VTn=Vn × · · · ×Vn。由于VT n是∆T n的一个子集,我们也可以将Pr看作∆T n上的一个分布(通过给不在V中的所有点分配零概率质量T n).

Pr是离散的(或由有限数量的元素支持,即VT n)上,但Pg很容易成为∆Tn上的连续分布。两个这样的分布之间的KL分歧是无限的,所以JS散度是饱和的。虽然GANs并没有从字面上最小化[16]的这些差异,但在实践中,这意味着一个鉴别器可能很快地学会拒绝所有不位于Vtn(单个热向量序列)上的样本,并为生成器提供无意义的梯度。然而,很容易看出,即使在X=Tn的非标准学习情形下,[2]的定理1和推论1也满足条件。这意味着W(Pr,Pg)仍然是定义良好的,处处连续且几乎处处可微,我们可以对它进行优化,就像在任何其他连续变量设置中一样。这表明在WGANs中,Lipschitz约束迫使批评家提供一个线性梯度,从所有的Tn指向VTn的实点。

其他使用GANs进行语言建模的尝试[32,14,30,5,15,10]通常使用离散模型和梯度估计器[28,12,17]。我们的方法更容易实现,尽管它是否可以扩展到一个玩具语言模型还不清楚。

Meaningful loss curves and detecting overfitting

权重限制WGANs的一个重要好处是,它们的损失与样本质量相关,并收敛于一个最小值。为了证明我们的方法保留了这个性质,我们训练一个在LSUN的卧室数据集[31]上的WGAN-GP,并在图5a中绘制负的批评家损失。我们看到当生成器使W(Pr,Pg)最小化时,损耗会收敛。

如果有足够的能力和太少的训练数据,GANs将会过拟合。为了探索网络超配时的损耗曲线行为,我们在MNIST的随机1000幅图像子集上训练了大量的非正则化WGANs,并在图5(b)中绘制了训练集和验证集上负的批评家损失。在WGAN和WGAN-gp中,两种损失都是不同的,这表明评论家对W(Pr,Pg)的估计是不准确的,在这一点上,忽视样品质量联系所有的赌注都是错的。而在WGAN-GP中,训练损失逐渐增大,验证损失逐渐减小。

[29]还通过估计生成器的log-likelihood来测量GANs中的过拟合。与那个工作相比,我们的方法检测到在批评家(而不是生成器)中的过拟合,并针对网络最小化的相同损失度量过拟合。

Conclusion

在这项工作中,我们在WGAN中演示了权重削减的问题,并引入了另一种形式的惩罚项,即在批评家损失中不显示相同的问题。使用我们的方法,我们展示了强大的建模性能和跨各种架构的稳定性。现在我们有了一个更稳定的算法来训练GANs,我们希望我们的工作为在大规模图像数据集和语言上更强的建模性能开辟了道路。另一个有趣的方向是将我们的惩罚项调整到标准的GAN目标函数中,它可能通过鼓励鉴别器学习更平滑的决策边界来稳定训练。

-------------本文结束感谢您的阅读-------------
显示 Gitment 评论