終末 A.I.

データいじりや機械学習するエンジニアのブログ

Chainer の imagenet サンプルで遊んでみる

Deep Learning といえば、やはり画像認識での利用です。今さら感もないではないですが、chainer の imagenet サンプルは、日々進歩している画像認識処理を様々な形のネットワークで試してみることができ、Deep Learning の仕組みがどのようになっているか、chainer でオリジナルのネットワークを組み立てるにはどのように記述したらいいか、概要を把握するにはぴったりのサンプルだと思います。 今回は、chainer 1.7.1 に付属している imagenet サンプルのうち NIN ネットワークサンプルを利用して画像認識を行ってみました。 画像データセットとして、こちらの256 Object Categoriesを使用しました。

Deep Learning で画像認識を行う場合、一般的に畳み込みニューラルネットワーク(CNN)と呼ばれるものを用います。ネットワークの名前にもなっている畳み込み処理ですが、Deep Learning 特有の処理ではなく、画像処理として一般的に用いられているフィルタ処理のことを指します。このフィルタ処理は、数学的には行列演算として表現することができ、前層の出力に重みをかけ合計することにより次層の入力とする、ニューラルネットワークの演算処理でそのまま表現することができます。 CNN では、学習によりこの重みが更新される、つまり画像を処理するためによりよいフィルタを獲得していくことが成果となります。また、フィルタは何層にも重ねることができ、フィルタの出力結果に対するフィルタといったより複雑な特徴抽出を行うことができます。

このフィルタ処理の結果は、似たような画像であっても微妙な違いにより出力にムラが出ることがあります。そのムラを除去するために用いられるのがプーリング層で、出力結果を安定させる効果があります。この層は、特定区間の平均や最大を出力とする固定の操作が行われるため、学習による重みの更新を行う必要はありません。プーリング層は、一般的に畳み込み層の直後に置かれます。

今回使用したNINは、畳み込み層として、MLPConv と呼ばれる非線形な畳み込み層を使用しています。具体的には、3つの畳み込み層の後に1つのプーリング層があるという構造になります。NINは、このMLPConv層を4つつなげることにより、複雑な特徴表現を学習するようなネットワークになっています。

imagenet サンプルでは、学習データの用意は自分で実装する必要があります。画像は適当な画像データベースからおとしてくれば良いですが、学習に用いるデータセットのリストの作成と、学習データを256x256サイズにリサイズしておく必要もあります。今回は、ラベル毎にランダムで、80%を学習用データ、残りを検証用データとして扱うようにしました。また、縦横それぞれが256ピクセルになるように拡大・縮小することにより、画像をCNNで扱えるようにリサイズしています。

import os
import argparse
import random
from PIL import Image


parser = argparse.ArgumentParser(description='Image net dataset create')
parser.add_argument('--root', '-r', default='.', help='image files root path')
parser.add_argument('--output', '-o', default='Images', help='output root path')


args = parser.parse_args()
train_list_file = open('train_list.txt', mode='w', encoding='utf-8')
validate_list_file = open('validate_list.txt', mode='w', encoding='utf-8')
train_rate = 0.8 # 80% の確率で学習用データとする

if not os.path.lexists(args.output): os.mkdir(args.output)
directories = os.listdir(args.root)

for i, directory in enumerate(directories):
    full_dir_path = os.path.join(args.root, directory)
    if not os.path.isdir(full_dir_path): continue

    for file_name in os.listdir(full_dir_path):
        try:
            # Resize 256x256
            input_path = os.path.join(full_dir_path, file_name)
            if not os.path.isfile(input_path): continue

            path = os.path.join(args.output, file_name)
            Image.open(input_path).resize((256, 256), Image.LANCZOS).convert("RGB").save(path)

            label = directory.split(".")[0]
            line = path + " " + label + "\n"

            if train_rate > random.random():
                train_list_file.write(line)
            else:
                validate_list_file.write(line)
        except:
            print(file_name, "is not image")

train_list_file.close()
validate_list_file.close()

ちょっと前置きが長くなりましたが、早速学習結果を見てみましょう。まず、30エポック学習させた結果どのような画像を認識できるようになったか見てみます。ラベルが256ありますが、そのうち野球用バッドの画像セットについての結果です。バッドといえば言わずもがなこんな感じの画像が含まれています。

f:id:KSKSKSKS2:20160327151445j:plain

一方で、下のようなバッド以外のものも混じった意地悪なデータも学習セットにはだいぶ含まれています。

f:id:KSKSKSKS2:20160327151435j:plain

このような画像を認識した結果、基本的に学習セットにあるデータに関しては、バッドと認識するようになりました。混合画像までバッドと認識してしまうのはちょっと考えものですが、そのように教え込んでいるので仕方ありません。一方、未学習データに関しては下の一番目や三番目の画像はバッドと認識してくれますが、二番目の画像は残念ながらバッドとしては認識してくれませんでした。なんとなくですが、向きが関係してるのではないかと予想しています。

f:id:KSKSKSKS2:20160327151645j:plain f:id:KSKSKSKS2:20160327152126j:plain f:id:KSKSKSKS2:20160327152408j:plain

次に、識別精度を見てみます。比較として同じデータで別に学習し、10エポックで学習を終了したモデルの結果も掲載します。見ての通り、学習用データにおいても、検証用データにおいても、30エポックまで学習したモデルの方が識別精度が良くなっています。 学習用データセットの識別精度は大幅に上昇しており、ほとんどの画像を正しくラベリングできるようになっています。しかし、先ほどのバッド例のように、他の物体も混じっている画像なども多く含まれているため、学び方としてはあまり良くないように思われます。 検証用データの精度があまり上昇していなことからもそのあたりは読み取っていただけるかなと思います。もうちょっとスクリーニングされているデータセットでないと、適当にやるには厳しかったかもしれません。

30 epoch result:
train accuracy: 0.926514966153
validate accuracy: 0.422843056697

10 epoch result:
train accuracy: 0.390098686893
validate accuracy: 0.315529991783

最後に、上で比較したそれぞれのモデルの第一層の重みを可視化したものを比較してみました。左がエポック10まで学習したもの、右がエポック30まで学習したものとなります。なんとなくですが、右の方が濃淡がはっきりしているように見えます。しっかり特徴抽出を行えるようになっているということですね。ここまで自動で識別できることを考えると、多少複雑な識別問題でもそれなりのデータセットを用意すれば的確に認識してくれそうです。

f:id:KSKSKSKS2:20160327154322p:plain f:id:KSKSKSKS2:20160327154325p:plain