[解決済み] 'tensorflow.python.framework.ops.EagerTensor'オブジェクトには属性 '_in_graph_mode' がありません。
質問
CNNフィルターを視覚化するために、ランダムな「画像」を最適化し、そのフィルターに高い平均活性化を生じさせようとしているが、これはニューロスタイル転送アルゴリズムと何らかの類似性がある。
そのために、TensorFlow==2.2.0-rc を使っています。しかし、最適化の過程で、以下のようなエラーが発生します。
'tensorflow.python.framework.ops.EagerTensor' object has no attribute '_in_graph_mode'
. デバッグしてみると、どうやら
opt.apply_gradients()
のように手動でグラデーションを適用します。
img = img - lr * grads
しかし、私は単純なSGDではなく、"Adam"オプティマイザを使いたいのです。
以下は、最適化部分のコードです。
opt = tf.optimizers.Adam(learning_rate=lr, decay = 1e-6)
for _ in range(epoch):
with tf.GradientTape() as tape:
tape.watch(img)
y = model(img)[:, :, :, filter]
loss = -tf.math.reduce_mean(y)
grads = tape.gradient(loss, img)
opt.apply_gradients(zip([grads], [img]))
解決方法は?
バグの原因は、tf.kerasのオプティマイザが変数オブジェクト(tf.Variable型)に勾配を適用し、あなたがテンソル(tf.Tensor型)に勾配を適用しようとすることにあります。TensorオブジェクトはTensorFlowではmutableではないので、オプティマイザはこれに勾配を適用することができない。
を初期化する必要があります。
img
をtf.Variableとして使用します。これがあなたのコードのあり方です。
# NOTE: The original image is lost here. If this is not desired, then you can
# rename the variable to something like img_var.
img = tf.Variable(img)
opt = tf.optimizers.Adam(learning_rate=lr, decay = 1e-6)
for _ in range(epoch):
with tf.GradientTape() as tape:
tape.watch(img)
y = model(img.value())[:, :, :, filter]
loss = -tf.math.reduce_mean(y)
grads = tape.gradient(loss, img)
opt.apply_gradients(zip([grads], [img]))
また、グラデーションの計算はテープのコンテキストの外側で行うことを推奨します。これは、中に入れておくと、テープがグラデーションの計算自体をトラッキングしてしまい、メモリ使用量が多くなってしまうからです。これは、高次のグラジェントを計算したい場合にのみ望ましい。それらは必要ないので、外に置いてある。
なお、私は以下の行を変更しました。
y = model(img)[:, :, :, filter]
から
y = model(img.value())[:, :, :, filter]
. これは、tf.kerasのモデルは変数ではなく、テンソルを入力として必要とするためです(バグ、あるいは特徴?)
関連
-
PythonによるLeNetネットワークモデルの学習と予測
-
Python 人工知能 人間学習 描画 機械学習モデル作成
-
風力制御におけるKS原理を深く理解するためのpythonアルゴリズム
-
Pythonの画像ファイル処理用ライブラリ「Pillow」(グラフィックの詳細)
-
[解決済み】 TypeError: += でサポートされていないオペランド型: 'int' および 'list' です。
-
[解決済み] Pythonでオブジェクトが属性を持つかどうかを知る方法
-
[解決済み] オブジェクトの種類を決定しますか?
-
[解決済み] Pythonのクラスはなぜオブジェクトを継承するのですか?
-
[解決済み] Pythonでnullオブジェクトを参照する
-
[解決済み] エラーです。" 'dict' オブジェクトには 'iteritems' という属性がありません "
最新
-
nginxです。[emerg] 0.0.0.0:80 への bind() に失敗しました (98: アドレスは既に使用中です)
-
htmlページでギリシャ文字を使うには
-
ピュアhtml+cssでの要素読み込み効果
-
純粋なhtml + cssで五輪を実現するサンプルコード
-
ナビゲーションバー・ドロップダウンメニューのHTML+CSSサンプルコード
-
タイピング効果を実現するピュアhtml+css
-
htmlの選択ボックスのプレースホルダー作成に関する質問
-
html css3 伸縮しない 画像表示効果
-
トップナビゲーションバーメニュー作成用HTML+CSS
-
html+css 実装 サイバーパンク風ボタン
おすすめ
-
Python 人工知能 人間学習 描画 機械学習モデル作成
-
python implement mysql add delete check change サンプルコード
-
Python LeNetネットワークの説明とpytorchでの実装
-
Python Pillow Image.save jpg画像圧縮問題
-
[解決済み] _tkinter.TclError: 表示名がなく、$DISPLAY環境変数もない。
-
[解決済み】TypeErrorの修正方法。Unicodeオブジェクトは、ハッシュ化する前にエンコードする必要がある?
-
[解決済み] データ型が理解できない
-
[解決済み] 'DataFrame' オブジェクトに 'sort' 属性がない
-
[解決済み] 'int'オブジェクトに'__getitem__'属性がない。
-
[解決済み】 AttributeError("'str' object has no attribute 'read'")