1. ホーム
  2. machine-learning

[解決済み] PyTorchのバックワード関数

2022-01-29 03:56:48

質問

pytorchのbackward関数についていくつか質問があるのですが、正しい出力が得られていないと思います。

import numpy as np
import torch
from torch.autograd import Variable
a = Variable(torch.FloatTensor([[1,2,3],[4,5,6]]), requires_grad=True) 
out = a * a
out.backward(a)
print(a.grad)

の場合、出力は

tensor([[ 2.,  8., 18.],
        [32., 50., 72.]])

もしかしたら、それは 2*a*a

しかし、私は出力は次のようになると思います。

tensor([[ 2.,  4., 6.],
        [8., 10., 12.]])

2*a. 原因 d(x^2)/dx=2x

解決方法は?

のドキュメントをよくお読みください。 backward() を使用することで、より深く理解することができます。

デフォルトでは、pytorchは backward() のために呼び出される 最後の の出力、つまり損失関数です。損失関数は常にスカラーを出力します。 スカラー の損失は、他のすべての変数/パラメータとの関係でうまく定義されます(連鎖法則を使用)。

したがって、デフォルトでは backward() はスカラーテンソルに対して呼び出され、引数を必要としない。

例えば

a = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.float, requires_grad=True)
for i in range(2):
  for j in range(3):
    out = a[i,j] * a[i,j]
    out.backward()
print(a.grad)

イールド

tensor([[ 2.,  4.,  6.],
        [ 8., 10., 12.]])

予想通りです。 d(a^2)/da = 2a .

しかし backward を2×3の out テンソル(もはやスカラー関数ではない) - 何を期待しているのか? a.grad になるのでしょうか?実際には、2×3×2×3の出力が必要です。 d out[i,j] / d a[k,l] (!)

Pytorchはこの非スカラー関数の導関数をサポートしていません。 その代わりに、pytorchは以下のように仮定します。 out は中間テンソルに過ぎず、どこかにスカラー損失関数が存在し、連鎖法則により、その損失関数は d loss/ d out[i,j] . この "upstream" の勾配は 2×3 のサイズであり、これは実際にあなたが提供する引数である。 backward この場合 out.backward(g) ここで g_ij = d loss/ d out_ij .

そして、グラデーションは鎖の法則によって計算されます。 d loss / d a[i,j] = (d loss/d out[i,j]) * (d out[i,j] / d a[i,j])

あなたが提供した a を上流のグラデーションとして取得しました。

a.grad[i,j] = 2 * a[i,j] * a[i,j]

もし、"upstream"のグラデーションをすべて1つずつ用意するのであれば

out.backward(torch.ones(2,3))
print(a.grad)

イールド

tensor([[ 2.,  4.,  6.],
        [ 8., 10., 12.]])

予想通りです。

チェーンルールにある