終末 A.I.

データいじりや機械学習するエンジニアのブログ

Cloud Storage Transfer ServiceでAssumeRoleを使ってS3からデータを移行する

※ この記事は2021年10月の情報に基づいて記載しています。

※ 最新情報はGCPのドキュメントを参照ください。

Cloud Storage Transfer Serviceは、GCP内から直接S3等のクラウドストレージ(もしくはオンプレミス)のデータ移行を行うことにより、高速で高並列にGCSへのデータ移行を実現できるサービスです。

一方で、データソースをS3とする場合は、AWSのアクセスキーとシークレットをGCPに設定しておく必要があり、アクセス情報の漏洩に対して、あまりセキュアではありませんでした。

それに対応できる機能が、2021年7月にパブリックプレビューとなりました。AWSのAssumeRoleWithWebIdentityを利用したフェデレーテッドアクセスにより、データの転送が行えるようになったのです。以下では、その使い方について、簡単に説明したいと思います。

GCPが管理する専用のサービスアカウント(SA)を確認する

Storage Transfer Serviceでは、GCPが管理する専用のサービスアカウントを用いてデータの転送処理を行っています。

通常のIAM一覧上では確認できず、必要な情報はサービス専用の googleServiceAccounts.get を呼び出すことで取得できます。このAPIの戻り値は以下のようになります。

{
  "accountEmail": "project-xxxxxxx@storage-transfer-service.iam.gserviceaccount.com",
  "subjectId": "xxxxxxx"
}

accountEmailは見ての通りSAを識別するための固有のメールアドレスで、xxxxの部分にはプロジェクト固有のプロジェクト識別番号が入ります。subjectIdもアカウントを識別するための固有のIDで、後でAWS上でフェデレーテッドアクセスを設定するために使用する情報になります。

AWSに専用のロールを作成する

続いて、AWS上に上記のSAがフェデレーテッドアクセスするためのロールを作成します。ロールには、GCPのSAがAssumeRoleを行うための権限と、該当のS3バケットからデータを読み出すための権限設定が必要です。

まず、フェデレーテッドアクセス用の権限ですが、対象のロールのAssumeRolePolicyに下記を設定すれば良いです。

{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Effect": "Allow",
      "Principal": {
        "Federated": "accounts.google.com"
      },
      "Action": "sts:AssumeRoleWithWebIdentity",
      "Condition": {
        "StringEquals": {
          "accounts.google.com:sub": "SAのsubjectId"
        }
      }
    }
  ]
}

設定内容の詳しい説明はAWSのドキュメントを参照していただきたいのですが、GoogleからOIDCでAssumeRoleを呼び出された場合に、アカウントのsubjectIdが指定したものについてのみ許可するという設定になります。

次にS3からデータを読み出すための権限をロールのポリシーに指定します。適宜必要に応じて追加したり削ったりしてもらえれば良いですが、簡単には下記のようになります。

{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Effect": "Allow",
      "Action": [
          "s3:Get*",
          "s3:List*",
       ],
      "Resource": "送信元バケット"
    }
  ]
}

GCSの書き込み権限を設定する

Storage Transfer Serviceが使用するSAには、対象のGCSバケットに書き込みを行う権限も必要になります。GCSの権限設定を用いて、バケットにSAが roles/storage.legacyBucketWriter および roles/storage.objectViewer を使用してアクセスできるように設定を行います。

Storage Transfer ServiceのJobを作成する

最後に、上記で作成したロールを用いてS3にアクセスするように設定したジョブを作成します。記事を書いている時点では、Webコンソールにて設定する方法はありませんでしたので、API呼び出しで作成する方法について記載します。

ジョブの作成は、transferJobs.createを使用して行います。設定内容は以下が設定されていれば最低限問題ありません。

{
  "name": "ジョブの名称",
  "projectId": "GCPのプロジェクトのID",
  "transferSpec": {
    "awsS3DataSource": {
      "bucketName": "送信元バケット名",
      "roleArn": "作成したAWSのロールのARN"
    },
    "gcsDataSink": {
      "bucketName": "送信先バケット名"
    }
  },
  "status": "ENABLED"
}

このAPIを実行したタイミングで、AssumeRoleでフェデレーテッドアクセスができるか、ロールでS3に指定のバケットの読み取りが行えるかの確認も走ります。アクセス情報に問題なければジョブが作成されます。

あとは、通常のジョブと同様手動で実行したりスケジュールで実行したりすることができます。

データテストライブラリー「Deequ」を触ってみた

DeequはAWSがリリースしているデータテストを行うためのライブラリです(Deequの説明ではUnit Testと表現されています)。

ここで言うデータテストは、ETL処理やデータマート作成処理などの意図通り動いているどうか、取り込んだデータが昔と変化していないかを確認するための検証処理のことを指しています。

ETL処理などを最初に作成したタイミングでは、その処理が意図したものになっているか確認すると思います。一方で、日次のバッチ処理や、動き続けているストリーム処理について、本当に意図したようにデータが加工されているかどうかは、通常の方法では処理自体が成功したかどうかくらいしか確認するすべがありません。

しかし、日々のデータ処理は簡単に意図しないデータを生み出してしまう可能性があります。気づいたらデータの中身が変わっていて、変換処理が意図しない動作をしてしまっていたり、そもそもソースデータがおかしくなっていて重要な指標がずれてしまう、というようなことも考えられるでしょう。

そのような時に役に立つのがデータテストです。データテストでは、Nullを許容しないはずのカラムに何故かNullか入ってくるようになっていないか、過去のデータと比較して極端にデータの数が変化していないか、などを調べることを含む概念です。一言でいうと、データが意図しないものになっていないかを確認する処理、と言えます。

目次

Deequの何がいいのか

AWSのDeequは、そんなデータテストを簡単に実施できるようにするためのScala製ライブラリです。PythonラッパーであるPyDeequもあります。 Deequは、例えばデータ変換ツールであるdbtでもサポートしていますが、それと比べると、下記のような点が特徴としてあげられます。

  • Sparkベースでできている
    • SQLクエリで直接アクセスできない、ファイルだけがあるようなデータにも適用できる
    • プラグラムベースでしか実現しにくいような処理でも比較的組み込みやすい
  • プリセットのテスト関数が豊富に組み込まれている
    • 手元のデータ単体にフォーカスしたテスト関数だけで40個ほどプリセットである
    • AnomalyDetectionという、過去のデータの状態も参照してテストするための処理も組み込まれている
  • 必要なテスト処理をカラムごとにレコメンドしてくれる機能も組み込まれている

個人的に特にいいのは、AnomalyDetectionの機能が最初から組み込まれている点です。言わずもがなデータは日々変わりますので、実データについて何が正解かを決めにくい、という問題がテストを実装するにあたって存在します。

そのため、今までのデータと比べてどうかというのが、データがおかしくなっているかどうかを判断するための1つの重要な指標になります。特定の指標について経年変化をもとに「違和感」を自動で検出できる仕組みは非常に重要です。

Deequは、このようなAnomalyDetectionと、一般的なデータ単体のテストを、1つのライブラリで完結して行うことができるようになっています。

データテストの実行

基本的な使い方

データテストの基本的な実行は、PyDeequでは下記のように書くことで実現できます。

サンプルコード1

from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pydeequ.checks import *
from pydeequ.verification import *

df = spark.createDataFrame(data=[
        (1, "Test1", "foo"), 
        (2, "Test2", "foo"), 
        (3, "Test3", "bar"), 
        (4, "Test4", "baz"), 
        (5, "Test5", "baz"), 
        (6, "Test6", "bar"), 
        (7, "Test7", None), 
        (8, "Test8", "bar"), 
], schema=StructType([
        StructField("a", IntegerType(), True),
        StructField("b", StringType(), True),
        StructField("c", StringType(), True),
]))

check_warning = Check(spark, CheckLevel.Warning, "Warning Check")
check_error = Check(spark, CheckLevel.Error, "Error Check")

checkResult = VerificationSuite(spark) \
    .onData(df) \
    .addCheck(
        check_warning.isComplete("a") \
        .isComplete("b") \
        .isComplete("c")) \
    .addCheck(
        check_error.isPositive("a") \
        .isUnique("b") \
        .isContainedIn("c", ["foo", "bar", "baz"])) \
    .run()

checkResult_df = VerificationResult.checkResultsAsDataFrame(spark, checkResult)
checkResult_df.show(truncate=False)

実行結果は下記のようになります。

結果1

+-------------+-----------+------------+-----------------------------------------------------------------------------------------------------------+-----------------+------------------------------------------------------+
|check        |check_level|check_status|constraint                                                                                                 |constraint_status|constraint_message                                    |
+-------------+-----------+------------+-----------------------------------------------------------------------------------------------------------+-----------------+------------------------------------------------------+
|Warning Check|Warning    |Warning     |CompletenessConstraint(Completeness(a,None))                                                               |Success          |                                                      |
|Warning Check|Warning    |Warning     |CompletenessConstraint(Completeness(b,None))                                                               |Success          |                                                      |
|Warning Check|Warning    |Warning     |CompletenessConstraint(Completeness(c,None))                                                               |Failure          |Value: 0.875 does not meet the constraint requirement!|
|Error Check  |Error      |Success     |ComplianceConstraint(Compliance(a is positive,COALESCE(CAST(a AS DECIMAL(20,10)), 1.0) > 0,None))          |Success          |                                                      |
|Error Check  |Error      |Success     |UniquenessConstraint(Uniqueness(List(b),None))                                                             |Success          |                                                      |
|Error Check  |Error      |Success     |ComplianceConstraint(Compliance(c contained in foo,bar,baz,`c` IS NULL OR `c` IN ('foo','bar','baz'),None))|Success          |                                                      |
+-------------+-----------+------------+-----------------------------------------------------------------------------------------------------------+-----------------+------------------------------------------------------+

VerificationSuiteとCheckの2種類のオブジェクトから構成され、Checkオブジェクトにメソッドチェーンの形でテストしたい項目を追加していき、addCheckでCheckオブジェクトをVerificationSuiteに登録、runでテストを実行できます。

Checkがテストする内容は、関数とカラム名によって決まります。 isComplete("a") であれば、カラムaの値がすべてNullでないことを検証してくれます。検証の結果は、結果が格納されたDataFrameのconstraint_statusカラムに格納され、成功であればSuccess、失敗であればFailureが入ります。

Checkオブジェクトには、WarningとErrorの2種類のCheckLevelと、そのCheckについての説明を指定することができます。同じCheckオブジェクトに紐付いたテストが1つでも失敗すると、結果のcheck_statusは、設定されているCheckLevelに応じた値が入るようになります。

Checkオブジェクト毎にどのようなテストを追加するかは後処理でどのようにしたいかによって分けるのが良いでしょう。ErrorLevelでオブジェクトを分ける、カラムごとにオブジェクトを分ける、テスト項目ごとにオブジェクトを分ける、などがパターンとして考えられます。

複数カラムの関係をテストする

Deequでは、特定カラムだけでなく、複数のカラムの組み合わせについてテストすることもできます。

下記のように、テスト関数の引数が異なるだけで、基本的な使い方は単一カラムのテストの場合と使い方は変わりません。

サンプルコード2

from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pydeequ.checks import *
from pydeequ.verification import *

df = spark.createDataFrame(data=[
        (1, 10, "foo"), 
        (2, 20, "foo"), 
        (3, 30, "bar"), 
        (4, 40, "baz"), 
        (5, 50, "baz"), 
        (6, 60, "bar"), 
        (7, 70, None), 
        (8, 80, "bar"), 
], schema=StructType([
        StructField("a", IntegerType(), True),
        StructField("b", IntegerType(), True),
        StructField("c", StringType(), True),
]))

check_warning = Check(spark, CheckLevel.Warning, "Warning Check")

checkResult = VerificationSuite(spark) \
    .onData(df) \
    .addCheck(
        check_warning.isLessThan("a", "b") \
        .hasCorrelation("a", "b", lambda x: x >= 1.0) \
        .hasUniqueness(["b", "c"], lambda x: x >= 1.0)) \
    .run()

checkResult_df = VerificationResult.checkResultsAsDataFrame(spark, checkResult)
checkResult_df.show(truncate=False)

isLessThanは、文字通りの意味でaの値が対応するbの値よりも小さいことを確認するテストです。

hasCorrelationは、指定した2つのカラムの相関係数が、指定したassertion関数を満たすことを確認するテストです。

hasUniquenessは、hasCorrelationと同様で、指定したカラムすべてを考慮してユニークかどうかを判定し、その結果がassertion関数を満たすことを確認するテストです。

制約条件を満たさないことをテストする

has系のテスト関数は、上記で書いたようにassertion関数を満たすかどうかでテスト結果が決まります。

このテスト関数の性質を利用して、条件を満たさないことをテストすることも可能です。

サンプルコード3

from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pydeequ.checks import *
from pydeequ.verification import *

df = spark.createDataFrame(data=[
        (1, 10, "foo"), 
        (2, 20, "foo"), 
        (3, 30, "bar"), 
        (4, 40, "baz"), 
        (5, 50, "baz"), 
        (6, 60, "bar"), 
        (7, 70, None), 
        (8, 80, "bar"), 
], schema=StructType([
        StructField("a", IntegerType(), True),
        StructField("b", IntegerType(), True),
        StructField("c", StringType(), True),
]))

check_warning = Check(spark, CheckLevel.Warning, "Warning Check")

checkResult = VerificationSuite(spark) \
    .onData(df) \
    .addCheck(
        check_warning.containsEmail("c", lambda x: x == 0.0) \
        .hasMin("a", lambda x: x > 0)) \
    .run()

checkResult_df = VerificationResult.checkResultsAsDataFrame(spark, checkResult)
checkResult_df.show(truncate=False)

hasMinは、指定したカラムの最小値が条件を満たすことを確認するテストです。0より小さくないことを確認する場合は、assertion関数で最小値が0より大きくなることを確認すればよいです。

containsEmailは、デフォルトでは指定したカラムのすべての値がEmailの形式を満たすかを確認するテスト関数です。assertion関数をしていした場合、Emailの値を含むカラムの割合を確認することができます。つまりこの値が0であることを確認すれば、指定したカラムにEmail形式の文字列を含まないことをテストすることができます。

カスタマイズした内容をテストする

satisfies関数を使う事により、任意のSQLを満たすかどうか、もしくは、レコードのうち何割が満たすかをテストすることができます。

サンプルコード4

from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pydeequ.checks import *
from pydeequ.verification import *

df = spark.createDataFrame(data=[
        (1, 10, "foo"), 
        (2, 20, "foo"), 
        (3, 30, "bar"), 
        (4, 40, "baz"), 
        (5, 50, "baz"), 
        (6, 60, "bar"), 
        (7, 70, None), 
        (8, 80, "bar"), 
], schema=StructType([
        StructField("a", IntegerType(), True),
        StructField("b", IntegerType(), True),
        StructField("c", StringType(), True),
]))

check_warning = Check(spark, CheckLevel.Warning, "Warning Check")

