終末 A.I.

データいじりや機械学習するエンジニアのブログ

【読書メモ】『Data Governance: The Definitive Guide』 Chapter 2

目次

エンタープライズディクショナリ

  • まず、組織がデータをどのように処理し、データガバナンスを可能にするかを理解することが重要
  • エンタープライズディクショナリは、組織が使用するインフォタイプについて合意された情報の集まり
  • インフォタイプは、たとえば「メールアドレス」や「住所」、さらには「給与額」など、単一の意味を持つ情報の一部
  • エンタープライズディクショナリを定義すると、その中のさまざまな個別のインフォタイプをデータクラスにグループ化することができ、データクラスごとにポリシーを定義することができる
  • エンタープライズディクショナリには通常、データクラス、データクラスに関連するポリシー、および追加のメタデータが含まれている

データクラス

  • 優れたエンタープライズディクショナリには、組織が処理するデータのクラスのリストが含まれる
  • データクラスは、ポリシー管理の観点から共通の方法で扱われるグループにまとめられたインフォタイプのこと
  • つまり、エンタープライズディクショナリには、インフォタイプの階層が含まれる
  • 多くの組織で見られるデータクラスの例
    • 個人情報
    • 金融関連情報
    • ビジネス知的財産
  • データクラスの種類は、業種や関心に応じて変化する
  • データクラスは、1つのトピックに属する情報要素の組み合わせであることに注意。たとえば、電話番号は通常データクラスではないが、個人情報は通常データクラス
  • データクラスの特徴は以下の2つ
    • データクラスは一連のポリシーにひもづく。同じデータクラス内のデータには、同じ保持ルールとアクセスルールが必要
    • データクラスは、個々のインフォタイプのセット(インフォタイプの階層構造)

データクラスとポリシー

  • 組織が処理するデータがエンタープライズディクショナリで定義されると、データクラスを管理するポリシーを割り当てることができる
  • エンタープライズポリシーブックには、組織は、「どのような種類のデータを処理するのか」という質問に答えられる必要がある
  • 組織が使用するデータクラス、処理されるデータの種類、およびそれらの処理方法を指定し、データの「許可されていることと許可されていないこと」について詳しく説明する
  • 責任、リスク管理、および法的措置への露出を制限するために、組織は通常、データの最大(および最小)保持率を定義する
  • 別の種類のポリシーはアクセス制御。データの場合、アクセス制御は「はい/いいえ」を超えて、アクセスなし、部分的アクセス(マスクされたデータやハッシュ化されたデータ)、または完全アクセスのいずれかになる
  • 通常、ポリシーブックには以下の内容を指定する
    • 誰が(組織の内部または外部で)データクラスにアクセスできるか
    • データクラスの保持ポリシー(データが保持される期間)
    • 該当する場合、データの常駐 / ローカリティルール
    • データの処理方法(どの処理方法では OK なのか、もしくはNGなのか)
    • 組織によるその他の考慮事項

ユースケースごとのデータポリシー

  • データアクセスのユースケースまたは目的は、理想的には、組織のメンバーシップと組織の役割の上にオーバーレイする必要がある
  • 収集される新しいさまざまなタイプのデータに対応するために要件や規制が変化するにつれて、データのユースケースはポリシー管理の重要な側面
  • 複数の役割を果たす可能性のある従業員がいる会社では、インフォタイプ/データクラスと従業員の役割だけを考慮するのではなく、データの使用目的(ユースケース)に関連するアクセスを検討する方が効率的

データ分類と組織化

  • データのガバナンスを制御するには、データの分類を少なくとも部分的に自動化することが有益
  • データ分類器は、非構造化データ、または構造化データのカラムのセットを調べて、データが何であるかを推測する
  • データ分類の自動化は、主に2つの方法で実行できる
    • 取り込み時にデータクラスを特定し、データソースの追加に関する分類ジョブをトリガーする
    • データのサンプルを確認しながら、データ分類ジョブを定期的にトリガーする
  • データを分類すると、必要な自動化のレベルに応じて、次のことができる
    • データに「このデータクラスに属する」というタグを付けます
    • データにアクセスまたは操作されるデータクラス、「目的」、またはコンテキストの定義に従って、データへのアクセスと保持を制御するポリシーを自動的に(または手動で)適用する

データカタログとデータマネジメント

  • メタデータが、基礎となるデータ自体と同じポリシーと制御に従うと考えるのは単純だが、これが邪魔になる場合が多くある
    • テーブル自体にアクセスできない場合でも、そのようなテーブルが存在することを知っていることは価値がある
  • メタデータには、データの場所とそれに関連する技術情報(テーブルスキーマ、テーブル名、列名、列の説明)が含まれる
  • ただし、組織内の誰がデータを所有しているかなど、追加の「ビジネス」メタデータの添付も許可する必要がある
  • データがローカルで生成されたものか外部で購入されたものか、本番のユースケースまたはテストに関連するかどうかなども含む
  • データガバナンス戦略が成長するにつれて、データガバナンス情報の詳細(データクラス、データ品質、機密性など)をデータカタログ内のデータに添付する必要がある

データアセスメントとプロファイリング

  • ほとんどのデータ活用ワークフローの重要なステップの一つは、データをふるいにかけるときに、そのデータの外れ値を確認すること
  • 外れ値は、データ入力エラーの結果であるか、残りのデータと矛盾している可能性があるが、弱い信号またはあまり現れていない新しいセグメントまたはパターンである可能性もある
  • 外れ値の保持または削除は、データが使用されているビジネス目的のコンテキストごとに行う必要がある
  • データエンジニアは通常、データの外れ値やその他の疑わしい品質の問題を含むレポートを作成する責任がある
  • 理想的には、カラムごとの異常を検出し、関連するコンテキストで異常が意味をなしているかどうかを判断するために、データのプロファイルを作成する必要がある
  • 各フィールドで受け入れられるデータの種類の境界が設定され、自動化されたルールによって、データのバッチまたはイベントストリームが取り込まれるように準備され、クリーンアップする

データ品質

  • データ品質は、データソースに関連するユースケースを決定する際の重要なパラメータ
  • データ品質管理プロセスには、検証用のコントロールの作成、品質の監視とレポートの有効化、インシデントの重大度のレベルを評価するためのトリアージプロセスのサポート、根本原因の分析とデータの問題に対する救済策の推奨、およびデータインシデントの追跡を可能にすることが含まれる
  • さまざまな品質データセットに割り当てられたさまざまな信頼水準が必要。混合品質の祖先データを使用して結果のデータセットを許可することについても考慮が必要
  • データ品質管理のための適切なプロセスは、分析のために測定可能なほど信頼できるデータを提供すること
  • データの生成を担当するビジネスユニットがそのデータの品質も所有し、ダウンストリームのユーザーに影響を残さないようにする
  • 組織は、データの所有者がデータが組織の品質基準に合格する品質であることを証明するまで、データの使用を許可されないデータ受け入れプロセスを作成できる

リネージ追跡

  • リネージは、データのソースとその過程でどのように操作されたかから生成される
  • リネージを作成する1つの理由は、結果のダッシュボード/集計の品質を理解すること
  • もう1つの理由は、組織のデータスケープ全体での機密データクラスの移動を全体的に把握して、機密データが不正な「箱」に誤って公開されないようにすること
  • リネージ追跡では、何よりもまず、「品質」などの結果のメトリック、またはデータが機密情報で「汚染」されているかどうかについての計算を提示できる必要がある
  • そして後で、データトラバーサル自体のグラフィカルな「グラフ」を表示できる必要がある
  • 多くの場合、リネージについて話すときは、データがどこから来てどこに行くのかを知ることに重点が置かれるが、何かが壊れた時と場所を視覚的に確認/把握し、すぐに行動を起こすことにも価値がある
  • ダッシュボードへの現在の入力が何であるかだけでなく、それらの入力が過去に何であったか、そしてリネージがどのように進化したかを追跡する必要もある

データ保持とデータ削除

  • データガバナンスツールのもう1つの重要な項目は、データの保持期間を制御する機能
  • 時折のストレージスペースの最適化に耐えるデータを特定することには、保持する価値が高いという明らかな利点があるが、価値の低いデータクラスのデータ保持に最大保持時間を設定し、それを自動的に削除することはあまり価値が明白ではない
  • データの保持と削除について話すとき、機密データの処理方法のコンテキストでそれらについて考えることがよくある。つまり、それを保持するか、暗号化するか、削除するか、または他の方法で処理するかどうか
  • ただし、ガバナンスポリシーによって、コンプライアンス違反から保護されるだけでなく、作業の損失から保護されるシナリオもある
  • 機密データをどこに、どのくらいの期間保持するか、削除するかどうかという観点から、機密データをどのように処理および処理するかだけでなく、ガバナンスプログラムで検討することをお勧めする
  • また、バックアップするのに重要な他のクラスやカテゴリのデータに同じプログラムを実装する方法もある
  • データの損失がコンプライアンス違反につながることはないかもしれないが、それは確かに他の壊滅的なビジネス上の結果をもたらす可能性がある

learning.oreilly.com

【読書メモ】『Data Governance: The Definitive Guide』 Chapter 1

目次

データガバナンスとは

  • データガバナンスは、組織によって収集されたデータの品質、整合性、セキュリティ、および使いやすさを保証するためのデータ管理機能のこと
  • データのライフサイクル全体で、データガバナンスは、すべての利害関係者が簡単にアクセスできる形式でデータを利用できるようにすることに重点を置く
  • 望ましいビジネス成果(インサイト、分析)を生成し、関連する場合は規制基準に準拠する方法で使用できるものでなければならない
  • データガバナンスでは、利害関係者が企業内のすべてのデータの高品質な統合ビューを確実に取得できるようにする必要がある
  • データが安全であることを保証するためにデータガバナンスを実施する必要がある。具体的には下記
    • 許可された方法で許可されたユーザーのみがアクセスできる
    • 監査可能です。つまり、変更を含むすべてのアクセスがログに記録される
    • 規制に準拠している
  • ユーザーが企業データを使用して、主要業績評価指標(KPI)を使用した意思決定、リスク評価、および管理をサポートできるようにするには、信頼できるデータが必要
  • データガバナンスの原則は、企業の規模やデータの量に関係なく同じ

