eyecatch
Sun, Mar 25, 2018

機械学習のHello World: MNISTの分類モデルをKerasで作ってみた

機械学習のHello WorldとしてよくやられるMNISTの分類モデルをKeras on TensorFlowで作ってみた話。 (adsbygoogle = window.adsbygoogle || []).push({}); MNISTとは 手書き数字画像のラベル付きデータセット。 6万個の訓練データと1万個のテストデータからなる。 CC BY-SA 3.0で配布されているっぽい。 一つの画像は28×28ピクセルの白黒画像で、0から9のアラビア数字が書かれている。 画像とラベルがそれぞれ独特な形式でアーカイブされていて、画像一つ、ラベル一つ取り出すのも一苦労する。 Kerasとは Pythonのニューラルネットワークライブラリ。 バックエンドとしてTensorFlowかCNTKかTheanoを使う。 今回はTensorFlowを使った。 やったこと KerasのMNISTのAPIとかコードサンプルとかがあけどこれらはスルー。 MNISTのサイトにあるデータセットをダウンロードしてきて、サイトに書いてあるデータ形式の説明を見ながらサンプルを取り出すコードを書いた。 で、KerasでVGGっぽいCNNを書いて、学習させてモデルをダンプして、ダンプしたモデルをロードしてテストデータで評価するコードを書いた。 コードはGitHubに。 ネットワークアーキテクチャ 入力画像のサイズに合わせてVGGを小さくした感じのCNNを作った。 VGGは2014年に発表されたアーキテクチャで、各層に同じフィルタを使い、フィルタ数を線形増加させるシンプルな構造でありながら、性能がよく、今でもよく使われるっぽい。 VGGを図にすると以下の構造。 実際はバッチ正規化とかDropoutもやるのかも。 プーリング層は数えないで16層なので、VGG-16とも呼ばれる。 パラメータ数は1億3800万個くらいで、結構深めなアーキテクチャ。 VGG-16は244×244×3の画像を入力して1000クラスに分類するのに対し、MNISTは28×28×1を入れて10クラスに分類するので、以下のような7層版を作った。 これでパラメータ数は27万個くらい。 訓練データのサンプル数が6万個なので、パラメータ数が大分多い感じではある。 コードは以下。 inputs: Tensor = Input(shape=(IMAGE_NUM_ROWS, IMAGE_NUM_COLS, 1)) x: Tensor = Conv2D(filters=8, kernel_size=(2, 2), padding='same', activation='relu')(inputs) x = Conv2D(filters=8, kernel_size=(2, 2), padding='same', activation='relu')(x) x = MaxPooling2D(pool_size=(2, 2))(x) x = Conv2D(filters=16, kernel_size=(2, 2), padding='same', activation='relu')(x) x = Conv2D(filters=16, kernel_size=(2, 2), padding='same', activation='relu')(x) x = Conv2D(filters=16, kernel_size=(2, 2), padding='same', activation='relu')(x) x = MaxPooling2D(pool_size=(2, 2))(x) x = Flatten()(x) x = Dense(units=256, activation='relu')(x) x = Dense(units=256, activation='relu')(x) predictions: Tensor = Dense(NUM_CLASSES, activation='softmax')(x) model: Model = Model(inputs=inputs, outputs=predictions) model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy']) 訓練・評価 上記モデルを6万個のサンプルでバッチサイズ512で一周(1エポック)学習させると、Intel Core i5-6300HQプロセッサー、メモリ16GBのノートPCで28秒前後かかる。 とりあえず3エポック学習させてみる。 (ml-mnist) C:\workspace\ml-mnist>python bin\ml_mnist.py train Using TensorFlow backend.