[解決済み] 学習済みKerasモデルをロードして学習を継続する
2022-08-01 07:40:11
質問
部分的に学習したKerasモデルを保存し、再度モデルをロードした後に学習を継続することは可能でしょうか。
この理由は、将来的にもっと多くの訓練データがあり、モデル全体を再び訓練したくないからです。
私が使用している関数は以下の通りです。
#Partly train model
model.fit(first_training, first_classes, batch_size=32, nb_epoch=20)
#Save partly trained model
model.save('partly_trained.h5')
#Load partly trained model
from keras.models import load_model
model = load_model('partly_trained.h5')
#Continue training
model.fit(second_training, second_classes, batch_size=32, nb_epoch=20)
編集1:完全に動作する例を追加
最初のデータセットで10エポック後の最後のエポックの損失は0.0748で、精度は0.9863になります。
モデルを保存、削除、再読み込みした後、2番目のデータセットで学習したモデルの損失と精度は、それぞれ0.1711と0.9504になります。
これは新しい学習データによるものでしょうか、それとも完全に再学習されたモデルによるものでしょうか?
"""
Model by: http://machinelearningmastery.com/
"""
# load (downloaded if needed) the MNIST dataset
import numpy
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense
from keras.utils import np_utils
from keras.models import load_model
numpy.random.seed(7)
def baseline_model():
model = Sequential()
model.add(Dense(num_pixels, input_dim=num_pixels, init='normal', activation='relu'))
model.add(Dense(num_classes, init='normal', activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
return model
if __name__ == '__main__':
# load data
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# flatten 28*28 images to a 784 vector for each image
num_pixels = X_train.shape[1] * X_train.shape[2]
X_train = X_train.reshape(X_train.shape[0], num_pixels).astype('float32')
X_test = X_test.reshape(X_test.shape[0], num_pixels).astype('float32')
# normalize inputs from 0-255 to 0-1
X_train = X_train / 255
X_test = X_test / 255
# one hot encode outputs
y_train = np_utils.to_categorical(y_train)
y_test = np_utils.to_categorical(y_test)
num_classes = y_test.shape[1]
# build the model
model = baseline_model()
#Partly train model
dataset1_x = X_train[:3000]
dataset1_y = y_train[:3000]
model.fit(dataset1_x, dataset1_y, nb_epoch=10, batch_size=200, verbose=2)
# Final evaluation of the model
scores = model.evaluate(X_test, y_test, verbose=0)
print("Baseline Error: %.2f%%" % (100-scores[1]*100))
#Save partly trained model
model.save('partly_trained.h5')
del model
#Reload model
model = load_model('partly_trained.h5')
#Continue training
dataset2_x = X_train[3000:]
dataset2_y = y_train[3000:]
model.fit(dataset2_x, dataset2_y, nb_epoch=10, batch_size=200, verbose=2)
scores = model.evaluate(X_test, y_test, verbose=0)
print("Baseline Error: %.2f%%" % (100-scores[1]*100))
編集2:tensorflow.kerasの備考
tensorflow.kerasでは、モデル適合のパラメータnb_epochsをepochsに変更します。importsとbasemodel関数は。
import numpy
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import load_model
numpy.random.seed(7)
def baseline_model():
model = Sequential()
model.add(Dense(num_pixels, input_dim=num_pixels, activation='relu'))
model.add(Dense(num_classes, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
return model
どのように解決するのですか?
実は
model.save
は、あなたのケースで学習を再開するために必要なすべての情報を保存します。モデルをリロードすることによって損なわれる可能性があるのは、オプティマイザの状態だけです。これを確認するには、次のようにしてください。
save
でモデルをリロードし、学習データで学習させてみてください。
関連
-
[解決済み] 学習後のモデルを保存・復元する方法は?
-
[解決済み] SQLAlchemy: セッションの作成と再利用
-
[解決済み] Pythonのキャッシュライブラリはありますか?
-
[解決済み] なぜ(0-6)は-6=偽なのか?重複
-
[解決済み] tensorflowのCPUのみのインストールでダイナミックライブラリ 'cudart64_101.dll' を読み込めなかった
-
[解決済み] あるメソッドが複数の引数のうち1つの引数で呼び出されたことを保証する
-
[解決済み] djangoのQueryDictをPythonのDictに変更するには?
-
[解決済み] Pythonでリストが空かどうかをチェックする方法は?重複
-
[解決済み] Alembicアップグレードスクリプトでインサートやアップデートを実行するにはどうすればよいですか?
-
[解決済み] ファブリックタスクにパラメータを渡す
最新
-
nginxです。[emerg] 0.0.0.0:80 への bind() に失敗しました (98: アドレスは既に使用中です)
-
htmlページでギリシャ文字を使うには
-
ピュアhtml+cssでの要素読み込み効果
-
純粋なhtml + cssで五輪を実現するサンプルコード
-
ナビゲーションバー・ドロップダウンメニューのHTML+CSSサンプルコード
-
タイピング効果を実現するピュアhtml+css
-
htmlの選択ボックスのプレースホルダー作成に関する質問
-
html css3 伸縮しない 画像表示効果
-
トップナビゲーションバーメニュー作成用HTML+CSS
-
html+css 実装 サイバーパンク風ボタン
おすすめ
-
[解決済み] PythonでファイルのMD5チェックサムを計算するには?重複
-
[解決済み] Django のテストデータベースをメモリ上だけで動作させるには?
-
[解決済み] django.db.migrations.exceptions.InconsistentMigrationHistory
-
[解決済み] データフレームをソートした後にインデックスを更新する
-
[解決済み] サブフォルダからのインポートモジュール
-
[解決済み] Pythonでマルチプロセッシングキューを使うには?
-
[解決済み] Pythonで、ウェブサイトが404か200かを確認するためにurllibをどのように使用しますか?
-
[解決済み] pycharmがタブをスペースに自動変換する
-
[解決済み] Pythonの文字列書式をリストで使う
-
[解決済み] Pythonでランダムなファイル名を生成する最良の方法