checkResult = VerificationSuite(spark) \
    .onData(df) \
    .addCheck(
        check_warning.satisfies("b % 10 = 0", "B is 10 dividable", lambda x: x == 1.0) \
        .satisfies("rlike(c, '[0-9]+?')", "C is not contained numeric", lambda x: x == 0.0) \
        .satisfies("b / a = 10", "b / a is 10", lambda x: x == 1.0)) \
    .run()

checkResult_df = VerificationResult.checkResultsAsDataFrame(spark, checkResult)
checkResult_df.show(truncate=False)

satisfies関数の引数は、条件のSQL文、条件の説明文、条件を満たすレコードの割合のassertion関数の3つが必要です。

SQL文は、Sparkのfilter関数で使用できるSQLの条件文であれば何でも使用することができます。rlikeで正規表現でマッチさせたり、複数のカラム間の関係を条件にすることもできます。

説明文は、下記の結果の出力内に使用される文字列ですので、後で識別できるものであれば何でも大丈夫です。

assertion関数は、pythonのsatisfies関数の定義ではOptionalですが、指定していないとエラーになってしまいます。すべてのレコードを満たすことを確認する場合は lambda x: x == 1.0 を、満たさないことを確認する場合は lambda x: x == 0.0 を指定しておく必要があります。

結果4

+-------------+-----------+------------+-------------------------------------------------------------------------------------+-----------------+------------------+
|check        |check_level|check_status|constraint                                                                           |constraint_status|constraint_message|
+-------------+-----------+------------+-------------------------------------------------------------------------------------+-----------------+------------------+
|Warning Check|Warning    |Success     |ComplianceConstraint(Compliance(B is 10 dividable,b % 10 = 0,None))                  |Success          |                  |
|Warning Check|Warning    |Success     |ComplianceConstraint(Compliance(C is not contained numeric,rlike(c, '[0-9]+?'),None))|Success          |                  |
|Warning Check|Warning    |Success     |ComplianceConstraint(Compliance(b / a is 10,b / a = 10,None))                        |Success          |                  |
+-------------+-----------+------------+-------------------------------------------------------------------------------------+-----------------+------------------+

AnomalyDetectionの実行

基本的な使い方

AnomalyDetectionの基本的な実行は、PyDeequでは下記のように書くことで実現できます。

サンプルコード5

from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pydeequ.analyzers import *
from pydeequ.anomaly_detection import *
from pydeequ.repository import *
from pydeequ.verification import *
import datetime

now = datetime.datetime.now().timestamp()
repo = InMemoryMetricsRepository(spark)
yesterdaysKey = ResultKey(spark, int(now * 1000) - 24 * 60 * 60 * 1000)

df_yesterday = spark.createDataFrame(data=[
        (1, 10, "foo"), 
        (2, 20, "foo"), 
        (3, 30, "bar"), 
        (4, 40, "baz"), 
        (5, 50, "baz"), 
        (6, 60, "bar"), 
        (7, 70, None), 
        (8, 80, "bar"), 
], schema=StructType([
        StructField("a", IntegerType(), True),
        StructField("b", IntegerType(), True),
        StructField("c", StringType(), True),
]))

checkResult = AnalysisRunner(spark) \
    .onData(df_yesterday) \
    .useRepository(repo) \
    .saveOrAppendResult(yesterdaysKey) \
    .addAnalyzer(Size()) \
    .addAnalyzer(Maximum("a")) \
    .run()

df = spark.createDataFrame(data=[
        (1, 10, "foo"), 
        (2, 20, "foo"), 
        (3, 30, "bar"), 
        (4, 40, "baz"), 
        (5, 50, "baz"), 
        (6, 60, "bar"), 
        (7, 70, None), 
        (8, 80, "bar"), 
        (9, 90, "bar"), 
        (10, 100, "bar"), 
], schema=StructType([
        StructField("a", IntegerType(), True),
        StructField("b", IntegerType(), True),
        StructField("c", StringType(), True),
]))

todaysKey = ResultKey(spark, int(now * 1000))
checkResult = VerificationSuite(spark) \
    .onData(df) \
    .useRepository(repo) \
    .saveOrAppendResult(todaysKey) \
    .addAnomalyCheck(AbsoluteChangeStrategy(-1.0, 1.0), Size()) \
    .addAnomalyCheck(AbsoluteChangeStrategy(-2.0, 2.0), Size()) \
    .addAnomalyCheck(RelativeRateOfChangeStrategy(0.9, 1.1), Maximum("a")) \
    .addAnomalyCheck(RelativeRateOfChangeStrategy(0.7, 1.3), Maximum("a")) \
    .run()

checkResult_df = VerificationResult.checkResultsAsDataFrame(spark, checkResult)
checkResult_df.show(truncate=False)

repo.load() \
  .getSuccessMetricsAsDataFrame() \
  .show()

AnomalyDetectionでは、通常のテストと違い、Checkの代わりに、AnomalyDetectionの方法を指定するオブジェクト、Analyzerというテスト項目を指定するオブジェクト、Repositoryという今までのテスト項目の結果を保存してくれるオブジェクト、ResultKeyといういつの時間のデータを保存するかを指定するオブジェクトが必要になります。

基本的にはRepositoryのオブジェクトに保存されている直前(オブジェクトに保存されている順番で直前)のデータを各テスト項目について比較して結果を出力します。Repositoryオブジェクトへの結果の保存は、AnalyzerRunnerクラスを用いて、Analyzerによるテスト項目の計算のみを行いその結果を保存する方法と、VerificationSuiteでAnomalyDetectionも行いながら、テスト項目の計算結果を保存する2つのアプローチがあります。Repositoryオブジェクトが管理しているファイルやメモリ上のオブジェクトの結果は、保存するたびに結果が追記されていくため、VerificationSuiteでの保存は用途に使い分けるのが良さそうです。

AbsoluteChangeStrategyは、保存されている直前の値と比較して、変化量がmaxRateDecreaseとmaxRateIncreaseの範囲に収まっているかどうかを確認します。

RelativeRateOfChangeStrategyは、保存されている直前の値と比較して、変化率がmaxRateDecreaseとmaxRateIncreaseの範囲に収まっているかどうかを確認します。

結果5

+---------------------------------+-----------+------------+----------------------------------+-----------------+-----------------------------------------------------+
|check                            |check_level|check_status|constraint                        |constraint_status|constraint_message                                   |
+---------------------------------+-----------+------------+----------------------------------+-----------------+-----------------------------------------------------+
|Anomaly check for Size(None)     |Warning    |Warning     |AnomalyConstraint(Size(None))     |Failure          |Value: 10.0 does not meet the constraint requirement!|
|Anomaly check for Size(None)     |Warning    |Success     |AnomalyConstraint(Size(None))     |Success          |                                                     |
|Anomaly check for Maximum(a,None)|Warning    |Warning     |AnomalyConstraint(Maximum(a,None))|Failure          |Value: 10.0 does not meet the constraint requirement!|
|Anomaly check for Maximum(a,None)|Warning    |Success     |AnomalyConstraint(Maximum(a,None))|Success          |                                                     |
+---------------------------------+-----------+------------+----------------------------------+-----------------+-----------------------------------------------------+

テスト項目は、Analyzerクラスの種類と指定されているカラムの組み合わせで決まります。 サンプルコードでは、todaysKey に関連して実行されている addAnomalyCheck には Size()Maximum("a") が指定されているものが2つありますが、下記のようにレポジトリに保存されている内容は、1つのみとなります。

保存内容5

+-------+--------+-------+-----+-------------+
| entity|instance|   name|value| dataset_date|
+-------+--------+-------+-----+-------------+
|Dataset|       *|   Size|  8.0|1628216988556|
| Column|       a|Maximum|  8.0|1628216988556|
|Dataset|       *|   Size| 10.0|1628303388556|
| Column|       a|Maximum| 10.0|1628303388556|
+-------+--------+-------+-----+-------------+

傾向の変化を検出する

AnomalyDetectionでは、直前だけでなく過去の複数のデータを使って、テストを行うことができます。

サンプルコード6

from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pydeequ.analyzers import *
from pydeequ.anomaly_detection import *
from pydeequ.repository import *
from pydeequ.verification import *
import datetime

now = datetime.datetime.now().timestamp()
repo = InMemoryMetricsRepository(spark)

for i in range(24):
  yesterdaysKey = ResultKey(spark, int(now * 1000) - (24 - i) * 60 * 60 * 1000)

  df_yesterday = spark.createDataFrame(data=[
          (1, 10, "foo"), 
          (2, 20, "foo"), 
          (3, 30, "bar"), 
          (4, 40, "baz"), 
          (5, 50, "baz"), 
          (6, 60, "bar"), 
          (7, 70, None), 
          (8, 80, "bar"), 
  ], schema=StructType([
          StructField("a", IntegerType(), True),
          StructField("b", IntegerType(), True),
          StructField("c", StringType(), True),
  ]))

  checkResult = AnalysisRunner(spark) \
      .onData(df_yesterday) \
      .useRepository(repo) \
      .saveOrAppendResult(yesterdaysKey) \
      .addAnalyzer(Mean("a")) \
      .addAnalyzer(Completeness("c")) \
      .run()

df = spark.createDataFrame(data=[
        (1, 10, "foo"), 
        (2, 20, "foo"), 
        (3, 30, "bar"), 
        (4, 40, "baz"), 
        (5, 50, None), 
        (6, 60, None), 
        (7, 70, None), 
        (8, 80, "bar"), 
        (9, 90, None), 
        (1000, 100, None), 
], schema=StructType([
        StructField("a", IntegerType(), True),
        StructField("b", IntegerType(), True),
        StructField("c", StringType(), True),
]))

todaysKey = ResultKey(spark, int(now * 1000))
checkResult = VerificationSuite(spark) \
    .onData(df) \
    .useRepository(repo) \
    .saveOrAppendResult(todaysKey) \
    .addAnomalyCheck(OnlineNormalStrategy(), Mean("a")) \
    .addAnomalyCheck(OnlineNormalStrategy(), Completeness("c")) \
    .run()

checkResult_df = VerificationResult.checkResultsAsDataFrame(spark, checkResult)
checkResult_df.show(truncate=False)

repo.load() \
  .getSuccessMetricsAsDataFrame() \
  .show()

OnlineNormalStrategyは、Analyzerで指定されたテスト項目の過去の値の平均と標準偏差を計算し、今回のテスト項目の値が mean - lowerDeviationFactor *stdDev と mean + upperDeviationFactor * stDev の間に収まっているかどうかを確認します。

ignoreAnomalies で履歴データ内の外れ値を無視して平均と標準偏差を計算してくれることを期待しますが、現状のDeequ側の実装では残念ながらそれらは無視されず、平均と標準偏差の計算に考慮されてしまいます。また、ウィンドウサイズのようなものを指定することができないため、Repositoryに保存されているデータを渡す前に特定日以前は切っておくというような操作が必要になります。

結果6

+--------------------------------------+-----------+------------+---------------------------------------+-----------------+------------------------------------------------------+
|check                                 |check_level|check_status|constraint                             |constraint_status|constraint_message                                    |
+--------------------------------------+-----------+------------+---------------------------------------+-----------------+------------------------------------------------------+
|Anomaly check for Mean(a,None)        |Warning    |Warning     |AnomalyConstraint(Mean(a,None))        |Failure          |Value: 104.5 does not meet the constraint requirement!|
|Anomaly check for Completeness(c,None)|Warning    |Warning     |AnomalyConstraint(Completeness(c,None))|Failure          |Value: 0.5 does not meet the constraint requirement!  |
+--------------------------------------+-----------+------------+---------------------------------------+-----------------+------------------------------------------------------+

季節性を考慮した変化の検出する

履歴をもとに AnomalyDetection を行う場合は、周期性を考慮して行いたいケースが多くあります。Deequでは、週単位、年単位の周期を考慮してテストすることができます。

サンプルコード7

from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pydeequ.analyzers import *
from pydeequ.anomaly_detection import *
from pydeequ.repository import *
from pydeequ.verification import *
import datetime

now = datetime.datetime.now().timestamp()
repo = InMemoryMetricsRepository(spark)

for k in range(2):
  for i in range(1, 7):
    yesterdaysKey = ResultKey(spark, int(now * 1000) - (k*7+i) * 24 * 60 * 60 * 1000)

    df_yesterday = spark.createDataFrame(data=[
            (1, 10, "foo"), 
            (2, 20, "foo"), 
            (3, 30, "bar"), 
            (4, 40, "baz"), 
            (5, 50, "baz"), 
            (6, 60, "bar"), 
            (7, 70, None), 
            (8, 80, "bar"), 
    ], schema=StructType([
            StructField("a", IntegerType(), True),
            StructField("b", IntegerType(), True),
            StructField("c", StringType(), True),
    ]))

    checkResult = AnalysisRunner(spark) \
        .onData(df_yesterday) \
        .useRepository(repo) \
        .saveOrAppendResult(yesterdaysKey) \
        .addAnalyzer(Mean("a")) \
        .addAnalyzer(Completeness("c")) \
        .run()
  yesterdaysKey = ResultKey(spark, int(now * 1000) - (k*7+7) * 24 * 60 * 60 * 1000)

  df_yesterday = spark.createDataFrame(data=[
          (1, 10, "foo"), 
          (2, 20, "foo"), 
          (3, 30, "bar"), 
          (4, 40, "baz"), 
          (5, 50, None), 
          (6, 60, None), 
          (7, 70, None), 
          (8, 80, None), 
  ], schema=StructType([
          StructField("a", IntegerType(), True),
          StructField("b", IntegerType(), True),
          StructField("c", StringType(), True),
  ]))

  checkResult = AnalysisRunner(spark) \
      .onData(df_yesterday) \
      .useRepository(repo) \
      .saveOrAppendResult(yesterdaysKey) \
      .addAnalyzer(Mean("a")) \
      .addAnalyzer(Completeness("c")) \
      .run()

df = spark.createDataFrame(data=[
        (1, 10, "foo"), 
        (2, 20, "foo"), 
        (3, 30, "bar"), 
        (4, 40, "baz"), 
        (5, 50, None), 
        (6, 60, None), 
        (7, 70, None), 
        (8, 80, "bar"), 
        (9, 90, None), 
        (1000, 100, None), 
], schema=StructType([
        StructField("a", IntegerType(), True),
        StructField("b", IntegerType(), True),
        StructField("c", StringType(), True),
]))

todaysKey = ResultKey(spark, int(now * 1000))
checkResult = VerificationSuite(spark) \
    .onData(df) \
    .useRepository(repo) \
    .saveOrAppendResult(todaysKey) \
    .addAnomalyCheck(HoltWinters(MetricInterval.Daily, SeriesSeasonality.Weekly), Mean("a")) \
    .addAnomalyCheck(HoltWinters(MetricInterval.Daily, SeriesSeasonality.Weekly), Completeness("c")) \
    .run()

checkResult_df = VerificationResult.checkResultsAsDataFrame(spark, checkResult)
checkResult_df.show(truncate=False)

