1. ホーム
  2. python

[解決済み] numpyで行列の乗算を一括処理

2022-03-03 02:18:09

質問

2つのnumpy配列があります ab 形状の [5, 5, 5][5, 5] それぞれ 両者とも ab の場合、shapeの最初のエントリはバッチサイズです。行列の乗算オプションを実行すると、形状の配列が得られます。 [5, 5, 5] . MWEは以下の通りです。

import numpy as np

a = np.ones((5, 5, 5))
b = np.random.randint(0, 10, (5, 5))
c = a @ b
# c.shape is (5, 5, 5)

仮に、バッチサイズのループを実行するとしたら、つまり a[0] @ b[0].T の配列が生成されます。 [5, 1] . 最後に、軸 1 に沿ってすべての結果を連結すると、結果としての配列は、形状 [5, 5] . 以下のコードでは、これらの行をよりよく説明しています。

a = np.ones((5, 5, 5))
b = np.random.randint(0, 10, (5, 5))
c = []
for i in range(5):
    c.append(a[i] @ b[i].T)
c = np.concatenate([d[:, None] for d in c], axis=1).T
# c.shape evaluates to be (5, 5)

loopを使わずに上記の機能を得ることはできますか?例えば、PyTorchには torch.bmm を計算することができます。ありがとうございます。

どのように解決するのですか?

numpyのeinsumを使って計算することができます。

c = np.einsum('BNi,Bi ->BN', a, b)

Pytorchもこのeinsum関数を、構文を少し変えて提供しています。そのため、簡単に動作させることができます。他の図形も簡単に扱える。

そうすれば、転置やスクイーズ演算を気にする必要はありません。また、内部で既存の行列のコピーを作成しないので、メモリの節約にもなります。