1. ホーム
  2. pytorch

torch.stack()の使用

2022-02-17 09:09:14


さっそくですが、写真をご覧ください。

この図は、a,b,cの3つの3x3 Tensorを持っています。

これら3つのテンソルの最後の次元の要素を重ね合わせて、新しいテンソルを形成したい場合

入力 d=torch.stack( (a,b,c) ,dim = 2)

のように2次元で重なり合っていることがわかります。 具体的には以下のように変化します。

d[0][0]の位置は、a[0][0]の[1]、b[0][0]の[10]、c[0][0]の[100]からなるサイズ3 [1,10,100] の新しい要素で重畳しています

つまり、dの次元は3×3×3である

ここで、dimパラメータに注目してください!!!!!!!!!!!!!!!!!!!!!(笑

2次元目でスタックしていますが、(pytorchでは0から数えるので、1次元目がdim=0、2次元目がdim=1ということです。)

しかし、スタック関数に書いたdimはdim=2であり、3次元であるため、最終結果は3次元であることになる と言うより

元の要素は3次元です。

そこで別の書き方があります。dim=-1と書いて、元々持っていたものを何でもいいので、最後の次元を指定します。

結果は、冒頭にdim=2と書いたときと同じです。


このdimパラメータの意味を理解するために別の例を追加すると、スタック作成後に要素の結果が位置する次元を指定します。

dim をそれぞれ 0, 1, 2 に設定し、その結果に注目してください。

ちょっとだけ分析?

c, dim = 0 c = [ a, b ] の時

d, dim = 1 d = [ [a[0] , b[0] ] , [a[1], b[1] ] の時

e, dim = 2, e = [ <スパン [ a[0][0]、b[0][0]です。  <スパン ]   , [ a[0][1]、b[0][1]。  <スパン ]   ,  a[0][2],b[0][2]です。 <スパン ] <スパン ] ,

[ <スパン [ a[1][0]、b[1][0]です。 ]   , [ a[1][1]、b[0][1]です。 <スパン ]   ,  [ a[1][2],b[1][2]です。 ] ] ]