最近在用VAE处理一些文本问题的时候遇到了对离散形式的后验分布求期望的问题,于是沿着“离散分布 + 重参数”这个思路一直搜索下去,最后搜到了Gumbel Softmax,从对Gumbel Softmax的学习过程中,把重参数的相关内容都捋了一遍,还学到一些梯度估计的新知识,遂记录在此。

文章从连续情形出发开始介绍重参数,主要的例子是正态分布的重参数;然后引入离散分布的重参数,这就涉及到了Gumbel Softmax,包括Gumbel Softmax的一些证明和讨论;最后再讲讲重参数背后的一些故事,这主要跟梯度估计有关。

基本概念 #

重参数(Reparameterization)实际上是处理如下期望形式的目标函数的一种技巧:

(1)Lθ=Ezpθ(z)[f(z)]

这样的目标在VAE中会出现,在文本GAN也会出现,在强化学习中也会出现(f(z)对应于奖励函数),所以深究下去,我们会经常碰到这样的目标函数。取决于z的连续性,它对应不同的形式:
(2)pθ(z)f(z)dz(连续情形)zpθ(z)f(z)(离散情形)

当然,离散情况下我们更喜欢将记号z换成y或者c

为了最小化Lθ,我们就需要把Lθ明确地写出来,这意味着我们要实现从pθ(z)中采样,而pθ(z)是带有参数θ的,如果直接采样的话,那么就失去了θ的信息(梯度),从而无法更新参数θ。而Reparameterization则是提供了这样的一种变换,使得我们可以直接从pθ(z)中采样,并且保留θ的梯度。(注:如果考虑最一般的形式,那么f(z)也应该带上参数θ,但这没有增加本质难度。)

连续情形 #

简单起见,我们先考虑连续情形

(3)Lθ=pθ(z)f(z)dz

其中pθ(z)是具有显式概率密度表达式的分布,在变分自编码器中常见的是正态分布pθ(z)=N(z;μθ,σθ2)

形式 #

从式(3)中知道,连续情形的Lθ实际上就对应一个积分,所以,为了明确写出Lθ,有两种途径:最直接的方式是精确地完成积分(3),得到显式表达式,但这通常都是不可能的了;所以,唯一的办法是转化为采样形式(1),并试图在采样过程中保留θ的梯度。

重参数就是这样的一种技巧,它假设从分布pθ(z)中采样可以分解为两个步骤:(1) 从无参数分布q(ε)中采样一个ε;(2) 通过变换z=gθ(ε)生成z那么,式(1)就变成了

(4)Lθ=Eεq(ε)[f(gθ(ε))]

这时候被采样的分布就没有任何参数了,全部被转移到f内部了,因此可以采样若干个点,当成普通的loss那样写下来了。

例子 #

一个最简单的例子就是正态分布:对于正态分布来说,重参数就是“从N(z;μθ,σθ2)中采样一个z”变成“从N(ε;0,1)中采样一个ε,然后计算ε×σθ+μθ”,所以

(5)EzN(z;μθ,σθ2)[f(z)]=EεN(ε;0,1)[f(ε×σθ+μθ)]

如何理解直接采样没有梯度而重参数之后就有梯度呢?其实很简单,比如我说从N(z;μθ,σθ2)中采样一个数来,然后你跟我说采样到5,我完全看不出5跟θ有什么关系呀(求梯度只能为0);但是如果先从N(ε;0,1)中采样一个数比如0.2,然后计算0.2σθ+μθ,这样我就知道采样出来的结果跟θ的关系了(能求出有效的梯度)。

总结 #

让我们把前面的内容重新整理一下。总的来说,连续情形的重参数还是比较简单的:连续情形下,我们要处理的Lθ实际上是式(3),由于精确的积分我们没有办法显式地写出来,所以需要转化为采样,而为了在采样的过程中得到有效的梯度,我们就需要重参数。

从数学本质来看,重参数是一种积分变换,即原来是关于z积分,通过z=gθ(ε)变换之后得到新的积分形式,

离散情形 #

为了突出“离散”,我们将随机变量z换成y,即对于离散情形要面对的目标函数是

(6)Lθ=Eypθ(y)[f(y)]=ypθ(y)f(y)

其中离散意味着一般情况y是可枚举的,换句话说pθ(y)此时是一个k分类模型:
(7)pθ(y)=softmax(o1,o2,,ok)=1i=1keoi(eo1,eo2,,eok)

