Skip to content

A tiny Web-based demo app of a handwritten digit discriminator using Chainer

Notifications You must be signed in to change notification settings

ytakata69/mnist-demo

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

22 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

mnist-demo

ニューラルネット・ライブラリ Chainer と手書き数字データ集合 MNIST を使ったお遊びデモ。

描画枠に数字を手書きして送信すると,判別結果(0〜9のいずれか)が表示される。

デモサイト on Heroku

学習済みニューラルネットによる推論

  • mnist.py - 推論プログラム
  • data/mnist.model - 学習済みパラメータ

MNISTを使って訓練したニューラルネットを使って,入力画像から推論結果 (0〜9) を計算する(mnist.py を単独で実行すると,data/three.png に対する推論結果を出力する)。

ほぼチュートリアルのプログラム例の通り(PNG画像をNumPy配列に変換する部分は後述)。

学習

学習は,chainer.gitexamples/mnist/train_mnist.py を,無改造・オプション指定なしで実行しただけ(GPUがなくても6〜7分で終わる)。

  • 中間層2層, それぞれ1000ノード, 活性化関数はReLU
  • 更新アルゴリズムはAdam (勾配降下法の改良の一つ. 参考記事
  • 損失関数はsoftmax cross entropy
  • エポック数 20, バッチサイズ 100
  • (テストデータの正解率 98.2%)

train_mnist.py による学習結果は result/snapshot_iter_12000 というファイルに保存される。このファイル (snapshotファイル) はネットワークパラメータ以外の情報も含んでいるので,以下のようなコードを実行してネットワークパラメータだけ取り出す(train_mnist_custom_loop.py を使った場合はsnapshotファイルではなくmodelファイルが出力される)。

from train_mnist import MLP
from chainer import serializers

n_units = 1000
net = MLP(n_units, 10)

serializers.load_npz(
  'result/snapshot_iter_12000',
  net,
  path='updater/model:main/predictor/')

serializers.save_npz('mnist.model', net)

画像の整形とNumPy配列への変換

MNISTに合わせて,入力画像を 28×28 = 784 画素 (0.0〜1.0, 背景が0.0) の配列に変換しなければならない。 画像の変換に Pillow を使っている。

MNIST 配布元の説明に従って,以下の前処理を実行(mnist.pyinferFromImageを参照)。重心の計算はNumPyで行っている。

  1. 余白を除く
  2. 20×20の矩形にぴったり合うようアスペクト比を変えずに大きさを調節
  3. 28×28の矩形の中に重心 (center of mass) を中心にして配置

mnist.py を単独で実行すると,前処理後の画像を PIL.Image.Image.show() を使って表示する。 Webアプリ版では,ページ下部のデバッグ出力欄に,前処理後の画像と,ラベル0〜9に対するニューラルネットの出力値(出力値の降順に整列)を表示する。

Webアプリ化

  • web.py - Webアプリのメインプログラム
  • templates/index.html - メインページのHTML
  • static/* - JavaScriptやCSSなどの静的ファイル

PythonベースWebフレームワーク Flask を使ってWebアプリ化し, Heroku に載せる。

web.py の機能は以下の2つだけ。数十行でできている。

  • / がアクセスされるとメインページを返す。
  • /send にてPNGまたはJPEGファイルのアップロードを受け付け,mnist.pyinferFromImage を呼び出して,実行結果をJSONで返す。

localhostで実行する場合は,requirements.txtに書かれているライブラリをインストールした後,web.py を実行すればよい(下記)。http://localhost:5000/ でアクセスできるサーバが起動する。

$ pip install -r requirements.txt  # 依存ライブラリをインストール
$ python3 web.py                   # サーバを起動

drawingboard.js による手書きUI

index.htmldrawingboard.js による手書きユーザインタフェースを取り付けた。

drawingboard.jsのデモページ

drawingboard.js を使うのに jQuery が必要。

drawingboard.js で描いた画像は data URI 形式で取り出されるが,これをファイルアップロード形式で送信するのが若干面倒だった(参考記事)。web.py を変更して data URI を受け取るようにする,という選択肢もある。

Heroku へのデプロイ

  • Procfile - プロセス定義ファイル
    web: gunicorn web:app --log-file=-
    
  • runtime.txt - 実行環境指定ファイル (Pythonのバージョンを指定)
    python-3.6.4
    
  • requirements.txt - 依存ライブラリのリスト
    chainer==3.2.0
    Flask==0.12.2
    gunicorn==19.7.1
    numpy==1.14.1
    Pillow==5.0.0
    

上記のファイルを置いておけば,Heroku が勝手に必要なライブラリをインストールしてWebサーバ (Gunicorn) を起動してサービスを開始する。

使用ライブラリ

  • Chainer - ニューラルネットワーク・ライブラリ
  • Pillow - PIL (Python Imaging Library) 互換ライブラリ
  • Flask - Pythonベースの軽量Webフレームワーク
  • jQuery - JavaScript拡張ライブラリ.drawingboard.jsが依存
  • drawingboard.js - HTML5 canvasベースのドローソフト

参考資料

About

A tiny Web-based demo app of a handwritten digit discriminator using Chainer

Topics

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published