1. ホーム
  2. python

[解決済み】PyTorchで学習したモデルを保存する最適な方法とは?

2022-02-21 10:42:03

質問

PyTorchで学習したモデルを保存するための代替方法を探していました。今のところ、2つの代替方法を見つけました。

  1. torch.save() でモデルを保存し torch.load() を使ってモデルをロードします。
  2. model.state_dict() を使って学習済みモデルを保存し model.load_state_dict() で保存されたモデルをロードします。

私は、このようなことに遭遇しました。 ディスカッション アプローチ2がアプローチ1より推奨されているところ。

質問ですが、なぜ2番目のアプローチが好まれるのでしょうか?それは トーチ.nn モジュールにはこの2つの機能があり、それを使うことが推奨されているのでしょうか?

どのように解決するのですか?

以下のものが見つかりました。 このページ のgithub repoに掲載されているので、その内容をここに貼り付けておきます。


モデルを保存するための推奨アプローチ

モデルのシリアライズとリストアには、主に2つのアプローチがあります。

1つ目(推奨)は、モデルのパラメータのみを保存し、ロードする方法です。

torch.save(the_model.state_dict(), PATH)

じゃあ、後で。

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

2つ目は、モデル全体の保存と読み込みです。

torch.save(the_model, PATH)

じゃあ、後で。

the_model = torch.load(PATH)

しかし、この場合、シリアライズされたデータは、特定のクラス と正確なディレクトリ構造を使用するため、様々な方法で壊れる可能性があります。 を他のプロジェクトで使用したり、深刻なリファクタリングを行った後に使用します。