其中各个oiθ的函数。

分析 #

读者看到(6)中的求和,第一反应可能是“求和?那就求呗,又不是求不了。”。

的确,这也是笔者当时看到它的第一反应。与连续情形的(3)不一样,式(3)如果直接硬杠的话需要完成积分(也可以看成无穷多个点的求和),我们没法做到这一点。但是对于离散的(6),只不过是有限项求和,理论上确实可以直接完成求和再去梯度下降。

但是,如果k特别大呢?举个例子,假设y是一个100维的向量,每个元素不是0就是1(二元变量),那么所有不同的y的总数目就是2100,要对这样的2100个单项进行求和,计算量是难以接受的;还有一个典型的例子是seq2seq的解码端(如果要做文本GAN就需要面对它),它的类别总数目是|V|l,其中|V|是词表大小而l是句子长度。这样的情况下,直接完成精确的求和都是难以实现的。

形式 #

所以,还是需要回到采样上去,如果能够采样若干个点就能得到(6)的有效估计,并且还不损失梯度信息,那自然是最好了。为此,需要先引入Gumbel Max,它提供了一种从类别分布中采样的方法。

假设每个类别的概率是p1,p2,,pk,那么下述过程提供了一种依概率采样类别的方案,称为Gumbel Max:

(8)argmaxi(logpilog(logεi))i=1k,εiU[0,1]

也就是说,先算出各个概率的对数logpi,然后从均匀分布U[0,1]中采样k个随机数ε1,,εk,把log(logεi)加到logpi上去,最后把最大值对应的类别抽取出来就行了。

后面我们会证明,这样的过程精确等价于依概率p1,p2,,pk采样一个类别,换句话说,在Gumbel Max中,输出i的概率正好是pi。由于现在的随机性已经转移到U[0,1]上去了,并且U[0,1]不带有未知参数,因此Gumbel Max就是离散分布的一个重参数过程。

但是,我们希望重参数不丢失梯度信息,但是Gumbel Max做不到,因为argmax不可导,为此,需要做进一步的近似。首先,留意到在神经网络中,处理离散输入的基本方法是转化为one hot形式,包括Embedding层的本质也是one hot全连接(参考《词向量与Embedding究竟是怎么回事?》),因此argmax实际上是one_hot(argmax)),然后,我们寻求one_hot(argmax))的光滑近似,它就是softmax(参考《函数光滑化杂谈:不可导函数的可导逼近》)。

由此,我们得到Gumbel Max的光滑近似版本——Gumbel Softmax:

(9)softmax((logpilog(logεi))/τ)i=1k,εiU[0,1]

其中参数τ>0称为退火参数,它越小输出结果就越接近one hot形式(但同时梯度消失就越严重)。提示一个小技巧,如果pi是softmax的输出,即(7)的形式,那么大可不必先算出pi再取对数,直接将logpi替换为oi即可:
(10)softmax((oilog(logεi))/τ)i=1k,εiU[0,1]

Gumbel Max的证明:

Gumbel Max的形式看上去有点复杂,远没有正态分布的重参数简单,但事实上只要鼓起勇气去看它,它连证明都不困难。我们想要证明Gumbel Max最后输出i的概率是pi,不失一般性,这里我们证明输出1的概率是p1

注意,输出1意味着logp1log(logε1)是最大的,这又意味着:

(11)logp1log(logε1)>logp2log(logε2)logp1log(logε1)>logp3log(logε3)logp1log(logε1)>logpklog(logεk)

注意每个不等式都是独立的,也就是说logp1log(logε1)logp2log(logε2)的关系如何,也不影响它跟logp3log(logε3)的关系。这样我们只需要单独分析每一个不等式的概率。不失一般性,我们只分析第一个不等式,化简后得到:

(12)ε2<ε1p2/p11

由于ε2U[0,1],所以ε2<ε1p2/p1的概率就是ε1p2/p1,这就是固定ε1的情况下,第一个不等式成立的概率。那么,所有不等式同时成立的概率是
(13)ε1p2/p1ε1p3/p1ε1pk/p1=ε1(p2+p3++pk)/p1=ε1(1/p1)1

然后对所有ε1求平均,就是
(14)01ε1(1/p1)1dε1=p1

这就是类别1出现的概率,它就是p1。至此,我们完成了Gumbel Max采样过程的证明。

例子 #

