Chainer の ptb サンプルで遊んでみる
AWS の GPU環境をなんとか整えたので、RNN で遊んでみようと思い、Chainer の ptb サンプルを試しに動かしてみました。 ptb サンプルでは、入力された単語系列(文章)を元に、次の単語を推論する構造で、RNNのよくあるモデリングになっています。 ちなみに前回までに入力された単語(ようするに文脈)を覚えておく構造にはLSTMという特殊な層を利用しています。 また、入力層と出力層にはそれぞれの単語に相当するユニットが存在します。語彙が一万個あればそれぞれ一万ユニット必要で、学習時にユニットが割り当てられていない単語については入力も出力も行うことができません。
このRNNの学習はなかなかに遅く、今回は epoch5 くらいまで学習させてみましたが、GPU機で軽く2時間はかかりました。 g2.xlarge インスタンスは1時間100円するので、学習が終わったら自動で停止してくれるギミックを作っておかないとオチオチ夜も寝られません。
今回は、学習はリモートのGPU環境、テストデータの結果出力はローカルのCPU環境で行いました。 サンプルコードは Chainer 1.7.1 に付属のものを使用しています。 ちなみに、結果出力用のコードは下のような感じです。
import numpy as np import six import chainer from chainer import cuda import chainer.links as L from chainer import optimizers from chainer import serializers import net def load_data(filename): global vocab global inv_vocab words = open(filename).read().replace('\n', '<eos>').strip().split() dataset = np.ndarray((len(words),), dtype=np.int32) for i, word in enumerate(words): if word not in vocab: vocab[word] = len(vocab) inv_vocab[len(vocab)-1]=word # 単語逆引き用辞書 dataset[i] = vocab[word] return dataset vocab = {} inv_vocab = {} train_data = load_data('ptb.train.txt') train_data = load_data('ptb.train.txt') test_data = load_data('ptb.test.txt') print('#vocab =', len(vocab)) # 学習済みモデル取り込み model = L.Classifier(net.RNNLM(len(vocab), 650)) serializers.load_npz('rnnlm.model', model) model.predictor.reset_state() # LSTM層の初期化 for i in six.moves.range(10): data = test_data[i:i+1] print(inv_vocab[data[0]]) x = chainer.Variable(data, volatile=False) h0 = model.predictor.embed(x) h1 = model.predictor.l1(h0) h2 = model.predictor.l2(h1) y = model.predictor.l3(h2) # 出力結果のうち上位3個を出力 prediction = list(zip(y.data[0].tolist(), inv_vocab.values())) prediction.sort() prediction.reverse() for i, (score, word) in enumerate(prediction): print('{},{}'.format(word, score)) if i >= 2: break
上で書いたように、RNNの学習はそれまでにどんな値を入力されたかどうかで変化します。 今回は、サンプルのテスト用データ内の「no it was n't black monday」という文を使用して、毎回LSTM層のパラメーターをリセットした文脈を無視した場合と、最初の単語の時のみLSTM層のパラメータをリセットした文脈を考慮した場合でそれぞれ結果を比べてみました。
- 文脈無視
no longer,10.33497428894043 doubt,8.990683555603027 way,8.030651092529297 it is,10.528656005859375 will,9.515053749084473 has,9.51209831237793 was a,8.827898979187012 <unk>,8.493036270141602 the,7.862812042236328 n't <unk>,7.8151350021362305 yet,6.953648090362549 <eos>,6.921860694885254 black <unk>,9.096553802490234 and,8.121747016906738 <eos>,7.601911544799805 monday <eos>,10.35195541381836 's,9.545741081237793 and,8.740907669067383 <eos> the,9.107492446899414 <unk>,7.814574241638184 but,7.787456512451172
<eos> は End Of State の略で文の終わりを意味します。<unk> は Unknown の略で未知語を意味し、今回学習に用いたサンプルデータでは、特定の人名に相当します。 最初の単語の no を見てみると no longer, no doubt, no way など慣用句的に使用されるような単語が出力されており、うまいこと学習できていることがわかります。他の単語を見ても、文章のどこかには出てきそうな組み合わせばかりで、次の単語を推測するというタスクの結果としては悪くないもののように思えます。
- 文脈考慮
no longer,10.33497428894043 doubt,8.990683555603027 way,8.030651092529297 it is,9.707919120788574 <eos>,9.36506462097168 will,9.056198120117188 was a,8.486732482910156 <unk>,8.168980598449707 the,8.025396347045898 n't a,8.514782905578613 the,8.452549934387207 clear,8.159638404846191 black <eos>,12.180181503295898 in,10.142268180847168 to,9.996685028076172 monday <eos>,11.57198715209961 and,9.793222427368164 in,9.25011920928955 <eos> the,9.041053771972656 but,8.242995262145996 mr.,8.005669593811035
さて今度は、以前までの単語系列を考慮した場合の学習結果です。特筆すべきは n't の次の単語を推測した結果で、文脈を考慮しなかった場合は、<unk>, yet, <eos> が推測結果でしたが、考慮した場合は、a, the, clear と was n't に続きそうな単語が推測されていることが分かります。(yet や <eos> は wasn't の次にはあまりこないように思われます。) なんとなくですが、文脈を考慮して結果が出てきてるような感じですね。モデルやデータセット、学習量によってどれくらい結果が変わるかも調べてみると面白そうです。