終末 A.I.

データいじりや機械学習するエンジニアのブログ

2018年風TensorFlowでの学習処理の記述方法

TensorFlowが登場して早いことで3年近く経とうとしています。 Deep Learning自体がブームになってからだと、それ以上の月日が経っているわけで、人工知能ブームも以外と続いているなあというのが正直な感想です。

Theanoやtorch、chainerに遅れをとって立ち上がったTensorFlowでしたが、はじめのうちはチュートリアルコードですらこのようなありさまで、とてもではありませんが簡単に誰もが使えるというような状態ではありませんでした。 1年ほど前からようやく、Keras の取り込みや Dataset API の実装、MonitoredTrainingSession のようなリッチな Session オブジェクトの導入などで、少し凝ったことをする場合でもかなり簡単に書けるようになってきました。

一方で公式のチュートリアルでは、データセットの読み込みはありもののAPIを使用するのが基本で、モデルの構築や学習処理もKerasのAPIのみや Eager Execution だけで解決できるようなシンプルな実装が多く、実践的にはどう書くとかゆい所に手が届きやすくなるのかがイマイチ掴みづらいところがあります。

どのような方法が現状でのベストプラクティスなのかわかりにくい状況ですが、自分用のメモも兼ねて、今回は自分がどのような考えで、どのように TensorFlow の学習処理を記述しているかを晒してみることにします。

学習処理の実装方針

今回のソースコードは以下にアップしています。

TensorFlow での学習処理を大まかに分けると、データの読み込み、モデルと学習処理の定義、学習環境の定義および学習の実行の3つに分けることができます。

一つずつどのような要件があれば良さそうかを考えてみます。

まず、「データの読み込み」で必要になるのは、様々な形式で保存されているデータセットを、適切な形に前処理しながらバッチ単位で読み出すというのがもりもりの要件になるかと思います。

データセットの読み込みや前処理自体は、あらかじめ tfrecord 形式に変更する際に実行したり、ファイルやDBへの書き込み内容そのものを事前に処理する方法もありますが、今回はそのような処理もまとめて学習時に実行する場合を想定します。

この時に活躍するのが tf.data.DataSet APIです。様々な形式からの読み込み、ストリーム処理による前処理、バッチサイズの指定やデータセットの繰り返しなど、データの読み込み時に発生する様々な処理をシームレスに実行する事ができます。

次に、「モデルと学習処理の定義」では、モデルグラフの定義および、ロスと学習処理の定義を行います。

モデルグラフの定義は、CNNやRNNなど実際に組むことになります。基本的には Keras の Layers API を使用することでほぼほぼ対応できますが、必要に応じて tf.Variable や tf.nn API を利用して自身でスクラッチで計算グラフを組んでいく必要もあります。

ロスと学習処理の定義は、ロスの計算の定義と tf.train.Optimizer を利用した勾配学習の定義を行える必要があります。 このあたりは、ニューラルネット系のフレームワークではクリティカルな部分でもあるので、簡単に実装できるようになっており、変わったレイヤー定義でもしない限りそこまで難しい部分ではなくなっています。

最後に、「学習環境の定義および実行」ですが、こちらを以下にシンプルに使いまわしやすく書いておくかが、学習処理をとっつきやすく開始できるかの鍵になるのではないでしょうか。必要なのは、学習の実行だけでなく、学習時のサマリーの記録、モデルの定期的な保存、定期的な学習ログの出力など学習処理に紐づく細かい実装が必要になります。

これらの実装には、MonitoredTrainingSessiontf.train.SessionRunHook を活用することでわかりやすく実現することができます。

Dataset API を利用したパイプラインの構築

上であげたように、データの読み込みパイプラインに必要な要素は以下のようになります。

  • ファイルなどからのデータの読み込み
  • データの前処理
  • バッチおよびデータの繰り返しなど学習上必要な処理

具体的にどのように TensorFlow の Dataset API で実現するかを一個ずつ見てみましょう。

ファイルからのデータの読み込み

まず、ファイルからのデータの読み込みは下記のようになります。

image_dataset = tf.data.FixedLengthRecordDataset(str(image_path), record_bytes=28*28, header_bytes=16)
label_dataset = tf.data.FixedLengthRecordDataset(str(label_path), record_bytes=1, header_bytes=8)
dataset = tf.data.Dataset.zip((image_dataset, label_dataset))

上記はローカルに保存した MNIST のデータセットを読み込んでいる箇所になります。 MNISTのデータはイメージデータとラベルデータがバラバラのファイルで管理され、それぞれにバイナリ形式で一定のバイト単位でデータが保存されています。

このようなデータを読み込む際に役に立つのが tf.data.FixedLengthRecordDataset クラスです。このクラスはまさに一定のバイト単位で保存されているデータを切り取って順番に読み出すことができます。

また、MNISTでは画像とラベルが別々のファイルに分かれているため実際の処理を行う前に結合しておく必要があります。tf.data.Dataset.zip メソッドは、python の組み込み関数である zip のように、複数のデータセットの出力をタプル形式でまとめて出力できる Dataset オブジェクトを返してくれます。