跟连续情形一样,Gumbel Softmax就是用在需要求Eypθ(y)[f(y)]、且无法直接完成对y求和的场景,这时候我们算出pθ(y)(或者oi),然后选定一个τ>0,用Gumbel Softmax算出一个随机向量来y~,代入计算得到f(y~),它就是Eypθ(y)[f(y)]的一个好的近似,且保留了梯度信息。

注意,Gumbel Softmax不是类别采样的等价形式,Gumbel Max才是。而Gumbel Max可以看成是Gumbel Softmax在τ0时的极限。所以在应用Gumbel Softmax时,开始可以选择较大的τ(比如1),然后慢慢退火到一个接近于0的数(比如0.01),这样才能得到比较好的结果。

下面提供一个自己实现的离散隐变量的VAE例子:
https://github.com/bojone/vae/blob/master/vae_keras_cnn_gs.py

效果图:

基于Gumbel Softmax重参数的离散隐变量VAE生成

基于Gumbel Softmax重参数的离散隐变量VAE生成

溯源 #

Gumbel Max由来已久,但首次提出并应用Gumbel Softmax的是论文《Categorical Reparameterization with Gumbel-Softmax》,这篇论文主要探讨了部分隐变量是离散型变量的变分推断问题,比如基于VAE的半监督学习(方法上有点类似《变分自编码器(四):一步到位的聚类方案》)。其后,在文章《GANS for Sequences of Discrete Elements with the Gumbel-softmax Distribution》中,Gumbel Softmax首次被用在离散序列生成,但还不是文本生成,而是比较简单的人造字符序列。

其后,SeqGAN被提出,自那以后文本GAN模型一直以与强化学习结合的方式出现,基于Gumbel Softmax的纯深度学习和梯度下降的方法相对沉寂,直到RelGAN的出现。RelGAN是ICLR 2019提出的模型,它提出了新型的生成器和判别器结构,使得直接用Gumbel Softmax训练出的文本GAN大幅度超过了以往的各种文本GAN模型。关于RelGAN,我们后面有机会再谈。

总结 #

这部分内容主要介绍的是Gumbel Softmax,它是离散情形下(1)型损失的一个重参数技巧。

理论上来说,离散情形的(1)只是有限项求和,不一定需要重参数。但事实上,“有限”也可能是相当大的数字,因此遍历求和可能难以进行,所以还是要转化为采样形式,从而需要重参数技巧,这就是Gumbel Softmax,源于对Gumbel Max的光滑化。

除了上述视角外,还有一个辅助的视角:Gumbel Softmax通过τ0的退火来逐渐逼近one hot,相比直接用原始的Softmax进行退火,区别在于原始Softmax退火只能得到最大值位置为1的one hot向量,而Gumbel Softmax有概率得到非最大值位置的one hot向量,增加了随机性,会使得基于采样的训练更充分一些。

背后的故事 #

重参数就这样介绍完了吗?远远没有,重参数的背后,实际上是一个称为“梯度估计(gradient estimator)”的大家族,而重参数只不过是这个大家族中的一员。每年的ICLR、ICML等顶会上搜索gradient estimatorREINFORCE等关键词,可以搜索到不少文章,说明这是个大家还在钻研的课题。

要想说清重参数的来龙去脉,也要说些梯度估计的故事。

SF估计 #

前面我们分别讲了连续型和离散型的重参数,都是在“loss层面”讲述的,也就是说都是想办法把loss显式地定义好,剩下的交给框架自动求导、自动优化就是了。而事实上,就算不能显式地写出loss函数,也不妨碍我们对它求导,自然也不妨碍我们去用梯度下降了。比如

(15)θpθ(z)f(z)dz=f(z)θpθ(z)dz=pθ(z)×f(z)pθ(z)θpθ(z)dz=Ezpθ(z)[f(z)pθ(z)θpθ(z)]=Ezpθ(z)[f(z)θlogpθ(z)]

现在我们得到了梯度的一个估计式,称为“SF估计”,全称是Score Function Estimator,这是对原来损失函数的最朴素的估计,在强化学习中z代表着策略,那么上式就是一个最基本的策略梯度,所以有时候也直接称上述估计为叫REINFORCE。要注意,对离散情形的损失函数重新推导一遍,结果也是一样的,也就是说,上述结果是通用的,不区分z是连续变量还是离散变量。现在我们可以直接从pθ(z)中采样若干个点来估算式(15)的值了,不用担心会不会没梯度,因为式(15)本身就是梯度了。

