読者です 読者をやめる 読者になる 読者になる

終末 A.I.

Deep Learning を中心に、週末に機械学習するエンジニアのブログ

テキスト生成モデル -SeqGAN-

Deep Learning 機械学習 自然言語処理 TensorFlow

この記事は、DeepLearning Advent Calendar 2016の20日目です。

今回は、時系列データに GAN の手法を適用した SeqGAN をご紹介したいと思います。SeqGAN は分かりやすく時系列データに GAN を適用しているためアルゴリズムが理解しやすく、公式の TensorFlow コードもあるので試しに動かしてみたい方にオススメできる手法です。

SeqGAN では、GAN と同じく生成モデルと識別モデルの両方を用いて生成モデルを学習させていきますが、時系列モデルに適用するにあたり( { Y_{1:t-1} } から { y_t } を求める生成モデルを作成するにあたり)、下図のように { y_{t+1} } 以降をモンテカルロ法により生成し、その結果も含めて本物か生成したものかを識別モデルに判定させることにより、{ Y_{1:t-1} } から { y_t } を求めた結果の評価値を決定します。

f:id:KSKSKSKS2:20161218112745p:plain

※ 論文から抜粋

この部分を式で表すと、{ G_{\theta} } を生成モデル、 { D_{\phi} } を識別モデル、{ Q } をリワード関数とすると、生成モデルの誤差関数は下記の用に表すことができます。

{ \displaystyle
J_{\theta} = \sum G_{\theta}(y_t|Y_{1:t-1}) Q^{G_{\theta}}_{D_{\phi}}(Y_{1:t-1}, y_t)
}

{ \displaystyle
Q^{G_{\theta}}_{D_{\phi}} = \frac{1}{N} \sum D_{\phi}(Y^{n}_{1:T}), Y^{n}_{1:T} \in MC^{G_{\beta}}(Y_{1:T};N)
}

誤差関数は、{ Y_{1:t-1} } から { y_t } を生成される最もらしさをリワード関数で表し、生成モデルにより{ Y_{1:t-1} } から { y_t } が生成される確率のもとに期待値をとった値となります。

リワード関数は、{ Y_{1:t-1} } から { y_t } を生成した後、{ y_{t+1} } 以降をパラメーター({\beta})を更新する前の生成モデルから、マルコフ連鎖により複数生成し、その生成結果を識別モデルで評価した平均を返します。

論文中では、生成モデルは LSTM を利用した RNN で、識別モデルは CNN で作成しています。

評価結果としては、最尤法で最適化した言語モデルで作成したテキストより、人間評価および oracle による評価でも改善することができました(oracle 評価に関しては学習曲線がかなり特殊なため再現性があるのかは微妙なのですが)。

テキスト生成をうまいことやるためには2つも3つも壁を超えなければならないように思いますが、時系列データの生成モデルとしては一つの参考となるモデルになるのではないでしょうか。