1. ホーム
  2. python

[解決済み] TensorFlowによるグラフのファイルへの保存/ファイルからの読み込み

2022-12-01 08:53:56

質問

TensorFlowのグラフをファイルに出力し、他のプログラムで読み込む方法はいくつかあるようですが、どのように動作するのか、明確な例や情報を見つけることができません。すでに分かっているのは、このようなことです。

  1. モデルの変数をチェックポイントファイル(.ckpt)に保存するのに tf.train.Saver() で保存し、後で復元できるように ( ソース )
  2. モデルを.pbファイルに保存し、再びロードするには tf.train.write_graph()tf.import_graph_def() ( ソース )
  3. .pbファイルからモデルを読み込み、再トレーニングし、Bazelを使用して新しい.pbファイルにダンプする ( ソース )
  4. グラフを凍結して、グラフと重みを一緒に保存する ( ソース )
  5. 使用方法 as_graph_def() を使用してモデルを保存し、重み/変数については定数 ( ソース )

しかし、これらの異なる方法に関するいくつかの疑問が解消されていません。

  1. チェックポイント ファイルについて、それはモデルの学習済み重みを保存するだけなのでしょうか。チェックポイント ファイルを新しいプログラムにロードして、モデルを実行するために使用することはできますか、それとも単に、ある時間/段階でモデルの重みを保存する方法として機能するのでしょうか。
  2. について tf.train.write_graph() についてですが、重み/変数も保存されていますか?
  3. Bazelについて、再トレーニングのための.pbファイルへの保存/ファイルからの読み込みしかできないのでしょうか?グラフを.pbにダンプするだけの簡単なBazelコマンドはないのでしょうか?
  4. 凍結についてですが、凍結したグラフを tf.import_graph_def() ?
  5. TensorFlowのAndroidデモでは、GoogleのInceptionモデルを.pbファイルから読み込んでいます。もし、私自身の.pbファイルを代用したい場合、どのようにすれば良いでしょうか?ネイティブのコードやメソッドを変更する必要があるのでしょうか?
  6. 一般的に、これらすべてのメソッドの違いは何ですか。あるいは、より広範に、以下の違いは何ですか? as_graph_def() /.ckpt/.pb の違いは何ですか?

要するに、私が探しているのは、グラフ (さまざまな操作など) とその重み/変数の両方をファイルに保存し、その後、グラフと重みを別のプログラムにロードして使用する (必ずしも継続/再トレーニングではない) ために使用できる方法です。

このトピックに関するドキュメントはあまり単純ではないので、どのような回答/情報でも大いに感謝されます。

どのように解決するのですか?

TensorFlowでモデルを保存する問題には多くのアプローチ方法があり、少し混乱することがあります。それぞれのサブクエスチョンを順番に見ていきましょう。

  1. チェックポイントファイル (たとえば saver.save() を呼び出すことで生成されます。 tf.train.Saver オブジェクト) には、重みと、同じプログラムで定義された他の変数だけが含まれています。それらを別のプログラムで使用するには、関連するグラフ構造を再作成する必要があります(例えば、それを再び構築するコードを実行するか、または tf.import_graph_def() を呼び出すなどして)、TensorFlowにその重みをどうするか指示する。なお saver.save() を呼び出すと MetaGraphDef このファイルにはグラフと、チェックポイントからの重みをそのグラフに関連付ける方法の詳細が含まれています。参照 チュートリアル を参照してください。

  2. tf.train.write_graph() はグラフの構造を書くだけで、重みは書きません。

  3. BazelはTensorFlowのグラフの読み書きに関係ありません。(おそらく私はあなたの質問を誤解しています。遠慮なくコメントで明確にしてください)。

  4. 凍結されたグラフは、以下の方法で読み込むことができます。 tf.import_graph_def() . この場合、重みは(通常)グラフに埋め込まれているため、別のチェックポイントを読み込む必要はありません。

  5. 主な変更点は、モデルに投入されるテンソルの名前と、モデルから取得されるテンソルの名前を更新することでしょう。TensorFlow Android デモでは、これは inputNameoutputName に渡される文字列は TensorFlowClassifier.initializeTensorFlow() .

  6. GraphDef はプログラムの構造で、通常、学習プロセスを通じて変化することはありません。チェックポイントは学習プロセスの状態のスナップショットであり、通常、学習プロセスの各ステップで変化する。そのため、TensorFlowはこれらのタイプのデータに対して異なるストレージフォーマットを使用し、低レベルAPIはそれらを保存およびロードする異なる方法を提供します。より高レベルのライブラリ、例えば MetaGraphDef というライブラリがあります。 Keras といった skflow は、モデル全体を保存・復元するためのより便利な方法を提供するために、これらのメカニズム上に構築されています。