梯度方差 #

看上去很美好,得到了一个连续和离散变量都适用的估计式,那为什么还需要重参数呢?

主要的原因是:SF估计的方差太大。式(15)是函数f(z)θlogpθ(z)在分布pθ(z)下的期望,我们要采样几个点来算(理想情况下,希望只采样一个点),换句话说,我们想用下面的近似

(16)Ezpθ(z)[f(z)θlogpθ(z)]f(z~)θlogpθ(z~),z~pθ(z)

于是问题就来了:这样的梯度估计方差很大。

什么是方差很大?它有什么影响?举个简单的例子,假如α=avg([4,5,6])=avg([0,5,10]),也就是说,我们的目标α是三个数的平均值,这三个数要不就是4,5,6,要不就是0,5,10,在精确估计的情况下,两者是等价的,但是如果每一组只能随机选其中一个数呢?第一组可能选到4,这也没什么,跟准确值5只差一点;但是第二组可能选到0,这跟准确值5差得就有点大了。也就是说,随机选一个的情况下,第二组估计的波动(方差)太大了。类似地,SF估计出来的梯度方差也是如此,这导致了我们用梯度下降优化的时候相当不稳定,非常容易崩。

降方差 #

从形式上看,式(15)是非常漂亮的,本身形式不复杂,而且对离散变量和连续变量都通用,还对f没有特别要求(相反,重参数要求f可导,但是在诸如强化学习的场景下,f(z)对应着奖励函数,很难做到光滑可导)。所以,很多文章探讨基于式(15)降方差技巧,论文《Categorical Reparameterization with Gumbel-Softmax》就列举了一些,近几年来也有一些新发展,总之,还是那句话,大家搜索gradient estimator、REINFORCE等关键词,就有不少文章了。

重参数是另一种降方差技巧,为此,我们写出重参数后的(4)的梯度表达式:

(17)θEεq(ε)[f(gθ(ε))]=Eεq(ε)[θf(gθ(ε))]=Eεq(ε)[fggθ(ε)θ]

对比SF估计的式(15),我们可以直观感知为什么上式方差更小了:

1、SF估计中包含了logpθ(z),我们知道,作为一个合理的概率分布,一般都在无穷远处(即z)都会有pθ(z)0,取了log之后反而会趋于负无穷,换句话说,logpθ(z)这一项实际上放大了无穷远处的波动,从而一定程度上增加了方差;

2、SF估计中包含的是f而重参数之后变成了fgf一般是神经网络,而通常我们定义的神经网络模型其实都是O(z)级别的模型,从而我们可以预期它的梯度是O(1)级别的(不严格成立,只能说在平均意义下基本成立),所以相对情况下更平稳一些,因此f的方差也比fg的方差要大。

鉴于这两个理由,我们就可以得出,一般情况下重参数之后梯度估计的方差会比SF估计要小。注意,这里还是要强调“一般情况”,换言之,“重参数降低梯度估计的方差”这个结论不是绝对成立的,上述两个理由都是在一般情况下(我们面对的多数模型)成立,如果非要较劲,我们总能构造出重参数反而增加方差的例子。

文章小结 #

经过一番长篇大论,我们总算把重参数的故事基本上都捋清楚了。更深入地理解重参数技巧,是更好地理解VAE及文本GAN的必经之路。

从loss层面看,我们需要分连续和离散两种情形:连续情形下,重参数是用采样形式且不损失梯度地写出loss的方法;离散情形下,重参数有着跟连续情形一样的作用,不过更根本的原因是降低计算量(否则直接遍历求和也行)。从梯度估计层面看,重参数是降低梯度估计方差的一种有效手段,而同时还有其他的降低方差手段也被不少学者研究中。

总之,怎么看也不是个让人省心的玩意~

转载到请包括本文地址:https://spaces.ac.cn/archives/6705

更详细的转载事宜请参考:《科学空间FAQ》

如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。

如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!

如果您需要引用本文,请参考:

苏剑林. (Jun. 10, 2019). 《漫谈重参数:从正态分布到Gumbel Softmax 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/6705

@online{kexuefm-6705,
        title={漫谈重参数:从正态分布到Gumbel Softmax},
        author={苏剑林},
        year={2019},
        month={Jun},
        url={\url{https://spaces.ac.cn/archives/6705}},
}