敵対的生成ネットワークGAN

敵対的生成ネットワーク

敵対的生成ネットワークは、英語だとGenerative Adversarial Networkで、GANと略されています。

以下、敵対的生成モデルをGANと呼びます。

元データがどのような分布になっているかを推測し、その分布に基づいて、元データと同じようなデータを生成することを目的としたモデルを生成モデルと言います。

とりわけ、ディープラーニングを取り入れた生成モデルのことを深層生成モデルと呼びます。

GANは、この深層生成モデルの一種です。

敵対的生成ネットワークのネットワーク

GANは、2種類のネットワークで構成されています。

それぞれ、ジェネレータとディスクリミネータと呼ばれます。

  1. ジェネレータ
  2. 入力として潜在空間のランダムベクトルを受け取り、画像を生成して出力するネットワーク

  3. ディスクリミネータ
  4. 入力として画像を受け取る。
    その画像が本物か偽物かを予測して出力するネットワーク
    (ジェネレータで生成されたか、されてないか)

この2つのネットワークは、敵対関係にあります。

この関係から、敵対的生成ネットワークとなっています。

ジェネレータで、画像(False)を生成し、ディスクリミネータでその画像が本物(True)かFalseかを予測し、ジェネレータが勝ったら、学習終了となります。

これらの構造をGANと呼びます。

ジェネレータ、ディスクリミネータのそれぞれのネットワークにCNNを用いたモデルをDCGANと呼びます。

$$
\min _ { G } \max _ { D } V ( D , G ) = \mathbb { E } _ { \boldsymbol { x } \sim p _ { \mathrm { data } } ( \boldsymbol { x } ) } [ \log D ( \boldsymbol { x } ) ] + \mathbb { E } _ { \boldsymbol { z } \sim p _ { \boldsymbol { z } } ( \boldsymbol { z } ) } [ \log ( 1 – D ( G ( \boldsymbol { z } ) ) ) ]
$$

ジェネレータ

$$
\mathcal { L } _ { G } = \frac { 1 } { m } \sum _ { i = 1 } ^ { m } \log \left( 1 – D \left( G \left( \boldsymbol { z } _ { i } \right) \right) \right)
$$

ディスクリミネータ

$$
\mathcal { L } _ { D } = \frac { 1 } { m } \sum _ { i = 1 } ^ { m } \left[ \log D \left( \boldsymbol { x } _ { i } \right) + \log \left( 1 – D \left( G \left( \boldsymbol { z } _ { i } \right) \right) \right) \right]
$$

学習アルゴリズム

ジェネレータとディスクリミネータは、それぞれ独立に学習する。

初めに、ジェネレータが学習する。
次に、ディスクリミネータが学習する。

一方が、ある程度学習し終わるまで、もう一方は、学習はしない(パラメータ固定)。
これを繰り返す。

敵対的生成ネットワークの論文

  1. Generative Adversarial Networks
  2. https://arxiv.org/abs/1406.2661

参考