終末 A.I.

データいじりや機械学習するエンジニアのブログ

TensorFlow Array Indexing Correspond to numpy

TensorFlowで配列処理を効率的に行うのはなかなか難しいことがあります。

例えば、下記のようなIndexing処理はnumpyでは簡単に実現することができますが、TnesorFlowではそうはいきません。

a[:, [2, 3]]

スライス以外の方法でインデックスを指定して値を取得する際には、下記のように tf.gather もしくは tf.gather_nd 関数を利用する必要があります。

tf.gather(a, [2, 3], axis=1)

また、配列の値の更新には tf.tensor_scatter_nd_update を使用する必要があります。

first = tf.tile(tf.expand_dims(tf.range(4), axis=1), (1, 2))
indexes = tf.tile([[2, 3]], (4, 1))
indices = tf.stack([first, indexes], axis=-1)
a = tf.tensor_scatter_nd_update(a, indices, tf.ones((4, 2, 4)) * 2)

この記事では、頻出するインデクシングのシチュエーションにおいて、TensorFlowでの値の取得方法、更新方法を記載していきます。

コードは、記事で紹介している以外の実装も含めて下記に置いています。

TensorFlow Indexing.ipynb · GitHub

目次

Slicingのみの場合

Slicingのみの場合、TensorFlowでも簡単にIndexingを実現できます。

numpyでの以下のような値の取得と、以下のような値の更新を行うケースを考えます。

a = np.ones((4, 4, 4))
a[:, 2:3]
a[:, 2:3] = np.ones((4, 1, 4)) * 2

このケースでは、TensorFlowでもほとんど同じように記述することができます。

a = tf.ones((4, 4, 4))
a[:, 2:3]
a = a[:, 2:3].assign(tf.ones((4, 1, 4)) * 2)

Boolean Array によるIndexingの場合

Boolean Array によるIndexingは、少し特殊な書き方が必要になりますが、パターンが分かれば簡単に実現できます。

下記のようなnumpyでの値の取得と値の更新を行うケースを考えます。

a = np.ones((4, 4, 4))
d = np.array([[[True, True, True, True], [False, False, False, False], [False, False, False, False], [False, False, False, False]]] * 4)
a[d]
a[d] = 2

値の取得は問題なくnumpyと同様に扱うことができます。

一方、値の更新はややトリッキーな書き方をする必要があります。Bool値の配列を1,0の配列に変換することにより、Indexingの配列がTrueの場合に代入する配列の値を、Falseの場合には元の配列の値を使用するような配列を作成する必要があります。

この方法では、元の配列と同じサイズの配列を用意する必要がありますが、実用上の大体のケースでは、特定の値に更新するか、元々同じサイズの配列の値に一部置き換えるといったようなケースであるため、そこまで問題にはならないでしょう。

a = tf.ones((4, 4, 4))
d = tf.constant([[[True, True, True, True], [False, False, False, False], [False, False, False, False], [False, False, False, False]]] * 4)
a[d]
d = tf.cast(d, dtype=a.dtype)
a = (1 - d) * a + d * np.ones((4, 4, 4)) * 2

Integer Array によるIndexingの場合

Integer Arrayを用いたIndexingは、TensorFlowの機能をフルに使用する必要があります。

下記のようなnumpyでの値の取得と値の更新を行うケースを考えます。

a = np.ones((4, 4, 4))
b = np.array([[2, 3]])
c = np.array([1, 2])
a[:, [2, 3]]
a[b, c]
a[:, [2, 3]] = np.ones((4, 2, 4)) * 2
a[b, c] = np.ones((1, 2, 4)) * 2

Integer Array による値の取得

値の取得は、下記のように tf.gather および tf.gather_nd を使用する必要があります。

a = tf.ones((4, 4, 4))
b = tf.constant([[2, 3]])
c = tf.constant([1, 2])
tf.gather(a, [2, 3], axis=1)
tf.gather_nd(a, tf.stack([b, [c]], axis=-1))

tf.gatherは、最初に対象の配列、次にIndexingをする対象を示す1次元の配列、axisにどの次元のIndexingを行うかを指定します。

つまり、tf.gather(a, [2, 3], axis=1)を実行すると、(4, 2, 4)の配列が取得できることになります。

tf.gather_ndは、tf.gatherを多次元の配列でIndexingするように拡張したものです。ただし、その配列の値の並びの解釈はtf.gatherとは異なります。

tf.gatherでは[2, 3]が与えられた場合、この配列は同じ次元の2番目と3番目の値を取得することを示しているのに対し、tf.gather_ndでは1次元目の2番目、2次元目の3番目の値を取得することを意味します。

つまり、tf.gatherはnumpyでいう a[[2, 3]] の挙動であり、tf.gather_ndは a[2, 3] の挙動と似たような挙動を示します。

numpyの配列と違い、tf.gather_ndは複数の[2, 3]のペアを与えることができ、またその配列のshapeに合わせて出力のshapeが変化します。

例えば、上記の行列aに対して、tf.gather_nd(a, [2,3]) を呼び出すと、numpyで言う a[2, 3] と同じ出力を得ることができますが、tf.gather_nd(a, [[2,3], [1, 2]]) を呼び出すと、 a[2, 3] の結果と a[1, 2] の結果を縦方向にstackした値を取得できます。

また、Indexingを示す配列のndimsは制限されておらず、 [[[[2,3]]]] のような配列を指定することができ、その場合の結果のshapeは(1, 1, 1, 4)になります。

注意点としては、最終次元の配列の次元数は元の配列のndimsより大きくなることはできず、また、個々にIndexingした結果はstackされることになるため、stackできるような配列になっている必要があります。

Integer Array による値の更新

値の更新は tf.tensor_scatter_nd_update を使用する事により実現できますが、tf.gatherに相当する関数がないため、Slicingを必要とする配列の更新の場合には一工夫必要になります。

具体的には、下記のような処理になります。

a = tf.Variable(tf.ones((4, 4, 4)))
b = tf.constant([[2, 3]])
c = tf.constant([1, 2])

# likely a[:, [2, 3]] = np.ones((4, 2, 4)) * 2
first = tf.tile(tf.expand_dims(tf.range(4), axis=1), (1, 2))
indexes = tf.tile([[2, 3]], (4, 1))
indices = tf.stack([first, indexes], axis=-1)
a = tf.tensor_scatter_nd_update(a, indices, tf.ones((4, 2, 4)) * 2)

# likely a[b, c] = np.ones((1, 2, 4)) * 2
indices = tf.stack([b, [c]], axis=-1)
a = tf.tensor_scatter_nd_update(a, indices, tf.ones((1, 2, 4)) * 2)

Indexingを行う対象を示す配列の仕様はtf.gather_ndと全く同じです。つまり、配列の1次元目から順にその何番目の値を更新するかを指定する必要があるため、スライス相当のIndexの指定を、tf.rangeやtf.tileなどを駆使して実装者が行う必要があります。

一方で、スライス相当の処理が必要ない場合は、tf.gather_ndと同じような考えで実現できます。代入する配列だけ、代入した領域と同じshapeの配列が必要になる点だけは注意が必要です。