1. ホーム
  2. pytorch

torch.stack()の公式解説、詳細、例題について

2022-02-17 04:52:49
<パス

一番下の【3.例】を直接見て、前の説明に戻ることができます

での pytorch には、主に2つの共通スプライシング機能があります。

  1. stack()
    cat()
    

実際には、この2つの機能は互いに補完し合い、異なるシナリオで使用されます。 cat() 参考 torch.cat() を使用しますが、この記事で取り上げるのは stack() .

機能の意味 : 使用 stack は、2つの情報を保持することができる。[1.配列]と[2.テンソル行列]の情報を保持することができ、[ ]に属します。 展開 resplicing]機能。

比喩的に言えば、データがすべて2次元の行列(平面)である場合、この1つ1つの平面を3次元(例えば時系列)の立方体に圧縮し、その長さを時系列の長さとすることができるのである。

この機能は、自然言語処理でよく見られる( NLP ) や画像畳み込みニューラルネットワーク ( CV )に含まれる。

1. stack()

公式説明:一連の入力テンソルを新しい次元に沿って連結する。一連のテンソルはすべて同じ形状であるべきである。

平たく言えば、2次元のテンソルを複数まとめて3次元のテンソルに、3次元のものを複数まとめて4次元のテンソルに...といった具合に、つまりは 積み重ねのための新しい次元の追加 .

outputs = torch.stack(inputs, dim=?) → Tensor

パラメータ

  • inputs : 接続されるテンソル列。
    注意事項 python の唯一の配列データです。 list tuple .

  • dim : 新たに追加する必要があるディメンジョン。 0 から len(outputs)
    注意事項 len(outputs) は生成されたデータの次元サイズ,すなわち outputs 次元値の

2.フォーカス

  1. 機能での入力 inputs はシーケンスのみ使用可能であり、シーケンス内のテンソル要素は必ず shape 同じである

の例 ----. [tensor_1, tensor_2,...] または (tensor_1, tensor_2,...) で、必ず

tensor_1.shape == tensor_2.shape
dim

  1. は、生成のために選択されたディメンジョンで 0<=dim<len(outputs) len(outputs) の後の出力です。 tensor の寸法は

理解できない場合は例題を見て、それから戻って読んでください。

3. 事例紹介

1. 準備 2 tensor のデータで、それぞれ shape

[3,3]
# Assume it's the output of time step T1
T1 = torch.tensor([[1, 2, 3],
        		[4, 5, 6],
        		[7, 8, 9]])
# Assume it is the output of time step T2
T2 = torch.tensor([[10, 20, 30],
        		[40, 50, 60],
        		[70, 80, 90]])

2. スタック関数のテスト

print(torch.stack((T1,T2),dim=0).shape)
print(torch.stack((T1,T2),dim=1).shape)
print(torch.stack((T1,T2),dim=2).shape)
print(torch.stack((T1,T2),dim=3).shape)
# outputs:
torch.Size([2, 3, 3])
torch.Size([3, 2, 3])
torch.Size([3, 3, 2])
'Selected dim>len(outputs), so error reported'
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)


コードをコピーして実行することで試すことができます:スプライシングされた tensor の形状によって変化します。 dim の変更が発生します。

<テーブル 薄暗い 形状 0 [ <マーク 2 1 [3, <マーク 2 2 [3, 3, <マーク 2 3 オーバーフローエラー報告

4. まとめ

  1. 機能の役割
    この関数は stack() について 配列データ の内部テンソルです。 <マーク 次元展開

  2. プレゼンス意味。
    自然言語処理とボリューム、ニューラルネットワークでは 通常、保存するために - [シーケンシャル(順次)情報] と [テンソルの行列情報] を保存する。 このときだけ stack .

意味のために機能がある?

RNNを手書きで書いたことがある人は知っていると思うが、リカレントニューラルネットワークの出力データは、次のようになる。 list に挿入されます。 seq_len の形状は [batch_size, output_size] tensor を使用する必要があり、計算には不向きである。 stack スプライシングのために、-[1.seq_len 時間ステップ] と -[2.テンソル属性 [batch_size, output_size] ].