1. ホーム
  2. python

[解決済み] PyTorchテンソルのリサイズ

2022-02-24 15:12:03

質問

現在、tensor.resize()関数を使ってテンソルを新しい形状にリサイズしています。 t = t.resize(1, 2, 3) .

これでは非推奨の警告が出ます。

非インプレイスリサイズは非推奨です

それ故に、私は tensor.resize_() 関数は、適切なインプレース置換であると思われます。しかし、これでは

grad を必要とする変数のリサイズができない

のエラーが発生しました。 に陥ることがあります。

from torch.autograd._functions import Resize
Resize.apply(t, (1, 2, 3))

これは非推奨の警告を避けるために、tensor.resize()が行うものです。 これは適切な解決策ではなく、むしろハックのように思える。 どうすれば正しく tensor.resize_() この場合

どのように解決するのですか?

代わりに tensor.reshape(new_shape) または torch.reshape(tensor, new_shape) というように

# a `Variable` tensor
In [15]: ten = torch.randn(6, requires_grad=True)

# this would throw RuntimeError error
In [16]: ten.resize_(2, 3)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-16-094491c46baa> in <module>()
----> 1 ten.resize_(2, 3)

RuntimeError: cannot resize variables that require grad


上記 RuntimeError を使用することで解決または回避することができます。 tensor.reshape(new_shape)

In [17]: ten.reshape(2, 3)
Out[17]: 
tensor([[-0.2185, -0.6335, -0.0041],
        [-1.0147, -1.6359,  0.6965]])

# yet another way of changing tensor shape
In [18]: torch.reshape(ten, (2, 3))
Out[18]: 
tensor([[-0.2185, -0.6335, -0.0041],
        [-1.0147, -1.6359,  0.6965]])