repo.load() \
  .getSuccessMetricsAsDataFrame() \
  .show()

HoltWintersは、データの頻度と周期の2つを指定することで、該当のテスト項目が履歴データから外れたものでないかを確認することができます。

結果7

+--------------------------------------+-----------+------------+---------------------------------------+-----------------+------------------------------------------------------+
|check                                 |check_level|check_status|constraint                             |constraint_status|constraint_message                                    |
+--------------------------------------+-----------+------------+---------------------------------------+-----------------+------------------------------------------------------+
|Anomaly check for Mean(a,None)        |Warning    |Warning     |AnomalyConstraint(Mean(a,None))        |Failure          |Value: 104.5 does not meet the constraint requirement!|
|Anomaly check for Completeness(c,None)|Warning    |Success     |AnomalyConstraint(Completeness(c,None))|Success          |                                                      |
+--------------------------------------+-----------+------------+---------------------------------------+-----------------+------------------------------------------------------+

Repositoryの格納内容

Repositoryには下記のように、すべての時間のすべてのAnalyzerのテスト結果が1つのファイル内(メモリだと1つのオブジェクト)に格納されます。

保存内容

[
  {
    "resultKey": {
      "dataSetDate": 1628402595296,
      "tags": {}
    },
    "analyzerContext": {
      "metricMap": [
        {
          "analyzer": {
            "analyzerName": "Size"
          },
          "metric": {
            "metricName": "DoubleMetric",
            "entity": "Dataset",
            "instance": "*",
            "name": "Size",
            "value": 8.0
          }
        },
        {
          "analyzer": {
            "analyzerName": "Maximum",
            "column": "a"
          },
          "metric": {
            "metricName": "DoubleMetric",
            "entity": "Column",
            "instance": "a",
            "name": "Maximum",
            "value": 8.0
          }
        }
      ]
    }
  },
  {
    "resultKey": {
      "dataSetDate": 1628316195296,
      "tags": {}
    },
    "analyzerContext": {
      "metricMap": [
        {
          "analyzer": {
            "analyzerName": "Size"
          },
          "metric": {
            "metricName": "DoubleMetric",
            "entity": "Dataset",
            "instance": "*",
            "name": "Size",
            "value": 10.0
          }
        },
        {
          "analyzer": {
            "analyzerName": "Maximum",
            "column": "a"
          },
          "metric": {
            "metricName": "DoubleMetric",
            "entity": "Column",
            "instance": "a",
            "name": "Maximum",
            "value": 10.0
          }
        }
      ]
    }
  },
]

その他の機能

データのプロファイリング

Profilerを用いることで、どのようなデータが格納されるカラムかを簡単に確認することができます。出力された値をもとにどのようなテスト項目があると良さそうかを、自動もしくは手動で判定するのに使うことが主な用途として想定されます。

サンプルコード8

from pydeequ.profiles import *

df = spark.createDataFrame(data=[
        (1, 10, "foo"), 
        (2, 20, "foo"), 
        (3, 30, "bar"), 
        (4, 40, "baz"), 
        (5, 50, "baz"), 
        (6, 60, "bar"), 
        (7, 70, None), 
        (8, 80, "bar"), 
        (9, 90, "bar"), 
        (10, 100, "bar"), 
], schema=StructType([
        StructField("a", IntegerType(), True),
        StructField("b", IntegerType(), True),
        StructField("c", StringType(), True),
]))

result = ColumnProfilerRunner(spark) \
    .onData(df) \
    .run()

for col, profile in result.profiles.items():
    print(profile)

取得できる主な項目は、completeness、approximateNumDistinctValues、dataType、histogram、meanなどの統計量です。

completenessはNullでない値の割合、approximateNumDistinctValuesは推定した値の種類、dataTypeはデータの型、histogramはどの値がどれだけ(数と割合)含まれているかの配列、統計量は数字カラムのみでそのカラム全体の平均や最大値などを返してくれます。

結果8

NumericProfiles for column: a: {
    "completeness": 1.0,
    "approximateNumDistinctValues": 10,
    "dataType": "Integral",
    "isDataTypeInferred": false,
    "typeCounts": {},
    "histogram": [
        [
            "8",
            1,
            0.1
        ],
        [
            "4",
            1,
            0.1
        ],
        [
            "9",
            1,
            0.1
        ],
        [
            "5",
            1,
            0.1
        ],
        [
            "10",
            1,
            0.1
        ],
        [
            "6",
            1,
            0.1
        ],
        [
            "1",
            1,
            0.1
        ],
        [
            "2",
            1,
            0.1
        ],
        [
            "7",
            1,
            0.1
        ],
        [
            "3",
            1,
            0.1
        ]
    ],
    "kll": "None",
    "mean": 5.5,
    "maximum": 10.0,
    "minimum": 1.0,
    "sum": 55.0,
    "stdDev": 2.8722813232690143,
    "approxPercentiles": []
}
NumericProfiles for column: b: {
    "completeness": 1.0,
    "approximateNumDistinctValues": 10,
    "dataType": "Integral",
    "isDataTypeInferred": false,
    "typeCounts": {},
    "histogram": [
        [
            "100",
            1,
            0.1
        ],
        [
            "40",
            1,
            0.1
        ],
        [
            "90",
            1,
            0.1
        ],
        [
            "50",
            1,
            0.1
        ],
        [
            "10",
            1,
            0.1
        ],
        [
            "80",
            1,
            0.1
        ],
        [
            "60",
            1,
            0.1
        ],
        [
            "20",
            1,
            0.1
        ],
        [
            "70",
            1,
            0.1
        ],
        [
            "30",
            1,
            0.1
        ]
    ],
    "kll": "None",
    "mean": 55.0,
    "maximum": 100.0,
    "minimum": 10.0,
    "sum": 550.0,
    "stdDev": 28.722813232690143,
    "approxPercentiles": []
}
StandardProfiles for column: c: {
    "completeness": 0.9,
    "approximateNumDistinctValues": 3,
    "dataType": "String",
    "isDataTypeInferred": false,
    "typeCounts": {
        "Boolean": 0,
        "Fractional": 0,
        "Integral": 0,
        "Unknown": 1,
        "String": 9
    },
    "histogram": [
        [
            "bar",
            5,
            0.5
        ],
        [
            "baz",
            2,
            0.2
        ],
        [
            "foo",
            2,
            0.2
        ],
        [
            "NullValue",
            1,
            0.1
        ]
    ]
}

テスト項目のレコメンデーション

ConstraintSuggestionは、Profilerからさらに進んで、Checkオブジェクトにどのようなテスト項目を追加したほうが良さそうかを、レコメンドしてくれます。

サンプルコード9

from pydeequ.suggestions import *

df = spark.createDataFrame(data=[
        (1, 10, "foo"), 
        (2, 20, "foo"), 
        (3, 30, "bar"), 
        (4, 40, "baz"), 
        (5, 50, "baz"), 
        (6, 60, "bar"), 
        (7, 70, None), 
        (8, 80, "bar"), 
        (9, 90, "bar"), 
        (10, 100, "bar"), 
], schema=StructType([
        StructField("a", IntegerType(), True),
        StructField("b", IntegerType(), True),
        StructField("c", StringType(), True),
]))

suggestionResult = ConstraintSuggestionRunner(spark) \
             .onData(df) \
             .addConstraintRule(DEFAULT()) \
             .run()

for item in suggestionResult['constraint_suggestions']:
  print(item)
  print()

レコメンドされる対象のルールは、addConstraintRuleで指定する事ができ、DEFAULTではsuggestionsモジュール配下のすべてのルールが含まれます。含まれるルールには、Nullでないかを確認するテストを追加すべきか判定する CompleteIfCompleteRule 、負の値が含まれていないかを確認するテストを追加すべきかを判定する NonNegativeNumbersRule などがあります。

constraint_suggestions の各結果には、constraint_name、column_name、current_value、description、suggesting_rule、rule_description、code_for_constraintの各値が含まれます。利用上もっとも重要なのが code_for_constraint で、Checkオブジェクトに該当のテスト項目を追加するための実装がそのまま記載されています。

結果9

{'constraint_name': 'CompletenessConstraint(Completeness(b,None))', 'column_name': 'b', 'current_value': 'Completeness: 1.0', 'description': "'b' is not null", 'suggesting_rule': 'CompleteIfCompleteRule()', 'rule_description': 'If a column is complete in the sample, we suggest a NOT NULL constraint', 'code_for_constraint': '.isComplete("b")'}

{'constraint_name': "ComplianceConstraint(Compliance('b' has no negative values,b >= 0,None))", 'column_name': 'b', 'current_value': 'Minimum: 10.0', 'description': "'b' has no negative values", 'suggesting_rule': 'NonNegativeNumbersRule()', 'rule_description': 'If we see only non-negative numbers in a column, we suggest a corresponding constraint', 'code_for_constraint': '.isNonNegative("b")'}

{'constraint_name': 'UniquenessConstraint(Uniqueness(List(b),None))', 'column_name': 'b', 'current_value': 'ApproxDistinctness: 1.0', 'description': "'b' is unique", 'suggesting_rule': 'UniqueIfApproximatelyUniqueRule()', 'rule_description': 'If the ratio of approximate num distinct values in a column is close to the number of records (within the error of the HLL sketch), we suggest a UNIQUE constraint', 'code_for_constraint': '.isUnique("b")'}

{'constraint_name': 'CompletenessConstraint(Completeness(a,None))', 'column_name': 'a', 'current_value': 'Completeness: 1.0', 'description': "'a' is not null", 'suggesting_rule': 'CompleteIfCompleteRule()', 'rule_description': 'If a column is complete in the sample, we suggest a NOT NULL constraint', 'code_for_constraint': '.isComplete("a")'}

{'constraint_name': "ComplianceConstraint(Compliance('a' has no negative values,a >= 0,None))", 'column_name': 'a', 'current_value': 'Minimum: 1.0', 'description': "'a' has no negative values", 'suggesting_rule': 'NonNegativeNumbersRule()', 'rule_description': 'If we see only non-negative numbers in a column, we suggest a corresponding constraint', 'code_for_constraint': '.isNonNegative("a")'}

{'constraint_name': 'UniquenessConstraint(Uniqueness(List(a),None))', 'column_name': 'a', 'current_value': 'ApproxDistinctness: 1.0', 'description': "'a' is unique", 'suggesting_rule': 'UniqueIfApproximatelyUniqueRule()', 'rule_description': 'If the ratio of approximate num distinct values in a column is close to the number of records (within the error of the HLL sketch), we suggest a UNIQUE constraint', 'code_for_constraint': '.isUnique("a")'}

{'constraint_name': "ComplianceConstraint(Compliance('c' has value range 'bar', 'baz', 'foo' for at least 99.0% of values,`c` IN ('bar', 'baz', 'foo'),None))", 'column_name': 'c', 'current_value': 'Compliance: 0.9999999999999999', 'description': "'c' has value range 'bar', 'baz', 'foo' for at least 99.0% of values", 'suggesting_rule': 'FractionalCategoricalRangeRule(0.9)', 'rule_description': 'If we see a categorical range for most values in a column, we suggest an IS IN (...) constraint that should hold for most values', 'code_for_constraint': '.isContainedIn("c", ["bar", "baz", "foo"], lambda x: x >= 0.99, "It should be above 0.99!")'}

{'constraint_name': 'CompletenessConstraint(Completeness(c,None))', 'column_name': 'c', 'current_value': 'Completeness: 0.9', 'description': "'c' has less than 29% missing values", 'suggesting_rule': 'RetainCompletenessRule()', 'rule_description': 'If a column is incomplete in the sample, we model its completeness as a binomial variable, estimate a confidence interval and use this to define a lower bound for the completeness', 'code_for_constraint': '.hasCompleteness("c", lambda x: x >= 0.71, "It should be above 0.71!")'}

TensorFlow Array Indexing Correspond to numpy

TensorFlowで配列処理を効率的に行うのはなかなか難しいことがあります。

例えば、下記のようなIndexing処理はnumpyでは簡単に実現することができますが、TnesorFlowではそうはいきません。

a[:, [2, 3]]

スライス以外の方法でインデックスを指定して値を取得する際には、下記のように tf.gather もしくは tf.gather_nd 関数を利用する必要があります。

tf.gather(a, [2, 3], axis=1)

また、配列の値の更新には tf.tensor_scatter_nd_update を使用する必要があります。

first = tf.tile(tf.expand_dims(tf.range(4), axis=1), (1, 2))
indexes = tf.tile([[2, 3]], (4, 1))
indices = tf.stack([first, indexes], axis=-1)
a = tf.tensor_scatter_nd_update(a, indices, tf.ones((4, 2, 4)) * 2)

この記事では、頻出するインデクシングのシチュエーションにおいて、TensorFlowでの値の取得方法、更新方法を記載していきます。

コードは、記事で紹介している以外の実装も含めて下記に置いています。

TensorFlow Indexing.ipynb · GitHub

目次

Slicingのみの場合

Slicingのみの場合、TensorFlowでも簡単にIndexingを実現できます。

numpyでの以下のような値の取得と、以下のような値の更新を行うケースを考えます。

a = np.ones((4, 4, 4))
a[:, 2:3]
a[:, 2:3] = np.ones((4, 1, 4)) * 2

このケースでは、TensorFlowでもほとんど同じように記述することができます。

a = tf.ones((4, 4, 4))
a[:, 2:3]
a = a[:, 2:3].assign(tf.ones((4, 1, 4)) * 2)

Boolean Array によるIndexingの場合

Boolean Array によるIndexingは、少し特殊な書き方が必要になりますが、パターンが分かれば簡単に実現できます。

下記のようなnumpyでの値の取得と値の更新を行うケースを考えます。

a = np.ones((4, 4, 4))
d = np.array([[[True, True, True, True], [False, False, False, False], [False, False, False, False], [False, False, False, False]]] * 4)
a[d]
a[d] = 2

値の取得は問題なくnumpyと同様に扱うことができます。

一方、値の更新はややトリッキーな書き方をする必要があります。Bool値の配列を1,0の配列に変換することにより、Indexingの配列がTrueの場合に代入する配列の値を、Falseの場合には元の配列の値を使用するような配列を作成する必要があります。

この方法では、元の配列と同じサイズの配列を用意する必要がありますが、実用上の大体のケースでは、特定の値に更新するか、元々同じサイズの配列の値に一部置き換えるといったようなケースであるため、そこまで問題にはならないでしょう。

a = tf.ones((4, 4, 4))
d = tf.constant([[[True, True, True, True], [False, False, False, False], [False, False, False, False], [False, False, False, False]]] * 4)
a[d]
d = tf.cast(d, dtype=a.dtype)
a = (1 - d) * a + d * np.ones((4, 4, 4)) * 2

Integer Array によるIndexingの場合

Integer Arrayを用いたIndexingは、TensorFlowの機能をフルに使用する必要があります。

下記のようなnumpyでの値の取得と値の更新を行うケースを考えます。

