MobileNetV2でCNN画像分類
機械学習をやっていると、
「データが少なすぎて、モデルの汎化性能がちっとも向上しない!」
という悩みはつきものです。特に画像分類なんて、あるカテゴリーの画像を何千枚、何万枚も集めるなんて、どれだけの費用と時間がかかるのか、考えるだけで絶望的になります。
仮に何らかの方法で必要な資料を集めることができたとしても、数万枚のデータを読み込ませて CNN を訓練するには、Google Colab の T4 GPU を使ったとしても、めちゃくちゃ時間がかかります。
転移学習とは?
そこで、今回は少ないリソースで、効率よくモデルを訓練できる、夢のような手法、転移学習 について紹介します。転移学習をざっくり説明すると、
「他の人が一生懸命訓練したモデルをちょっとばかり拝借して、自分のモデルに組み込んで楽をしよう」
という手法です。もう少しきちんと説明すると、あるタスクで学習した知識を、関連するタスクに適用する手法です。今回は画像分類で転移学習を活用しますが、転移学習は以下のような分野でも活用されています。
・自然言語処理:大規模テキストコーパスで学習した言語モデルを、特定分野(医療、法律など)のテキスト分析に適用する
・音声認識:一般的な音声認識モデルを、特定のアクセントやノイズ環境での音声認識に特化させる
ImageNetで訓練されたMobileNetV2を画像分類に活用します
今回、紹介する MobileNetV2 は、ImageNet(動物、植物、建物、乗り物、日用品など、日常生活のあらゆる物体が階層的に分類されたデータセット)を用いて、基本的な画像特徴(エッジ、テクスチャ、形状など)を事前学習しています。このため、基礎的な特徴抽出部分を再学習する必要がなく、少ないデータセット(今回のような1000枚程度の画像)でも、モデルの性能を迅速に向上させることが可能です。
それでは、さっそく、TensorFlow の CNN 画像分類で MobileNetV2 の威力を実感してもらいます。まずは必要なライブラリをまとめて読み込んでおきます。
# TensorFlow CNN画像分類 MobileNetV2 組み込みモデル # In[1] import tensorflow as tf from tensorflow.keras import layers, models import tensorflow_datasets as tfds import matplotlib.pyplot as plt
次に tensorflow_datasetsから、犬猫の画像と、正解ラベルがセットで格納されたデータを読み込みます。
# In[2]
# データセットの読み込み
# as_supervised=Trueで教師あり学習用の(input,label)形式で取得する
dataset, info = tfds.load('cats_vs_dogs', split=['train'], with_info=True, as_supervised=True)
full_dataset = dataset[0]
‘cats_vs_dogs’ は機械学習用のデータとしては比較的軽量ですが、それでも総計23262枚もの画像データを含むので、Google Colab の T4 GPU を使ったとしても、訓練に数分間かかります。練習段階で訓練の度に待ち時間があると鬱陶しいですし、今回の記事の目的は「学習済みモデルを援用すると、どれぐらい早くトレーニングできるか」を示すのが目的なので、敢えて少ないデータを扱います。というわけで、1000 個だけ抜き出します。
# In[3] # サブセットの作成 subset_size = 1000 subset_dataset = full_dataset.take(subset_size)
一部のデータを抜き出しているので、犬と猫、それぞれ同程度の画像が含まれているか心配ですよね。もちろん、天下の TensorFlow が用意した学習データなので、そのような点でぬかりはないと思いますが、心配性の私としては、こういう細かいところがどうしても気になってしまうので、念のために確認しておきます。
# In[4]
# 訓練用データのラベル分布を確認
labels = [label.numpy() for _, label in subset_dataset]
print("Cats: ", labels.count(0))
print("Dogs: ", labels.count(1))
# Cats: 515
# Dogs: 485
大丈夫そうですね。以下のコードでデータの概要が確認できます。
# In[5] # cats_vs_dogsの情報を表示 print(info) # 実行結果は長いので省略
実行結果には
'train': <SplitInfo num_examples=23262, num_shards=16>
とありますが、なぜ画像枚数が 23262 という中途半端な数字なのか気になりませんか?
info に含まれる “There are 1738 corrupted images that are dropped.” という一文に注目してください。 本来、’cats_vs_dogs’ には 25,000 枚のデータが用意されていましたが、画像形式が正しくなかったり、画像ファイルのヘッダーやメタデータが破損していたりしているので、データセットの読み込み時に 1738 枚が自動的に破棄されたということです(23,262 + 1,738 = 25,000)。まあ、知っていても知らなくても、どうでもいいような情報なので、本題に戻りましょう。訓練用とテスト用のデータ数を確認しておきます。
# In[6] # データセットのシャッフル # シャッフルバッファサイズを指定(全データ数以上でデータが完全にシャッフルされる) subset_dataset = subset_dataset.shuffle(buffer_size=subset_size, seed=42) # 訓練データセットとテストデータセットの分割(8:2の比率) train_size = int(subset_size * 0.8) # 訓練データは800枚 test_size = subset_size - train_size # 残りのテストデータは200枚 # データセットの分割 train_dataset = subset_dataset.take(train_size) # 訓練データ test_dataset = subset_dataset.skip(train_size) # テストデータ print(train_size, test_size) # 800 200
cats_vs_dogs の画像はサイズが揃っていません(”画像分類あるある” ですね!)。面倒ですがきちんと前処理をしておきましょう。画像を適当なサイズに揃えて、正規化しておきます。
# In[7]
# 画像のサイズ
IMG_SIZE = 128
# 前処理関数(リサイズと正規化)
def preprocess(image, label):
image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
image = image / 255.0 # 正規化
return image, label
1000枚でも、フルバッチで学習させるのはさすがに大変なので、トレーニング用データとテスト用データをバッチ化しておきます。
# In[8] # バッチサイズを32に指定 # トレーニングデータが800サンプル、バッチサイズが32なので、 # トレーニングデータのバッチ数は800/32=25 BATCH_SIZE = 32 # トレーニング用データとテスト用データをバッチ化 # 効率的なデータロードのために、prefetch(tf.data.AUTOTUNE)を追加しておく train_dataset = train_dataset.map(preprocess).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE) test_dataset = test_dataset.map(preprocess).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
ベースモデルとして、MobileNetV2 を読み込んで、モデルを構築します。ベースモデルの下には、プーリング層と全結合層だけ加えます。畳み込み層は必要ありません。
# In[9]
base_model = tf.keras.applications.MobileNetV2(input_shape=(IMG_SIZE, IMG_SIZE, 3),
include_top=False,
weights='imagenet')
# 事前学習済み部分を固定する(再学習させない)
base_model.trainable = False
# モデルを構築
model = models.Sequential([
base_model,
layers.GlobalAveragePooling2D(),
layers.Dense(128, activation='relu'),
layers.Dense(1, activation='sigmoid')
])
# モデルの概要を表示
model.summary()
| Layer (type) | Output Shape | Param # |
|---|---|---|
| mobilenetv2_1.00_128 (Functional) | (None, 4, 4, 1280) | 2,257,984 |
| global_average_pooling2d(GlobalAveragePooling2D) | (None, 1280) | 0 |
| dense (Dense) | (None, 128) | 163,968 |
| dense_1 (Dense) | (None, 1) | 129 |
モデルをコンパイルします。
# In[10]
# モデルのコンパイル
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
loss='binary_crossentropy',
metrics=['accuracy'])
# モデルのトレーニング
EPOCHS = 10
history = model.fit(train_dataset,
validation_data=test_dataset,
epochs=EPOCHS)
# 学習の進行状況をプロット
plt.plot(history.history['accuracy'], label='Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
# モデルの評価
loss, accuracy = model.evaluate(test_dataset)
print(f"Test Loss: {loss}")
print(f"Test Accuracy: {accuracy}")
# Test Loss: 0.00033416482619941235
# Test Accuracy: 1.0

僅か 1000 枚の画像しか使用していないにも関わらず、5 Epoch × 25 = 125 step で Validation Accuracy が 1.000 に達しています。画像分類としては驚異的な速さです。これがどれだけ速いか知るには、普段やるみたいに、畳み込み層とプーリング層をごりごり積んで実行してみたください。たとえば、以下のコードでは 10 Epoch 学習させても、Validation Accuracy は 0.75 までしか到達しません。実行結果は省略しますが、興味のある人は試してみてください。
# In[11]
# MobileNetV2を使用しない場合のコード
# モデルの構築
model = models.Sequential([
# 畳み込み層とプーリング層
layers.Conv2D(32, (3, 3), activation='relu', input_shape=(128, 128, 3)),
layers.MaxPooling2D((2, 2)),
layers.Conv2D(64, (3, 3), activation='relu'),
layers.MaxPooling2D((2, 2)),
layers.Conv2D(128, (3, 3), activation='relu'),
layers.MaxPooling2D((2, 2)),
layers.Conv2D(128, (3, 3), activation='relu'),
layers.MaxPooling2D((2, 2)),
# 全結合層
layers.Flatten(),
layers.Dense(128, activation='relu'),
# 2クラス分類のためシグモイド関数を使用
layers.Dense(1, activation='sigmoid')
])
# モデルのコンパイル
model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy'])
# 学習の実行
EPOCHS = 10
history = model.fit(train_dataset,
validation_data=test_dataset,
epochs=EPOCHS)
# 学習の進行状況をプロット
plt.plot(history.history['accuracy'], label='Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
コメント