[解決済み] PyTorch autograd -- gradはスカラー出力に対してのみ暗黙的に作成することができます。
2022-02-28 10:49:24
質問
を使用しています。
autograd
ツールで
PyTorch
そして、1次元のテンソルの値に整数のインデックスでアクセスする必要があることに気がついた。このようなものだ。
def basic_fun(x_cloned):
res = []
for i in range(len(x)):
res.append(x_cloned[i] * x_cloned[i])
print(res)
return Variable(torch.FloatTensor(res))
def get_grad(inp, grad_var):
A = basic_fun(inp)
A.backward()
return grad_var.grad
x = Variable(torch.FloatTensor([1, 2, 3, 4, 5]), requires_grad=True)
x_cloned = x.clone()
print(get_grad(x_cloned, x))
以下のエラーメッセージが表示されます。
[tensor(1., grad_fn=<ThMulBackward>), tensor(4., grad_fn=<ThMulBackward>), tensor(9., grad_fn=<ThMulBackward>), tensor(16., grad_fn=<ThMulBackward>), tensor(25., grad_fn=<ThMulBackward>)]
Traceback (most recent call last):
File "/home/mhy/projects/pytorch-optim/predict.py", line 74, in <module>
print(get_grad(x_cloned, x))
File "/home/mhy/projects/pytorch-optim/predict.py", line 68, in get_grad
A.backward()
File "/home/mhy/.local/lib/python3.5/site-packages/torch/tensor.py", line 93, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/home/mhy/.local/lib/python3.5/site-packages/torch/autograd/__init__.py", line 90, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
私は一般的に、ある変数のクローン版を使用することで、その変数がグラデーション計算に使用され続けることになるのか、少し疑問を持っています。その変数自体は、事実上
A
を呼び出すと
A.backward()
のように、その操作の一部であってはならない。
この方法について、あるいは、勾配の履歴を失わず、かつ、1次元テンソルに対して
requires_grad=True
!
**編集部(9月15日):**。
res
は1〜5の2乗値を含む0次元のテンソルのリストである。1.0, 4.0, ..., 25.0]を含む一つのテンソルを連結するために、私は以下のように変更した。
return Variable(torch.FloatTensor(res))
を
torch.stack(res, dim=0)
を生成します。
tensor([ 1., 4., 9., 16., 25.], grad_fn=<StackBackward>)
.
しかし、私はこの新しいエラーに遭遇しました。
A.backward()
という行があります。
Traceback (most recent call last):
File "<project_path>/playground.py", line 22, in <module>
print(get_grad(x_cloned, x))
File "<project_path>/playground.py", line 16, in get_grad
A.backward()
File "/home/mhy/.local/lib/python3.5/site-packages/torch/tensor.py", line 93, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/home/mhy/.local/lib/python3.5/site-packages/torch/autograd/__init__.py", line 84, in backward
grad_tensors = _make_grads(tensors, grad_tensors)
File "/home/mhy/.local/lib/python3.5/site-packages/torch/autograd/__init__.py", line 28, in _make_grads
raise RuntimeError("grad can be implicitly created only for scalar outputs")
RuntimeError: grad can be implicitly created only for scalar outputs
解決方法は?
を変更しました。
basic_fun
を以下のように変更したところ、問題が解決しました。
def basic_fun(x_cloned):
res = torch.FloatTensor([0])
for i in range(len(x)):
res += x_cloned[i] * x_cloned[i]
return res
このバージョンはスカラー値を返します。
関連
-
Pythonコンテナのための組み込み汎用関数操作
-
python string splicing.join()とsplitting.split()の説明
-
PythonはWordの読み書きの変更操作を実装している
-
パッケージングツールPyinstallerの使用と落とし穴の回避
-
[解決済み】なぜ「LinAlgError: Grangercausalitytestsから「Singular matrix」と表示されるのはなぜですか?
-
[解決済み】TypeErrorを取得しました。エントリを持つ子テーブルの後に親テーブルを追加しようとすると、 __init__() missing 1 required positional argument: 'on_delete'
-
[解決済み] 'DataFrame' オブジェクトに 'sort' 属性がない
-
[解決済み】Python Error: "ValueError: need more than 1 value to unpack" (バリューエラー:解凍に1つ以上の値が必要です
-
[解決済み】LogisticRegression: Pythonでsklearnを使用して、未知のラベルタイプ: '連続'を使用しています。
-
[解決済み】cアンダースコア式`c_`は、具体的に何をするのですか?
最新
-
nginxです。[emerg] 0.0.0.0:80 への bind() に失敗しました (98: アドレスは既に使用中です)
-
htmlページでギリシャ文字を使うには
-
ピュアhtml+cssでの要素読み込み効果
-
純粋なhtml + cssで五輪を実現するサンプルコード
-
ナビゲーションバー・ドロップダウンメニューのHTML+CSSサンプルコード
-
タイピング効果を実現するピュアhtml+css
-
htmlの選択ボックスのプレースホルダー作成に関する質問
-
html css3 伸縮しない 画像表示効果
-
トップナビゲーションバーメニュー作成用HTML+CSS
-
html+css 実装 サイバーパンク風ボタン
おすすめ
-
Pythonコードの可読性を向上させるツール「pycodestyle」の使い方を詳しく解説します
-
[解決済み】TypeError: unhashable type: 'numpy.ndarray'.
-
[解決済み】pygame.error: ビデオシステムが初期化されていない
-
[解決済み】Pythonスクリプトで「Expected 2D array, got 1D array instead: 」というエラーが発生?
-
[解決済み】numpy: true_divide で無効な値に遭遇
-
[解決済み】 NameError: グローバル名 'xrange' は Python 3 で定義されていません。
-
[解決済み】「SyntaxError.Syntax」は何ですか?Missing parentheses in call to 'print'」はPythonでどういう意味ですか?
-
[解決済み】SyntaxError: デフォルト以外の引数がデフォルトの引数に続く
-
[解決済み】Python: OverflowError: 数学の範囲エラー
-
[解決済み】ImportError: bs4という名前のモジュールがない(BeautifulSoup)