データガバナンスに関連するもの

データの信頼性の強化

  • データガバナンスの目的は、データの信頼性を構築すること
  • データの信頼性を確保するには、データガバナンス戦略が、発見可能性、セキュリティ、および説明責任という3つの重要な側面に対処する必要がある
  • 発見可能性のためには、技術メタデータ、リネージ情報、およびビジネス用語集をすぐに利用できるようにすることが必要
  • セキュリティのためには、規制への準拠、機密データ(個人を特定できる情報など)の管理、データのセキュリティと漏洩防止の観点が必要
  • 検出可能性とセキュリティが整っているてはじめて、データ自体を製品として扱い始めることができる。その時点で、説明責任が重要になり、データドメインの境界の周りの所有権と説明責任のための運用モデルを提供する必要がある

データの分類とアクセス制御

  • データガバナンスに関連する主なアクティビティには、データの分類とアクセス制御が含まれる
  • ガバナンスポリシーは通常、データに責任を持つグループ(例えば、雇用者情報の場合は人事部門)によって指定される
  • ポリシー自体は、多くの場合ITチームによって実行される

データガバナンスとデータイネーブルメントおよびデータセキュリティの関係

  • データガバナンスは、データイネーブルメントを拡張して、データ取得を実行できるワークフローを含む
  • ユーザーは、コンテキストと説明でデータを検索し、関連するデータストアを見つけて、正当な理由として目的のユースケースを含めてアクセスを要求する
  • 承認者(データスチュワード)は、ユーザーのニーズを確認し、ニーズが正当化かどうか、アクセスを要求されているデータが実際にユースケースに役立つかどうかを判断し、データにアクセスできるようにする必要がある
  • データガバナンスは、データセキュリティの仕組みが整っていることに依存するが、不正アクセスの防止だけでなく、データ自体に関するポリシー、つまりデータクラスに応じた変換にまで及ぶ

なぜデータガバナンスはより重要になっていくのか

データサイズの成長

  • 2018年11月に発行されたホワイトペーパーで、International Data Corporationは、グローバルデータスフィアが2025年までに175ZBに膨れ上がると予測している

データを利用する人々の指数関数的増加

  • IDCはまた、現在世界で50億人を超える人々がデータを操作していると報告しており、この数は2025年には60億人(世界の人口の約75%)に増加すると予測している
  • 企業は「データ主導の意思決定」が可能であることに夢中になっており、膨大な数の人員を必要としている

データ収集方法の高度化

  • データをバッチ処理して分析のためにロードするだけではなく、企業は、リアルタイムのストリーミングデータと分析を活用して、顧客により良い、よりパーソナライズされたエンゲージメントを提供することが必要になっている

多様な種類のデータ(機密性の高いデータを含む)の収集

  • データのやり取りの多くには、社会保障番号、クレジットカード番号、名前、住所、健康状態など、無数の機密データの生成とその結果の収集が含まれる
  • これらの非常に機密性の高いタイプのデータの収集が急増しているため、そのデータがどのように使用および処理され、誰がそれを表示できるかについて、顧客は大きな懸念を抱いている

データ利用ケースの拡張

  • 企業は、データを使用してより良いビジネス上の意思決定を行うよう努めている
  • さらに、顧客がより良い意思決定を行うのを助けるためにデータを使用している

データを扱う新しい規制や法律

  • データとデータの可用性の向上により、データ、データ収集、データアクセス、およびデータ使用に関する規制が望まれ、必要になっている
  • EUの一般データ保護規則(GDPR)や米国のカリフォルニア州消費者プライバシー法(CCPA)などの新しい規制は、無数の企業に適用される使用および収集管理についての規制の例にすぎない
  • 従来のデータアーキテクチャ戦略に組み込まれていなかったため、これらの新しい規制へのコンプライアンスを維持するためにテクノロジーとビジネスプロセスを変更するのに苦労している

データ利用についての倫理的関心

データガバナンスのビジネス的価値

  • データガバナンスは、知識労働者が必要とする簡単にインサイトを得るという戦略的ニーズに対応する
  • データガバナンスが戦略的プロセスである組織では、知識労働者は、ミッションを遂行するために必要なすべてのデータを簡単に見つけ、安全にアクセスを申請し、明確なタイムラインと透過的な承認プロセスを備えたシンプルなプロセスでデータへのアクセスを許可されることを期待できる
  • データの承認者とガバナンス担当者は、どのデータに誰がアクセスできるか、どのデータがガバナンスの管理ゾーンの「外側」にあるかを簡単に把握できる
  • CIOは、組織内のデータの高レベルの分析をレビューして、「データの総量」や「準拠していないデータ」などの定量化可能なメトリックを総合的にレビューし、リスクを理解(および軽減)することができる

イノベーションの促進

  • データガバナンス戦略は、うまく機能している場合、プロセス(ガバナンスの下でデータを利用できるようにする)、人(ポリシーを管理し、組織全体のデータアクセスを導き、必要に応じてサイロを解消する)、および上記を容易にするツールの組み合わせで構成される
  • データガバナンスは、理想的には、組織のリスク態勢を維持しながら、組織内のすべての従業員が一連のガバナンスルールの下ですべてのデータにアクセスできるようにする
  • すべての知識労働者に管理された方法でデータへのアクセスを提供することで、個人が組織内に存在するデータに基づいて質問への回答を迅速にプロトタイプ化できるようにすることで、イノベーションを促進できる

データガバナンスとデータ分析の大衆化の緊張関係

  • 多くの場合、完全なデータの民主化は、データガバナンスと矛盾すると考えられているが、そうではない
  • 覚えておくべき重要な概念は、データには2つのレイヤーがあるということ。データ自体とメタデータ
  • データガバナンスを使用すると、次の3つのことを実行できる
    • 管理されているすべてのデータのインデックスを含むメタデータカタログにアクセスし、特定のデータの存在を検索できるようにする。優れたデータカタログには、検索の範囲を制限する特定のアクセス制御ルールも含まれている
    • データへのアクセスを管理する。これには、取得プロセスおよび最小アクセスの原則を順守する方法が含まれる
    • 他の手順とは別に、データアクセス要求、データアクセス承認サイクル、承認者(データスチュワード)、および後続のすべてのアクセス操作で「監査ログ」を利用できるようにする
  • データガバナンスは、データの民主化を可能にし、より多くの知識を持つ従業員がより多くのデータにアクセスできるようにする機能になる。したがって、データの使用をより簡単かつ迅速にするビジネスの加速器になる

リスク管理 (盗難、誤用、破損)

  • CIOと責任あるデータスチュワードが長い間抱えていた主な懸念事項は、リスク要因は何か、それを軽減する計画は何か、そして潜在的な損害は何か、ということ
  • CIOはこれらの質問への回答に基づいてリソースを割り当てる
  • データガバナンスは、そこに提示されている他のトピックの中でも、データに対するリスクを管理するための一連のツール、プロセス、およびポジションを担当者に提供する
  • リスクには、盗難、誤用、データの破損がある

コンプライアンス

  • データガバナンスは、一連の規制がビジネス、特にビジネスプロセスのデータに適用される場合に活用される
  • 規制は、本質的に、組織が運営するビジネス環境内で機能するために遵守しなければならないポリシーである
  • ポリシーを実現するために以下のようなことを行う
きめ細かいアクセス制御
  • アクセス制御は、何よりもセキュリティに関連する確立されたトピックである。きめ細かいアクセス制御は、アクセス制御に次の考慮事項を追加する
  • アクセスを提供するとき、適切なサイズのコンテナへのアクセスを提供しているか
  • アクセスを提供するとき、適切なレベルのアクセスを提供しているか
  • アクセスを提供する場合、アクセスはどのくらい開いたままにする必要があるか
データの保持と削除
  • 重要な規制は、データの削除と保存を扱っている。設定された期間、およびその期間以上のデータを保存するという要件は一般的
  • 逆に、組織は特定の情報を保持する時間を制限して、責任を制限しながら迅速な結論を導き出すことができる
監査ログ
  • 規制当局に監査ログを表示できることは、ポリシーが遵守されていることの証拠として役立つ。削除されたデータを提示することはできないが、データが作成、操作、共有(および誰と)、アクセス(および誰によって)され、後で期限切れまたは削除された手段の監査証跡を表示できる
  • データガバナンスの目的で役立つためには、監査ログは不変で、書き込み専用であり、最も要求の厳しいデータ保存ポリシーである限り、それ自体で長期間保存される必要がある
  • 監査ログには、データとデータ操作自体に関する情報だけでなく、データ管理機能の周辺で発生する操作に関する情報も含める必要がある
機密データクラス
  • 多くの場合、規制当局は、あるクラスのデータを他のデータとは異なる方法で処理する必要があると判断する。これは、保護された人々のグループ、または一種の活動に最も一般的に関係する規制の中心である
  • データのどの部分を実際に処理するか、およびこのデータを構造化ストレージまたは非構造化ストレージに保存されているデータと比較する方法を正しく特定するのは、組織の責任

データガバナンスについて考える組織の考慮事項

  • 組織がデータガバナンスプログラムとその目標を定義し始めるとき、それらが運営される環境を考慮に入れるべき
規制とコンプライアンスのニーズの変化
  • 規制環境の変化により、組織はガバナンスに関して警戒を怠らない必要がある
  • 企業は既存の規制について知っている必要があるだけでなく、変化する規制や規定、およびビジネスのやり方に影響を与える可能性のある新しい規制についても把握する必要がある
