PythonとMXNetで遊んでみる


更新日から1年以上経過しています。情報が古い可能性がございます。

社内勉強会ネタその3

DeepLearningフレームワークのMXNetを使ってみようの回。
MXNetはAWSがサポートすることになったフレームワークとして有名なやつで、
言語的にはPythonだけでなく、RやScalaなんかでも利用できます。
最近Apache Incubatorになったりして一部で話題になったりもしました。

フレームワークの説明は各自調べてもらうとして、
早速ネタに入っていこうと思います。

今回は単純パーセプトロンで表現できない、XORを学習してみます。
使用バージョンは以下の通り。
・Python 3.6.3
・numpy 1.13.3
・mxnet 0.12.0

下準備

Jupyter上で実行しているので、上のセルから順に載せていきます。
まずはインポート。

import numpy as np
import mxnet as mx

次に、途中でログを出力して欲しいので、コールバック関数を定義していきます。

def log_output(params):
    epoch = params[0]
    batch = params[1]
    name, value = params[2].get()
    print('Epoch[{0}], Batch[{1}], Validation-{2} : {3}'.format(epoch, batch, name, value))

続いてネットワークの作成。

net = mx.sym.Variable('data')
net = mx.sym.FullyConnected(data=net, name='fc1', num_hidden=2)
net = mx.sym.Activation(data=net, act_type='sigmoid')
net = mx.sym.FullyConnected(data=net, name='fc2', num_hidden=2)
net = mx.sym.SoftmaxOutput(data=net, name='softmax')

このネットワークを基にモデルを作ります。

mod = mx.mod.Module(symbol=net)

これで必要な道具の用意ができたので、データを定義していきましょう。

学習用データの作成

XORを学ばせるので、用意するデータはラベルと合わせてこんな感じですね。

x_train = np.array([[0, 0],
                    [1, 0],
                    [0, 1],
                    [1, 1]], np.float32)

y_train = np.array([0,
                    1,
                    1,
                    0], np.int32)

このデータからイテレータを作成します。

data1 = mx.io.NDArrayIter(data=x_train, label=y_train, batch_size=4, shuffle=True)
data2 = data1

これでデータの準備も完了しました。

学習

1000エポックほど回してみましょう。
OptimizerはAdamを用いてみました。
eval_end_callbackに初めに定義した関数を与えているので、
これが1エポック学習後にeval_dataに対して評価した時に呼ばれます。

mod.fit(train_data=data1,
        eval_data=data2,
        num_epoch=1000,
        optimizer='Adam',
        eval_metric='acc',
        eval_end_callback=log_output
       )

以下の様な出力が得られます。

Epoch[0], Batch[1], Validation-accuracy : 0.5
Epoch[1], Batch[1], Validation-accuracy : 0.5
Epoch[2], Batch[1], Validation-accuracy : 0.5
Epoch[3], Batch[1], Validation-accuracy : 0.5
Epoch[4], Batch[1], Validation-accuracy : 0.5
Epoch[5], Batch[1], Validation-accuracy : 0.5
   ...中略...
Epoch[996], Batch[1], Validation-accuracy : 1.0
Epoch[997], Batch[1], Validation-accuracy : 1.0
Epoch[998], Batch[1], Validation-accuracy : 1.0
Epoch[999], Batch[1], Validation-accuracy : 1.0

評価データと学習データに同じものを与えていますが、
学習の結果accuracyが1.0になっているのが出力から分かります。

確認

では本当にXORが学習できているか確認してみましょう。

test = mx.io.NDArrayIter(data=x_train, label=y_train, batch_size=4, shuffle=False)
result = mod.predict(test).asnumpy()
print(list(map(lambda x: x.argmax(), result)))

実行すると以下の出力が得られます。

[0, 1, 1, 0]

入力に対し、正しくXORができていることが確認できました。

ネットワークの可視化

MXNetにはネットワークを可視化する機能がついているので、
今回作成したネットワークを見てみましょう。

mx.viz.plot_network(net)

以下の様な画像が表示されます。

以上、MXNetで遊んで見る回でした。


コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です