こんにちは。レッジでデータサイエンティストをしている今村と申します。
レッジでは機械学習モデルの構築といったエンジニア色強めの業務や、クライアントへの提案といった上流工程など、AIプロジェクトにおける上流から下流まで広く関わらせていただいています。
今回は、先日弊ブログで紹介したStreamlitを用いて、GANを活用したWebアプリを作ってみました。Streamlitについては、前回記事、もしくは公式ドキュメントを是非チェックしてみてください。
GANの題材として、世界中の人に愛されているポケモンの画像を用いて、新しいポケモンを生成することを試みてみました。加えて、Webアプリ化することでGANの学習結果をインタラクティブに表現でき、見て・触って楽しいWebアプリを作ってみました。
GANとは
このような技術ブログを閲覧してくださる方には、GANの仕組みをご存知の方が多いと思うので、ここでは簡潔に説明します。
GAN(Generative Adversarial Network)とはディープラーニング における生成モデルの一種で、2014年にIan Goodfellow氏により提唱されました。画像生成の主流モデルの1つとして挙げられています。
GANは2種類のモデル(生成器;Generator・識別器;Discriminator)を使用することが大きな特徴です。生成器は乱数を元に画像を出力するモデルで、識別器は画像が実際の画像、もしくは生成された画像かということを識別するモデルです。その2つのモデルを「敵対的」に学習を行うことで、あたかも本物に近い画像を生成してしまうというものです。
ここで述べている「敵対的」学習というものは下記のような構造です。
- 生成器はなるべく本物に近い画像を生成するように学習する
- 識別器は本物の画像と、生成器によって出力された画像を精度良く識別できるように学習する
つまり、生成器からするといかに識別器をダマせる画像を作れるかということを学習します。一方で識別器からすると、生成器が作った偽物にダマされず、きちんと識別できるように学習するということになります。あたかも2つのモデルを競わせるという形で学習を行うため、敵対的と呼ばれる所以になっています。
GANと一言で言っても実は様々なGANが存在しています。今回のWebアプリでは下記のネットワークをベースに学習を行いました。
- DCGAN(Deep Convolutional GAN)
- SAGAN(Self-Attention GAN)
これらのネットワークの具体的な内容はこちらを参考にしてください。
Webアプリの説明
このWebアプリはポケモンの画像を元にGANによる実行結果のデモとなっています。実際に操作している様子は下の動画をご覧ください。
Webアプリでは以下の条件を設定することで、その条件に応じた出力結果を表示できます。
- GANの切り替え(DCGAN, SAGAN)
- エポック数(学習回数)
これらを指定することで、GANのネットワークと特定のエポック数に応じた生成画像をインタラクティブに確認できます。
また、入力値は乱数を設定しているため画面の更新のたびに異なる画像が出力されるようになっています。
なお、実装コードについてはGitHubにて公開しています。
環境
以下の環境で実装しています。
- python == 3.7.4
- streamlit == 0.61.0
- torch == 1.5.0
データセット
使用したデータセットはkvpratama氏が公開している「Pokemon Images Dataset」用いました。
全891種類のポケモンのpngデータが格納されているデータセットです。
コードの説明
GitHubに記載の手順を参照してください。
なお、リポジトリにはモデル構造や学習中のログを記録したtensorboardも合わせて格納しています。
以下ではWebアプリの実装したapp.py
について説明していきます。
まずは必要なライブラリをインポートします。ここでfrom src import models
は別ファイルでまとめているモデル構造をインポートしています。
import numpy as np import matplotlib.pyplot as plt import torch import streamlit as st import torchvision.utils as vutils from src import models
続いてグローバル変数の設定を行います。
Z_DIM
が生成モデルの入力における乱数の次元を設定します。この値は任意の数を設定します。この数値が大きいほど生成画像の多様性を高めることができます。(本アプリでは400と設定しています)
OUTPUT_IMAGE_NUM
はアプリ上に表示させる画像の枚数を指定します。
# グローバル変数 Z_DIM = 400 OUTPUT_IMAGE_NUM = 20 device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
続いてWebアプリのタイトルやテキスト、インタラクション周りの設定をおこまないます。テキストについてはそれっぽい言葉を書いてみました。
下のセットアップの定義の項目で、「GANの切り替え」と「エポック数」を選択できるようにします。GANの切り替えはセレクトボックス、エポック数はスライダーを用いました。スライダーは最大値、最小値、値の間隔をそれぞれ引数で指定できます。
Webアプリ上で選択したGANをexp
、エポック数をepoch
という変数に格納します。
# タイトルとディスクリプションの設定 st.title('Pokémon GAN') st.markdown('---') st.markdown('This app is a demo of GAN using Pokémon images.') st.markdown('In this demo, the DCGAN and SAGAN models can be used to generate images.') st.markdown('By selecting the type of model and the number of training epochs, ' 'you can see the results of the Generative Images from GAN.') st.markdown('Maybe, you can find a new Pokémon... Enjoy!!') st.markdown('---') # セットアップの定義 st.sidebar.subheader('Setup') exp = st.sidebar.selectbox('Select GAN', ('DCGAN', 'SAGAN')) epoch = st.sidebar.slider('Select Epoch', min_value=0, max_value=1000, step=40)
上で指定したexp
とepoch
を元に学習済みモデルを読み込みます。
# モデルの初期化関数 @st.cache def model_init(weight_path, exp): """ モデルの初期化関数 Parameters ---------- weight_path : str 学習済みモデルの重みのパス exp : str モデルの種類(DCGAN, SAGAN) Returns ------- G : torch.nn.Module 学習済みモデル """ if 'DCGAN' in exp: G = models.Generator_dcgan(z_dim=Z_DIM, image_size=64, out_channel=3) elif 'SAGAN' in exp: G = models.Generator_sagan(z_dim=Z_DIM, image_size=48, out_channel=3) G.load_state_dict(torch.load(weight_path, map_location=device)) G.eval() return G # モデルの準備 weight_path = f'./weights/{exp}_netG_epoch_{epoch}.pth' G = model_init(weight_path, exp) G = G.to(device)
最後に画像を生成し、出力する処理を行います。
G
は生成モデルをあらかじめ学習させていた生成モデルを表しています。
# 画像生成 z = torch.randn(OUTPUT_IMAGE_NUM, Z_DIM, 1, 1) with torch.no_grad(): out = G(z.to(device)) # 複数画像を一つに結合 img = vutils.make_grid(out.detach().cpu(), normalize=True, padding=2, nrow=5, pad_value=1) img = np.transpose(img.numpy(), [1, 2, 0]) # pyplotで図の構成を作成 plt.imshow(img) plt.axis('off') plt.tight_layout() # アプリ上で画像を表示 st.subheader('Generative Pokémon Images') st.pyplot()
以上がWebアプリの実装をまとめたコードになります。一部、説明を省略した箇所があるため、全体のコードを確認したい場合はGitHubを参照してください。
GANの出力結果
さて、ここから実際の生成画像を見ていきましょう。
まずはDCGANモデルです。
- 0エポック
最初は乱数をそのまま返しているだけなので、こういった砂嵐の画像になってしまいます。
- 80エポック
生き物みたいなものが現れました。とは言えきちんとした形とはなっておらず、なんとも形容しにくいですね。。
- 600エポック
思い切ってエポック数を600に上げてみました。かなり詳細な画像ができてきました!
羽根や足のようなものが再現されているような様子も確認できます。遠目で見ると、それっぽく見える、、かも?
形はいい感じなのかなと思ってますが、色合いの表現はまだまだ弱いですね。
- 880エポック
もっと学習させたらどうなるか、ということでエポック数を880まで上げてみました。すると、出力画像がどれも同じような画像になってしまいました。
これは俗に言う 「モード崩壊」 と呼ばれる現象で、生成モデルが似たような画像しか出力しなくなる状態のことを指します。
続いて、SAGANモデルでも同じように見てみましょう。
- 0エポック
DCGANモデルと同じ砂嵐画像です。さあここからどんな画像が出力されるのか。
- 80エポック
生き物のような形が出力されました。DCGANの80エポックと比較するとノイズが少ない鮮明な画像が出力できていることが分かります。とは言え、これだとよくわからないので学習を進めていきます。
- 280エポック
より複雑な形をした画像を出力できました。こちらも遠目で見ると、それっぽく見えますかね、、?
- 520エポック
さらに学習させていくと、SAGANでもモード崩壊を引き起こしてしまいました。
モード崩壊について
先述したモード崩壊とは、GANの生成モデルが偏った画像しか出力しない現象を指します。このモード崩壊はGANの学習の不安定性の1つとして挙げられています。詳細について知りたい方はこちらを確認してみてください。
本稿ではモード崩壊の詳細を割愛させていただき、モード崩壊が起きたタイミングを学習時のLossの分布を見て簡単な考察をしてみます。
DCGANとSAGANの生成モデルを学習した際のLossの変動をそれぞれ可視化してみました。下のグラフは横軸をエポック数、縦軸をLossで表現したグラフです。今回のWebアプリの元になっているモデルは最大1000エポックの学習を行なっています。
まずはDCGANについて見てみます。
グラフを見ると、700エポック付近からLossが大きく発散していることが分かります。
この700エポック前後の出力をWebアプリ上で確認すると、Lossが発散したタイミングとモード崩壊が発生するタイミングはほぼ一致することが分かります。この辺りについては是非ともWebアプリで見てみてください。
続いてSAGANについても見てみましょう。
SAGANについては、やや分かりにくいですが400エポック付近からLossが発散し始めており、こちらも同じように400エポック以降からモード崩壊を起こす様子がWebアプリ上で確認できます。
最後に
Streamlitのメリットはエポック数やモデルの設定をインタラクティブに変更し、その結果をアプリ上に出力できることだと思います。
このWebアプリを使ってみると、エポック数を上げていくことでより詳細な画像が生成されていくことが分かるかと思います。
特にGANなどの生成モデルや、Openpose、セマンティックセグメンテーションなど画像がメインになるアプリケーションだと真価を発揮するのではないかと思います。
加えて、このStreamlitを用いることで、このようなWebアプリがpythonのみで数十行書くだけで構築できることが非常に魅力的でした。簡易なWebアプリを作るのであればこれで十分ではないでしょうか。機械学習エンジニアやデータサイエンティストが気軽にアプリを作ることができるのは画期的だと感じました。
反省
今回作成したWebアプリでは微妙な画像しか出力することができませんでした。アウトプットとしてはだいぶお粗末な結果です。。(Streamlitの良さを知ってもらうことがメインなので温かい目で見ていただければ。。。)
GANの学習は不安定で難しい、と言われていましたがまさにその通りでした。即席で作った単純なネットワークだとやはり難しく、不安定性の解消、特にモード崩壊対策を考慮したモデルを構築、学習する必要があるなと改めて感じました。