漫谈重参数:从正态分布到Gumbel Softmax
By 苏剑林 | 2019-06-10 | 124872位读者 |最近在用VAE处理一些文本问题的时候遇到了对离散形式的后验分布求期望的问题,于是沿着“离散分布 + 重参数”这个思路一直搜索下去,最后搜到了Gumbel Softmax,从对Gumbel Softmax的学习过程中,把重参数的相关内容都捋了一遍,还学到一些梯度估计的新知识,遂记录在此。
文章从连续情形出发开始介绍重参数,主要的例子是正态分布的重参数;然后引入离散分布的重参数,这就涉及到了Gumbel Softmax,包括Gumbel Softmax的一些证明和讨论;最后再讲讲重参数背后的一些故事,这主要跟梯度估计有关。
基本概念 #
重参数(Reparameterization)实际上是处理如下期望形式的目标函数的一种技巧:
这样的目标在VAE中会出现,在文本GAN也会出现,在强化学习中也会出现(对应于奖励函数),所以深究下去,我们会经常碰到这样的目标函数。取决于的连续性,它对应不同的形式:
当然,离散情况下我们更喜欢将记号换成或者。
为了最小化,我们就需要把明确地写出来,这意味着我们要实现从中采样,而是带有参数的,如果直接采样的话,那么就失去了的信息(梯度),从而无法更新参数。而Reparameterization则是提供了这样的一种变换,使得我们可以直接从中采样,并且保留的梯度。(注:如果考虑最一般的形式,那么也应该带上参数,但这没有增加本质难度。)
连续情形 #
简单起见,我们先考虑连续情形
其中是具有显式概率密度表达式的分布,在变分自编码器中常见的是正态分布。
形式 #
从式中知道,连续情形的实际上就对应一个积分,所以,为了明确写出,有两种途径:最直接的方式是精确地完成积分,得到显式表达式,但这通常都是不可能的了;所以,唯一的办法是转化为采样形式,并试图在采样过程中保留的梯度。
重参数就是这样的一种技巧,它假设从分布中采样可以分解为两个步骤:(1) 从无参数分布中采样一个;(2) 通过变换生成。那么,式就变成了
这时候被采样的分布就没有任何参数了,全部被转移到内部了,因此可以采样若干个点,当成普通的loss那样写下来了。
例子 #
一个最简单的例子就是正态分布:对于正态分布来说,重参数就是“从中采样一个”变成“从中采样一个,然后计算”,所以
如何理解直接采样没有梯度而重参数之后就有梯度呢?其实很简单,比如我说从中采样一个数来,然后你跟我说采样到5,我完全看不出5跟有什么关系呀(求梯度只能为0);但是如果先从中采样一个数比如,然后计算,这样我就知道采样出来的结果跟的关系了(能求出有效的梯度)。
总结 #
让我们把前面的内容重新整理一下。总的来说,连续情形的重参数还是比较简单的:连续情形下,我们要处理的实际上是式,由于精确的积分我们没有办法显式地写出来,所以需要转化为采样,而为了在采样的过程中得到有效的梯度,我们就需要重参数。
从数学本质来看,重参数是一种积分变换,即原来是关于积分,通过变换之后得到新的积分形式,
离散情形 #
为了突出“离散”,我们将随机变量换成,即对于离散情形要面对的目标函数是
其中离散意味着一般情况是可枚举的,换句话说此时是一个分类模型:
其中各个是的函数。
分析 #
读者看到中的求和,第一反应可能是“求和?那就求呗,又不是求不了。”。
的确,这也是笔者当时看到它的第一反应。与连续情形的不一样,式如果直接硬杠的话需要完成积分(也可以看成无穷多个点的求和),我们没法做到这一点。但是对于离散的,只不过是有限项求和,理论上确实可以直接完成求和再去梯度下降。
但是,如果特别大呢?举个例子,假设是一个100维的向量,每个元素不是0就是1(二元变量),那么所有不同的的总数目就是,要对这样的个单项进行求和,计算量是难以接受的;还有一个典型的例子是seq2seq的解码端(如果要做文本GAN就需要面对它),它的类别总数目是,其中是词表大小而是句子长度。这样的情况下,直接完成精确的求和都是难以实现的。
形式 #
所以,还是需要回到采样上去,如果能够采样若干个点就能得到的有效估计,并且还不损失梯度信息,那自然是最好了。为此,需要先引入Gumbel Max,它提供了一种从类别分布中采样的方法。
假设每个类别的概率是,那么下述过程提供了一种依概率采样类别的方案,称为Gumbel Max:
也就是说,先算出各个概率的对数,然后从均匀分布中采样个随机数,把加到上去,最后把最大值对应的类别抽取出来就行了。
后面我们会证明,这样的过程精确等价于依概率采样一个类别,换句话说,在Gumbel Max中,输出的概率正好是。由于现在的随机性已经转移到上去了,并且不带有未知参数,因此Gumbel Max就是离散分布的一个重参数过程。
但是,我们希望重参数不丢失梯度信息,但是Gumbel Max做不到,因为不可导,为此,需要做进一步的近似。首先,留意到在神经网络中,处理离散输入的基本方法是转化为one hot形式,包括Embedding层的本质也是one hot全连接(参考《词向量与Embedding究竟是怎么回事?》),因此实际上是,然后,我们寻求的光滑近似,它就是(参考《函数光滑化杂谈:不可导函数的可导逼近》)。
由此,我们得到Gumbel Max的光滑近似版本——Gumbel Softmax:
其中参数称为退火参数,它越小输出结果就越接近one hot形式(但同时梯度消失就越严重)。提示一个小技巧,如果是softmax的输出,即的形式,那么大可不必先算出再取对数,直接将替换为即可:
Gumbel Max的证明:
Gumbel Max的形式看上去有点复杂,远没有正态分布的重参数简单,但事实上只要鼓起勇气去看它,它连证明都不困难。我们想要证明Gumbel Max最后输出的概率是,不失一般性,这里我们证明输出1的概率是。
注意,输出1意味着是最大的,这又意味着:
注意每个不等式都是独立的,也就是说与的关系如何,也不影响它跟的关系。这样我们只需要单独分析每一个不等式的概率。不失一般性,我们只分析第一个不等式,化简后得到:
由于,所以的概率就是,这就是固定的情况下,第一个不等式成立的概率。那么,所有不等式同时成立的概率是
然后对所有求平均,就是
这就是类别1出现的概率,它就是。至此,我们完成了Gumbel Max采样过程的证明。
例子 #
跟连续情形一样,Gumbel Softmax就是用在需要求、且无法直接完成对求和的场景,这时候我们算出(或者),然后选定一个,用Gumbel Softmax算出一个随机向量来,代入计算得到,它就是的一个好的近似,且保留了梯度信息。
注意,Gumbel Softmax不是类别采样的等价形式,Gumbel Max才是。而Gumbel Max可以看成是Gumbel Softmax在时的极限。所以在应用Gumbel Softmax时,开始可以选择较大的(比如1),然后慢慢退火到一个接近于0的数(比如0.01),这样才能得到比较好的结果。
下面提供一个自己实现的离散隐变量的VAE例子:
https://github.com/bojone/vae/blob/master/vae_keras_cnn_gs.py
效果图:
溯源 #
Gumbel Max由来已久,但首次提出并应用Gumbel Softmax的是论文《Categorical Reparameterization with Gumbel-Softmax》,这篇论文主要探讨了部分隐变量是离散型变量的变分推断问题,比如基于VAE的半监督学习(方法上有点类似《变分自编码器(四):一步到位的聚类方案》)。其后,在文章《GANS for Sequences of Discrete Elements with the Gumbel-softmax Distribution》中,Gumbel Softmax首次被用在离散序列生成,但还不是文本生成,而是比较简单的人造字符序列。
其后,SeqGAN被提出,自那以后文本GAN模型一直以与强化学习结合的方式出现,基于Gumbel Softmax的纯深度学习和梯度下降的方法相对沉寂,直到RelGAN的出现。RelGAN是ICLR 2019提出的模型,它提出了新型的生成器和判别器结构,使得直接用Gumbel Softmax训练出的文本GAN大幅度超过了以往的各种文本GAN模型。关于RelGAN,我们后面有机会再谈。
总结 #
这部分内容主要介绍的是Gumbel Softmax,它是离散情形下型损失的一个重参数技巧。
理论上来说,离散情形的只是有限项求和,不一定需要重参数。但事实上,“有限”也可能是相当大的数字,因此遍历求和可能难以进行,所以还是要转化为采样形式,从而需要重参数技巧,这就是Gumbel Softmax,源于对Gumbel Max的光滑化。
除了上述视角外,还有一个辅助的视角:Gumbel Softmax通过的退火来逐渐逼近one hot,相比直接用原始的Softmax进行退火,区别在于原始Softmax退火只能得到最大值位置为1的one hot向量,而Gumbel Softmax有概率得到非最大值位置的one hot向量,增加了随机性,会使得基于采样的训练更充分一些。
背后的故事 #
重参数就这样介绍完了吗?远远没有,重参数的背后,实际上是一个称为“梯度估计(gradient estimator)”的大家族,而重参数只不过是这个大家族中的一员。每年的ICLR、ICML等顶会上搜索gradient estimator、REINFORCE等关键词,可以搜索到不少文章,说明这是个大家还在钻研的课题。
要想说清重参数的来龙去脉,也要说些梯度估计的故事。
SF估计 #
前面我们分别讲了连续型和离散型的重参数,都是在“loss层面”讲述的,也就是说都是想办法把loss显式地定义好,剩下的交给框架自动求导、自动优化就是了。而事实上,就算不能显式地写出loss函数,也不妨碍我们对它求导,自然也不妨碍我们去用梯度下降了。比如
现在我们得到了梯度的一个估计式,称为“SF估计”,全称是Score Function Estimator,这是对原来损失函数的最朴素的估计,在强化学习中代表着策略,那么上式就是一个最基本的策略梯度,所以有时候也直接称上述估计为叫REINFORCE。要注意,对离散情形的损失函数重新推导一遍,结果也是一样的,也就是说,上述结果是通用的,不区分是连续变量还是离散变量。现在我们可以直接从中采样若干个点来估算式的值了,不用担心会不会没梯度,因为式本身就是梯度了。
梯度方差 #
看上去很美好,得到了一个连续和离散变量都适用的估计式,那为什么还需要重参数呢?
主要的原因是:SF估计的方差太大。式是函数在分布下的期望,我们要采样几个点来算(理想情况下,希望只采样一个点),换句话说,我们想用下面的近似
于是问题就来了:这样的梯度估计方差很大。
什么是方差很大?它有什么影响?举个简单的例子,假如,也就是说,我们的目标是三个数的平均值,这三个数要不就是,要不就是,在精确估计的情况下,两者是等价的,但是如果每一组只能随机选其中一个数呢?第一组可能选到4,这也没什么,跟准确值5只差一点;但是第二组可能选到0,这跟准确值5差得就有点大了。也就是说,随机选一个的情况下,第二组估计的波动(方差)太大了。类似地,SF估计出来的梯度方差也是如此,这导致了我们用梯度下降优化的时候相当不稳定,非常容易崩。
降方差 #
从形式上看,式是非常漂亮的,本身形式不复杂,而且对离散变量和连续变量都通用,还对没有特别要求(相反,重参数要求可导,但是在诸如强化学习的场景下,对应着奖励函数,很难做到光滑可导)。所以,很多文章探讨基于式的降方差技巧,论文《Categorical Reparameterization with Gumbel-Softmax》就列举了一些,近几年来也有一些新发展,总之,还是那句话,大家搜索gradient estimator、REINFORCE等关键词,就有不少文章了。
重参数是另一种降方差技巧,为此,我们写出重参数后的的梯度表达式:
对比SF估计的式,我们可以直观感知为什么上式方差更小了:
1、SF估计中包含了,我们知道,作为一个合理的概率分布,一般都在无穷远处(即)都会有,取了之后反而会趋于负无穷,换句话说,这一项实际上放大了无穷远处的波动,从而一定程度上增加了方差;
2、SF估计中包含的是而重参数之后变成了,一般是神经网络,而通常我们定义的神经网络模型其实都是级别的模型,从而我们可以预期它的梯度是级别的(不严格成立,只能说在平均意义下基本成立),所以相对情况下更平稳一些,因此的方差也比的方差要大。
鉴于这两个理由,我们就可以得出,一般情况下重参数之后梯度估计的方差会比SF估计要小。注意,这里还是要强调“一般情况”,换言之,“重参数降低梯度估计的方差”这个结论不是绝对成立的,上述两个理由都是在一般情况下(我们面对的多数模型)成立,如果非要较劲,我们总能构造出重参数反而增加方差的例子。
文章小结 #
经过一番长篇大论,我们总算把重参数的故事基本上都捋清楚了。更深入地理解重参数技巧,是更好地理解VAE及文本GAN的必经之路。
从loss层面看,我们需要分连续和离散两种情形:连续情形下,重参数是用采样形式且不损失梯度地写出loss的方法;离散情形下,重参数有着跟连续情形一样的作用,不过更根本的原因是降低计算量(否则直接遍历求和也行)。从梯度估计层面看,重参数是降低梯度估计方差的一种有效手段,而同时还有其他的降低方差手段也被不少学者研究中。
总之,怎么看也不是个让人省心的玩意~
转载到请包括本文地址:https://spaces.ac.cn/archives/6705
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Jun. 10, 2019). 《漫谈重参数:从正态分布到Gumbel Softmax 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/6705
@online{kexuefm-6705,
title={漫谈重参数:从正态分布到Gumbel Softmax},
author={苏剑林},
year={2019},
month={Jun},
url={\url{https://spaces.ac.cn/archives/6705}},
}
December 21st, 2021
您好,我看到有些文献中有这样的表述:当隐变量为离散变量时无法使用重参数化,是发表在NIPS2016上的文章《The Concrete Distribution:A Continuous Relaxation of Discrete Random Variables》的2.3节
文章链接是 http://bayesiandeeplearning.org/2016/papers/BDL_31.pdf
我不是很理解这里的隐变量表示什么。
隐变量就是求和的对象吧。我觉得它这里想表达的意思是“当隐变量为离散变量时无法使用(精确的)重参数化”,因为哪怕是本文的Gumbel Softmax,只能算是一个可用的近似,不算精确的等价结果。
感谢您的回复,我继续思考之后的理解是这样的。N(μ,σ2) 时,可以精确地进行重参数化:
Gumbel Softmax 是一个针对类别分布的重参数化技巧,用于近似离散的类别分布;而当原分布是连续分布,例如最常见的例子
不过我这里新的问题又产生了,重参数化的目的在于保持梯度信息,那么理论上是否可以使用一个离散分布
十分感谢您的回复。
你没搞清楚问题所在。
重参数本质上,就是希望将待训练参数从分布转移到样本,那么求参数的梯度就可以通过对样本求梯度来完成。对于“离散”的分布来说,样本是什么?一个整数可以是离散的样本,一只猫一只狗也可以是离散的样本,这些怎么求导?转化为神经网络可以处理的输入,那么就变成了one hot向量,one hot可以看成是分段常数函数,其导数是0。
所以说,只要你做到精确的离散采样,那么样本的梯度就必然是0;如果你要重参数,就必然不可能做到精确的离散采样。
十分感谢 我明白我的疑惑根源在哪里了 我上面的问题表述和自己的想法也存在偏差
February 9th, 2022
苏神,看到您在评论区提到GAN不适用于文本生成,这是为什么呢?
因为GAN一般情况下只适用于连续型对象生成,而文本是离散的。
April 10th, 2022
苏神你好
假设我有一个长度为200的sequence,然后sequence中的每个向量都是5维的(softmax过,实际上就是代表着类别1,2,3,4,5的概率),那么就是200*5的一个矩阵,那么这很明显是离散情况的,如果按照(6)式求和,有5^200种可能性,这是一个很大的数字,所以我需要进行蒙特卡洛采样,那么为了保证方差小,就需要Gumbel-softmax进行随机性的转移
那么我想问:我采样完成后还是一个200*5的矩阵(添加了Gumbel的随机性的softmax向量集合),那么请问我最后的类别该如何取呢,毕竟我想要的是200*1的类别向量,还有就是即使取到了200*1的类别,那请问这个概率是多少呢?
感谢苏神解答
如果你是用神经网络处理,那么你的类别应该是用自己的Embedding来表征的,所以可以用200*5的采样矩阵来对类Embedding进行求和。这一步一般是避免不了的。
April 29th, 2022
[...]漫谈重参数:从正态分布到Gumbel Softmax[...]
July 11th, 2022
您好,在您证明Gumbel Max最后的那里,有一个对所有ε1求平均的过程,就是随后的从0到1的积分。可您前面说被积函数的那个式子就是输出1的概率,我没有搞懂这个积分的意义,想在这里请教一下您。感谢啦!
因为前面的讨论是对于固定的ε1 来进行了,得到的结果也是\varpesilon1 的函数。事实上\varpesilon1 也是随机的,所以要对它进行平均一下。
September 7th, 2022
苏神您好,您代码中实现为Gi=−log(−log(Ui)),Ui∈U(0,1) ,但是我发现pytorch的实现直接从指数分布采样,代码地址为:https://github.com/pytorch/pytorch/blob/b136f3f310aa01a8b3c1e63dc0bfda8fd2234b06/torch/nn/functional.py#L1892
这是为何呢?
将−logε,ε∼U[0,1] 看成一个整体,它服从指数分布。
September 28th, 2022
[...]πθ(a~∣s) 的具体形式, 比如高斯策略等, 可以用再参数化技巧来处理:[...]
October 18th, 2022
您好 我基础比较薄弱些 想再请教一下:
例如discrete VAE中是要优化ELBO:
Eq(z|x)[log p(z)q(z|x)]
q(z|x)我们是无法求出的 所以使用GumbelSoftmax进行采样 假设z是一个3维的10元变量 那么GumbelSoftmax后得到的向量维度应该也是3*10的 可是我要怎么将这个采样值代回上面的ELBO计算呢 是把GumbelSoftmax后得到的向量中z的每个维度对应最大概率的类别找到 然后将找到的这10个类别对应的概率进行如下计算 最后相加嘛
log p(zi)q(zi|x)
这一点我一直不是很清楚 希望您能不吝赐教 非常感激!!
一般有几种思路,但绝不可能是你这种思路,因为你说的“对应最大概率的类别找到”这个操作,就又把梯度断掉了,前面的努力都白费了。
首先,我们要清楚,我们平时是怎么处理离散信号输入的。像NLP中的“词”,我们是通过转换为整数id,然后传入Embedding层来构建特征的对吧。而事实上,这种“整数id+Embedding”的操作,只是一个简化的表象,它们本质都是“one hot + 全连接”,参考:https://kexue.fm/archives/4122
也就是说,所有处理离散特征的方式,基本上都是将它转化为one hot,然后加全连接。而对于GumbelSoftmax来说,采样出来的结果是一个非one hot的分布,后面也可以直接加同样的全连接处理。
感谢您的回答!!
May 10th, 2023
[...]【参考文章】线性回归1线性回归2线性回归3逻辑回归生成模型和判别模型生成模型和判别模型k均值聚类k均值聚类变分变分重参数化-推荐重参数化英文解释checkpointcheckpointcheckpoint数据泄露数据泄露检测和修复ELBO证据下界PRPRPRROC/AUCROC/AUC[...]