1. ホーム
  2. machine-learning

[解決済み] tf.reset_default_graph() の使用方法

2022-02-28 22:12:38

質問

を使おうとすると、必ず tf.reset_default_graph() このようなエラーが発生します。 IndexError: list index out of range または `` です。コードのどの部分でこれを使うべきですか?いつこれを使うべきですか?

編集してください。

コードを更新しましたが、まだエラーは発生します。

def evaluate():
    with tf.name_scope("loss"):
        global x # x is a tf.placeholder()
        xentropy = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=neural_network(x))
        loss = tf.reduce_mean(xentropy, name="loss")

    with tf.name_scope("train"):
        optimizer = tf.train.AdamOptimizer()
        training_op = optimizer.minimize(loss)

    with tf.name_scope("exec"):
        with tf.Session() as sess:
            for i in range(1, 2):
                sess.run(tf.global_variables_initializer())
                sess.run(training_op, feed_dict={x: np.array(train_data).reshape([-1, 1]), y: label})
                print "Training " + str(i)
                saver = tf.train.Saver()
                saver.save(sess, "saved_models/testing")
                print "Model Saved."


def predict():
    with tf.name_scope("predict"):
        tf.reset_default_graph()
        with tf.Session() as sess:
            saver = tf.train.import_meta_graph("saved_models/testing.meta")
            saver.restore(sess, "saved_models/testing")
            output_ = tf.get_default_graph().get_tensor_by_name('output_layer:0')
            print sess.run(output_, feed_dict={x: np.array([12003]).reshape([-1, 1])})


def main():
    print "Starting Program..."
    evaluate()
    writer = tf.summary.FileWriter("mygraph/logs", tf.get_default_graph())
    predict()

更新したコードからtf.reset_default_graph()を削除すると、以下のエラーが発生します。 ValueError: cannot add op with name hidden_layer1/kernel/Adam as that name is already used

私の現在の理解では、tf.reset_default_graph()はすべてのグラフを削除するので、上記のようなエラーは避けられました( ValueError: cannot add op with name hidden_layer1/kernel/Adam as that name is already used )

解決方法は?

このような使い方があるのでしょう。

import tensorflow as tf
a = tf.constant(1)
with tf.Session() as sess:
    tf.reset_default_graph()

セッションで使うからエラーになるんだろ。から tf.reset_default_graph() のドキュメントをご覧ください。

<ブロッククオート

tf.Sessionまたはtf.InteractiveSessionが存在する時にこの関数を呼び出すと がアクティブな場合、未定義の動作になります。以前に作成した この関数を呼び出した後にtf.Operationまたはtf.Tensorオブジェクトを呼び出すと は、未定義の動作となります。


tf.reset_default_graph() は、jupyter notebookで実験している間、テスト段階で(少なくとも私にとっては)役に立ちます。しかし、私は本番で使ったことがなく、そこでどのように役立つかはわかりません。

ノートブックにありそうな例です。

import tensorflow as tf
# create some graph
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print sess.run(...)

しかし、別のグラフを作成してtensorboardで可視化すると、古いノードと新しいノードが表示されます。これを解決するには、カーネルを再起動し、次のセルだけを実行すればよい。しかし、私はただそうすることができる。

tf.reset_default_graph()
# create a new graph
with tf.Session() as sess:
    print sess.run(...)

OPがコードを追加した後に編集 :

with tf.name_scope("predict"):
    tf.reset_default_graph()

おおよそこんな感じです。あなたのコードが失敗するのは tf.name_scope はすでにグラフに何かを追加しています。この "adding something to the graph" の中にいるとき、TFにグラフを完全に削除するように言いますが、何かを追加するのに忙しいので、削除することができません。