a = np.ones((4, 4, 4))
b = np.array([[2, 3]])
c = np.array([1, 2])
a[:, [2, 3]]
a[b, c]
a[:, [2, 3]] = np.ones((4, 2, 4)) * 2
a[b, c] = np.ones((1, 2, 4)) * 2

Integer Array による値の取得

値の取得は、下記のように tf.gather および tf.gather_nd を使用する必要があります。

a = tf.ones((4, 4, 4))
b = tf.constant([[2, 3]])
c = tf.constant([1, 2])
tf.gather(a, [2, 3], axis=1)
tf.gather_nd(a, tf.stack([b, [c]], axis=-1))

tf.gatherは、最初に対象の配列、次にIndexingをする対象を示す1次元の配列、axisにどの次元のIndexingを行うかを指定します。

つまり、tf.gather(a, [2, 3], axis=1)を実行すると、(4, 2, 4)の配列が取得できることになります。

tf.gather_ndは、tf.gatherを多次元の配列でIndexingするように拡張したものです。ただし、その配列の値の並びの解釈はtf.gatherとは異なります。

tf.gatherでは[2, 3]が与えられた場合、この配列は同じ次元の2番目と3番目の値を取得することを示しているのに対し、tf.gather_ndでは1次元目の2番目、2次元目の3番目の値を取得することを意味します。

つまり、tf.gatherはnumpyでいう a[[2, 3]] の挙動であり、tf.gather_ndは a[2, 3] の挙動と似たような挙動を示します。

numpyの配列と違い、tf.gather_ndは複数の[2, 3]のペアを与えることができ、またその配列のshapeに合わせて出力のshapeが変化します。

例えば、上記の行列aに対して、tf.gather_nd(a, [2,3]) を呼び出すと、numpyで言う a[2, 3] と同じ出力を得ることができますが、tf.gather_nd(a, [[2,3], [1, 2]]) を呼び出すと、 a[2, 3] の結果と a[1, 2] の結果を縦方向にstackした値を取得できます。

また、Indexingを示す配列のndimsは制限されておらず、 [[[[2,3]]]] のような配列を指定することができ、その場合の結果のshapeは(1, 1, 1, 4)になります。

注意点としては、最終次元の配列の次元数は元の配列のndimsより大きくなることはできず、また、個々にIndexingした結果はstackされることになるため、stackできるような配列になっている必要があります。

Integer Array による値の更新

値の更新は tf.tensor_scatter_nd_update を使用する事により実現できますが、tf.gatherに相当する関数がないため、Slicingを必要とする配列の更新の場合には一工夫必要になります。

具体的には、下記のような処理になります。

a = tf.Variable(tf.ones((4, 4, 4)))
b = tf.constant([[2, 3]])
c = tf.constant([1, 2])

# likely a[:, [2, 3]] = np.ones((4, 2, 4)) * 2
first = tf.tile(tf.expand_dims(tf.range(4), axis=1), (1, 2))
indexes = tf.tile([[2, 3]], (4, 1))
indices = tf.stack([first, indexes], axis=-1)
a = tf.tensor_scatter_nd_update(a, indices, tf.ones((4, 2, 4)) * 2)

# likely a[b, c] = np.ones((1, 2, 4)) * 2
indices = tf.stack([b, [c]], axis=-1)
a = tf.tensor_scatter_nd_update(a, indices, tf.ones((1, 2, 4)) * 2)

Indexingを行う対象を示す配列の仕様はtf.gather_ndと全く同じです。つまり、配列の1次元目から順にその何番目の値を更新するかを指定する必要があるため、スライス相当のIndexの指定を、tf.rangeやtf.tileなどを駆使して実装者が行う必要があります。

一方で、スライス相当の処理が必要ない場合は、tf.gather_ndと同じような考えで実現できます。代入する配列だけ、代入した領域と同じshapeの配列が必要になる点だけは注意が必要です。

Keras Loss Behavior with Language Model

KerasのModelクラスを使用した際のロスの計算は、Paddingで追加した余計な値を勾配の計算から除外する処理は自動でやってくれるのですが、
historyに記録されるlossの平均値を求める際に、maskを部分的にしか考慮しておらず、padding数が多くなればなるほど、実際のロスより小さくなってしまうという現象が発生します。

この記事は、KerasのModelクラスでLossを利用する、
特にEmbedding層で、mask_zeroをTrueにした場合に、Paddingで追加した余計な値を、勾配の計算に使用しない、ロスの計算から完全に除外する方法についてのメモです。

検証用のコードはこちらです。

マスクを使用したロスの計算は、TensorFlowのチュートリアルを参考にしています。

目次

KerasのLossの種類と種類毎の処理の違い

Kerasのmodel.compileで指定できるLossの種類は大きく分けると以下の3つがあります。
※ 参考:compile内で呼ばれているtf.keras.Model.prepare_loss_functions

  • tf.keras.losses.Lossの派生クラス

  • 独自定義したloss関数

  • 独自定義したcallableなLossオブジェクト

このうち、「独自定義したloss関数」は、prepare_loss_functions内でtf.keras.losses.Lossの派生クラスであるLossFunctionWrapperクラスに変換されるため、実質は2種類のLossオブジェクトが使い分けられることになります。

Lossを含んだ計算グラフの構築は、tf.keras.Model._prepare_total_loss内で行われます。
この関数の中で、上記2つのオブジェクトは別々の処理が行われことになります。

tf.keras.losses.Lossの派生クラスの場合の処理

tf.keras.losses.Lossの場合は、Kerasの内部関数であるtf.keras.utils.losses_util.compute_weighted_loss内で、以下の処理が行われます。

  1. ロス関数の呼び出しによるロスの計算

  2. tf.keras.utils.losses_util.scale_losses_by_sample_weightでロス×sample_weightsを計算

  3. Lossオブジェクトのreductionに応じた集計をtf.keras.utils.losses_util.reduce_weighted_lossで計算。大体の場合は、ロスの平均値を計算する処理

これから、sample_weightsを使用しない場合は、独自定義のロス関数でどんなshapeの値を返しても良いことがわかります。

ただし、sample_weightsとしてEmbedding層で計算したmaskが暗黙的に使用される場合があるので後述のように注意が必要です。

独自定義したcallableなLossオブジェクトの場合の処理

正解ラベルとニューラルネットワークの出力だけでなく、sample_weightsも引数に渡されて、Lossオブジェクトのcall関数が呼ばれます。
ベクトル値を返した場合、ベクトルの平均値が最終のロスとして使用されます。

Embedding層でmask_zeroをTrueにした場合のLossの挙動

Embedding層でmask_zeroをTrueに指定した場合、sample_weightsとしてEmbedding層で計算したmaskが暗黙的に使用されます。
maskはpadding処理で0埋めした部分はFalse、それ以外はTrueとなり、paddingで埋めた部分の計算に勾配が伝わらないようにしてくれます。

ただし、あくまで勾配が伝わらないようにしてくれるだけで、ロスの平均を取る際の分母の数をmask分減らしてはくれません。
このため、Perplexity等の計算をこのLossが出力した値を元に行うと、paddingで埋める長さが長いほど小さい値になってしまいおかしくなります。

例えば、正解ラベルとして  [1, 1, 2, 0, 0 ] の系列データがあるケースを考えます。

この時、ネットワークの出力として、各時点でのラベルの予測確率が  [0.2, 0.2, 0.2, 0.2, 0.2 ]と得られたとします。

0は計算処理の都合上いれているだけのデータで、実際の処理では無視するため、このデータに対するクロスエントロピーロスは下記のようになります。

 -(log(0.2) + log(0.2) + log(0.2)) / 3. = 1.60943...

しかしKerasのModelクラスを使用して計算すると

 -(log(0.2) + log(0.2) + log(0.2)) / 5. = 0.96566...

という値になってしまい、実際の値に比べてかなり小さくなってしまいます。

これを避けるためには、検証コード内LOSS_MODEが2の場合のように、独自定義したLossオブジェクトで、padiding部分の値を元にした処理に勾配が行かないように、かつ平均計算時の分母からの除去するように必要があります。

ちなみにMetoricsはレートに直す際に、sample_weightsの合計で割るような実装になっているため、maskの場合も意図した動作になります。

Microsoft Academic Search APIで自分専用の論文検索エンジンを作る

サーベイなどで論文検索をする時によく困るのが、キーワードをこねくり回さないと以外と読むべき論文に出会えないという点です。

特に「Dialogue System」や「Image Captioning」などのように、母数が少ないニッチな分野になると、学術用検索エンジンにキーワードを入力するだけでは、キーワードにマッチするものがトップに上がってくるだけで、必ずしもその分野を代表するような論文がヒットしてくれるわけではありません。

ホットな分野であれば、サーベイ論文、学会のチュートリアル資料など、人工知能学会の「私のブックマーク」を漁ると良さそうな情報が見つかることもありますが、なかなか新しい情報がまとまっていないということも多くあります。

その点で検索しやすいなと思っているのが、Microsoft Academicです。

下記の記事にもまとまっていますように、文献に紐付けられたトピックで論文を絞り込める、類似の論文のサジェストの性能がまあまあ良い、Saliencyなどのどのような論文から引用されているかでランキングできるという点が、論文を探す際に非常に便利で、特にあまり馴染みのない、体系的に理解できていない分野の論文を探す時に重宝しています。

個人的なオススメは、「キーワード入力」→「トピックで絞り込み」→「Saliencyでランキング」で検索すると、割といい感じに論文を見つけることができます。

ただやはり難点があって、特にトピックにまではなっていないような時に、下記のうまくいかないケースが発生します。

  • 検索キーワードと論文中のワーディングや語順が違うと引っかからないことがある

  • 分野内で、最近よく利用されている技術や引用されている考え方を知りたい

前者については検索キーワードを変えながら、後者については自分で網羅的に文献を読むことにより、なんとか達成する事もできますが、分野のことをろくに知らない状態でワーディングを考えたり、ざっととはいえ論文を読むことは結構辛いものです。

「ヒットした文献内でよく参照されている論文」などを簡単に探すことができれば、上記の要望を満たせそうですが、残念ながらそのような要望を満たしてくれるような検索エンジンは僕の知る限り存在しません。

ならば自分で作ればよいのでは!?と考えたすえ、Microsoft Research APIsの中のAcademic Search APIを使用して論文検索エンジンを自作してみました。

自作したサイトはこちらです(APIの使用制限やエラー処理が適当なため動かないことがありますので、あしからず)。

ソースコードは下記。

中で使用しているクエリの書き方は以下2つの記事か、本記事内の「クエリ構文の書き方」をご参考ください。

以降は主にAcademic Search APIの使い方について記載していきます。

目次

Academic Search APIとは

Academic Search APIは、Microsoft Academicで使用されている論文データベースに、Microsoft Academicで使用されている自然言語による検索や内部的に使用されているクエリ構文でアクセスできるように提供されているAPIです。

月に10,000トランザクションまでという制限がありますが、無料で使用することができます。APIキーはこちらのページの「Subscribe」から簡単に取得することができます。

他にも上記の論文データベースにアクセスする方法がいくつか用意されていますす。

いずれの方法も利用回数無制限で使用できますが、手続きが必要であったりとAcademic Search APIに比べると気軽さが落ちますので、データをAggregationした結果を大量に使用する、というような重たいクエリを大量に発行するような使い方をする場合に検討してみるのがオススメです。

APIの種類と使い方

APIにはいくつかメソッドがありますが、主に使用するのはInterpetEvaluateです。

  • Interpet
    Microsoft Academic上の検索窓に入力するようなクエリを投げると、「Evaluate」メソッドで使用できるクエリの形式に変換したテキストが取得できるメソッドです。
    「query」に検索したいキーワードを入力し、「complete」を1に設定すると、検索窓で表示されるクエリ候補が返されるようなイメージです。

  • Evaluate
    専用のクエリ構文で記述したクエリに基づいて、各論文の情報を取得できるメソッドで、主に今回使用するのはこのメソッドです。
    「expr」にクエリ構文で記述したクエリ、「attributes」に取得したい論文に関する情報、「count」に取得したい論文の数(1000が上限)を指定します。

Evaluateメソッドは、「expr」と「attributes」には共通の表現を使用するなど、グラフ上のデータということもあり使い勝手はなんとなくGraph QLに似ています。

ただし、参照している文献のタイトルなども指定してまとめて引っ張ってくる、といったような使い方はできませんので、このあたりはプログラムで自分でカバーしてやる必要があります。

クエリ構文の書き方

クエリ構文のポイントは、①論文の属性で絞り込む、②Compositeを使用して論文の複合属性を絞り込む、③And/Orで複雑な条件を実現するの3つです。以下それぞれについて説明していきたいと思います。

論文の属性で絞り込む

論文の属性とよんでいるのが、この属性一覧に記載されている属性のうち、「.」が含まれていないもののことです。例えば、論文のタイトルである「Ti」や出版年である「Y」などです。

主に絞り込みで使用するのは、「W」や「Y」あたりです。

「W」は、論文のタイトルに含まれていて欲しい文字列を指定します。例えば「W='text'」といったように指定します。また、後で説明する「And」や「Or」と組み合わせることにより、「And(W='text',W='generation')」といった複数のキーワードを指定した検索を実現することができます。

「Y」は、上述しましたように論文の出版年のことで、「Y=2010」のようにクエリを指定します。数値なのでもちろん範囲指定や上限・下限を指定することができ、「Y<=2015」で2015年以前の論文、「Y=[2010, 2012]」で2010年から2012年に出版された論文を検索することができます。

Compositeを使用して論文の複合属性を絞り込む

ここで論文の複合属性と読んでいるのが、「.」が含まれている属性のことで、「AA.AuN」(著者名)や「F.FN」(分野名)のことです。

これらは上記のシンプルな属性とは異なり「Composite(AA.AuN='james')」のようにクエリを書く必要があります。論文を絞り込むための属性そのものが、一意に決定できないようなケースがあるため、このような特殊な書き方をします。

また、この性質のため、And/Or検索時の挙動には少し注意が必要です。本サイトの記事からの引用となりますが、例えば、「Composite(And(AA.AuN='mike smith',AA.AfN='harvard university'))」は「ハーバード大学に所属するmike smithが著者にいる」論文が検索されますが、「And(Composite(AA.AuN='mike smith'),Composite(AA.AfN='harvard university'))」は、「ある著書がmike smithで、ある著者がハーバード大学に所属する」論文を検索されます。

検索条件としてよく使用するのは、

And/Orで複雑な条件を実現する

And/Or検索はさして難しくなく、And()やOr()内に、有効なクエリを「,」で区切ることで実現できます。

例えば、「And(W='text',Y<=2015)」は、タイトルに「text」が含まれる2015年以前の論文を表しますし、「And(Or(W='text',W='sentence'),Y<=2015)」は、タイトルに「text」か「sentence」が含まれる2015年以前の論文を表します。

もちろん複合属性をクエリに含めることもでき、「And(Composite(F.FN='computer vision'),Y<=2015)」は、Computer Vision分野の2015年以前の論文を表します。

