0%

Professor forcing

Professor forcing

引入前的想法

teacher forcing的不足

  • teacher forcing的方法通过将被观测的序列值作为训练过程中的输入和使用该网络自己的提前一步预测(one-step-ahead-predictions)l来进行多步采样

    • 比如时间序列的条件分布模型:

      • $$
        P(y_1,y_2,…,y_T)=P(y_1)\prod_{t=1}^T P(y_t|y_1,…,y_{t-1})
        $$

      • 这种形式一般的机器学习的训练策略就是最大似然,而在RNN中,这种策略可以类似为teacher forcing,由于使用真值样本,将其反馈到模型中,以便对后一时刻输出进行预测。这种反馈迫使(force)RNN接近真实的序列

不过:teacher forcing有个问题就是,在测试的时候不能使用真值样本,而只能根据前面的采样,但是这样的条件环境优又会有所不同。比如在一段时间序列的预测中,一个地方出错,就容易产生大偏差

professor forcing的想法

  • 所以,Professor Forcing的目标就是:使free running(或说self feeding)和teacher forcing行为尽可能的接近。对于接近,抛开概率的KL,那么就是GAN了。
    • 用GAN一点的话就是:我们在生成RNN与训练数据匹配的同时,我们还希望网络的行为(无论是输出中还是隐藏状态的动态中)能无法区分其输入是限制在teacher forcing的还是free running的

模型

这里的professor forcing使用对抗域适应(adversarial domain adaptation)来促进训练网络的动态变化相同

在这里插入图片描述

判别器训练

对D(d)最大化下面这个
$$
C_d(\theta_d|\theta_g)=E_{(x,y)\sim data}[-\log D(B(x,y,\theta_g),\theta_d)+E_{y\sim P_{\theta_g(y|x)}}[-\log (1-D(B(x,y,\theta_g)),\theta_d)]
$$

  • D就是判别属于哪种,是概率分类器
  • B(x,y,$\theta_g$)是给定数据(x是训练数据,y是训练数据或是自生成的)输出序列,==注意==:这里的输出序列包含hidden state和output

生成器训练

最小化下列的东西

生成器参数$\theta_g$按两种方式训练

  • 1、最大化数据概率(正常RNN的loss部分)

    • RNN通常的teacher forcing训练标准:负对数似然

      • $$
        NNL(\theta_g)=E_{(x,y)\sim data}[-log_{\theta_g}(y|x)]
        $$
  • 2、愚弄判别器(而第二种方式有两个变体,。。。其实就是把GAN的式子拆开了)

    • 1、训练目标只是试图改变free-running的行为,为了更好地匹配teacher forcing

      • $$
        C_f(\theta_g|\theta_d)=E_{x\sim data,y\sim P_{\theta_g}(y|x)}[-\log D(B(x,y,\theta_g),\theta_d)]
        $$
    • 2、可以让teacher forcing 与free running无法区分

      • $$
        C_t(\theta_g|\theta_d)=E_{(x,y)\sim data}[-\log (1-D(B(x,y,\theta_g),\theta_d))]
        $$

实现中可选$NLL+C_f$或$NLL+C_f+C_y$

为什么要匹配hidden state(Ethan Caballero)

  • 在采样的输出(argmax)上使用GAN明显难得多,因为这些是离散的,如hidden state和各自的softmax。
    • 所以不得不使用seqGAN中的策略梯度来估计离散输出
    • 所以要融合hidden state和output(无论如何,hidden state已经包含了有关离散采样输出的信息(分布中最高的概率索引))
  • professor forcing的独特之处就是:可以在,试图将两者推近的、两个序列生成模式的每个时间步上访问每个输出的连续概率分布。
    • 相反,像seqGAN这样的模型中传统上将GAN应用到逼近真实样本和生成的样本时,在每个时间步只能访问下一个离散输出(而不是下一个输出的连续分布),这会阻止直接进行区分(用于professor forcing),迫使人们使用策略梯度估计。如果有人用连续分布的词嵌入(从预训练的word2vec,GloVe等中)替换掉每个离散的采样令牌,则有可能在传统采样情况下使用直截了当的训练来训练seqGAN

实验

  • 按照人为评估,该法在笔迹生成、音乐合成方面好
  • 比起teacher forcing能生成更长的序列,长期依赖性的字符建模中表现较好
  • 作者使用t-SNE展示使用professor forcing是的隐藏状态分布上变得更加相似