ニューラルネット・ライブラリ Chainer と手書き数字データ集合 MNIST を使ったお遊びデモ。
描画枠に数字を手書きして送信すると,判別結果(0〜9のいずれか)が表示される。
mnist.py
- 推論プログラムdata/mnist.model
- 学習済みパラメータ
MNISTを使って訓練したニューラルネットを使って,入力画像から推論結果 (0〜9) を計算する(mnist.py
を単独で実行すると,data/three.png
に対する推論結果を出力する)。
ほぼチュートリアルのプログラム例の通り(PNG画像をNumPy配列に変換する部分は後述)。
学習は,chainer.git の examples/mnist/train_mnist.py
を,無改造・オプション指定なしで実行しただけ(GPUがなくても6〜7分で終わる)。
- 中間層2層, それぞれ1000ノード, 活性化関数はReLU
- 更新アルゴリズムはAdam (勾配降下法の改良の一つ. 参考記事)
- 損失関数はsoftmax cross entropy
- エポック数 20, バッチサイズ 100
- (テストデータの正解率 98.2%)
train_mnist.py
による学習結果は result/snapshot_iter_12000
というファイルに保存される。このファイル (snapshotファイル) はネットワークパラメータ以外の情報も含んでいるので,以下のようなコードを実行してネットワークパラメータだけ取り出す(train_mnist_custom_loop.py
を使った場合はsnapshotファイルではなくmodelファイルが出力される)。
from train_mnist import MLP
from chainer import serializers
n_units = 1000
net = MLP(n_units, 10)
serializers.load_npz(
'result/snapshot_iter_12000',
net,
path='updater/model:main/predictor/')
serializers.save_npz('mnist.model', net)
MNISTに合わせて,入力画像を 28×28 = 784 画素 (0.0〜1.0, 背景が0.0) の配列に変換しなければならない。 画像の変換に Pillow を使っている。
MNIST 配布元の説明に従って,以下の前処理を実行(mnist.py
のinferFromImage
を参照)。重心の計算はNumPyで行っている。
- 余白を除く
- 20×20の矩形にぴったり合うようアスペクト比を変えずに大きさを調節
- 28×28の矩形の中に重心 (center of mass) を中心にして配置
mnist.py
を単独で実行すると,前処理後の画像を PIL.Image.Image.show() を使って表示する。
Webアプリ版では,ページ下部のデバッグ出力欄に,前処理後の画像と,ラベル0〜9に対するニューラルネットの出力値(出力値の降順に整列)を表示する。
web.py
- Webアプリのメインプログラムtemplates/index.html
- メインページのHTMLstatic/*
- JavaScriptやCSSなどの静的ファイル
PythonベースWebフレームワーク Flask を使ってWebアプリ化し, Heroku に載せる。
web.py
の機能は以下の2つだけ。数十行でできている。
/
がアクセスされるとメインページを返す。/send
にてPNGまたはJPEGファイルのアップロードを受け付け,mnist.py
のinferFromImage
を呼び出して,実行結果をJSONで返す。
localhostで実行する場合は,requirements.txt
に書かれているライブラリをインストールした後,web.py
を実行すればよい(下記)。http://localhost:5000/
でアクセスできるサーバが起動する。
$ pip install -r requirements.txt # 依存ライブラリをインストール
$ python3 web.py # サーバを起動
index.html
に
drawingboard.js による手書きユーザインタフェースを取り付けた。
drawingboard.js を使うのに jQuery が必要。
drawingboard.js で描いた画像は data URI 形式で取り出されるが,これをファイルアップロード形式で送信するのが若干面倒だった(参考記事)。web.py
を変更して data URI を受け取るようにする,という選択肢もある。
Procfile
- プロセス定義ファイルweb: gunicorn web:app --log-file=-
runtime.txt
- 実行環境指定ファイル (Pythonのバージョンを指定)python-3.6.4
requirements.txt
- 依存ライブラリのリストchainer==3.2.0 Flask==0.12.2 gunicorn==19.7.1 numpy==1.14.1 Pillow==5.0.0
上記のファイルを置いておけば,Heroku が勝手に必要なライブラリをインストールしてWebサーバ (Gunicorn) を起動してサービスを開始する。
- Chainer - ニューラルネットワーク・ライブラリ
- Pillow - PIL (Python Imaging Library) 互換ライブラリ
- Flask - Pythonベースの軽量Webフレームワーク
- jQuery - JavaScript拡張ライブラリ.drawingboard.jsが依存
- drawingboard.js - HTML5 canvasベースのドローソフト
- Chainer v3 ビギナー向けチュートリアル
- 実践! GPUサーバでディープラーニング, 長谷川猛, Software Design 2018年3月号
- Tensorflow, MNIST and your own handwritten digits