1. ホーム
  2. pytorch

pytorchのSpeat()関数

2022-02-17 19:57:52
<パス

pytorchのrepeat()関数はテンソルのコピーを作成します。

引数が2つだけの場合、第1引数は複製後の列数、第2引数は複製後の行数を示す。
パラメータが3つの場合、第1パラメータはコピー後のチャンネル数、第2パラメータはコピー後のカラム数、第3パラメータはコピー後の行数を示す。

次に、例を挙げてイメージしてみましょう。

>>> x = torch.tensor([6,7,8])
>>> x.repeat(4,2)
tensor([[6, 7, 8, 6, 7, 8],
        [6, 7, 8, 6, 7, 8],
        [6, 7, 8, 6, 7, 8],
        [6, 7, 8, 6, 7, 8]])
>>> x.repeat(4,2,1)
tensor([[6, 7, 8],
         [6, 7, 8]],

        [[6, 7, 8]],
         [6, 7, 8]],

        [[6, 7, 8],
         [6, 7, 8]],

        [[6, 7, 8],
         [6, 7, 8]]])
>>> x.repeat(4,2,1).size()
torch.Size([4, 2, 3])