Pytorch torch.Tensor.detach()メソッドの使い方と、指定したモジュールの重みを変更する方法
2022-02-17 11:45:14
デタッチ
detachの正式な解釈は、現在の計算グラフから分離した新しいTensorを返すことである。
返されたTensorは元のTensorと同じ記憶領域を共有するが、返されたTensorがグラディエントを必要とすることはないことに注意すること。
import torch as t
a = t.ones(10,)
b = a.detach()
print(b)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
では、この関数は何をするのでしょうか?
-ネットワークAがTensor型の変数aを出力し、aがネットワークBの入力として渡され、損失関数を通してネットワークBのパラメータをバックプロパゲートしたいが、ネットワークAのパラメータは変更したくない場合、 detcah() メソッドを使用することができます。
a = A(input)
a = detach()
b = B(a)
loss = criterion(b, target)
loss.backward()
実例を見るには
import torch as t
x = t.ones(1, requires_grad=True)
x.requires_grad #True
y = t.ones(1, requires_grad=True)
y.requires_grad #True
x = x.detach() #After separation
x.requires_grad #False
y = x+y #tensor([2.])
y.requires_grad #I am still True
y.retain_grad() #y is not a leaf tensor, add this line
z = t.pow(y, 2)
z.backward() #backpropagate
y.grad #tensor([4.])
x.grad #None
上記のコードでは、バックプロパゲーションはyで終了し、xには到達していないので、xのgradプロパティはNoneです。
さて、ここまでモデルの重みを変更することについて話してきましたが、もうひとつ別のケースがあります。
ネットワークAがTensor型変数aを出力し、aがネットワークBの入力として渡される場合、損失関数をバックプロパゲートすることによってネットワークAのパラメータを修正したいが、ネットワークBのパラメータは修正したくない場合はどうすればよいだろうか。
Tensor.requests_gradプロパティを使用し、requests_gradをFalseに変更すれば良いのです。
for param in B.parameters():
param.requires_grad = False
a = A(input)
b = B(a)
loss = criterion(b, target)
loss.backward()
関連
-
Pytorch-1-TX2にpytorchをインストール(自分でやったよ)
-
[Centernet recurrence] AttributeError:Can't pickle local object 'get_dataset.<locals>.Dataset
-
AttributeError NoneType オブジェクトに属性データがない。
-
PytorchがNotImplementedErrorを発生させるようです。
-
ピトーチリピートの使用方法
-
pytorchのSpeat()関数
-
torch.stack()の使用
-
torch.stack()の公式解説、詳細、例題について
-
ピトーチテンソルインデックス
-
AttributeError: 'Graph' オブジェクトには 'node' という属性がありません。
最新
-
nginxです。[emerg] 0.0.0.0:80 への bind() に失敗しました (98: アドレスは既に使用中です)
-
htmlページでギリシャ文字を使うには
-
ピュアhtml+cssでの要素読み込み効果
-
純粋なhtml + cssで五輪を実現するサンプルコード
-
ナビゲーションバー・ドロップダウンメニューのHTML+CSSサンプルコード
-
タイピング効果を実現するピュアhtml+css
-
htmlの選択ボックスのプレースホルダー作成に関する質問
-
html css3 伸縮しない 画像表示効果
-
トップナビゲーションバーメニュー作成用HTML+CSS
-
html+css 実装 サイバーパンク風ボタン