1. ホーム
  2. python

[解決済み] NumPyのtranspose()メソッドは、どのように配列の軸を並べ替えるのですか?

2022-12-13 08:41:32

質問

In [28]: arr = np.arange(16).reshape((2, 2, 4))

In [29]: arr
Out[29]: 
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7]],

       [[ 8,  9, 10, 11],
        [12, 13, 14, 15]]])


In [32]: arr.transpose((1, 0, 2))
Out[32]: 
array([[[ 0,  1,  2,  3],
        [ 8,  9, 10, 11]],

       [[ 4,  5,  6,  7],
        [12, 13, 14, 15]]])

整数のタプルを transpose() 関数に渡すと、どうなるでしょうか?

具体的には、これは3D配列です。軸のタプルを渡すと、NumPyはどのように配列を変換するのでしょうか? (1, 0 ,2) ? これらの整数がどの行または列を参照しているか説明できますか?また、NumPyの文脈では、軸番号とは何ですか?

どのように解決するのですか?

配列を転置するために、NumPyは各軸の形状とストライドの情報を入れ替えるだけです。以下はそのストライドです。

>>> arr.strides
(64, 32, 8)

>>> arr.transpose(1, 0, 2).strides
(32, 64, 8)

転置操作は軸0と軸1の歩幅を入れ替えたことに注意してください。 これらの軸の長さも入れ替わりました(両方の長さが 2 である)。

NumPyは、新しい配列を構築するために、基礎となるメモリをどのように見るかを変更するだけでよいのです。


ストライドの可視化

ストライドの値は、配列の軸の次の値に到達するためにメモリ上を移動しなければならないバイト数を表します。

さて、3次元配列の arr はこのようになります(軸をラベル付けしてあります)。

この配列は 連続したメモリブロックに格納されます。 に格納されており、本質的には一次元です。これを3Dオブジェクトとして解釈するために、NumPyは3つの軸のいずれかに沿って移動するために、ある一定のバイト数を飛び越える必要があります。

各整数は8バイトのメモリを使用するため(int64 d型を使用)、各次元のストライド値はジャンプする必要のある値の数の8倍となります。例えば、軸1に沿って移動するには、4つの値(32バイト)がジャンプされ、軸0に沿って移動するには、8つの値(64バイト)がジャンプされる必要があります。

と書くと arr.transpose(1, 0, 2) と書くと、軸0と軸1を入れ替えています。 入れ替えた配列は次のようになります。

NumPyがすべきことは、軸0と軸1のストライド情報を交換することです(軸2は変更されません)。今、私たちは軸0よりも軸1に沿って移動するためにさらにジャンプしなければなりません。

この基本的な概念は、配列の軸のどんな並べ換えに対しても働きます。転置を処理する実際のコードはCで書かれており、次の場所で見つけることができます。 ここで .