自分専用の論文検索エンジンを作る

上記までで、かなり複雑なクエリを実現できるようになりましたが、特に「分野内で、最近よく利用されている技術や引用されている考え方を知りたい」を見つけてきたい場合は、クエリ検索だけではいまいちなことがあります。

下記は、「And(W='dialogue',W='generation',Y>=2015)」(2015年以降の「dialogue」と「generation」をタイトルに含む論文)から上位10件をAPIを使用して検索した結果になります。重複した結果は、学会で発表された論文とarxivなどのプリプリント論文が別のものとして扱われているために発生しています。

- Semantically Conditioned LSTM-based Natural Language Generation for Spoken Dialogue Systems
- Semantically Conditioned LSTM-based Natural Language Generation for Spoken Dialogue Systems
- Adversarial Learning for Neural Dialogue Generation
- Adversarial Learning for Neural Dialogue Generation
- Deep Reinforcement Learning for Dialogue Generation
- How NOT To Evaluate Your Dialogue System: An Empirical Study of Unsupervised Evaluation Metrics for Dialogue Response Generation
- How NOT To Evaluate Your Dialogue System: An Empirical Study of Unsupervised Evaluation Metrics for Dialogue Response Generation
- Deep Reinforcement Learning for Dialogue Generation
- Multiresolution Recurrent Neural Networks: An Application to Dialogue Response Generation.
- Multiresolution Recurrent Neural Networks: An Application to Dialogue Response Generation

この検索結果はいくつか難点があります。1つは、明示的に検索クエリ(この場合は「dialogue」)が含まれる論文しか検索結果に含まれない点。その結果、分野に関連する代表的な論文を必ずしも見つけるができない、というのが2点目です。

Computer Vision」や「Natural Language Processing」のようなトピックとして整理されているレベルのものであればもっとマシな結果が得られますが、サブタスクのような実際に興味のある範囲や、新しく振興してきた分野については適当な結果が得られないことが多くあります。

複雑なクエリをプログラムで実現する

上記のようなケースで最も効果的なのが、「ヒットした文献内でよく参照されている論文」を見つけることです。基本的にはこれらは自力で論文を読む中で見つけてくるものですが、できれば最初から重要そうな論文を見つけたいというのが本音です。

Academic Search API一発では上記を実現できませんが、応答を元に集計処理を行えば簡単に上記のような論文を見つけることができます。

方針としては下記です。

  1. クエリにマッチするような論文を最大件数(1000件)検索する

  2. 応答内の各論文の「RId」属性を参照し、その論文が引用している論文を抽出する

  3. 抽出した引用されている論文を、引用数が多い順にソートする

  4. 引用数が多い上位N件の論文の情報を、Academic Search APIで再度取得する

この手順で先程と同じ「And(W='dialogue',W='generation',Y>=2015)」で論文を検索した結果が下記です。Seq2SeqやAttentionのような要素技術や、この分野で近年注目されている論文を効果的に取得することができています。

- Sequence to Sequence Learning with Neural Networks
- Neural Machine Translation by Jointly Learning to Align and Translate
- Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation
- Bleu: a Method for Automatic Evaluation of Machine Translation
- Long short-term memory
- Building End-To-End Dialogue Systems Using Generative Hierarchical Neural Network Models
- A Neural Conversational Model
- A Diversity-Promoting Objective Function for Neural Conversation Models
- A Hierarchical Latent Variable Encoder-Decoder Model for Generating Dialogues
- A Network-based End-to-End Trainable Task-oriented Dialogue System

まとめ

この記事では、Academic Search APIの紹介とそれをもとに作成した自作の論文検索エンジンについて紹介しました。

検索できるクエリは紹介したもの以外にも色々ありますし、プログラムを使用することも考えるとより複雑なクエリを実現することができます。

また、プログラムで論文管理ツールやスプレッドシートと連携することも簡単にできますし、APIがあるだけでサーベイ作業をだいぶ効率化できそうだなと考えています。

ぜひ皆さんも自分専用の論文検索エンジンを作ってみてください。

VariationalでEnd2EndなDialogue Response Generationの世界

この記事は、自然言語処理 #2 Advent Calendar 2019の24日目の記事です。

Open-Domain Dialogueや非タスク指向対話、雑談対話と呼ばれる領域において、発話データのみを使用したEnd2Endな対話応答生成を試みる歴史はそこまで古くなく、[Ritter et al+ 11]や[Jafarpour+ 10]がまず名前をあげられるように、比較的最近始まった研究テーマとなります。

これらは、Twitterなどの登場により、ユーザー間で行われる、ほとんどドメインを限定しない、もしくは多様なドメインにまたがる、大量の対話データを、容易に収集できるようになったことにより、活発に研究されるようになってきました。

初期の研究である[Ritter+ 11]や[Jafarpour+ 10]では、統計的機械翻訳ベースや情報検索ベースの手法でEnd2Endな対話システムを構成していました。

これらの研究は、複数のモジュールをつなぎ合わせ、対話データだけでなく、人がアノテーションしたデータを元にシステムの学習を行う、従来の対話システム研究にパラダイムシフトをもたらしました。しかし、下記のような雑談対話システムの特性のため、初期の研究では扱える処理に限界がありました。

  • 一つのクエリー(ユーザー発話)に対して、複数の応答(システム発話)が可能である点

  • クエリー自体はそこまで長いものでないため、長期的なコンテキスト(対話履歴など)を考慮した応答生成を行う必要がある点

それを受けて現れたのが、ニューラルネットワークを使用した対話システムです。[Sordoni+ 15][Vinyals+ 15][Shang+ 15]

これらの論文はほぼ同時期に登場することになりますが、いずれも、ニューラルネットワークの表現学習能力を利用して、対話の長いコンテキストを一定次元の潜在表現に圧縮し、それを元にRNNを使用したデコーダーにて応答を生成することを試みています。

また、[Serban+ 16]にて提案されたHREDモデルは、発話文のエンコーダーと、コンテキストのエンコーダーを分けることにより、より効率的に長い対話コンテキストを利用できることを示し、構成の簡単さもあり、現在ではEnd2Endな対話応答生成のベースラインとしてよく使用されるモデルとなっています。

しかし、単純にニューラルネット化しただけでは、長期的なコンテキストを考慮しやすくはなっても、「I'm OK」や「I don't know」のような、どのようなコンテキストにも合いやすい文を生成してしまうという問題は解決しませんでした。

[Li+ 16]では、相互情報量(Mutual Information)を目的関数にすることにより、出現頻度の高い文を応答として使用することに対し、ペナルティーを与えるような設定でそれを克服しようとしました。

[Li+ 17]では、GANベースの学習方法を使用して、生成される対話を、実際のデータセットに含まれるような対話に近づけることにより、この問題を克服することを試みています。

そして、[Serban+ 17]および[Zhao+ 17]では、VAE、特にCVAEのテクニックを対話応答生成に応用して、潜在変数のサンプリングによって多様な応答生成が行えるようになることを示しています。

本記事では、特にこのVAEベースの手法に注目して、どのように発展してきたか、現在はどのような部分にスポットが当たっているかについて記載していきたいと思います。

目次

End2End対話システムの基本

End2Endな対話応答生成の問題設定

