深層強化学習②
今日はこのページ↓
MNISTデータの取得
# 手書き数字の画像データMNISTをダウンロード
from sklearn.datasets import fetch_mldata
mnist = fetch_mldata('MNIST original', data_home=".") # data_homeは保存先を指定します
ここで、HTTPError が発生
mldata.org が service is temporarily down! しているらしい。
これで変数mnistにデータが格納されました。fetch_mldata()は手書き数字の画像データとラベルデータをダウンロードするのですが、ときおりダウンロード先のサーバーの都合でうまく動かない場合があります。その場合何回か実行しているとうまくいくので、繰り返してみてください。
に従って、3回ぐらいトライしてみたけどだめだった。
他のサイトから、mnist-original.mat をダウンロードすればいいらしいが、どのフォルダに入れればいいかすらわからない。
たどりついたのが↓
どのフォルダかは、
from sklearn.datasets import get_data_home print(get_data_home())
これでわかった。
~/scikit_learn_data/mldata/
に mnist-orginal.mat を置いて
from sklearn.datasets import fetch_mldata
mnist = fetch_mldata('MNIST original')
でエラーはなくなった。
# 1. データの前処理(画像データとラベルに分割し、正規化)
X = mnist.data / 255 # 0-255を0-1に正規化
y = mnist.target
# MNISTのデータの1つ目を可視化する
import matplotlib.pyplot as plt
% matplotlib inline
plt.imshow(X[0].reshape(28, 28), cmap='gray')
print("この画像データのラベルは{:.0f}です".format(y[0]))
この画像データのラベルは0です
ここから先は、コピペでうまくいきそう。
本筋でないところで時間を取られて、わかったような気になってしまうのはどうかと思うけど、暇つぶしにはちょうどいいかも。