データの前処理

データの前処理で必要なものは、主にフィルタリングと学習で使用する形式への変換(Data Augmentation 含む)です。 以下は、Dataset の filter メソッドおよび map メソッドを利用してそれらを実現している例です。 呼び出し自体は、それぞれ実際の処理を行う関数を渡すだけで実現できます。

dataset = (dataset
           .filter(converter.filter)
           .map(converter.convert, num_parallel_calls=threads))

実際の処理は、下記のように Dataset の出力を引数に受けて、filter の場合 bool 値を、map の場合は変換したデータを出力する必要があります。 Dataset の出力は、今回の場合は zip で2つの Dataset をまとめていますので、処理を行う関数の引数は2つになっています。 また、tf.data.FixedLengthRecordDataset で読み込んでいるためそれぞれのデータはバイナリ形式になっており、 tf.decode_raw メソッドで事前に変換しています。

前処理の実際の処理は、 tf.py_func メソッドで生 python の関数を呼び出すようにしています。今回の実装では、必要な処理ではないですが、このようにしておくと処理をいくらでも柔軟に組み替えることが可能になります(ただしCPUパワーが必要になってきますが)。

tf.py_func メソッドは、処理をする関数、TensorFlow オブジェクトと出力する型を指定して、処理を行った関数の出力を戻り値として受け取ります。 戻り値は shape が指定されていないので、あとあとの処理のために set_shape で設定しておく必要があります。 また、tf.py_func 内の処理の組み方によっては、入力データを複数に分割して出力するような処理も実現することができます。

def _filter(self, images, labels):
    return True

def filter(self, images, labels):
    predicts = tf.py_func(
                self._filter,
                [images, labels],
                [tf.bool])[0]
    return predicts

def _convert(self, images, labels):
    images = images.reshape((28, 28, 1))
    images = images.astype(np.float32)

    labels = labels.astype(np.uint8)
    labels = labels.reshape((1, ))

    return images, labels

def convert(self, images, labels):
    images = tf.decode_raw(images, tf.uint8)
    labels = tf.decode_raw(labels, tf.uint8)
    images, labels = tf.py_func(
                self._convert,
                [images, labels],
                [tf.float32, tf.uint8])
    images.set_shape((28, 28, 1))
    labels.set_shape((1, ))
    return images, labels

バッチおよびデータの繰り返し

バッチ化およびデータの繰り返しは、 shuffle、repeat、batch(padded_batch)の3つのメソッドで実現します。

shuffle メソッドは、文字通り指定したファイルから順番に読み出したデータを buffer_size 分メモリに貯めて、その中でランダムに batch_size 分のデータを返すような挙動になります。

上述のようにこのメソッドは、メモリにデータをロードすることになるので、(GPUで学習していればGPUの)メモリが必要になりますが buffer_size の値が大きければ大きいほど、基本的にはランダムの質はよくなります(ファイルの前の方のデータと後ろの方のデータが混ざりやすくなります)。ロスの値などの metric がバッチ毎に偏ってしまっているならば、buffer_size の値を増やすことを検討してみるべきです。

repeat メソッドは、それこそそのままでデータを繰り返し読み出してくれます。引数に回数を指定した場合その回数分、何も指定しない場合エンドレスでデータを読み出し続けてくれます。

batch メソッドは、指定した batch_size 分の値を各ステップで読み出してくれます。padded_batch メソッドは、バッチ単位でデータの不揃いを整形してくれます。padding する値も引数で指定することができます。

if is_training:
    dataset = dataset.shuffle(buffer_size=buffer_size).repeat()

dataset = dataset.padded_batch(batch_size, dataset.output_shapes)

モデルと学習処理の定義

モデルと学習処理の定義は、Keras の Layers API やありものの Optimizer およびロスの定義用に、 tf.nn API や一部の数値計算用のメソッドを使用することでほぼ問題なく実装することができます。

自分のモデルクラスの実装は、Keras の Layers に準じて基本的に build と call で構成され、それぞれパラメーターやレイヤーの定義と計算グラフの構築を担当します。別途モデルに関連しているメソッドとして、loss 計算用のメソッドと学習処理(最適化処理)を構築するメソッド、予測を実行するメソッドをインスタンスに含めています。

def call(self, inputs, is_train=True):
    outputs = tf.reshape(inputs, (-1, 28*28)) / 255.0
    outputs = self.dense1(outputs)
    if is_train:
        outputs = tf.nn.dropout(outputs, 0.2)
    outputs = self.dense2(outputs)
    return outputs

def loss(self, logits, labels):
    labels = tf.reshape(tf.cast(labels, tf.int32), (-1, ))
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
    loss = tf.reduce_mean(loss)
    return loss