データの蓄積と組織の成長
  • インフラストラクチャのコストが急速に減少し、組織が有機的に成長し、追加のビジネスユニットを取得することで成長する中で、データ蓄積のトピックと、大量のデータを迅速に蓄積するための適切な対応方法が重要になる
  • 組織は、データレイクを構築することですべての問題を解決できると考えていたが、現在、これらのデータレイクは、理解および管理することが不可能な大量のデータを含むデータの沼地になりつつある
データをクラウドに移動する
  • 従来、すべてのデータは、組織によって提供および維持されるインフラストラクチャに存在していた。これは、組織がアクセスを完全に制御できることを意味し、リソースの動的な共有はなかった
  • クラウドコンピューティングの出現により、組織はオンプレミスとクラウドインフラストラクチャの対応と投資について考える必要がある
データインフラストラクチャの専門知識
  • ハイブリッドコンピューティングにより、組織はオンプレミスとクラウドの両方のインフラストラクチャを利用でき、マルチクラウドにより、組織は複数のクラウドプロバイダーを利用できる
  • これによりガバナンスが複雑になり、ガバナンスの実装に使用されるツールの機能を超えてしまう

learning.oreilly.com

DataHub vs OpenMetadata ~OSSデータカタログツール比較~ 【概要編】

データカタログ、皆さんはどう運用してますでしょうか。必要だとは思うけどプライオリティーが低く特に導入していない、スプレッドシート(エクセル)管理でお茶を濁している、各クラウドベンダー標準のものをとりあえず使っている、という所も多いのかなと勝手に想像しています。

とはいえデータカタログは、データを利用したい、特に初めて触るようなデータを利用しようとしているユーザーにとっては、一番最初に触れるデータプロダクトになる可能性が高く、ここがおざなりになっていて本当に良いのだろうか、というのがこの2つのツールを比較しようと思ったきっかけです。まずは、データカタログのニーズを一身に受けて開発が進めらているこれらのツールを比較してみることで、データカタログに求められているものを把握してみようという魂胆です。

※ 以降の比較は、DataHub v0.8.34、OpenMetadata ver. 0.10.0 をもとに行っています。

目次

各ツールの基本情報

DataHub

DataHubは、LinkedIn内部のメタデータ管理に利用されていたツールがオープンソース化されたもので2019年から公開されています。OpenMetadataはじめほかのツールと比較した際の一番の特徴はこのブログポストにもあるように、スケーラブルなメタデータ管理ができるようなアーキテクチャ構造や内部のデータモデル設計がされているという点です。Kafkaを利用したイベントソーシングベースの実装でメタデータの登録と検索エンジン(ElasticSearch)内のインデックスを作成するようになっていたり、データモデルが柔軟に必要なエンティティを追加できるような実装になっています。

同社より2016年に公開されていたWhereHowsというツールがベースになっており、それを再構成したものです。

OpenMetadata

OpenMetadataは、2021年から開発が開始された新興のデータカタログツールで、データに関わる様々なメタデータを集中管理することを目指して開発されています(参考)。RedashやAirByteのように、SaaS版も用意される(もうすでに用意してある?)予定もあるようです(参考)。

このブログポストからも読み取れるように、アーキテクチャ面では反DataHubの姿勢をとっており、なるべくミニマムな構成、ミニマムなデータ構造で機能を実現できるように配慮されています。一方で、以降で比較するように機能面でもDataHubは意識されているように見え、とりあえずまずはDataHubが一般的にした機能を提供することを目標に置かれているように感じます。

ちなみに管理主体は open-metadata.org となっていますが、元々Hadoop系のコミュニティーで活動していて2021年4月までUberで働いていた Suresh Srinivas 氏が創業者である Collate, inc. 内で主に開発されているようです。

機能概要

主にサポートされている機能は以下になります。

機能 DataHub OpenMetadata
メタデータの検索 キーワード検索(詳細検索可)
タグによる絞り込み
階層構造からの絞り込み
キーワード検索(詳細検索可)
タグによる絞り込み
一覧からの絞り込み
メタデータの詳細 スキーマ
テーブル・カラムの説明
タグの設定
スキーマ
テーブル・カラムの説明
タグの設定
プロファイリング レコードの数
カラムのNULLの数および比率
カラム中のdistinctな値の数および比率
統計量
値のサンプル
参考
レコードの数
カラム中ののNULLの比率
カラム中のdistinctな値の比率
ユニークな値の比率
統計量
値のサンプル
取り込み処理 configファイル&CLI AirFlowのタスク
configファイル&CLI
データの使用履歴 テーブル毎のクエリ一覧
カラム毎のアクセス数
テーブル毎のクエリ一覧
参考
データ品質 Great Expectations連携でテストを追加 GUIや設定ファイルで独自仕様のテストを追加
データリネージ SQLログから自動作成
dbtなどの変換ツールクエリから自動生成
BIツールのSQLから自動生成
SQLログから自動作成
dbtなどの変換ツールクエリから自動生成
BIツールのSQLから自動生成
扱えるエンティティ データセット
チャート
ダッシュボード
パイプライン
MLモデル
ML特徴量
など
データセット
チャート
ダッシュボード
パイプライン
MLモデル
など
メタデータのバージョニング APIで書き込む 自動付与
アクセスコントロール アクションの許可設定
リソース単位でのアクションの許可設定
参考
アクションのAllowおよびDeny設定
SSO OIDC OIDC

構成の概要

DataHub

DataHubは上記でも説明したように、スケーラブルな処理を目指しており、ツールもコンポーネント毎に独立した構造になっています。メインのサーバーとフロントエンドが分かれているだけでなく、MySQLにKafkaやElasticSearch、neo4jも動いており、気軽に立ち上げるにはかなり大規模です(参考)。また、これに加えデータ取り込み処理を実行するためのAirFlowなどは別に設置する必要があります。

OpenMetadata

サーバー以外には、MySQLとElasticSearch、そしてデータ取り込み用のAirFlowがあればよく、比較的シンプルに構築することができます(参考)。AirFlowはGUIで構築した取り込み処理によってタスクが自動作成されるなど大きく依存していますが、既存のインスタンスがあればそれを再利用可能になっているなど、AirFlowユーザーにとっては負担の少ない構成になっています。

次回以降は、各機能の詳細、デプロイや連携のしやすさなどを細かく見ていこうと思います。

Cloud Storage Transfer ServiceでAssumeRoleを使ってS3からデータを移行する

※ この記事は2021年10月の情報に基づいて記載しています。

※ 最新情報はGCPのドキュメントを参照ください。

Cloud Storage Transfer Serviceは、GCP内から直接S3等のクラウドストレージ(もしくはオンプレミス)のデータ移行を行うことにより、高速で高並列にGCSへのデータ移行を実現できるサービスです。

一方で、データソースをS3とする場合は、AWSのアクセスキーとシークレットをGCPに設定しておく必要があり、アクセス情報の漏洩に対して、あまりセキュアではありませんでした。

それに対応できる機能が、2021年7月にパブリックプレビューとなりました。AWSのAssumeRoleWithWebIdentityを利用したフェデレーテッドアクセスにより、データの転送が行えるようになったのです。以下では、その使い方について、簡単に説明したいと思います。

GCPが管理する専用のサービスアカウント(SA)を確認する

Storage Transfer Serviceでは、GCPが管理する専用のサービスアカウントを用いてデータの転送処理を行っています。

通常のIAM一覧上では確認できず、必要な情報はサービス専用の googleServiceAccounts.get を呼び出すことで取得できます。このAPIの戻り値は以下のようになります。

{
  "accountEmail": "project-xxxxxxx@storage-transfer-service.iam.gserviceaccount.com",
  "subjectId": "xxxxxxx"
}

accountEmailは見ての通りSAを識別するための固有のメールアドレスで、xxxxの部分にはプロジェクト固有のプロジェクト識別番号が入ります。subjectIdもアカウントを識別するための固有のIDで、後でAWS上でフェデレーテッドアクセスを設定するために使用する情報になります。

AWSに専用のロールを作成する

続いて、AWS上に上記のSAがフェデレーテッドアクセスするためのロールを作成します。ロールには、GCPのSAがAssumeRoleを行うための権限と、該当のS3バケットからデータを読み出すための権限設定が必要です。

まず、フェデレーテッドアクセス用の権限ですが、対象のロールのAssumeRolePolicyに下記を設定すれば良いです。

{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Effect": "Allow",
      "Principal": {
        "Federated": "accounts.google.com"
      },
      "Action": "sts:AssumeRoleWithWebIdentity",
      "Condition": {
        "StringEquals": {
          "accounts.google.com:sub": "SAのsubjectId"
        }
      }
    }
  ]
}

設定内容の詳しい説明はAWSのドキュメントを参照していただきたいのですが、GoogleからOIDCでAssumeRoleを呼び出された場合に、アカウントのsubjectIdが指定したものについてのみ許可するという設定になります。

次にS3からデータを読み出すための権限をロールのポリシーに指定します。適宜必要に応じて追加したり削ったりしてもらえれば良いですが、簡単には下記のようになります。

{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Effect": "Allow",
      "Action": [
          "s3:Get*",
          "s3:List*",
       ],
      "Resource": "送信元バケット"
    }
  ]
}

GCSの書き込み権限を設定する

Storage Transfer Serviceが使用するSAには、対象のGCSバケットに書き込みを行う権限も必要になります。GCSの権限設定を用いて、バケットにSAが roles/storage.legacyBucketWriter および roles/storage.objectViewer を使用してアクセスできるように設定を行います。

Storage Transfer ServiceのJobを作成する

最後に、上記で作成したロールを用いてS3にアクセスするように設定したジョブを作成します。記事を書いている時点では、Webコンソールにて設定する方法はありませんでしたので、API呼び出しで作成する方法について記載します。

ジョブの作成は、transferJobs.createを使用して行います。設定内容は以下が設定されていれば最低限問題ありません。

