1. ホーム
  2. python

[解決済み] torch.argmaxでdim=1が行インデックスを返すのはなぜですか?

2022-02-16 17:46:59

質問

をやっています。 argmax という PyTorch の関数が定義されている。

torch.argmax(input, dim=None, keepdim=False)

例を考えてみましょう

a = torch.randn(4, 4)
print(a)
print(torch.argmax(a, dim=1))

ここで、dim=1 を使うと、列ベクトルを検索する代わりに、以下のように行ベクトルを検索するようになります。

print(a) :   
tensor([[-1.7739,  0.8073,  0.0472, -0.4084],  
        [ 0.6378,  0.6575, -1.2970, -0.0625],  
        [ 1.7970, -1.3463,  0.9011, -0.8704],  
        [ 1.5639,  0.7123,  0.0385,  1.8410]])  

print(torch.argmax(a, dim=1))  
tensor([1, 1, 0, 3])

私の仮定では、dim = 0 は行、dim = 1 は列を表します。

どのように解決するのですか?

そろそろ 正しく理解する どのように axis または dim の引数は PyTorch で動作します。


上の図を理解すれば、次の例は理解できるはずです。

    |
    v
  dim-0  ---> -----> dim-1 ------> -----> --------> dim-1
    |   [[-1.7739,  0.8073,  0.0472, -0.4084],
    v    [ 0.6378,  0.6575, -1.2970, -0.0625],
    |    [ 1.7970, -1.3463,  0.9011, -0.8704],
    v    [ 1.5639,  0.7123,  0.0385,  1.8410]]
    |
    v

# argmax (indices where max values are present) along dimension-1
In [215]: torch.argmax(a, dim=1)
Out[215]: tensor([1, 1, 0, 3])


備考 : dim (略 '次元' に相当するものです。 '軸' を NumPy で使用する。