深層強化学習②

今日はこのページ↓

MNISTデータの取得

# 手書き数字の画像データMNISTをダウンロード
 
from sklearn.datasets import fetch_mldata
 
mnist = fetch_mldata('MNIST original', data_home=".")  # data_homeは保存先を指定します

 ここで、HTTPError が発生

mldata.org が service is temporarily down! しているらしい。

これで変数mnistにデータが格納されました。fetch_mldata()は手書き数字の画像データとラベルデータをダウンロードするのですが、ときおりダウンロード先のサーバーの都合でうまく動かない場合があります。その場合何回か実行しているとうまくいくので、繰り返してみてください。 

 に従って、3回ぐらいトライしてみたけどだめだった。

他のサイトから、mnist-original.mat をダウンロードすればいいらしいが、どのフォルダに入れればいいかすらわからない。

たどりついたのが↓

どのフォルダかは、

from sklearn.datasets import get_data_home
print(get_data_home())

これでわかった。

~/scikit_learn_data/mldata/

に mnist-orginal.mat を置いて

from sklearn.datasets import fetch_mldata
mnist = fetch_mldata('MNIST original') 

 でエラーはなくなった。

 

# 1. データの前処理(画像データとラベルに分割し、正規化)

X = mnist.data / 255 # 0-255を0-1に正規化
y = mnist.target

 

# MNISTのデータの1つ目を可視化する

import matplotlib.pyplot as plt
% matplotlib inline

plt.imshow(X[0].reshape(28, 28), cmap='gray')
print("この画像データのラベルは{:.0f}です".format(y[0]))

この画像データのラベルは0です

f:id:kazuzo88:20180815113258p:plain

ここから先は、コピペでうまくいきそう。

本筋でないところで時間を取られて、わかったような気になってしまうのはどうかと思うけど、暇つぶしにはちょうどいいかも。