{
  "name": "ジョブの名称",
  "projectId": "GCPのプロジェクトのID",
  "transferSpec": {
    "awsS3DataSource": {
      "bucketName": "送信元バケット名",
      "roleArn": "作成したAWSのロールのARN"
    },
    "gcsDataSink": {
      "bucketName": "送信先バケット名"
    }
  },
  "status": "ENABLED"
}

このAPIを実行したタイミングで、AssumeRoleでフェデレーテッドアクセスができるか、ロールでS3に指定のバケットの読み取りが行えるかの確認も走ります。アクセス情報に問題なければジョブが作成されます。

あとは、通常のジョブと同様手動で実行したりスケジュールで実行したりすることができます。

データテストライブラリー「Deequ」を触ってみた

DeequはAWSがリリースしているデータテストを行うためのライブラリです(Deequの説明ではUnit Testと表現されています)。

ここで言うデータテストは、ETL処理やデータマート作成処理などの意図通り動いているどうか、取り込んだデータが昔と変化していないかを確認するための検証処理のことを指しています。

ETL処理などを最初に作成したタイミングでは、その処理が意図したものになっているか確認すると思います。一方で、日次のバッチ処理や、動き続けているストリーム処理について、本当に意図したようにデータが加工されているかどうかは、通常の方法では処理自体が成功したかどうかくらいしか確認するすべがありません。

しかし、日々のデータ処理は簡単に意図しないデータを生み出してしまう可能性があります。気づいたらデータの中身が変わっていて、変換処理が意図しない動作をしてしまっていたり、そもそもソースデータがおかしくなっていて重要な指標がずれてしまう、というようなことも考えられるでしょう。

そのような時に役に立つのがデータテストです。データテストでは、Nullを許容しないはずのカラムに何故かNullか入ってくるようになっていないか、過去のデータと比較して極端にデータの数が変化していないか、などを調べることを含む概念です。一言でいうと、データが意図しないものになっていないかを確認する処理、と言えます。

目次

Deequの何がいいのか

AWSのDeequは、そんなデータテストを簡単に実施できるようにするためのScala製ライブラリです。PythonラッパーであるPyDeequもあります。 Deequは、例えばデータ変換ツールであるdbtでもサポートしていますが、それと比べると、下記のような点が特徴としてあげられます。

  • Sparkベースでできている
    • SQLクエリで直接アクセスできない、ファイルだけがあるようなデータにも適用できる
    • プラグラムベースでしか実現しにくいような処理でも比較的組み込みやすい
  • プリセットのテスト関数が豊富に組み込まれている
    • 手元のデータ単体にフォーカスしたテスト関数だけで40個ほどプリセットである
    • AnomalyDetectionという、過去のデータの状態も参照してテストするための処理も組み込まれている
  • 必要なテスト処理をカラムごとにレコメンドしてくれる機能も組み込まれている

個人的に特にいいのは、AnomalyDetectionの機能が最初から組み込まれている点です。言わずもがなデータは日々変わりますので、実データについて何が正解かを決めにくい、という問題がテストを実装するにあたって存在します。

そのため、今までのデータと比べてどうかというのが、データがおかしくなっているかどうかを判断するための1つの重要な指標になります。特定の指標について経年変化をもとに「違和感」を自動で検出できる仕組みは非常に重要です。

Deequは、このようなAnomalyDetectionと、一般的なデータ単体のテストを、1つのライブラリで完結して行うことができるようになっています。

データテストの実行

基本的な使い方

データテストの基本的な実行は、PyDeequでは下記のように書くことで実現できます。

サンプルコード1

from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pydeequ.checks import *
from pydeequ.verification import *

df = spark.createDataFrame(data=[
        (1, "Test1", "foo"), 
        (2, "Test2", "foo"), 
        (3, "Test3", "bar"), 
        (4, "Test4", "baz"), 
        (5, "Test5", "baz"), 
        (6, "Test6", "bar"), 
        (7, "Test7", None), 
        (8, "Test8", "bar"), 
], schema=StructType([
        StructField("a", IntegerType(), True),
        StructField("b", StringType(), True),
        StructField("c", StringType(), True),
]))

check_warning = Check(spark, CheckLevel.Warning, "Warning Check")
check_error = Check(spark, CheckLevel.Error, "Error Check")

checkResult = VerificationSuite(spark) \
    .onData(df) \
    .addCheck(
        check_warning.isComplete("a") \
        .isComplete("b") \
        .isComplete("c")) \
    .addCheck(
        check_error.isPositive("a") \
        .isUnique("b") \
        .isContainedIn("c", ["foo", "bar", "baz"])) \
    .run()

checkResult_df = VerificationResult.checkResultsAsDataFrame(spark, checkResult)
checkResult_df.show(truncate=False)

実行結果は下記のようになります。

結果1

+-------------+-----------+------------+-----------------------------------------------------------------------------------------------------------+-----------------+------------------------------------------------------+
|check        |check_level|check_status|constraint                                                                                                 |constraint_status|constraint_message                                    |
+-------------+-----------+------------+-----------------------------------------------------------------------------------------------------------+-----------------+------------------------------------------------------+
|Warning Check|Warning    |Warning     |CompletenessConstraint(Completeness(a,None))                                                               |Success          |                                                      |
|Warning Check|Warning    |Warning     |CompletenessConstraint(Completeness(b,None))                                                               |Success          |                                                      |
|Warning Check|Warning    |Warning     |CompletenessConstraint(Completeness(c,None))                                                               |Failure          |Value: 0.875 does not meet the constraint requirement!|
|Error Check  |Error      |Success     |ComplianceConstraint(Compliance(a is positive,COALESCE(CAST(a AS DECIMAL(20,10)), 1.0) > 0,None))          |Success          |                                                      |
|Error Check  |Error      |Success     |UniquenessConstraint(Uniqueness(List(b),None))                                                             |Success          |                                                      |
|Error Check  |Error      |Success     |ComplianceConstraint(Compliance(c contained in foo,bar,baz,`c` IS NULL OR `c` IN ('foo','bar','baz'),None))|Success          |                                                      |
+-------------+-----------+------------+-----------------------------------------------------------------------------------------------------------+-----------------+------------------------------------------------------+

VerificationSuiteとCheckの2種類のオブジェクトから構成され、Checkオブジェクトにメソッドチェーンの形でテストしたい項目を追加していき、addCheckでCheckオブジェクトをVerificationSuiteに登録、runでテストを実行できます。

Checkがテストする内容は、関数とカラム名によって決まります。 isComplete("a") であれば、カラムaの値がすべてNullでないことを検証してくれます。検証の結果は、結果が格納されたDataFrameのconstraint_statusカラムに格納され、成功であればSuccess、失敗であればFailureが入ります。

Checkオブジェクトには、WarningとErrorの2種類のCheckLevelと、そのCheckについての説明を指定することができます。同じCheckオブジェクトに紐付いたテストが1つでも失敗すると、結果のcheck_statusは、設定されているCheckLevelに応じた値が入るようになります。

Checkオブジェクト毎にどのようなテストを追加するかは後処理でどのようにしたいかによって分けるのが良いでしょう。ErrorLevelでオブジェクトを分ける、カラムごとにオブジェクトを分ける、テスト項目ごとにオブジェクトを分ける、などがパターンとして考えられます。

複数カラムの関係をテストする

Deequでは、特定カラムだけでなく、複数のカラムの組み合わせについてテストすることもできます。

下記のように、テスト関数の引数が異なるだけで、基本的な使い方は単一カラムのテストの場合と使い方は変わりません。

サンプルコード2

from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pydeequ.checks import *
from pydeequ.verification import *

df = spark.createDataFrame(data=[
        (1, 10, "foo"), 
        (2, 20, "foo"), 
        (3, 30, "bar"), 
        (4, 40, "baz"), 
        (5, 50, "baz"), 
        (6, 60, "bar"), 
        (7, 70, None), 
        (8, 80, "bar"), 
], schema=StructType([
        StructField("a", IntegerType(), True),
        StructField("b", IntegerType(), True),
        StructField("c", StringType(), True),
]))

check_warning = Check(spark, CheckLevel.Warning, "Warning Check")

checkResult = VerificationSuite(spark) \
    .onData(df) \
    .addCheck(
        check_warning.isLessThan("a", "b") \
        .hasCorrelation("a", "b", lambda x: x >= 1.0) \
        .hasUniqueness(["b", "c"], lambda x: x >= 1.0)) \
    .run()

checkResult_df = VerificationResult.checkResultsAsDataFrame(spark, checkResult)
checkResult_df.show(truncate=False)

isLessThanは、文字通りの意味でaの値が対応するbの値よりも小さいことを確認するテストです。

hasCorrelationは、指定した2つのカラムの相関係数が、指定したassertion関数を満たすことを確認するテストです。

hasUniquenessは、hasCorrelationと同様で、指定したカラムすべてを考慮してユニークかどうかを判定し、その結果がassertion関数を満たすことを確認するテストです。

制約条件を満たさないことをテストする

has系のテスト関数は、上記で書いたようにassertion関数を満たすかどうかでテスト結果が決まります。

このテスト関数の性質を利用して、条件を満たさないことをテストすることも可能です。

サンプルコード3

from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pydeequ.checks import *
from pydeequ.verification import *

df = spark.createDataFrame(data=[
        (1, 10, "foo"), 
        (2, 20, "foo"), 
        (3, 30, "bar"), 
        (4, 40, "baz"), 
        (5, 50, "baz"), 
        (6, 60, "bar"), 
        (7, 70, None), 
        (8, 80, "bar"), 
], schema=StructType([
        StructField("a", IntegerType(), True),
        StructField("b", IntegerType(), True),
        StructField("c", StringType(), True),
]))

check_warning = Check(spark, CheckLevel.Warning, "Warning Check")

checkResult = VerificationSuite(spark) \
    .onData(df) \
    .addCheck(
        check_warning.containsEmail("c", lambda x: x == 0.0) \
        .hasMin("a", lambda x: x > 0)) \
    .run()

checkResult_df = VerificationResult.checkResultsAsDataFrame(spark, checkResult)
checkResult_df.show(truncate=False)

