終末 A.I.

データいじりや機械学習するエンジニアのブログ

Chainer の ptb サンプルで遊んでみる

AWSGPU環境をなんとか整えたので、RNN で遊んでみようと思い、Chainer の ptb サンプルを試しに動かしてみました。 ptb サンプルでは、入力された単語系列(文章)を元に、次の単語を推論する構造で、RNNのよくあるモデリングになっています。 ちなみに前回までに入力された単語(ようするに文脈)を覚えておく構造にはLSTMという特殊な層を利用しています。 また、入力層と出力層にはそれぞれの単語に相当するユニットが存在します。語彙が一万個あればそれぞれ一万ユニット必要で、学習時にユニットが割り当てられていない単語については入力も出力も行うことができません。

このRNNの学習はなかなかに遅く、今回は epoch5 くらいまで学習させてみましたが、GPU機で軽く2時間はかかりました。 g2.xlarge インスタンスは1時間100円するので、学習が終わったら自動で停止してくれるギミックを作っておかないとオチオチ夜も寝られません。

今回は、学習はリモートのGPU環境、テストデータの結果出力はローカルのCPU環境で行いました。 サンプルコードは Chainer 1.7.1 に付属のものを使用しています。 ちなみに、結果出力用のコードは下のような感じです。

import numpy as np
import six

import chainer
from chainer import cuda
import chainer.links as L
from chainer import optimizers
from chainer import serializers

import net


def load_data(filename):
    global vocab
    global inv_vocab
    words = open(filename).read().replace('\n', '<eos>').strip().split()
    dataset = np.ndarray((len(words),), dtype=np.int32)
    for i, word in enumerate(words):
        if word not in vocab:
            vocab[word] = len(vocab)
            inv_vocab[len(vocab)-1]=word # 単語逆引き用辞書
        dataset[i] = vocab[word]
    return dataset

vocab = {}
inv_vocab = {}
train_data = load_data('ptb.train.txt')
train_data = load_data('ptb.train.txt')
test_data = load_data('ptb.test.txt')
print('#vocab =', len(vocab))

# 学習済みモデル取り込み
model = L.Classifier(net.RNNLM(len(vocab), 650))
serializers.load_npz('rnnlm.model', model)
model.predictor.reset_state() # LSTM層の初期化

for i in six.moves.range(10):
    data = test_data[i:i+1]
    print(inv_vocab[data[0]])

    x = chainer.Variable(data, volatile=False)
    h0 = model.predictor.embed(x)
    h1 = model.predictor.l1(h0)
    h2 = model.predictor.l2(h1)
    y = model.predictor.l3(h2)

    # 出力結果のうち上位3個を出力
    prediction = list(zip(y.data[0].tolist(), inv_vocab.values()))
    prediction.sort()
    prediction.reverse()

    for i, (score, word) in enumerate(prediction):
        print('{},{}'.format(word, score))
        if i >= 2: break

上で書いたように、RNNの学習はそれまでにどんな値を入力されたかどうかで変化します。 今回は、サンプルのテスト用データ内の「no it was n't black monday」という文を使用して、毎回LSTM層のパラメーターをリセットした文脈を無視した場合と、最初の単語の時のみLSTM層のパラメータをリセットした文脈を考慮した場合でそれぞれ結果を比べてみました。

  • 文脈無視
no
longer,10.33497428894043
doubt,8.990683555603027
way,8.030651092529297
it
is,10.528656005859375
will,9.515053749084473
has,9.51209831237793
was
a,8.827898979187012
<unk>,8.493036270141602
the,7.862812042236328
n't
<unk>,7.8151350021362305
yet,6.953648090362549
<eos>,6.921860694885254
black
<unk>,9.096553802490234
and,8.121747016906738
<eos>,7.601911544799805
monday
<eos>,10.35195541381836
's,9.545741081237793
and,8.740907669067383
<eos>
the,9.107492446899414
<unk>,7.814574241638184
but,7.787456512451172

<eos> は End Of State の略で文の終わりを意味します。<unk> は Unknown の略で未知語を意味し、今回学習に用いたサンプルデータでは、特定の人名に相当します。 最初の単語の no を見てみると no longer, no doubt, no way など慣用句的に使用されるような単語が出力されており、うまいこと学習できていることがわかります。他の単語を見ても、文章のどこかには出てきそうな組み合わせばかりで、次の単語を推測するというタスクの結果としては悪くないもののように思えます。

  • 文脈考慮
no
longer,10.33497428894043
doubt,8.990683555603027
way,8.030651092529297
it
is,9.707919120788574
<eos>,9.36506462097168
will,9.056198120117188
was
a,8.486732482910156
<unk>,8.168980598449707
the,8.025396347045898
n't
a,8.514782905578613
the,8.452549934387207
clear,8.159638404846191
black
<eos>,12.180181503295898
in,10.142268180847168
to,9.996685028076172
monday
<eos>,11.57198715209961
and,9.793222427368164
in,9.25011920928955
<eos>
the,9.041053771972656
but,8.242995262145996
mr.,8.005669593811035

さて今度は、以前までの単語系列を考慮した場合の学習結果です。特筆すべきは n't の次の単語を推測した結果で、文脈を考慮しなかった場合は、<unk>, yet, <eos> が推測結果でしたが、考慮した場合は、a, the, clear と was n't に続きそうな単語が推測されていることが分かります。(yet や <eos> は wasn't の次にはあまりこないように思われます。) なんとなくですが、文脈を考慮して結果が出てきてるような感じですね。モデルやデータセット、学習量によってどれくらい結果が変わるかも調べてみると面白そうです。