Navigation

    Gpushare.com

    • Register
    • Login
    • Search
    • Popular
    • Categories
    • Recent
    • Tags

    Gumbel-Softmax完全解析

    语音识别与语义处理领域
    1
    1
    25
    Loading More Posts
    • Oldest to Newest
    • Newest to Oldest
    • Most Votes
    Reply
    • Reply as topic
    Log in to reply
    This topic has been deleted. Only users with topic management privileges can see it.
    • 155****7220
      155****7220 last edited by 155****7220

      写在前面

      本文对大部分人来说可能仅仅起到科普的作用,因为Gumbel-Max仅在部分领域会用到,例如GAN、VAE等。笔者是在研究EMNLP上的一篇论文时,看到其中有用Gumbel-Softmax公式解决对一个概率分布进行采样无法求导的问题,故想到对Gumbel-Softmax做一个总结,由此写下本文

      为什么我们需要Gumbel-Softmax ?

      假设现在我们有一个离散随机变量ZZZ的分布
      a1ba5b05-1347-40ec-860b-2e020ccd07c6-image.png
      其中,∑iπi=1\sum_i \pi_i=1∑i​πi​=1。我们想根据p1,p2,…,pxp_1,p_2,…,p_xp1​,p2​,…,px​的概率采样得到一系列离散zzz的值。但是这么做有一个问题,我们采样出来的zzz只有值,没有生成zzz的式子。例如我们要求ZZZ的期望,那么就有公式
      4a5519dd-59e1-4a49-b3cd-2ec9cd5a52db-image.png
      ZZZ对p1,p2,…,pxp_1,p_2,…,p_xp1​,p2​,…,px​的导数都很清楚。但是现在我们的需求是采样一些具体的zzz值,采样这个操作没有任何公式,因此也就无法求导。于是一个很自然的想法就产生了,我们能不能给一个以p1,p2,…,pzp_1,p_2,…,p_zp1​,p2​,…,pz​为参数的公式,让这个公式返回的结果是zzz采样的结果呢?

      Gumbel-Softmax

      一般来说πi\pi_iπi​是通过神经网络预测对于类别iii的概率,这在分类问题中非常常见,假设我们将一个样本送入模型,最后输出的概率分布为[0.2,0.4,0.1,0.2,0.1][0.2, 0.4,0.1,0.2,0.1][0.2,0.4,0.1,0.2,0.1],表明这是一个5分类问题,其中概率最大的是第2类,到这一步,我们直接通过argmax就能获得结果了,但现在我们不是预测问题,而是一个采样问题。对于模型来说,直接取出概率最大的就可以了,但对我们来说,每个类别都是有一定概率的,我们想根据这个概率来进行采样,而不是直接简单无脑的输出概率最大的值

      最常见的采样z\mathbf{z}z的onehot公式为

      9d43fa7d-8fda-491a-b171-f759e4b962f1-image.png

      其中i=1,2,…,xi=1,2,…,xi=1,2,…,x是类别的下标,随机变量uuu服从均匀分布U(0,1)U(0,1)U(0,1)

      上面这个过程实际上是很巧妙的,我们将概率分布从前往后不断加起来,当加到πi\pi_iπi​时超过了某个随机值0≤u≤10\leq u \leq 1 0≤u≤1,那么这一次随机采样过程,zzz就被随机采样为第iii类,最后通过一个onehot变换

      但是上述公式存在一个致命的问题:max函数是不可导的

      Gumbel-Max Trick

      Gumbel-Max技巧就是解决max函数不可导问题的,我们可以用argmax替换max,即

      30d54dea-68d8-4678-b004-1f5a3656284a-image.png

      其中,gi=−log⁡(−log⁡(ui)),ui∼U(0,1)g_i=-\log(-\log(u_i)), u_i \sim U(0,1)gi​=−log(−log(ui​)),ui​∼U(0,1),这一项名为Gumbel噪声,或者叫Gumbel分布,目的是使得z\mathbf{z}z的返回结果不固定

      可以看到式(2)(2)(2)的整个过程中,不可导的部分只有argmax,实际上我们可以用可导的softmax函数,在参数τ\tauτ的控制下逼近argmax,最终ziz_izi​的公式为

      5a7f5be3-e698-4305-86e6-022a1f7a3844-image.png

      其中,τ\tauτ越小(τ→0)(\tau \to 0)(τ→0),整个softmax越光滑逼近argmax,并且z=zi∣i=1,2,…,x\mathbf{z} = {z_i\mid i=1,2,…,x}z=zi​∣i=1,2,…,x也越接近onehot向量;τ\tauτ越大(τ→∞)(\tau \to \infty)(τ→∞),z\mathbf{z}z向量越接近于均匀分布

      总结

      整个过程相当于我们把不可导的取样过程,从z\mathbf{z}z本身转移到了求z\mathbf{z}z的公式中的一项gig_igi​中,而gig_igi​本身不依赖p1,…,pxp_1,…,p_xp1​,…,px​,所以zzz对p1,…,pxp_1,…,p_xp1​,…,px​就可以到了,而且我们得到的z\mathbf{z}z仍然是离散概率分布的采样。这种采样过程转嫁的技巧有一个专有名词,叫重参数化技巧(Reparameterization Trick)

      References

      • What is Gumbel-Softmax
      • Gumbel-Softmax Trick和Gumbel分布
      1 Reply Last reply Reply Quote 2
      • First post
        Last post