hasMinは、指定したカラムの最小値が条件を満たすことを確認するテストです。0より小さくないことを確認する場合は、assertion関数で最小値が0より大きくなることを確認すればよいです。

containsEmailは、デフォルトでは指定したカラムのすべての値がEmailの形式を満たすかを確認するテスト関数です。assertion関数をしていした場合、Emailの値を含むカラムの割合を確認することができます。つまりこの値が0であることを確認すれば、指定したカラムにEmail形式の文字列を含まないことをテストすることができます。

カスタマイズした内容をテストする

satisfies関数を使う事により、任意のSQLを満たすかどうか、もしくは、レコードのうち何割が満たすかをテストすることができます。

サンプルコード4

from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pydeequ.checks import *
from pydeequ.verification import *

df = spark.createDataFrame(data=[
        (1, 10, "foo"), 
        (2, 20, "foo"), 
        (3, 30, "bar"), 
        (4, 40, "baz"), 
        (5, 50, "baz"), 
        (6, 60, "bar"), 
        (7, 70, None), 
        (8, 80, "bar"), 
], schema=StructType([
        StructField("a", IntegerType(), True),
        StructField("b", IntegerType(), True),
        StructField("c", StringType(), True),
]))

check_warning = Check(spark, CheckLevel.Warning, "Warning Check")

checkResult = VerificationSuite(spark) \
    .onData(df) \
    .addCheck(
        check_warning.satisfies("b % 10 = 0", "B is 10 dividable", lambda x: x == 1.0) \
        .satisfies("rlike(c, '[0-9]+?')", "C is not contained numeric", lambda x: x == 0.0) \
        .satisfies("b / a = 10", "b / a is 10", lambda x: x == 1.0)) \
    .run()

checkResult_df = VerificationResult.checkResultsAsDataFrame(spark, checkResult)
checkResult_df.show(truncate=False)

satisfies関数の引数は、条件のSQL文、条件の説明文、条件を満たすレコードの割合のassertion関数の3つが必要です。

SQL文は、Sparkのfilter関数で使用できるSQLの条件文であれば何でも使用することができます。rlikeで正規表現でマッチさせたり、複数のカラム間の関係を条件にすることもできます。

説明文は、下記の結果の出力内に使用される文字列ですので、後で識別できるものであれば何でも大丈夫です。

assertion関数は、pythonのsatisfies関数の定義ではOptionalですが、指定していないとエラーになってしまいます。すべてのレコードを満たすことを確認する場合は lambda x: x == 1.0 を、満たさないことを確認する場合は lambda x: x == 0.0 を指定しておく必要があります。

結果4

+-------------+-----------+------------+-------------------------------------------------------------------------------------+-----------------+------------------+
|check        |check_level|check_status|constraint                                                                           |constraint_status|constraint_message|
+-------------+-----------+------------+-------------------------------------------------------------------------------------+-----------------+------------------+
|Warning Check|Warning    |Success     |ComplianceConstraint(Compliance(B is 10 dividable,b % 10 = 0,None))                  |Success          |                  |
|Warning Check|Warning    |Success     |ComplianceConstraint(Compliance(C is not contained numeric,rlike(c, '[0-9]+?'),None))|Success          |                  |
|Warning Check|Warning    |Success     |ComplianceConstraint(Compliance(b / a is 10,b / a = 10,None))                        |Success          |                  |
+-------------+-----------+------------+-------------------------------------------------------------------------------------+-----------------+------------------+

AnomalyDetectionの実行

基本的な使い方

AnomalyDetectionの基本的な実行は、PyDeequでは下記のように書くことで実現できます。

サンプルコード5

from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pydeequ.analyzers import *
from pydeequ.anomaly_detection import *
from pydeequ.repository import *
from pydeequ.verification import *
import datetime

now = datetime.datetime.now().timestamp()
repo = InMemoryMetricsRepository(spark)
yesterdaysKey = ResultKey(spark, int(now * 1000) - 24 * 60 * 60 * 1000)

df_yesterday = spark.createDataFrame(data=[
        (1, 10, "foo"), 
        (2, 20, "foo"), 
        (3, 30, "bar"), 
        (4, 40, "baz"), 
        (5, 50, "baz"), 
        (6, 60, "bar"), 
        (7, 70, None), 
        (8, 80, "bar"), 
], schema=StructType([
        StructField("a", IntegerType(), True),
        StructField("b", IntegerType(), True),
        StructField("c", StringType(), True),
]))

checkResult = AnalysisRunner(spark) \
    .onData(df_yesterday) \
    .useRepository(repo) \
    .saveOrAppendResult(yesterdaysKey) \
    .addAnalyzer(Size()) \
    .addAnalyzer(Maximum("a")) \
    .run()

df = spark.createDataFrame(data=[
        (1, 10, "foo"), 
        (2, 20, "foo"), 
        (3, 30, "bar"), 
        (4, 40, "baz"), 
        (5, 50, "baz"), 
        (6, 60, "bar"), 
        (7, 70, None), 
        (8, 80, "bar"), 
        (9, 90, "bar"), 
        (10, 100, "bar"), 
], schema=StructType([
        StructField("a", IntegerType(), True),
        StructField("b", IntegerType(), True),
        StructField("c", StringType(), True),
]))

todaysKey = ResultKey(spark, int(now * 1000))
checkResult = VerificationSuite(spark) \
    .onData(df) \
    .useRepository(repo) \
    .saveOrAppendResult(todaysKey) \
    .addAnomalyCheck(AbsoluteChangeStrategy(-1.0, 1.0), Size()) \
    .addAnomalyCheck(AbsoluteChangeStrategy(-2.0, 2.0), Size()) \
    .addAnomalyCheck(RelativeRateOfChangeStrategy(0.9, 1.1), Maximum("a")) \
    .addAnomalyCheck(RelativeRateOfChangeStrategy(0.7, 1.3), Maximum("a")) \
    .run()

checkResult_df = VerificationResult.checkResultsAsDataFrame(spark, checkResult)
checkResult_df.show(truncate=False)

repo.load() \
  .getSuccessMetricsAsDataFrame() \
  .show()

AnomalyDetectionでは、通常のテストと違い、Checkの代わりに、AnomalyDetectionの方法を指定するオブジェクト、Analyzerというテスト項目を指定するオブジェクト、Repositoryという今までのテスト項目の結果を保存してくれるオブジェクト、ResultKeyといういつの時間のデータを保存するかを指定するオブジェクトが必要になります。

基本的にはRepositoryのオブジェクトに保存されている直前(オブジェクトに保存されている順番で直前)のデータを各テスト項目について比較して結果を出力します。Repositoryオブジェクトへの結果の保存は、AnalyzerRunnerクラスを用いて、Analyzerによるテスト項目の計算のみを行いその結果を保存する方法と、VerificationSuiteでAnomalyDetectionも行いながら、テスト項目の計算結果を保存する2つのアプローチがあります。Repositoryオブジェクトが管理しているファイルやメモリ上のオブジェクトの結果は、保存するたびに結果が追記されていくため、VerificationSuiteでの保存は用途に使い分けるのが良さそうです。

AbsoluteChangeStrategyは、保存されている直前の値と比較して、変化量がmaxRateDecreaseとmaxRateIncreaseの範囲に収まっているかどうかを確認します。

RelativeRateOfChangeStrategyは、保存されている直前の値と比較して、変化率がmaxRateDecreaseとmaxRateIncreaseの範囲に収まっているかどうかを確認します。

結果5

+---------------------------------+-----------+------------+----------------------------------+-----------------+-----------------------------------------------------+
|check                            |check_level|check_status|constraint                        |constraint_status|constraint_message                                   |
+---------------------------------+-----------+------------+----------------------------------+-----------------+-----------------------------------------------------+
|Anomaly check for Size(None)     |Warning    |Warning     |AnomalyConstraint(Size(None))     |Failure          |Value: 10.0 does not meet the constraint requirement!|
|Anomaly check for Size(None)     |Warning    |Success     |AnomalyConstraint(Size(None))     |Success          |                                                     |
|Anomaly check for Maximum(a,None)|Warning    |Warning     |AnomalyConstraint(Maximum(a,None))|Failure          |Value: 10.0 does not meet the constraint requirement!|
|Anomaly check for Maximum(a,None)|Warning    |Success     |AnomalyConstraint(Maximum(a,None))|Success          |                                                     |
+---------------------------------+-----------+------------+----------------------------------+-----------------+-----------------------------------------------------+

テスト項目は、Analyzerクラスの種類と指定されているカラムの組み合わせで決まります。 サンプルコードでは、todaysKey に関連して実行されている addAnomalyCheck には Size()Maximum("a") が指定されているものが2つありますが、下記のようにレポジトリに保存されている内容は、1つのみとなります。

保存内容5

+-------+--------+-------+-----+-------------+
| entity|instance|   name|value| dataset_date|
+-------+--------+-------+-----+-------------+
|Dataset|       *|   Size|  8.0|1628216988556|
| Column|       a|Maximum|  8.0|1628216988556|
|Dataset|       *|   Size| 10.0|1628303388556|
| Column|       a|Maximum| 10.0|1628303388556|
+-------+--------+-------+-----+-------------+

傾向の変化を検出する

AnomalyDetectionでは、直前だけでなく過去の複数のデータを使って、テストを行うことができます。

サンプルコード6

from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pydeequ.analyzers import *
from pydeequ.anomaly_detection import *
from pydeequ.repository import *
from pydeequ.verification import *
import datetime

now = datetime.datetime.now().timestamp()
repo = InMemoryMetricsRepository(spark)