End2Endな対話応答生成は、N個の単語(もしくはトークンおよび文字)からなるユーザー発話文  x = \{w_1^{x}, w_2^{x}, ..., w_N^{x} \} と、L個のユーザー発話とシステム発話の履歴であるコンテキスト  C = \{c_1 = \{w_1^{c_1}, w_2^{c_1}, ..., w_{N'}^{c_1}\}, c_2, ..., c_L\} が与えられた際に、適当なシステム発話文  y = \{w_1^{y}, w_2^{y}, ..., w_M^{y} \} を生成することが目的となります。

ここでは、一般的な問題設定を紹介していますが、ユーザー発話やシステム発話は必ずしもテキストのみである必要はありません。またコンテキストとして、対話履歴だけでなく、注目している画像や、文章、知識ベースなど、多様な周辺情報が含まれるケースも考える事ができます。

ここで、モデル  P_\theta (y | x, C) を考えた場合、自動で応答を生成することができるように、データセット  D = \{x_i, y_i, C_i \}を元にモデルのパラメーター  \theta を学習することが、End2Endな対話応答生成システムに機械学習システムを適用する際の基本的な考え方となります。

データセットにそもそも含まれないような状況には、対応することができないという課題もありますが、それはルールベースやアノテーションデータをもとに学習するような既存の対話システムでも同様です。

End2Endな対話システムは、ユーザー発話の理解から応答の生成までを一貫して行うことにより、データセットに含まれるような対話とは多少異なるような状況であっても、コンテキストやユーザー発話の違いを考慮して、より柔軟な応答が生成されることを期待しているものと考えることができます。

RNNを利用した対話応答生成

DCGM(Dynamic-Context Generative Model)

[Sordoni+ 15a]では、ユーザー発話文 xとコンテキスト Cに含まれる単語を元にbag-of-words表現を生成し、それをもとにRNNを使用したデコーダーにより、システム発話文 yを生成するシステムが提案されています。

直前のユーザー発話文とコンテキストのBOWのベクトルを分けるケースと、まとめたBOWを使用する2パターンが提案されていますが、以降の処理は共通で、BOWベクトルを元に多層のニューラルネットワークを元にコンテキストベクトル c_lを生成した後、RNNの各ステップで c_lを使用するようにシンプルなRNNを改良した層を元に、 yを構成する各単語を次々に生成していくようなアーキテクチャです。

 
\begin{align}
h_t &= \sigma (h_{t-1} W_h + c_l + w_{t-1}^{y} W_{in}) \\
p_\theta &(w_{t}^{y} | w_{1}^{y}, ..., w_{t-1}^{y}, x, C) = \sigma(h_t W_{out}) \\
\end{align}

ここで  W _ h, W _ {in}, W _ {out} は学習により最適化するNNのパラメーターです。以降も大文字の Wで表す行列は基本的に学習対象のパラメーターとします。

このように、RNNの計算中に逐次コンテキストを考慮するように計算を行うことにより、既存の機械翻訳ベース、情報検索ベースの手法と比較して、SNSでのやり取りを元にした対話応答生成タスクにおいて、高い性能で応答生成が行えることが示しています。

Seq2Seqを利用した対話応答生成

Neural Conversational Model

Seq2Seqは、[Sutskever+ 14]にて提案された、RNNを利用したEncoder-Decoderを持つネットワーク構造のことです。

この構造の一番大きな特徴は、そのシンプルさです。Encoderは、入力文ををもとにどんどん自身の隠れ層を更新していき、入力が完了した際の隠れ層のベクトル h _ {N}^{E}を出力とします。デコーダーは、この h _ N^{E}を自身の隠れ層の初期値( h _ 0^{D})とした上で、スタートトークンから逐次的に w_{t}^{y}をRNNの出力層を元に生成してきます。

f:id:KSKSKSKS2:20191222180159p:plain
[Luong+ 15]より

このモデルは、このようなシンプルな構造ながら当時の機械翻訳の最高性能モデルに匹敵する性能を示し、テキスト生成界隈を驚かせました。

[Vinyals+ 15]は、このSeq2Seqを対話応答生成に素直に応用した手法で、直前のユーザー発話文を含む対話履歴(つまり、 x C)をエンコーダーに入力し、システム発話文をデコーダーにより生成することを試みました。

この論文の結果は非常に衝撃的で、大量の対話データさえあれば、このようなシンプルな構造のモデルでも、対話システムとして成立する可能性を示しました。

一方で、同じ対話中であっても、コンテキストをうまく考慮できず、似たような質問に異なる応答を返してしまう問題があり、この一貫性(Consistency)については、以降の対話応答生成手法においても、重要なキーワードの一つになっていきます。

Neural Responding Machine

Seq2Seqにはほぼセットで使用される機構にAttentionがあります。Attentionは、デコーダーの注目している隠れ層のベクトルを元に、エンコーダーのどのステップの隠れ層に注目するかを決定し、注目した隠れ層のベクトルの値をデコーダーが生成する出力の計算に利用します。

確率的に注目するエンコーダーのステップを決定するソフトアテンションと呼ばれる機構を、[Luong+ 15]ではさらにグローバルアテンションとローカルアテンションに分類しており、よく使用されるのはグローバルアテンションです。

グローバルアテンションの計算は、 h_t^{D}デコーダーの隠れ層のベクトル、 h_t^{E}エンコーダーの隠れ層のベクトルとした場合、下記のように表せます。RNNが複数層スタックされている場合は、一般に最終層の隠れ層のベクトルを使用します。

 
\begin{align}
a_{it} &= \frac {score(h_t^{D}, h_i^{E})} {\sum_j^{N} score(h_t^{D}, h_j^{E}) }  \\
c_t &= \sum_i^{N} a_{it} h_i^{E} \\
p_\theta &(w_{t}^{y} | w_{1}^{y}, ..., w_{t-1}^{y}, x, C) = softmax([c_t; h_t^D] W_{out}) \\
\end{align}

 score は隠れ層のベクトル同士の類似度を計算する関数で、一般にはドット積( h_t^{D} h_j^{E})や h_t^{D} W_a h_j^{E}がよく使用されます。

f:id:KSKSKSKS2:20191222184629p:plain
[Luong+ 15]より

主に機械翻訳分野にて使われだしたアテンション機構ですが、[Shang+ 15]ではこの仕組みを対話応答生成に利用しました。

一つ前のデコーダー隠れ層のベクトルを元にアテンションを計算する部分と、アテンションを元に最新の隠れ層のベクトルを更新する部分以外は、前述のグローバルアテンションと同じような機構を使用し、Weiboで収集した大量の対話データを元に、既存の検索ベース、統計的機械翻訳ベースのシステムを上回る、対話応答生成性能を示しました。

このように、Seq2Seqベースの手法と大量の雑談対話データとの相性の良さが様々な研究によって示され、End2Endな対話応答生成においては、ニューラルネット、特にRNNを使用したSeq2Seqベースの構造は非常に良く使用されるようになっていきます。

階層構造を利用した対話応答生成

HRED(Hierarchical Recurrent Encoder-Decoder)

階層型の応答生成アーキテクチャであるHREDは、[Sordoni+ 15b]にて、まずはマルチターンの質問応答タスクにおいて採用されたSeq2Seqの拡張手法です。 通常のSeq2Seqとの大きな違いは、エンコーダーが発話エンコーダーとコンテキストエンコーダーの2つに分かれる点です。  h_t^{UE}を発話エンコーダーの隠れ層のベクトル、 h_l^{CE}をコンテキストエンコーダーの隠れ層のベクトルとすると、下記の用に表現する事ができます。

 
\begin{align}
h_t^{UE} &= UtteranceEncoder(x)  \\
h_l^{CE} &= ContextEncoder(h_{l-1}^{CE}, h_t^{UE}) \\
p_\theta(y | x, C) &= Decoder(h_l^{CE}) \\
\end{align}

UtteranceEncoder、ContextEncoder、Decoderは基本的にRNNを使用することが多く、コンテキストエンコーダーの隠れ層のベクトルは、一層のニューラルネットワークで次元数をそろえる処理をした後、デコーダーの隠れ層の初期値として使用されることが一般的です。

f:id:KSKSKSKS2:20191222193446p:plain
[Sordoni+ 15b]より

[Sordoni+ 15b]でマルチターンの質問応答タスクで応用されたHREDは、[Serban+ 16]で対話応答生成に応用され、映画シナリオを元に作成したデータセットにおいて、直前のユーザー発話だけでなく、対話履歴中の発話も考慮した発話生成が行える傾向が見られることを示しました。

HREDは、長いコンテキストを扱うことになる対話応答生成と相性が良いと考えられており、現在では対話応答生成のベースラインモデルとして頻繁に参照されています。

VariationalでEnd2Endな対話応答生成

上記までで、大量の対話データとニューラルネットを使用することにより、長いコンテキストを考慮したEnd2Endな対話応答生成処理を実現できそうであることがわかりました。

しかし、これらの対話応答生成には、特に確率的に最も有り得そうな文をデコーダーからサンプルすると、「I'm OK」や「I don't know」のような、どのようなコンテキストにも合う文を生成してしまうという課題がありました。[Serban+ 16][Li+ 16]

f:id:KSKSKSKS2:20191224111954p:plain
[Li+ 16]より

上の表は、Seq2Seqベースの手法で学習したEnd2End対話システムにて、各ユーザー発話文を与えた際に、尤度の高い順にBeam Searchでデコードしたシステム発話文を並べたものです。

ユーザ発話に対する固有の応答よりも、どのユーザー発話にでもある程度使いまわしのきく文が、生成されてしまいやすいという問題が見えてきます。

非常に限定された情報しか含まない対話データを使用して、最尤推定により生成モデルを作成する以上、上記は避けることが難しい問題でもあるため、目的関数を相互情報量ベースのものや、GAN・強化学習ベースのものに置き換えることにより、これらを解決しようとも試みられています。

一方で、学習の安定性が高い最尤推定の枠組みのままこの問題を解決しようという方向性もあります。それが、VAEを応用した対話応答生成手法です。VAEを応用した手法では、潜在変数をサンプリングすることにより、生成される文の多様化を試みます。

以降では、対話応答生成にどのようにVAEが応用されているのかについて記載していきます。

VAE(Variational Auto-Encoder)とテキスト生成

変分近似の基本

VAE(Variational Auto-Encoder)は、[Kingma+ 14a]にて提案された、ニューラルネットワークを使用した生成モデルの一種です。

VAEは、他の最尤推定ベースの生成モデルと同様、データ xが与えられた際に、尤度  P_\theta(x)が最大となるモデルのパラメーター \thetaを推定することを考えます。

この時、一般に潜在変数 zを用いて、下記のようにモデル化します。

 
\begin{align}
P_\theta(X) &= \int P_\theta(x, z) dz  \\
&= \int P_\theta(x | z)P(z) dz \\
\end{align}

一般的には、このままでは最適なパラメーター \thetaを解析的に求める事はできないため、別のパラメーターを持つ事前分布  Q_\phi(z)を用いて、以下のように対数尤度を変形します。

 
\begin{align}
log P_\theta(x) &= log \int P_\theta(x, z) dz  \\
&= log \int Q_\phi(z) \bullet P_\theta(x, z) / Q_\phi(z) dz \\
& \ge \int Q_\phi(z) \bullet log (P_\theta(x, z) / Q_\phi(z)) dz \\
& = \int Q_\phi(z) \bullet log (P_\theta(x | z)P(z) / Q_\phi(z)) dz \\
& = \int Q_\phi(z) \bullet log (P(z) / Q_\phi(z)) + Q_\phi(z) \bullet log P_\theta(x | z) dz \\
& = E_{z \sim Q_\phi(z)}[ log P_\theta(x | z) ] - KL[Q_\phi(z) \parallel P(z)] \\
&= L(x, z)
\end{align}

  L(X, z)を変分下限と呼ばれるもので、文字通り尤度の下限に当たる値です。尤度を最大化することを考える場合、この変分下限の値が最大となるようなパラメーター \thetaを求めればよいわけです。

ここで、 KL[Q _ \phi(z) \parallel P(z)]は、 Q _ \phi(z) P(z)のKLダイバージェンス E _ {z \sim Q _ \phi(z)}[ log P _ \theta(x | z) ]は、 Q _ {\phi}(z)を考えた時の log P _ {\theta}(x | z)の期待値です。

一般的な最尤推定EMアルゴリズム)では、 P(z)を標準正規分布として、 KL[Q _ {\phi}(z) \parallel P(z)]を最小化するような \phiを選んだ後(Eステップ)、 E _ {z \sim Q _ \phi(z)}[ log P _ \theta(x | z) ]を最大化する \thetaを選ぶ処理(Mステップ)を繰り返して、最適なパラメーターを求めます。

VAEの基本

VAEでは、事後分布の近似である Q _ {\phi}(z)をデータ xの識別モデルである Q _ {\phi}(z | x)に拡張し、 P _ \theta(x | z)とともにニューラルネットに置き換えてモデル化します。

この時、 Q _ \phi(z | x)を識別モデル(Discriminator)や事後分布(Posterior)、 P(z)を事前分布(Prior)、 P _ \theta(x | z)を生成モデル(Generator)と一般的に呼ばれます。

また、変分下限は、 L(x, z) = E _ {z \sim Q _ \phi(z | x)}[ log P _ \theta(x | z) ] - KL[Q _ \phi(z | x) \parallel P(z)] となります。

VAEでは、上記のような変更を行うことで、以下のような効果を期待します。

  1. パラメーター \thetaを最適化する際に用いる zの効率的な獲得

  2. データを反映した潜在表現である zの獲得

  3.  P _ \theta(x | z)からサンプリングの簡易化

ここで、生成モデルをニューラルネットにするにあたり、全てのモデルは微分可能な関数として置く必要があります。 P(z)を標準正規分布として置く場合、 Q _ {\phi}(z \mid x)正規分布となるように置くと、KLダイバージェンスを解析的に求めることができるようになるため、 \thetaについては微分可能となります。

しかし、 Q _ {\phi}(z | x) を直接確率分布としてモデル化しまうと、 zが確率変数となるため \phiについて微分することができません。

そこで、下記のようにReparameterization Trickを用いて、 zを決定変数として扱えるようにします。

  1. 確率変数 \epsilon N(0, I)からサンプリングする

  2.  z = \mu _ \phi(x) + \sigma _ \phi(x) \bullet \epsilon zを定義する

  3.  E _ {z \sim Q _ \phi(z | x)}[ log P _ \theta(x | z) ]が E _ {\epsilon \sim P(\epsilon)}[ log P _ \theta(x | \mu _ \phi(x) + \sigma _ \phi(x) \bullet \epsilon) ]と考えられるようになる

 \mu _ \phi(x)および \sigma _ \phi(x)は、 xから正規分布の平均と分散を推定するモデルで、一般的に途中までは一つのCNNやRNNで構成され、それらの出力をMLPでそれぞれように変換する処理を行います。

これにより、 \phiについても微分できるようになり、変分下限をSGDを使用して最大化できるようになります。

VAEを利用したテキスト生成

VAEをSeq2Seqを利用してテキストの生成に利用したモデルはいくつかありましたが、[Bowman +16]では、潜在変数をより生成するテキストに反映できるようにし、VAEによるテキスト生成の有効性を示した点で画期的でした。

Seq2Seqを用いたモデルでは、特にAttentionを利用しないような状態では、エンコーダーで捉えた情報をほとんど使わずに、デコーダーが入力された単語の情報を元にテキストを生成してしまうという問題がありました。

VAEでは、潜在変数 zに格納される情報に価値がなくなり、KL項が早々に0になってしまうという形で問題が現れます(KL-Vanishing Problem や KL Collapse Problem 、 Posterior Collapse Problem と呼ばれることもあります)。

このようなVAEの弱点に対し、[Bowman +16]では、KLコストアニーリングやWord Dropoutという手法を提案し、入力文を復元できるような潜在表現をエンコーダーによって獲得でき、デコード時に全く単語の情報を与えない状況でも、デコーダーは潜在表現をもとに出力文を生成できるようになることを示しました。

f:id:KSKSKSKS2:20191223124628p:plain
[Bowman +16]より

また、事後分布からサンプルした潜在変数で文を再構成した場合、意味的に同一だが表現が多少異なるような文を生成できること、それぞれの文の事後分布からサンプルした潜在変数の調和平均を取った潜在変数をデコードした場合、それぞれの文を混ぜたような意味や表現の文が生成されるなど、潜在表現を元にデコードしていると考えられる実験結果をいくつも提示しています。

f:id:KSKSKSKS2:20191223124649p:plain
[Bowman +16]より

この成果を受けて、多様な文を生成したい際に、VAEを使用して実現するというアプローチが広く行われるようになっていきます。

CVAEを利用した対話応答生成

CVAEの基本

CVAE(Conditional Variational Auto-Encoder)は、[Sohn+ 15]にて提案された、VAEをConditional情報を含んだ状態に拡張し、Conditional情報が与えられた際に学習データが生成される条件付き尤度、 P _ \theta(x | C)を最大化するように学習する手法です。

似たような定式化を行っている[Kingma+ 14b]では、 P _ \theta(x, C)や、コンディショなる情報を非自明とした上で P _ \theta(x)を最大化する事を考えますが、[Sohn+ 15]では条件付き尤度である P _ \theta(x | C)を最大化するように最適化します。

変分下限は、コンディショナル情報を加えた形で、下記のように定義します。


L(x, C, z) = E_{z \sim Q_\phi(z | x, C)}[ log P_\theta(x | C, z) ] - KL[Q_\phi(z | x, C) \parallel P_\theta(z | C)]

事前分布にも、事後分布にも、生成モデルにもコンディショナル情報の条件が含まれるようになる点がポイントです。 P  _ \theta(x | C)の条件付き尤度の式からも分かるように、CVAEはコンディショナル情報からデータ xを生成することに使用することを主な利用目的としています。

[Sohn+ 15]では、不完全なデータをコンディショナル情報 Cとして、 xを復元するためにCVAEを使用しています。

以降で述べるように、対話応答生成では、対話履歴と直前のユーザー発話をコンディショナル情報として、そこから次のシステム発話を生成することを試みることになります。

VHRED

VHREDは、[Serban+ 17]にて提案されたHREDをCVAE化した手法です。ContextEncoderが出力した隠れ層のベクトルの値を元に、事後分布 Q_\phi(z | y, C)から潜在変数 zをサンプリングして、それを元に出力文を生成するように学習します。

 
\begin{align}
h_t^{UE} &= UtteranceEncoder(x)  \\
h_l^{CE} &= ContextEncoder(h_{l-1}^{CE}, h_t^{UE}) \\
h_{t+1}^{UE} &= UtteranceEncoder(y) \\
z &\sim Q_\phi(z | h_{t+1}^{UE}, h_l^{CE}) \\
P_\theta(y | x, C) &= Decoder(z, h_l^{CE}) \\
\end{align}

実際に利用する際は、事前分布 P _ \theta(z | C)から潜在変数 zをサンプリングして、次のシステム発話を生成することを試みます。コンディショナル情報 C = \{x, C \}は、対話履歴と直前のユーザー発話から構成される前提です。

f:id:KSKSKSKS2:20191223130321p:plain
[Serban+ 17]より

VHREDは、SNS上の対話データや具体的な技術に関する対話データにおいて、特に数ターンと長いコンテキストを持つ際の対話応答生成、および生成される文の多様性で、通常のSeq2SeqベースのモデルおよびHREDと比較し、大幅に性能が向上することを示しました。

VAEの利点である、データを反映した潜在表現を得られるようになる点、サンプリングすることで生成結果の多様性に影響を与えることができる点が、対話応答生成において有効に機能する結果を示した論文です。

CVAE/kgCVAE

[Zhao+ 17]では、VHREDとほぼ同時期にほぼ同じようなモデルを提案しています。

VHREDと異なる点は、学習時にWord Dropoutの代わりに、(1)新たに提案したAuxiliary LossであるBag-of-Words Lossを用いて、潜在変数から生成する文のBOWベクトルを復元できるように学習できるようにした点と、(2)潜在変数だけでなく何かしらのラベル([Zhao+ 17]ではダイアログアクト)もコンディショナル情報から推定し、推定した潜在変数とラベルをもとにデコーダーで文を生成するkgCVAE(Knowledge-Guided CVAE)という拡張手法を提案している点です。

f:id:KSKSKSKS2:20191223150228p:plain
[Zhao+ 17]より

ラベル情報を aと置くと、ラベル情報を追加した場合の変分下限は下記のように定式化できます(図中では、ラベル情報が y、生成対象のテキストが xと表現されていますが、同じことを表しています)。


\begin{align}
L(y, C, a, z) &= E_{z \sim Q_\phi(z | y, C, a)}[ log P_\theta(y | C, a, z) ] - KL[Q_\phi(z | y, C, a) \parallel P_\theta(z | C)] \\
&+ E_{z \sim Q_\phi(z | y, C, a)}[ log P_\theta(a | C, z) ]
\end{align}

式の通り、識別モデル Q _ \phiは、生成対象の文とコンディショナル情報だけでなく、真のラベル情報も元にして潜在変数 zを推定するように学習します。一方デコーダ P _ \theta(y)は、コンディショナル情報、潜在変数、推定した(もしくは陽に与えた)ラベル情報から対象の文を生成します。

f:id:KSKSKSKS2:20191223145900p:plain
[Zhao+ 17]より

よりグラフィカルに書くと、上図のようになります。 MLP _ yと記載されている部分が、 P _ \theta(a | C, z)に、 MLP _ bと記載されている部分がBOW Lossに対応します。

青いラベル情報に関する処理を抜くとただのCVAEであり、基本的にVHREDと同様になります。

実験では、特にkgCVAEは、対話履歴を考慮するだけでなく、指定したラベルに基づいた応答を生成できることを示し、様々な情報を追加することにより、コントロール性と生成する文の多様性の両方を兼ね備えた、対話応答生成の可能性を示しました。

VariationalでEnd2Endな対話応答生成の現在

VHREDやCVAEで一通り基本的なモデルは出尽くしたきらいがありますが、より有用な対話応答生成を行うための拡張手法の提案が日々行われています。

その方向性の一つが、潜在変数をよりデコーダーが参照するように、デコーダーが参照しやすいようによりリッチな潜在変数表現を獲得するように学習し、今まで以上に多様な応答生成を可能にしようという方向性です。

も一つが、単純に、対話応答を生成するだけでなく、ある制約ある条件下のもとで対話応答を生成したいという方向性です。

以降では、多様な応答生成を可能にするためにどのようにVHREDが改良されていっているか、制約条件として発話するユーザーを指定して応答を生成したいというケースに、どのように対応することが試みられているかについて紹介していきたいと思います。

生成される発話の多様性の改善

VHCR

[Park+ 18]では、VHREDに対話全体の潜在変数である  z^{conv} を追加したVHCRモデルとUtterance Dropテクニックを元に、より潜在変数を考慮した対話の生成が行える方法を提案しています。

 
\begin{align}
h^{conv} &= ConvEncoder(all\_dialogue) \\
z^{conv} &\sim  Q_\phi(z | h^{conv}) \\
h_t^{UE} &= UtteranceEncoder(x)  \\
h_l^{CE} &= ContextEncoder(h_{l-1}^{CE}, h_t^{UE}, z^{conv}) \\
h_{t+1}^{UE} &= UtteranceEncoder(y) \\
z &\sim Q_\phi(z | h_{t+1}^{UE}, h_l^{CE}, z^{conv}) \\
P_\theta(y | x, C) &= Decoder(z, h_l^{CE}, z^{conv}) \\
\end{align}

 z^{conv}は全ての対話データから事後分布を元に推定され、コンテキストエンコーダーおよびデコーダーで使用されます。

f:id:KSKSKSKS2:20191224121827p:plain
[Park+ 18]より

変分下限は論文中には明記されていませんが、下記のようになります。


\begin{align}
L(y, C, z) &= E_{z^{conv}\sim Q_\phi(z^{conv} | all\_dialogue)}[E_{z \sim Q_\phi(z | y, C, z^{conv})}[ log P_\theta(y | C, z, z^{conv}) ] - KL[Q_\phi(z | y, C, z^{conv}) \parallel P_\theta(z | C, z^{conv})]] \\
&- KL[Q_\phi(z^{conv} | all\_dialogue) \parallel P(z^{conv})]
\end{align}

Utterance Dropは、コンテキストエンコーダーにわたす h_t^{UE}をランダムに事前定義済みのUnknown Vectorに変更することにより、 z^{conv}をもとにシステム発話文が生成されることを狙うテクニックです。

この二つの改良により、VHCRは発話一つずつではなく、対話全体を一つの潜在変数によりコントロールすることを試みており、実験では、 z^{conv}をもとに対話全体を実際にサンプリングすることができることを示しています。

CVAE+VAE

[Shen+ 18]では、KL項が消失してしまう問題の原因を、潜在変数の表現力が正規分布に制限されてしまうことであると考え、VHREDにAAE(Adversarial Auto-Encoder)[Makhzani+ 16]を応用し、潜在変数 zがより多様な形を取りうるように改良した手法です。

AAEはVAEのKL項をGANのスキームに置き換えることにより、多様な分布を事前分布として仮定できるようにした枠組みですが、[Shen+ 18]ではGANをさらにCVAEに置き換えることにより、コンディショナル情報も踏まえた上で、事前分布と事後分布を比較する枠組みに改良した手法を提案しました。

f:id:KSKSKSKS2:20191223163832p:plain
[Shen+ 18]より

上図のM1はほぼVAEと等価な構成で、対話履歴とシステム発話文yをエンコードした情報から、システム発話文yを再構成するようにパラメーターを最適化します。

VAEとの大きな違いは、(1)潜在変数を直接サンプリングするのではなく \epsilonという緩衝材となるパラメーターを用いて決める点と、(2) \epsilonを潜在変数に変換するモデル G _ \phi(\epsilon)はすでに学習済みのものを使用するという点、(3)KL項はM2モデルで別途最適化を行うという点です。


\begin{align}
\tilde{z} &= UtteranceEncoder(y) \\
z  &= (1 -p) G_\phi(\epsilon) + p \tilde{z} \\
L_{M1}(y, z, C) &= E _ {z \sim Q _ \phi(z | \tilde{z}, C)}[ log P _ \theta(y | z, C) ] \\
\end{align}

 \epsilonを潜在変数に変換するモデル G _ \phi(\epsilon)は、上図のM2に当たるCVAEパートにより学習を行います。対話履歴のみから \epsilonを推定する事前分布 P _ \theta(\epsilon | C)と、システム発話文yの情報も利用して \epsilonを推定する事後分布 Q _ \phi(\epsilon | \tilde{z}, C) の出力が近づくように学習を行います。

また、発話エンコーダーがシステム発話文 yエンコードしたベクトル \tilde{z}を、 \epsilonから再構成できるように学習することにより、事前分布の形状によらない潜在変数を対話履歴のみから得ることを試みます。


\begin{align}
&\tilde{z} = UtteranceEncoder(y) \\
&L_{M2}(\tilde{z}, C, \epsilon) = E_{\epsilon \sim Q_\phi(\epsilon | \tilde{z}, C)}[ log P_\theta(z | C, \epsilon) ] - KL[Q_\phi(\epsilon | \tilde{z}, C)  \parallel P_\theta(\epsilon | C)] \\
&E_{\epsilon \sim Q_\phi(\epsilon | \tilde{z}, C)}[ log P_\theta(z | C, \epsilon) ] = \frac {1} {2} E_{\epsilon \sim Q_\phi(\epsilon | \tilde{z}, C)}[ \parallel \tilde{z} - G _ \phi(\epsilon) \parallel_2^{2} ]
\end{align}

[Shen+ 18]では、M1モデルとM2モデルの学習交互に行うことにより、事前分布と事後分布の形状が近接するようになること、自動評価でも既存手法と比較して良い性能を得られることを示しました。

DialogWAE

[Gu+ 19]で提案されたDialogWAEは、上記の[Shen+ 18]が検討したより自由な事前分布を対話応答生成に適用するという発想を、別のアプローチを用いてCVAEベースのモデルに適用した手法です。

[Tolstikhin+ 18]で提案され、[Zhao+ 18]でラベル情報も利用するように拡張された、Wasserstein Auto-Encodersの枠組みを使用することにより、事前分布を複数のモードを表現できるガウス混合分布として表現できるようCVAEを拡張しています。

Wasserstein Auto-Encodersでは、Wasserstein GANの式を利用してVAEの潜在変数を最適化するというAAEを拡張したような手法で、自由度の高い潜在空間から潜在変数をサンプリングすることを試みています。

f:id:KSKSKSKS2:20191223172207p:plain
[Gu+ 19]より

オリジナルのWasserstein Auto-Encodersでは、事後分布からは特にサンプリングは行わず、エンコーダーが変換した値をそのまま潜在変数として使用しますが、[Gu+ 19]ではエンコーダーが変換した値をもとに事後分布 Q _ \phi(\epsilon | y, C)のパラメーターを推定し、そこからサンプリングした値をさらにニューラルネットワーク G _ \phi(\epsilon)で変換したベクトルを潜在変数 zとして用いるような構成になっています。


\begin{align}
\epsilon &\sim Q _ \phi(\epsilon | y, C) \\
z  &= G_\phi(\epsilon) \\
L_{rec}(y, z, C) &= E _ {z \sim Q _ \phi(z | y, C)}[ log P _ \theta(y | z, C ] \\
\end{align}

KL項に当たる部分は、Discriminatorとのmin maxゲームで最適化します。具体的には下記の L _ {disc}を最小化するようにパラメーター \psiを、最大化するようにパラメーター \theta \phiを学習します。


\begin{align}
\epsilon &\sim Q _ \phi(\epsilon | y, C) \\
z  &= G_\phi(\epsilon) \\
\tilde{\epsilon} &\sim P _ \theta(\epsilon | C) \\
\tilde{z}  &= G_\theta(\tilde{\epsilon}) \\
L_{disc}(y, C, z, \tilde{z}) &= E_{z \sim Q _ \phi(z | y, C)}[ D_\psi(z) ] - E_{\tilde{z} \sim P _ \theta(z | C)}[ D_\psi(\tilde{z}) ] \\
\end{align}

また事前分布には、Gumbel-Softmax[Jang+ 17]を応用して、一つの値のみが1、それ以外はほぼ0となるような重み係数 v _ kを取得することにより、ガウス混合分布を仮定し、対話履歴によりサンプルに利用されるガウス分布が変更されるように学習することを試みています。


\begin{align}
P _ \theta(\epsilon | C) = \sum_{k=1}^{K} v_k N(\epsilon; \mu_k, \sigma_k^2 I)
\end{align}

実験では、事前分布の形状によらずとも、コンテキストに則した文が生成されやすくなるという方向で、生成されるシステム発話文の多様性が大幅に向上することが示されています。

また、事前分布をガウス混合分布とした際には、構成する正規分布のうち特定の正規分布を使用して生成される応答が、肯定よりか否定よりかといったような、別々のモードの文が生成されることが可能せることを示しました。

ユーザーの特性を考慮した応答文生成への応用

発話するユーザーを指定して、応答を生成したいというケースに対して、Seq2Seqベースの手法は複数提案されており、比較的初期の手法である[Li+ 16]では、デコーダーに毎ステップ発話するユーザーの情報を与えることにより、指定するユーザーによって生成される応答を変更できることを示しています。

[Bhatia+ 17]では、[Li+ 16]の手法を拡張し、ユーザーを示す埋め込み表現を、SNSソーシャルグラフから取得するnode2vecをもとに生成したものを使用する手法が提案されています。

[Luan+ 17]では、デコーダーをSeq2SeqとAutoEncoderで共有し、次の発話の予測と元の発話の復元を同じデコーダーを使用して行うことにより、対話以外のデータからも、ユーザーに特徴的な発話を学習できる手法を提案しました。

現在では、これらのベーシックな手法をCVAE系のモデルに応用したものが、複数登場しています。

SSVN

[Chang+ 19]で提案されたSSVN(Semi-Supervised Stable Variational Network)は、[Zhao+ 17]で提案されたkgCVAEを拡張し、ダイアログアクトの代わりに、対話履歴に含まれる以前のシステム発話を元にVAEを用いて生成したユーザーベクトルをと使用し、発話するユーザーに合った発話応答を生成しようというものです。

f:id:KSKSKSKS2:20191223180929p:plain
[Chang+ 19]より

ユーザーベクトル z^{replier}は、対話履歴 Cのうち、 replier(ようするにシステム)が発話した履歴のみ C^{replier} = \{ c _ 2^{replier}, c _ 4^{replier}, ..., c _ L^{replier} \}を入力として、それを復元するVAEの潜在変数として構築されます。


\begin{align}
L_{VAE}(z^{replier}, C^{replier}) &= E _ {z^{replier} \sim Q _ \phi(z | C^{replier})}[ log P _ \theta(C^{replier} | z^{replier}) ] \\
&- KL[Q _ \phi(z | C^{replier}) \parallel P(z) ] \\
\end{align}

CVAEパートは、上述のように、kgCVEのラベル情報をユーザーベクトルに変えたものと考えることができ、下記のように変分下限を定式化できます。この二つの変分下限が最大になるように、パラメーターを最適化していきます。これにより、以前のシステム発話を通常のCVAE以上に考慮した応答生成が行われることを試みます。


\begin{align}
z^{replier} &\sim Q _ \phi(z | C^{replier}) \\
L_{CVAE}(y, C, z, z^{replier}) &= E_{z \sim Q_\phi(z | y, C, z^{replier})}[ log P_\theta(y | C, z, z^{replier}) ] \\
&- KL[Q_\phi(z | y, C, z^{replier}) \parallel P_\theta(z | C, z^{replier})] \\
\end{align}

もう一点特徴的なのは、各事前分布および事後分布の確率分布としてvon Mises-Fisher分布を仮定しているところです。

von Mises-Fisher分布を用いたVAEは、[Kingma+ 14a]で提案されているベーシックなReparameterization Trickでは対応することはできませんが、[Naesseth + 17]で提案されている棄却サンプリングを応用してノイズ項をサンプリングするアプローチを用いることにより、VAEにおいて使用することができるようになります(余談ですが、Reparameterization Trickの幅広い分布への応用は、[Figurnov+ 18]などでより高速な手法が提案されるなど、こちらも年を追うごとに使い勝手が良くなっていっています)。

[Chang+ 19]では、[Naesseth + 17]が提案したアプローチを、von Mises-Fisher分布に応用した[Davidson+ 18]と[Xu+ 18]のアプローチを踏襲しています。von Mises-Fisher分布を使用したVAEは、定式化によってはKL項の値を定数にすることができるため、潜在変数にデコーダーが使用する情報を格納しやすくなる点がメリットとなります。

実験では、以前までのシステム発話をより考慮しやすくなったため、既存手法と比較して大幅に生成される文の多様性が大幅に向上し、対話履歴内の以前の発話にも合った表現で文が生成されるようになることが示されています。

VHUCM

[Bak+ 19]で提案されたVHUCM(Variational Hierarchical User-based Conversation Model)は、[Park+ 18]が提案したVHCRを、対話中の二人のユーザーベクトルをコンディショナル情報として z^{conv}を生成すように、二重のCVAEを使用するように改良した手法です。

f:id:KSKSKSKS2:20191223183251p:plain
[Bak+ 19]より

ユーザーベクトルは、[Bhatia+ 17]でも使用されているSNSソーシャルグラフを元にnode2vecを用いたもの使用することにより、ユーザーおよびユーザーの組み合わせにより、発話内容が変更されることを試みます。

 
\begin{align}
h^{conv} &= ConvEncoder(all\_dialogue) \\
z^{conv} &\sim  Q_\phi(z | h^{conv}, u^a, u^b) \\
h_t^{UE} &= UtteranceEncoder(x)  \\
h_l^{CE} &= ContextEncoder(h_{l-1}^{CE}, h_t^{UE}, z^{conv}) \\
h_{t+1}^{UE} &= UtteranceEncoder(y) \\
z &\sim Q_\phi(z | h_{t+1}^{UE}, h_l^{CE}, z^{conv}, u^{s_t}) \\
P_\theta(y | x, C) &= Decoder(z, h_l^{CE}, z^{conv}) \\
\end{align}

実験では、リファレンス文と生成される文の一致度合いが向上すること、対話ペアによって同じ質問でも受け答えの意味が変化することを示し、事前学習済みの情報をコンディショナル情報として使用することにより、生成する文の内容をうまく制約できるようになることを示しています。

まとめ

この記事では、Seq2Seqの応用から進化してきたEnd2Endな対話応答生成タスクについて、特にVAE/CVAEを応用した手法に注目して、その変遷、および模索されている方向性について紹介しました。

VHREDなどのCVAE派生手法の登場で、基本のアーキテクチャについては落ち着いてきているものの、VHREDで実現できないレベルでの更なる発話文の多様性の追求や、実用上重要となる発話内容のコントロールを行うための方法が検討されている状況です。

いずれのアプローチも、構造が複雑化してきており、データセットの分量やパラメーターの最適化の難易度も考えると、思っている以上に難しい定式化になっていそうというのが正直な感想です。

また、以前の記事でも書きましたが、対話応答生成の分野にて自動評価指標が確立されていないという問題の弊害が非常に大きく、レポートされている数字だけでは、実際にどれくらい改善されているのか、どのように改善されているのかがイマイチつかみづらいのが現状です。

一方で、同じEnd2Endな対話システムでも、Knowledge Baseや文書を参照しながらマルチターンな質問応答に答えるようなシステムは、この記事で紹介したものとはアーキテクチャが大きく異なり、Attentionを有効に活用し、正解部分を効率よく見つけて文を生成することが大きなモチベーションとなっています。

さらに、完全に文を生成するのではなく、部分的な構造を指定した上で文を生成する方法は近年注目を集めており、検索ベースのような生成とは全く違ったアプローチもまだまだ根強い人気があります。

どのようなアプローチがどのような条件下で適切なのか、対話応答生成だけでもややタスク設定が細分化してきているきらいもありますが、対話以外の分野、機械翻訳や文書要約、画像キャプション生成なども含めて、これからも追っていければと思います。

参考文献

全体

Seq2Seq

VAE

End2End Dialogue Response Generation

User-based Dialogue Response Generation

End2Endな対話システムの評価指標

この記事は、Qiita 自然言語処理アドベントカレンダーの2日目です。

1日目は jojonki さんによるゼロから作った形態素解析器Taiyakiで学ぶ形態素解析でした。

この記事では、End2Endな対話システムの評価指標、特に応答文生成の自動評価指標に注目して、どのような指標があるのか、どのような点が問題と考えられているのかに注目して、現在の動向やどのような課題があると考えられているかについて記載しています。

自然言語処理分野、特にその応用分野へのDeep Learningへの適用は、特にSeq2SeqとAttention機構によって進んできたと言っても過言ではありません、

対話システムでも、機械翻訳や文書要約といったその他の自然言語処理の応用分野と同じく、End2Endなモデルで対話システムを構築しようという試みが多く行われています。

Deep Learning応用の比較的初期の頃からEnd2Endな対話システムの構築という問題は取り組まれており、手法としては、言語モデルベース、GANベース、VAEベースのもの、そしてそれらの様々な派生手法が提案されています。

一方で、End2Endな対話システムのための評価指標については、新たに提案されるものはそこまで多くはないものの、既存の自然言語処理分野で使用されていた評価指標も含めて、決定版と考えられているものがない状態です。

本記事では、現在主にどのような評価指標が使われているのかを概観した後、それらの指標がどのような問題点をもちどのような解決案が提示されているかの順で、記載していきます。

目次

対話システムの評価指標

対話システムの評価指標は、機械翻訳や文書要約といった、自然言語処理の他のテキスト出力型のタスクと同様な指標が使用されることが一般的です。

Facebookが公開している、対話システムの学習・評価環境であるParlAIでも、メトリクスとしてはBLEUおよびROUGEといった、伝統的なn-gramベースのリファレンステキストとのマッチング指標が使用されています。*1

一方で、End2Endな生成型の対話システムが一般化してきたタイミングで、対話システム独自の評価指標も用いられるようになり始めました。

1つは、Embeddingベースと呼称される評価指標の一群で、テキストを構成する単語そのものではなく、テキストをEmbedding表現に変換した状態で、生成文とリファレンス文を比較する指標です。

もう1つが、Distinct-Nと呼ばれる、システムが生成する文の多様性を評価するための指標です。

これらの指標は対話分野以外では使用されないこともあり(対話分野でも応答文生成以外では一般的でないため)、NLP系の各種ツール(NLTKやAllenNLP)には実装がなく、nlg-evalneural-dialogue-metricsなど、対話システム向けのリポジトリにまとめられたものが公開されています。

手動評価指標としても、対話システムに独特な指標が使用されるようになってきています。

以降では、BLEUなどの伝統的な指標を含めて、どのような評価指標があるのかについて説明していきます。

自動評価

BLEU(Bilingual Evaluation Understudy)

BLEUは、名前が示す通り機械翻訳で使用され始めた評価指標です。

Modified n-gram precisionとBest match lengthを計算し、最終的な評価スコアを出力する、n-gramベースの単純なPrecisionを置き換える指標です。

センテンスベースのBLEUは下記の用に表されます。

 Precision = exp(\sum_{n=1}^N w_n log p_n)

 BP = 1 (c \gt r), exp(1 - \frac {r} {c}) (c \le r)

 BLEU-N := BP * Precision

 p_nは生成文に含まれるn-gramの数で、リファレンス文にも含まれるn-gramn-gramにつき1つと数え上げた数で割った値です。

例えば、一般のUnigramのPrecisionでは、 I work on machine learningというreferenceがあった際、

ⅰ. He works on machine learningは60%、
ⅱ. He works on on machine machine learning learningは75%

がPrecisionとなりますが、ⅰの方がどう考えても良い文章です。

この問題を解決するため、BLEUではn-gramの出現回数を一回しかカウントしないことにより、同じ単語が複数回登場するだけの文の評価をあげないように工夫されています。*2

 w_nは重みパラメーターで、一様になるように  w_n = \frac {1} {N}と置かれる事が一般的です。*3

BPは短すぎる文にペナルティーを課す項で、 cは生成文の合計の長さ、 rは基準の長さ、一般的にはリファレンス文の平均長が使用されます。

ROUGE(Recall Oriented Understudy for Gisting Evaluation)

ROUGEは、BLEUとは異なり、n-gramベースのPrecisionだけでなくRecallにも注目する手法で、文書要約の分野で登場してきた指標です。

計算方法によっていくつかの種類が提案されています。

ROUGE-Nは、生成文中に現れるn-gramがリファレンス文中に出現する回数をもとに、RecallとPrecisionを計算する指標です。RecallとPrecisionを合わせて評価するため、F1まで計算しそこで比較することが一般的です。

ROUGE-Lは、LCS(Longest Common Subsequence) という関数を用いて、以下の式で表現される。

 R = \frac {LCS(target, reference)} {m}

 P = \frac {LCS(target, reference)} {n}

 ROUGE-L := \frac {R * P} {(1 - \alpha)R + \alpha * P}

LCSは、与えられた二つの文で最も長い一致部分の長さを返す関数で、mがリファレンス文の長さ、nが生成文の長さ、 \alphaは、調整用パラメーターで一般的には0.5が使用されます。

LCSを拡張したような、ROUGE-WやROUGE-Sも同じ論文の中で提案されていますが、一般的にROUGE-NやROUGE-Lがよく使用されます。

Embedding Base

Embedding Baseの評価指標は、How NOT To Evaluate Your Dialogue System: An Empirical Study of Unsupervised Evaluation Metrics for Dialogue Response Generation[Liu 2016]にて、対話応答生成・応答選択向けにまとめられた評価指標です。

上記の論文自体は対話システムの文脈では比較的よく参照されており、2019/11/10現在で401件の論文で引用されています。*4

後述しますが、上記の論文でもEmbedding Baseの指標と人間評価の間には相関はないと結論付けられていますが、語彙が異なっていても意味的な近さを既存指標よりも評価しやすいという期待からたびたび使用されています。

上記の論文では、Embedding Baseの指標として以下の3種類があげられています。

Embedding Averageは、生成文に含まれる単語のベクトルの和と、リファレンス文に含まれる単語のベクトルの和とのコサイン類似度をスコアとしたものです。

Vector Extremaは、文に含まれる各単語の単語ベクトルのうち、各次元ごとに最大値もしくは最小値を、文のベクトルの対応する次元の値として、リファレンス文と生成文のコサイン類似度をスコアとしたものです。

Greedy Matchingは、リファレンス文と生成文に含まれる単語ベクトルと比較した際に、最もコサイン類似度が高くなる単語のコサイン類似度の平均をスコアとしたものです。

いずれの指標でも、学習済みのベクトルとして、word2vecのリポジトリ*5で公開されているGoogleのNews Corpusで作成されたものを使用することが多いようです。

Distinct

A Diversity-Promoting Objective Function for Neural Conversation Models[Li 2016]にて、対話応答生成向けに、生成された文の多様性評価指標としして提案された指標です。

生成型の対話システムでは、 "I don’t know" や "I’m OK" といった無難で前後の文脈にも合うが、システムの挙動としては期待しない文が生成されることが問題なっており、上記の論文はその解決策の走りとして提出されたもので、他の論文からの参照も多く、2019/11/10現在で495件の論文で引用されています。*6

Distinctは、上記の論文内にてあくまで提案手法の評価用に定義されただけのものですが、生成文の多様性はEnd2Endな対話システムでは大きな課題として考えられており、同じような課題意識を持つ論文では幅広く使用されています。

定義は非常に簡単で、Distinct-Nは、生成された複数文中で登場したn-gramの種類数を、複数文中に登場する全てのn-gramの数で割った値として定義されます。*7

人間評価

人間による評価では、システム全体の良さ、もしくは評価軸に基づいて、直接2つのシステムを比較したり、5段階評価で点数つけしてもらうことにより評価を行います。

評価軸に基づいた評価は、FluencyやNaturalnessといった文生成一般に使用される評価軸以外にも、以下の二つの評価軸が対話システムの評価としてよく使用されます。

Interest (Informative, Richness): 対話として情報のある文かどうか。Distinctを同じ用に、常に I don’t know 等で返す場合低い評価となる

Relevance (Consistency): 直前のユーザー発話文や対話履歴との関連性がある文かどうか

後述するように、End2Endな対話システムの評価では決定的な自動評価指標はなく、必ずと言ってよいほど人間による評価を行うことが一般的です。

特に、上記の二つのように、対話として有意味か、前後の文脈を評価できているかについては、自動評価による評価が難しく、特に評価指標としてよく用いられています。

評価指標の課題

既存指標の課題

対話システムの応答生成は、他の文生成タスクと比較して、以下の2点の特徴があるため自動評価が難しいと考えられています。

  • コンテキストにあえば多様な応答が正解となりえ、リファレンス文と使用されている単語が違っていても正解となる
  • 単語が一致しないだけでなく、意味的にリファレンス文と全く異なっていても、コンテキストにあえば正解となる

Liu 2016では、検索型および生成型の応答生成タスクで、それぞれ複数のモデルを用いて、非タスク指向対話の2つのコーパスを使用にて、各種評価指標と人間評価の相関を検証しています。

各対話システムが生成した応答文を、25人の評価者に5段階で評価してもらい、平均スコアを算出。その値と、自動評価指標の評価結果の値の相関を見るというのが主なアプローチです。

下記グラフは、2つのコーパスにて、人間の評価と、BLEU、Embedding Average、2グループの評価者の評価の対応関係を散布図で表したものです。

f:id:KSKSKSKS2:20191116154708p:plain
Liu 2016より

グラフの通り、人間同士の評価の相関はいずれのコーパスでも非常に高い(0.9以上)一方、自動評価指標はBLEUもEmbedding Averageも、いずれのコーパスでもほぼ相関がないようなグラフとなっています。

実際、Twitterコーパスでは、BLEUは相関が0.35ほど(さらに、ストップワードを抜いた場合は0.2ほど)、Embedding Averageも0.2ほど、Ubuntuコーパスでは、どちらもほぼ0付近の相関値となっています。

この問題に対応する方法として、論文中では以下の3つのアプローチが可能なのではないかと示唆されています。

  • タスク指向対話システムのように、ポリシーがが決定したアクションをもとに応答を生成するなど、正解のバリエーションがすくなるなるようにタスク設計する
  • それぞれのコンテキストにおいて、正解と考えられる発話文を複数用意したデータセットを構築する
  • リファレンス文と生成文の意味や、コンテキストと生成文の関係性を評価できる新しい評価指標を検討する

1番目は、Relevance of Unsupervised Metrics in Task-Oriented Dialogue for Evaluating Natural Language Generation[Sharma 2017]にて検討されており、非タスク指向対話に比べるとかなり高い相関が見られることを報告していますが、相関が見られる範囲および使用しているデータセットが限定できである点について考慮が必要です。*8

2番目は、残念ながら対話というジャンルの性質上構築が難しいのか、複数の応答が正解として用意されているようなデータセットはまだ見かけません。

3番目は、BERTScoreやSentence Mover’s Similarityのような、事前学習済みのEmbeddingを使用した新しい評価指標が、対話以外の分野から続々と登場してきています。

また後述するように、対話システムの分野でも、対話のコンテキストを考慮した複数の評価指標が提案されています。

対話のコンテキストを考慮した評価指標

ADEM(Automatic Dialogue Evaluation Model)

ADEM[Lowe 2017]は、最初期のコンテキストを考慮した対話システムの評価指標で、対話履歴、リファレンスの応答文、システムが生成した応答文をもとにスコアを出力します。

スコアを出力するモデルは、ベースラインシステムが生成(選択)した対話応答文を、人間が評価した値を集めたデータセットを学習して構築します。

上記のデータセットのテストセットで評価した結果、人間評価との相関値がBLEUやROUGEなどの既存の指標を大幅に上回ることを示しています。

f:id:KSKSKSKS2:20191123172728p:plain
Lowe 2017より

一方で、上記データセットに含まれる該当システムへの人間評価を学習時に取り除いた場合、該当システムが生成した応答文への評価の相関値が大幅に低下する実験結果も示されており、扱いには注意が必要です。

RUBER(Referenced metric and Unreferenced metric Blended Evaluation)

RUBER[Tao 2018]は、生成した応答文と、リファレンスの応答文との一致率と、対話コンテキストの一致率を別々に計算し統合することにより、人間評価に近い評価を得られる指標の構築を試みた指標です。

生成した応答文とリファレンスの応答文との一致率(Referenced Metric)は、Embedding Baseの指標と同様、既存の事前学習済み単語ベクトルを使用します。

生成した応答文と対話コンテキストの一致率(Unreferenced Metric)は、ニューラルネットを使用して、学習データに含まれるコンテキスと応答文の組み合わせでは高く、学習データに含まれない組み合わせでは低くなるように、スコアを出力するモデルを学習します。

Referenced MetricとUnreferenced Metricの統合方法は、最小値を用いるケース、算術平均を用いるケース、幾何平均を用いるケース、最大値を用いるケースで検証されており、最大値以外は、概ね人間評価との相関値が0.45前後と比較的高い数値を得られることを示しています。

f:id:KSKSKSKS2:20191123174639p:plain
Tao 2018より

一方で、上記の結果は対話システムと同じデータを用いてUnreferenced Metricを算出するモデルを学習した場合の値であり、異なるデータで学習した対話システムが生成した応答文の評価においては、人間評価との相関値が0.35ほどと10%ほど下がってしまうことが示されています。

また、Referenced MetricとUnreferenced Metricともに、単語をベクトルに使用する事前学習済みのモデルとして、word embeddingではなくBERTなどのContextual Embedingにするなど、いくつか改良を加えた方が、人間評価との相関値がより良くなるという報告も存在しています。

まとめ

本記事では、End2Endな対話システムの自動評価指標として、言語処理分野で一般的なものから、対話の応答生成用に独自で使用されている指標についても紹介しました。

また、人間評価との乖離という問題についても、どのような議論が行われているかを紹介しました。

残念ながら決定的な自動評価指標がないのがまだまだ現状です。

ここで紹介されていないけど、こんな評価指標あるよなどがあれば、ぜひ教えていただければ幸いです。

参考文献