Irisデータセット
機械学習用ライブラリ scikit-learn には練習用データセットがいくつか用意されています。その中の 1 つ、Iris flower data set には、Iris(アヤメ属)に属する 3 品種、setosa(セトサ)、versicolor(バージカラー)、versinica(バージニカ)の特徴量測定値とクラスデータ(品種データ)が収められています。
[画像はアイキャッチも含めて Wikipedia から引用しています]
今後しばらくは、この Irisデータを使ってニューラルネットワークに品種分類を学習させるので、今回はデータを読み込んで、ネットワークに入力できるようにデータ構造を整えておきます。
Irisデータの読み込みと加工
最初に scikit-learn の datasets モジュールをインポートして、iris のデータを読み込みます。
# In[10]
from sklearn import datasets
# irisデータをロード
iris = datasets.load_iris()
# irisのがくの長さ、がくの幅、花びらの長さ、花びらの幅
data_in = iris.data
# Setosa,Versicolor,Versinicaのクラスデータ
data_c = iris.target
iris.data には、sepal length (がくの長さ)、sepal width (がくの幅)、petal length (花弁の長さ)、petal width (花弁の幅) の測定値が収められています。データの 5 行目までを表示してみます)。
# In[11]
# data_inを5行目までを表示
print(data_in[:5])
'''
[[5.1 3.5 1.4 0.2]
[4.9 3. 1.4 0.2]
[4.7 3.2 1.3 0.2]
[4.6 3.1 1.5 0.2]
[5. 3.6 1.4 0.2]]
'''
iris.target はクラスデータ (品種データ) です。
setosa = 0, versicolor = 1, virginica = 2 で区分されています。
# In[12]
# data_cを表示
print(data_c)
'''
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
'''
ただし、このままではネットワークに入力できないので、次のコードで 1 of K (one hot) 表記に変換しておきます。
# In[13]
# data_cをone of Kに変換
data_c = np.identity(3, dtype = "int8")[data_c]
data_c の中身を確認しておきましょう。
# In[14]
# 変更されたdata_cを5行目まで表示
print(data_c[:5])
'''
[[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]]
'''
学習を安定させるために、data_in を「標準化」しておきます。
# In[15]
# 入力データの平均値
data_in_av = np.average(data_in, axis = 0)
# 入力データの標準偏差
data_in_sd = np.std(data_in, axis = 0)
# 入力データの標準化
data_in = (data_in - data_in_av) / data_in_sd
最後に data_in と data_c を訓練データとテストデータに分割しておきます。
# In[16]
# インデックス配列を作成
np.random.seed(10)
idx = np.arange(len(data_c))
np.random.shuffle(idx)
idx_train = idx[idx % 2 == 0]
idx_test = idx[idx % 2 != 0]
# 入力データを訓練データとテストデータに分割
data_in_train = data_in[idx_train]
data_in_test = data_in[idx_test]
# クラスデータを訓練データとテストデータに分割
data_c_train = data_c[idx_train]
data_c_test = data_c[idx_test]
Irisのクラス分布
Iris のクラス分布の様子も見ておきましょう。今回は特徴量が 4 種類あるので、「がくの長さ」と「がくの幅」、「花弁の長さ」と「花弁の幅」の組合わせで 2 枚のマップを作ることにします。
「がくの長さ」と「がくの幅」のクラス分布描画コードです。
# In[17]
# setosaの「がくの長さ」と「がくの幅」を抽出
x1 = data_in[data_c[:,0] == 1][:, 0]
y1 = data_in[data_c[:,0] == 1][:, 1]
# versicolorの「がくの長さ」と「がくの幅」を抽出
x2 = data_in[data_c[:,1] == 1][:, 0]
y2 = data_in[data_c[:,1] == 1][:, 1]
# versinicaの「がくの長さ」と「がくの幅」を抽出
x3 = data_in[data_c[:,2] == 1][:, 0]
y3 = data_in[data_c[:,2] == 1][:, 1]
# 「がくの長さ」と「がくの幅」のクラス分布
fig = plt.figure(figsize = (5, 5))
ax = fig.add_subplot(111)
ax.set_xlabel("sepal length", size = 15, labelpad = 10)
ax.set_ylabel("sepal width", size = 15, labelpad = 10)
ax.scatter(x1, y1, marker = "D", color = "green", label = "setosa")
ax.scatter(x2, y2, marker = "+", color = "darkblue", label = "versicolor")
ax.scatter(x3, y3, marker = "o", color = "darkorange", label = "versinica")
ax.legend()
plt.show()
「花弁の長さ」と「花弁の幅」のクラス分布描画コードです。
# In[18]
# setosaの「花弁の長さ」と「花弁の幅」を抽出
x1 = data_in[data_c[:,0] == 1][:, 2]
y1 = data_in[data_c[:,0] == 1][:, 3]
# versicolorの「花弁の長さ」と「花弁の幅」を抽出
x2 = data_in[data_c[:,1] == 1][:, 2]
y2 = data_in[data_c[:,1] == 1][:, 3]
# versinicaの「花弁の長さ」と「花弁の幅」を抽出
x3 = data_in[data_c[:,2] == 1][:, 2]
y3 = data_in[data_c[:,2] == 1][:, 3]
# 「花弁の長さ」と「花弁の幅」のクラス分布
fig = plt.figure(figsize = (5, 5))
ax = fig.add_subplot(111)
ax.set_xlabel("petal length", size = 15, labelpad = 10)
ax.set_ylabel("petal width", size = 15, labelpad = 10)
ax.scatter(x1, y1, marker = "D", color = "green", label = "setosa")
ax.scatter(x2, y2, marker = "+", color = "darkblue", label = "versicolor")
ax.scatter(x3, y3, marker = "o", color = "darkorange", label = "versinica")
ax.legend()
plt.show()
がくも花弁も、セトサは他の品種とは大きく異なる特徴を示しています。バージカラーとバージニカを比較すると、がくの特徴量にはほとんど差がありませんが、花弁の特徴量の境界は比較的明瞭です。
コメント