for i in range(24):
  yesterdaysKey = ResultKey(spark, int(now * 1000) - (24 - i) * 60 * 60 * 1000)

  df_yesterday = spark.createDataFrame(data=[
          (1, 10, "foo"), 
          (2, 20, "foo"), 
          (3, 30, "bar"), 
          (4, 40, "baz"), 
          (5, 50, "baz"), 
          (6, 60, "bar"), 
          (7, 70, None), 
          (8, 80, "bar"), 
  ], schema=StructType([
          StructField("a", IntegerType(), True),
          StructField("b", IntegerType(), True),
          StructField("c", StringType(), True),
  ]))

  checkResult = AnalysisRunner(spark) \
      .onData(df_yesterday) \
      .useRepository(repo) \
      .saveOrAppendResult(yesterdaysKey) \
      .addAnalyzer(Mean("a")) \
      .addAnalyzer(Completeness("c")) \
      .run()

df = spark.createDataFrame(data=[
        (1, 10, "foo"), 
        (2, 20, "foo"), 
        (3, 30, "bar"), 
        (4, 40, "baz"), 
        (5, 50, None), 
        (6, 60, None), 
        (7, 70, None), 
        (8, 80, "bar"), 
        (9, 90, None), 
        (1000, 100, None), 
], schema=StructType([
        StructField("a", IntegerType(), True),
        StructField("b", IntegerType(), True),
        StructField("c", StringType(), True),
]))

todaysKey = ResultKey(spark, int(now * 1000))
checkResult = VerificationSuite(spark) \
    .onData(df) \
    .useRepository(repo) \
    .saveOrAppendResult(todaysKey) \
    .addAnomalyCheck(OnlineNormalStrategy(), Mean("a")) \
    .addAnomalyCheck(OnlineNormalStrategy(), Completeness("c")) \
    .run()

checkResult_df = VerificationResult.checkResultsAsDataFrame(spark, checkResult)
checkResult_df.show(truncate=False)

repo.load() \
  .getSuccessMetricsAsDataFrame() \
  .show()

OnlineNormalStrategyは、Analyzerで指定されたテスト項目の過去の値の平均と標準偏差を計算し、今回のテスト項目の値が mean - lowerDeviationFactor *stdDev と mean + upperDeviationFactor * stDev の間に収まっているかどうかを確認します。

ignoreAnomalies で履歴データ内の外れ値を無視して平均と標準偏差を計算してくれることを期待しますが、現状のDeequ側の実装では残念ながらそれらは無視されず、平均と標準偏差の計算に考慮されてしまいます。また、ウィンドウサイズのようなものを指定することができないため、Repositoryに保存されているデータを渡す前に特定日以前は切っておくというような操作が必要になります。

結果6

+--------------------------------------+-----------+------------+---------------------------------------+-----------------+------------------------------------------------------+
|check                                 |check_level|check_status|constraint                             |constraint_status|constraint_message                                    |
+--------------------------------------+-----------+------------+---------------------------------------+-----------------+------------------------------------------------------+
|Anomaly check for Mean(a,None)        |Warning    |Warning     |AnomalyConstraint(Mean(a,None))        |Failure          |Value: 104.5 does not meet the constraint requirement!|
|Anomaly check for Completeness(c,None)|Warning    |Warning     |AnomalyConstraint(Completeness(c,None))|Failure          |Value: 0.5 does not meet the constraint requirement!  |
+--------------------------------------+-----------+------------+---------------------------------------+-----------------+------------------------------------------------------+

季節性を考慮した変化の検出する

履歴をもとに AnomalyDetection を行う場合は、周期性を考慮して行いたいケースが多くあります。Deequでは、週単位、年単位の周期を考慮してテストすることができます。

サンプルコード7

from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pydeequ.analyzers import *
from pydeequ.anomaly_detection import *
from pydeequ.repository import *
from pydeequ.verification import *
import datetime

now = datetime.datetime.now().timestamp()
repo = InMemoryMetricsRepository(spark)

for k in range(2):
  for i in range(1, 7):
    yesterdaysKey = ResultKey(spark, int(now * 1000) - (k*7+i) * 24 * 60 * 60 * 1000)

    df_yesterday = spark.createDataFrame(data=[
            (1, 10, "foo"), 
            (2, 20, "foo"), 
            (3, 30, "bar"), 
            (4, 40, "baz"), 
            (5, 50, "baz"), 
            (6, 60, "bar"), 
            (7, 70, None), 
            (8, 80, "bar"), 
    ], schema=StructType([
            StructField("a", IntegerType(), True),
            StructField("b", IntegerType(), True),
            StructField("c", StringType(), True),
    ]))

    checkResult = AnalysisRunner(spark) \
        .onData(df_yesterday) \
        .useRepository(repo) \
        .saveOrAppendResult(yesterdaysKey) \
        .addAnalyzer(Mean("a")) \
        .addAnalyzer(Completeness("c")) \
        .run()
  yesterdaysKey = ResultKey(spark, int(now * 1000) - (k*7+7) * 24 * 60 * 60 * 1000)

  df_yesterday = spark.createDataFrame(data=[
          (1, 10, "foo"), 
          (2, 20, "foo"), 
          (3, 30, "bar"), 
          (4, 40, "baz"), 
          (5, 50, None), 
          (6, 60, None), 
          (7, 70, None), 
          (8, 80, None), 
  ], schema=StructType([
          StructField("a", IntegerType(), True),
          StructField("b", IntegerType(), True),
          StructField("c", StringType(), True),
  ]))

  checkResult = AnalysisRunner(spark) \
      .onData(df_yesterday) \
      .useRepository(repo) \
      .saveOrAppendResult(yesterdaysKey) \
      .addAnalyzer(Mean("a")) \
      .addAnalyzer(Completeness("c")) \
      .run()

df = spark.createDataFrame(data=[
        (1, 10, "foo"), 
        (2, 20, "foo"), 
        (3, 30, "bar"), 
        (4, 40, "baz"), 
        (5, 50, None), 
        (6, 60, None), 
        (7, 70, None), 
        (8, 80, "bar"), 
        (9, 90, None), 
        (1000, 100, None), 
], schema=StructType([
        StructField("a", IntegerType(), True),
        StructField("b", IntegerType(), True),
        StructField("c", StringType(), True),
]))

todaysKey = ResultKey(spark, int(now * 1000))
checkResult = VerificationSuite(spark) \
    .onData(df) \
    .useRepository(repo) \
    .saveOrAppendResult(todaysKey) \
    .addAnomalyCheck(HoltWinters(MetricInterval.Daily, SeriesSeasonality.Weekly), Mean("a")) \
    .addAnomalyCheck(HoltWinters(MetricInterval.Daily, SeriesSeasonality.Weekly), Completeness("c")) \
    .run()

checkResult_df = VerificationResult.checkResultsAsDataFrame(spark, checkResult)
checkResult_df.show(truncate=False)

repo.load() \
  .getSuccessMetricsAsDataFrame() \
  .show()

HoltWintersは、データの頻度と周期の2つを指定することで、該当のテスト項目が履歴データから外れたものでないかを確認することができます。

結果7

+--------------------------------------+-----------+------------+---------------------------------------+-----------------+------------------------------------------------------+
|check                                 |check_level|check_status|constraint                             |constraint_status|constraint_message                                    |
+--------------------------------------+-----------+------------+---------------------------------------+-----------------+------------------------------------------------------+
|Anomaly check for Mean(a,None)        |Warning    |Warning     |AnomalyConstraint(Mean(a,None))        |Failure          |Value: 104.5 does not meet the constraint requirement!|
|Anomaly check for Completeness(c,None)|Warning    |Success     |AnomalyConstraint(Completeness(c,None))|Success          |                                                      |
+--------------------------------------+-----------+------------+---------------------------------------+-----------------+------------------------------------------------------+

Repositoryの格納内容

Repositoryには下記のように、すべての時間のすべてのAnalyzerのテスト結果が1つのファイル内(メモリだと1つのオブジェクト)に格納されます。

保存内容

[
  {
    "resultKey": {
      "dataSetDate": 1628402595296,
      "tags": {}
    },
    "analyzerContext": {
      "metricMap": [
        {
          "analyzer": {
            "analyzerName": "Size"
          },
          "metric": {
            "metricName": "DoubleMetric",
            "entity": "Dataset",
            "instance": "*",
            "name": "Size",
            "value": 8.0
          }
        },
        {
          "analyzer": {
            "analyzerName": "Maximum",
            "column": "a"
          },
          "metric": {
            "metricName": "DoubleMetric",
            "entity": "Column",
            "instance": "a",
            "name": "Maximum",
            "value": 8.0
          }
        }
      ]
    }
  },
  {
    "resultKey": {
      "dataSetDate": 1628316195296,
      "tags": {}
    },
    "analyzerContext": {
      "metricMap": [
        {
          "analyzer": {
            "analyzerName": "Size"
          },
          "metric": {
            "metricName": "DoubleMetric",
            "entity": "Dataset",
            "instance": "*",
            "name": "Size",
            "value": 10.0
          }
        },
        {
          "analyzer": {
            "analyzerName": "Maximum",
            "column": "a"
          },
          "metric": {
            "metricName": "DoubleMetric",
            "entity": "Column",
            "instance": "a",
            "name": "Maximum",
            "value": 10.0
          }
        }
      ]
    }
  },
]

その他の機能

データのプロファイリング

Profilerを用いることで、どのようなデータが格納されるカラムかを簡単に確認することができます。出力された値をもとにどのようなテスト項目があると良さそうかを、自動もしくは手動で判定するのに使うことが主な用途として想定されます。

サンプルコード8

from pydeequ.profiles import *

df = spark.createDataFrame(data=[
        (1, 10, "foo"), 
        (2, 20, "foo"), 
        (3, 30, "bar"), 
        (4, 40, "baz"), 
        (5, 50, "baz"), 
        (6, 60, "bar"), 
        (7, 70, None), 
        (8, 80, "bar"), 
        (9, 90, "bar"), 
        (10, 100, "bar"), 
], schema=StructType([
        StructField("a", IntegerType(), True),
        StructField("b", IntegerType(), True),
        StructField("c", StringType(), True),
]))

result = ColumnProfilerRunner(spark) \
    .onData(df) \
    .run()