def optimize(self, loss, clipped_value=1.0):
    grads = self.optimizer.compute_gradients(loss, self.variables)
    clipped_grads = [(tf.clip_by_value(g, -clipped_value, clipped_value), v) for g, v in grads]
    train_op = self.optimizer.apply_gradients(clipped_grads)
    return train_op

def predict(self, logits):
    _, indices = tf.nn.top_k(logits, 1, sorted=False)
    return indices

callメソッドでどこまで担当するかいくつ値を返すか、lossやoptimize、predictメソッドの引数に何を取るかはモデルによって変わってしまう構成をとっていますが、大まかな関数の呼び出し順は、どのようなモデルを構築する際もこれをベースに行うことができます。 例えば、GANの実装では、call メソッドで descriminator の logits 元画像及び生成画像に対して計算して返し、loss で generator および descriminator のロスを計算、optimize でそれぞれの勾配計算を tf.control_dependencies や tf.group などで関連付けて返せば良いという実装になります。

ここで主に注意する点は、ロスの計算周りと最適化計算処理です。

ロスの計算では、 cross_entropy の計算などで log に渡す値や割る値が 0 になり、出力として -inf や NaN を返す可能性が十分にあります。keras の categorical_cross_entropy の実装などでは、log の計算で -inf が返らないように内部的に微小な値(epsilon値)を使って調整してくれていますが、生のTensorFlow はそこまで親切ではないので、その辺りを予防するために必要があればロス関数を自力実装する必要性があります。

また、最適化処理も、keras の Oprimizer のように勾配爆発を防ぐために勾配を clipping するのが一般的です(学習率の調整やデータの前処理でどうにかできる場合はそれでも問題ないです)。 clipping だけでは勾配消失に対応できませんが、こと画像認識においては Residual Block などのテクニックで大幅にそのリスクを減らすことができます。

学習環境の定義及び実行

学習の実行処理で必要になる処理は細かくも様々です。なんとなしに必要なものを洗い出すと、ただ学習をさせたいだけなのに色々と必要になってきます。

  • データの取得
  • モデルの定義
  • サマリーの保存
  • モデルの保存
  • 学習の監視
  • 学習処理の実行

このうち上の2つは、上記の2項目で実装した関数を呼び出すだけで実行できるようにしておくことがベストです。そうなっていない場合は、該当部分のクラス構造や関数定義を見直してみるほうが良いでしょう。

サマリーの保存は、session として MonitoredTrainingSession を使用していれば、tf.summary の各 API を使用して監視したいオブジェクトを設定しておくだけで勝手に summary_dir に指定したディレクトリにサマリーを保存してくれます。 保存する間隔も save_summaries_steps もしくは save_summaries_secs で指定することが可能です。

モデルの保存は、tf.train.Scaffold で tf.train.Saver をラップすることにより、同じく MonitoredTrainingSession に渡すことで実現できるようになります。 tf.train.Saver の初期化パラメーターで、最大いくつのチェックポイントを残すか(max_to_keep)と、何時間に一回のチェックポイントを残すか(keep_checkpoint_every_n_hours)を指定できます。 また、MonitoredTrainingSession のパラメーターとして、save_checkpoint_secs もしくは save_checkpoint_steps でチェックポイントを保存する間隔を、checkpoint_dir でチェックポイントを保存するディレクトリを指定することができます。

scaffold = tf.train.Scaffold(
    saver=tf.train.Saver(
        max_to_keep=checkpoints_to_keep,
        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours))

学習の監視は、tf.train.SessionRunHook の派生オブジェクトのリストを、MonitoredTrainingSession にわたすことにより実現できます。 定期的に指定したオブジェクトの値をログとして出力してくれる tf.train.LoggingTensorHook。 指定したオブジェクトの値が NaN になった際にエラーを発生させてくれる tf.train.NanTensorHook。 global_step が指定の値以上になった時に、 session.shoud_stop メソッドを True にしてくれる tf.train.StopAtStepHook。 など、さまざまな Hook が用意されています。Hook 自体は自作もそこまで難しくありませんので、必要に応じて作ることも可能です。

hooks.append(tf.train.LoggingTensorHook(metrics, every_n_iter=100))
hooks.append(tf.train.NanTensorHook(loss))
if max_steps:
    hooks.append(tf.train.StopAtStepHook(last_step=max_steps)

ここまでくれば学習の実行は非常に簡単で、下記のようにひたすら session.run を実行するだけですみます。

with session:
    while not session.should_stop():
        session.run([train_op])

他にも、tf.ConfigProto で GPU の実行設定を行ったり、tf.train.ClusterSpec および tf.train.Server を使用した複数サーバーでの分散学習も設定することができます。

終わりに

近年の TensorFlow は、フレームワークにそれなりに足を突っ込めば、学習時にやりたいことは一通りできるようになってきています。それでも学習コストは十分にあるのですが、以前の時間をかけて理解しても結局よくわからないという状態に比べれば天と地ほどの差があるのは見てのとおりです。

TensorFlow 2.0 ではこの辺りが整理されてよりわかりやすくより使いやすくなることを期待します。

参考