Keras Loss Behavior with Language Model
KerasのModelクラスを使用した際のロスの計算は、Paddingで追加した余計な値を勾配の計算から除外する処理は自動でやってくれるのですが、
historyに記録されるlossの平均値を求める際に、maskを部分的にしか考慮しておらず、padding数が多くなればなるほど、実際のロスより小さくなってしまうという現象が発生します。
この記事は、KerasのModelクラスでLossを利用する、
特にEmbedding層で、mask_zeroをTrueにした場合に、Paddingで追加した余計な値を、勾配の計算に使用しない、ロスの計算から完全に除外する方法についてのメモです。
検証用のコードはこちらです。
マスクを使用したロスの計算は、TensorFlowのチュートリアルを参考にしています。
目次
KerasのLossの種類と種類毎の処理の違い
Kerasのmodel.compileで指定できるLossの種類は大きく分けると以下の3つがあります。
※ 参考:compile内で呼ばれているtf.keras.Model.prepare_loss_functions
tf.keras.losses.Lossの派生クラス
独自定義したloss関数
独自定義したcallableなLossオブジェクト
このうち、「独自定義したloss関数」は、prepare_loss_functions内でtf.keras.losses.Lossの派生クラスであるLossFunctionWrapperクラスに変換されるため、実質は2種類のLossオブジェクトが使い分けられることになります。
Lossを含んだ計算グラフの構築は、tf.keras.Model._prepare_total_loss内で行われます。
この関数の中で、上記2つのオブジェクトは別々の処理が行われことになります。
tf.keras.losses.Lossの派生クラスの場合の処理
tf.keras.losses.Lossの場合は、Kerasの内部関数であるtf.keras.utils.losses_util.compute_weighted_loss内で、以下の処理が行われます。
ロス関数の呼び出しによるロスの計算
tf.keras.utils.losses_util.scale_losses_by_sample_weightでロス×sample_weightsを計算
Lossオブジェクトのreductionに応じた集計をtf.keras.utils.losses_util.reduce_weighted_lossで計算。大体の場合は、ロスの平均値を計算する処理
これから、sample_weightsを使用しない場合は、独自定義のロス関数でどんなshapeの値を返しても良いことがわかります。
ただし、sample_weightsとしてEmbedding層で計算したmaskが暗黙的に使用される場合があるので後述のように注意が必要です。
独自定義したcallableなLossオブジェクトの場合の処理
正解ラベルとニューラルネットワークの出力だけでなく、sample_weightsも引数に渡されて、Lossオブジェクトのcall関数が呼ばれます。
ベクトル値を返した場合、ベクトルの平均値が最終のロスとして使用されます。
Embedding層でmask_zeroをTrueにした場合のLossの挙動
Embedding層でmask_zeroをTrueに指定した場合、sample_weightsとしてEmbedding層で計算したmaskが暗黙的に使用されます。
maskはpadding処理で0埋めした部分はFalse、それ以外はTrueとなり、paddingで埋めた部分の計算に勾配が伝わらないようにしてくれます。
ただし、あくまで勾配が伝わらないようにしてくれるだけで、ロスの平均を取る際の分母の数をmask分減らしてはくれません。
このため、Perplexity等の計算をこのLossが出力した値を元に行うと、paddingで埋める長さが長いほど小さい値になってしまいおかしくなります。
例えば、正解ラベルとして ] の系列データがあるケースを考えます。
この時、ネットワークの出力として、各時点でのラベルの予測確率が ]と得られたとします。
0は計算処理の都合上いれているだけのデータで、実際の処理では無視するため、このデータに対するクロスエントロピーロスは下記のようになります。
しかしKerasのModelクラスを使用して計算すると
という値になってしまい、実際の値に比べてかなり小さくなってしまいます。
これを避けるためには、検証コード内LOSS_MODEが2の場合のように、独自定義したLossオブジェクトで、padiding部分の値を元にした処理に勾配が行かないように、かつ平均計算時の分母からの除去するように必要があります。
ちなみにMetoricsはレートに直す際に、sample_weightsの合計で割るような実装になっているため、maskの場合も意図した動作になります。