for col, profile in result.profiles.items():
    print(profile)

取得できる主な項目は、completeness、approximateNumDistinctValues、dataType、histogram、meanなどの統計量です。

completenessはNullでない値の割合、approximateNumDistinctValuesは推定した値の種類、dataTypeはデータの型、histogramはどの値がどれだけ(数と割合)含まれているかの配列、統計量は数字カラムのみでそのカラム全体の平均や最大値などを返してくれます。

結果8

NumericProfiles for column: a: {
    "completeness": 1.0,
    "approximateNumDistinctValues": 10,
    "dataType": "Integral",
    "isDataTypeInferred": false,
    "typeCounts": {},
    "histogram": [
        [
            "8",
            1,
            0.1
        ],
        [
            "4",
            1,
            0.1
        ],
        [
            "9",
            1,
            0.1
        ],
        [
            "5",
            1,
            0.1
        ],
        [
            "10",
            1,
            0.1
        ],
        [
            "6",
            1,
            0.1
        ],
        [
            "1",
            1,
            0.1
        ],
        [
            "2",
            1,
            0.1
        ],
        [
            "7",
            1,
            0.1
        ],
        [
            "3",
            1,
            0.1
        ]
    ],
    "kll": "None",
    "mean": 5.5,
    "maximum": 10.0,
    "minimum": 1.0,
    "sum": 55.0,
    "stdDev": 2.8722813232690143,
    "approxPercentiles": []
}
NumericProfiles for column: b: {
    "completeness": 1.0,
    "approximateNumDistinctValues": 10,
    "dataType": "Integral",
    "isDataTypeInferred": false,
    "typeCounts": {},
    "histogram": [
        [
            "100",
            1,
            0.1
        ],
        [
            "40",
            1,
            0.1
        ],
        [
            "90",
            1,
            0.1
        ],
        [
            "50",
            1,
            0.1
        ],
        [
            "10",
            1,
            0.1
        ],
        [
            "80",
            1,
            0.1
        ],
        [
            "60",
            1,
            0.1
        ],
        [
            "20",
            1,
            0.1
        ],
        [
            "70",
            1,
            0.1
        ],
        [
            "30",
            1,
            0.1
        ]
    ],
    "kll": "None",
    "mean": 55.0,
    "maximum": 100.0,
    "minimum": 10.0,
    "sum": 550.0,
    "stdDev": 28.722813232690143,
    "approxPercentiles": []
}
StandardProfiles for column: c: {
    "completeness": 0.9,
    "approximateNumDistinctValues": 3,
    "dataType": "String",
    "isDataTypeInferred": false,
    "typeCounts": {
        "Boolean": 0,
        "Fractional": 0,
        "Integral": 0,
        "Unknown": 1,
        "String": 9
    },
    "histogram": [
        [
            "bar",
            5,
            0.5
        ],
        [
            "baz",
            2,
            0.2
        ],
        [
            "foo",
            2,
            0.2
        ],
        [
            "NullValue",
            1,
            0.1
        ]
    ]
}

テスト項目のレコメンデーション

ConstraintSuggestionは、Profilerからさらに進んで、Checkオブジェクトにどのようなテスト項目を追加したほうが良さそうかを、レコメンドしてくれます。

サンプルコード9

from pydeequ.suggestions import *

df = spark.createDataFrame(data=[
        (1, 10, "foo"), 
        (2, 20, "foo"), 
        (3, 30, "bar"), 
        (4, 40, "baz"), 
        (5, 50, "baz"), 
        (6, 60, "bar"), 
        (7, 70, None), 
        (8, 80, "bar"), 
        (9, 90, "bar"), 
        (10, 100, "bar"), 
], schema=StructType([
        StructField("a", IntegerType(), True),
        StructField("b", IntegerType(), True),
        StructField("c", StringType(), True),
]))

suggestionResult = ConstraintSuggestionRunner(spark) \
             .onData(df) \
             .addConstraintRule(DEFAULT()) \
             .run()

for item in suggestionResult['constraint_suggestions']:
  print(item)
  print()

レコメンドされる対象のルールは、addConstraintRuleで指定する事ができ、DEFAULTではsuggestionsモジュール配下のすべてのルールが含まれます。含まれるルールには、Nullでないかを確認するテストを追加すべきか判定する CompleteIfCompleteRule 、負の値が含まれていないかを確認するテストを追加すべきかを判定する NonNegativeNumbersRule などがあります。

constraint_suggestions の各結果には、constraint_name、column_name、current_value、description、suggesting_rule、rule_description、code_for_constraintの各値が含まれます。利用上もっとも重要なのが code_for_constraint で、Checkオブジェクトに該当のテスト項目を追加するための実装がそのまま記載されています。

結果9

{'constraint_name': 'CompletenessConstraint(Completeness(b,None))', 'column_name': 'b', 'current_value': 'Completeness: 1.0', 'description': "'b' is not null", 'suggesting_rule': 'CompleteIfCompleteRule()', 'rule_description': 'If a column is complete in the sample, we suggest a NOT NULL constraint', 'code_for_constraint': '.isComplete("b")'}

{'constraint_name': "ComplianceConstraint(Compliance('b' has no negative values,b >= 0,None))", 'column_name': 'b', 'current_value': 'Minimum: 10.0', 'description': "'b' has no negative values", 'suggesting_rule': 'NonNegativeNumbersRule()', 'rule_description': 'If we see only non-negative numbers in a column, we suggest a corresponding constraint', 'code_for_constraint': '.isNonNegative("b")'}

{'constraint_name': 'UniquenessConstraint(Uniqueness(List(b),None))', 'column_name': 'b', 'current_value': 'ApproxDistinctness: 1.0', 'description': "'b' is unique", 'suggesting_rule': 'UniqueIfApproximatelyUniqueRule()', 'rule_description': 'If the ratio of approximate num distinct values in a column is close to the number of records (within the error of the HLL sketch), we suggest a UNIQUE constraint', 'code_for_constraint': '.isUnique("b")'}

{'constraint_name': 'CompletenessConstraint(Completeness(a,None))', 'column_name': 'a', 'current_value': 'Completeness: 1.0', 'description': "'a' is not null", 'suggesting_rule': 'CompleteIfCompleteRule()', 'rule_description': 'If a column is complete in the sample, we suggest a NOT NULL constraint', 'code_for_constraint': '.isComplete("a")'}

{'constraint_name': "ComplianceConstraint(Compliance('a' has no negative values,a >= 0,None))", 'column_name': 'a', 'current_value': 'Minimum: 1.0', 'description': "'a' has no negative values", 'suggesting_rule': 'NonNegativeNumbersRule()', 'rule_description': 'If we see only non-negative numbers in a column, we suggest a corresponding constraint', 'code_for_constraint': '.isNonNegative("a")'}

{'constraint_name': 'UniquenessConstraint(Uniqueness(List(a),None))', 'column_name': 'a', 'current_value': 'ApproxDistinctness: 1.0', 'description': "'a' is unique", 'suggesting_rule': 'UniqueIfApproximatelyUniqueRule()', 'rule_description': 'If the ratio of approximate num distinct values in a column is close to the number of records (within the error of the HLL sketch), we suggest a UNIQUE constraint', 'code_for_constraint': '.isUnique("a")'}

{'constraint_name': "ComplianceConstraint(Compliance('c' has value range 'bar', 'baz', 'foo' for at least 99.0% of values,`c` IN ('bar', 'baz', 'foo'),None))", 'column_name': 'c', 'current_value': 'Compliance: 0.9999999999999999', 'description': "'c' has value range 'bar', 'baz', 'foo' for at least 99.0% of values", 'suggesting_rule': 'FractionalCategoricalRangeRule(0.9)', 'rule_description': 'If we see a categorical range for most values in a column, we suggest an IS IN (...) constraint that should hold for most values', 'code_for_constraint': '.isContainedIn("c", ["bar", "baz", "foo"], lambda x: x >= 0.99, "It should be above 0.99!")'}

{'constraint_name': 'CompletenessConstraint(Completeness(c,None))', 'column_name': 'c', 'current_value': 'Completeness: 0.9', 'description': "'c' has less than 29% missing values", 'suggesting_rule': 'RetainCompletenessRule()', 'rule_description': 'If a column is incomplete in the sample, we model its completeness as a binomial variable, estimate a confidence interval and use this to define a lower bound for the completeness', 'code_for_constraint': '.hasCompleteness("c", lambda x: x >= 0.71, "It should be above 0.71!")'}

TensorFlow Array Indexing Correspond to numpy

TensorFlowで配列処理を効率的に行うのはなかなか難しいことがあります。

例えば、下記のようなIndexing処理はnumpyでは簡単に実現することができますが、TnesorFlowではそうはいきません。

a[:, [2, 3]]

スライス以外の方法でインデックスを指定して値を取得する際には、下記のように tf.gather もしくは tf.gather_nd 関数を利用する必要があります。

tf.gather(a, [2, 3], axis=1)

また、配列の値の更新には tf.tensor_scatter_nd_update を使用する必要があります。

first = tf.tile(tf.expand_dims(tf.range(4), axis=1), (1, 2))
indexes = tf.tile([[2, 3]], (4, 1))
indices = tf.stack([first, indexes], axis=-1)
a = tf.tensor_scatter_nd_update(a, indices, tf.ones((4, 2, 4)) * 2)

この記事では、頻出するインデクシングのシチュエーションにおいて、TensorFlowでの値の取得方法、更新方法を記載していきます。

コードは、記事で紹介している以外の実装も含めて下記に置いています。

TensorFlow Indexing.ipynb · GitHub

目次

Slicingのみの場合

Slicingのみの場合、TensorFlowでも簡単にIndexingを実現できます。

numpyでの以下のような値の取得と、以下のような値の更新を行うケースを考えます。

a = np.ones((4, 4, 4))
a[:, 2:3]
a[:, 2:3] = np.ones((4, 1, 4)) * 2

