1. ホーム
  2. Pytorch

トーチ.スタック(リスト)

2022-02-17 04:26:35

トーチ.スタック(リスト,0)

listの各要素はテンソルの0次元の各要素である

import torch

a = torch.Tensor([[1, 3, 2], [1, 3, 2]])
b = torch.Tensor([[2, 1, 1], [2, 1, 1]])
c = torch.Tensor([[3, 2, 3], [2, 1, 1]])

my_list = [a, b, c]
print(torch.stack(my_list, 0))


The
tensor([[ 1., 3., 2.]],
         [ 1., 3., 2.]],

        [[ 2., 1., 1.]],
         [ 2., 1., 1.]],

        [[ 3., 2., 3.]],
         [ 2., 1., 1.]]])

troch.stack(list,1) リストの各要素の0次元目〜n次元目までの各要素をグループ化したもの

print(troch.stack(my_list, 1))

The
tensor([[ 1., 3., 2,
         [ 2., 1., 1.],
         [ 3., 2., 3.]],

        [[ 1., 3., 2.],
         [ 2., 1., 1.],
         [ 2., 1., 1.]]])