このケースでは、TensorFlowでもほとんど同じように記述することができます。

a = tf.ones((4, 4, 4))
a[:, 2:3]
a = a[:, 2:3].assign(tf.ones((4, 1, 4)) * 2)

Boolean Array によるIndexingの場合

Boolean Array によるIndexingは、少し特殊な書き方が必要になりますが、パターンが分かれば簡単に実現できます。

下記のようなnumpyでの値の取得と値の更新を行うケースを考えます。

a = np.ones((4, 4, 4))
d = np.array([[[True, True, True, True], [False, False, False, False], [False, False, False, False], [False, False, False, False]]] * 4)
a[d]
a[d] = 2

値の取得は問題なくnumpyと同様に扱うことができます。

一方、値の更新はややトリッキーな書き方をする必要があります。Bool値の配列を1,0の配列に変換することにより、Indexingの配列がTrueの場合に代入する配列の値を、Falseの場合には元の配列の値を使用するような配列を作成する必要があります。

この方法では、元の配列と同じサイズの配列を用意する必要がありますが、実用上の大体のケースでは、特定の値に更新するか、元々同じサイズの配列の値に一部置き換えるといったようなケースであるため、そこまで問題にはならないでしょう。

a = tf.ones((4, 4, 4))
d = tf.constant([[[True, True, True, True], [False, False, False, False], [False, False, False, False], [False, False, False, False]]] * 4)
a[d]
d = tf.cast(d, dtype=a.dtype)
a = (1 - d) * a + d * np.ones((4, 4, 4)) * 2

Integer Array によるIndexingの場合

Integer Arrayを用いたIndexingは、TensorFlowの機能をフルに使用する必要があります。

下記のようなnumpyでの値の取得と値の更新を行うケースを考えます。

a = np.ones((4, 4, 4))
b = np.array([[2, 3]])
c = np.array([1, 2])
a[:, [2, 3]]
a[b, c]
a[:, [2, 3]] = np.ones((4, 2, 4)) * 2
a[b, c] = np.ones((1, 2, 4)) * 2

Integer Array による値の取得

値の取得は、下記のように tf.gather および tf.gather_nd を使用する必要があります。

a = tf.ones((4, 4, 4))
b = tf.constant([[2, 3]])
c = tf.constant([1, 2])
tf.gather(a, [2, 3], axis=1)
tf.gather_nd(a, tf.stack([b, [c]], axis=-1))

tf.gatherは、最初に対象の配列、次にIndexingをする対象を示す1次元の配列、axisにどの次元のIndexingを行うかを指定します。

つまり、tf.gather(a, [2, 3], axis=1)を実行すると、(4, 2, 4)の配列が取得できることになります。

tf.gather_ndは、tf.gatherを多次元の配列でIndexingするように拡張したものです。ただし、その配列の値の並びの解釈はtf.gatherとは異なります。

tf.gatherでは[2, 3]が与えられた場合、この配列は同じ次元の2番目と3番目の値を取得することを示しているのに対し、tf.gather_ndでは1次元目の2番目、2次元目の3番目の値を取得することを意味します。

つまり、tf.gatherはnumpyでいう a[[2, 3]] の挙動であり、tf.gather_ndは a[2, 3] の挙動と似たような挙動を示します。

numpyの配列と違い、tf.gather_ndは複数の[2, 3]のペアを与えることができ、またその配列のshapeに合わせて出力のshapeが変化します。

例えば、上記の行列aに対して、tf.gather_nd(a, [2,3]) を呼び出すと、numpyで言う a[2, 3] と同じ出力を得ることができますが、tf.gather_nd(a, [[2,3], [1, 2]]) を呼び出すと、 a[2, 3] の結果と a[1, 2] の結果を縦方向にstackした値を取得できます。

また、Indexingを示す配列のndimsは制限されておらず、 [[[[2,3]]]] のような配列を指定することができ、その場合の結果のshapeは(1, 1, 1, 4)になります。

注意点としては、最終次元の配列の次元数は元の配列のndimsより大きくなることはできず、また、個々にIndexingした結果はstackされることになるため、stackできるような配列になっている必要があります。

Integer Array による値の更新

値の更新は tf.tensor_scatter_nd_update を使用する事により実現できますが、tf.gatherに相当する関数がないため、Slicingを必要とする配列の更新の場合には一工夫必要になります。

具体的には、下記のような処理になります。

a = tf.Variable(tf.ones((4, 4, 4)))
b = tf.constant([[2, 3]])
c = tf.constant([1, 2])

# likely a[:, [2, 3]] = np.ones((4, 2, 4)) * 2
first = tf.tile(tf.expand_dims(tf.range(4), axis=1), (1, 2))
indexes = tf.tile([[2, 3]], (4, 1))
indices = tf.stack([first, indexes], axis=-1)
a = tf.tensor_scatter_nd_update(a, indices, tf.ones((4, 2, 4)) * 2)

# likely a[b, c] = np.ones((1, 2, 4)) * 2
indices = tf.stack([b, [c]], axis=-1)
a = tf.tensor_scatter_nd_update(a, indices, tf.ones((1, 2, 4)) * 2)

Indexingを行う対象を示す配列の仕様はtf.gather_ndと全く同じです。つまり、配列の1次元目から順にその何番目の値を更新するかを指定する必要があるため、スライス相当のIndexの指定を、tf.rangeやtf.tileなどを駆使して実装者が行う必要があります。

一方で、スライス相当の処理が必要ない場合は、tf.gather_ndと同じような考えで実現できます。代入する配列だけ、代入した領域と同じshapeの配列が必要になる点だけは注意が必要です。

Keras Loss Behavior with Language Model

KerasのModelクラスを使用した際のロスの計算は、Paddingで追加した余計な値を勾配の計算から除外する処理は自動でやってくれるのですが、
historyに記録されるlossの平均値を求める際に、maskを部分的にしか考慮しておらず、padding数が多くなればなるほど、実際のロスより小さくなってしまうという現象が発生します。

この記事は、KerasのModelクラスでLossを利用する、
特にEmbedding層で、mask_zeroをTrueにした場合に、Paddingで追加した余計な値を、勾配の計算に使用しない、ロスの計算から完全に除外する方法についてのメモです。

検証用のコードはこちらです。

マスクを使用したロスの計算は、TensorFlowのチュートリアルを参考にしています。

目次

KerasのLossの種類と種類毎の処理の違い

Kerasのmodel.compileで指定できるLossの種類は大きく分けると以下の3つがあります。
※ 参考:compile内で呼ばれているtf.keras.Model.prepare_loss_functions

  • tf.keras.losses.Lossの派生クラス

  • 独自定義したloss関数

  • 独自定義したcallableなLossオブジェクト

このうち、「独自定義したloss関数」は、prepare_loss_functions内でtf.keras.losses.Lossの派生クラスであるLossFunctionWrapperクラスに変換されるため、実質は2種類のLossオブジェクトが使い分けられることになります。

Lossを含んだ計算グラフの構築は、tf.keras.Model._prepare_total_loss内で行われます。
この関数の中で、上記2つのオブジェクトは別々の処理が行われことになります。

tf.keras.losses.Lossの派生クラスの場合の処理

tf.keras.losses.Lossの場合は、Kerasの内部関数であるtf.keras.utils.losses_util.compute_weighted_loss内で、以下の処理が行われます。

  1. ロス関数の呼び出しによるロスの計算

  2. tf.keras.utils.losses_util.scale_losses_by_sample_weightでロス×sample_weightsを計算

  3. Lossオブジェクトのreductionに応じた集計をtf.keras.utils.losses_util.reduce_weighted_lossで計算。大体の場合は、ロスの平均値を計算する処理

これから、sample_weightsを使用しない場合は、独自定義のロス関数でどんなshapeの値を返しても良いことがわかります。

ただし、sample_weightsとしてEmbedding層で計算したmaskが暗黙的に使用される場合があるので後述のように注意が必要です。

独自定義したcallableなLossオブジェクトの場合の処理

正解ラベルとニューラルネットワークの出力だけでなく、sample_weightsも引数に渡されて、Lossオブジェクトのcall関数が呼ばれます。
ベクトル値を返した場合、ベクトルの平均値が最終のロスとして使用されます。

Embedding層でmask_zeroをTrueにした場合のLossの挙動

Embedding層でmask_zeroをTrueに指定した場合、sample_weightsとしてEmbedding層で計算したmaskが暗黙的に使用されます。
maskはpadding処理で0埋めした部分はFalse、それ以外はTrueとなり、paddingで埋めた部分の計算に勾配が伝わらないようにしてくれます。

ただし、あくまで勾配が伝わらないようにしてくれるだけで、ロスの平均を取る際の分母の数をmask分減らしてはくれません。
このため、Perplexity等の計算をこのLossが出力した値を元に行うと、paddingで埋める長さが長いほど小さい値になってしまいおかしくなります。

例えば、正解ラベルとして  [1, 1, 2, 0, 0 ] の系列データがあるケースを考えます。

この時、ネットワークの出力として、各時点でのラベルの予測確率が  [0.2, 0.2, 0.2, 0.2, 0.2 ]と得られたとします。

0は計算処理の都合上いれているだけのデータで、実際の処理では無視するため、このデータに対するクロスエントロピーロスは下記のようになります。

 -(log(0.2) + log(0.2) + log(0.2)) / 3. = 1.60943...

しかしKerasのModelクラスを使用して計算すると

 -(log(0.2) + log(0.2) + log(0.2)) / 5. = 0.96566...

という値になってしまい、実際の値に比べてかなり小さくなってしまいます。

これを避けるためには、検証コード内LOSS_MODEが2の場合のように、独自定義したLossオブジェクトで、padiding部分の値を元にした処理に勾配が行かないように、かつ平均計算時の分母からの除去するように必要があります。

ちなみにMetoricsはレートに直す際に、sample_weightsの合計で割るような実装になっているため、maskの場合も意図した動作になります。