1. ホーム
  2. python

[解決済み] Pytorchはワンホットベクターをサポートしないのですか?

2022-02-11 23:01:55

質問

Pytorchがワンホットベクトルをどのように扱うのか、非常に混乱しています。これには チュートリアル ニューラルネットワークは、出力としてワンホットベクトルを生成します。私が理解する限り、チュートリアルのニューラルネットワークの概略構造は次のようになるはずです。

しかし labels は一発ベクトル形式ではありません。以下のようになります。 size

print(labels.size())
print(outputs.size())

output>>> torch.Size([4]) 
output>>> torch.Size([4, 10])

奇跡的に outputslabels から criterion=CrossEntropyLoss() であれば、全くエラーになりません。

loss = criterion(outputs, labels) # How come it has no error?

私の仮説

多分、pytorchは自動的に labels をワンホットベクトル形式に変換します。そこで、損失関数に渡す前に、ラベルを1-hot vectorに変換するようにしています。

def to_one_hot_vector(num_class, label):
    b = np.zeros((label.shape[0], num_class))
    b[np.arange(label.shape[0]), label] = 1

    return b

labels_one_hot = to_one_hot_vector(10,labels)
labels_one_hot = torch.Tensor(labels_one_hot)
labels_one_hot = labels_one_hot.type(torch.LongTensor)

loss = criterion(outputs, labels_one_hot) # Now it gives me error

しかし、次のようなエラーが発生しました。

RuntimeError: multi-target not supported at /opt/pytorch/pytorch/aten/src/THCUNN/generic/ClassNLLCriterion.cu:15

で、ワンホットベクターはサポートされていません。 Pytorch ? どのように Pytorch が計算されます。 cross entropy に対して、2つのテンソル outputs = [1,0,0],[0,0,1]labels = [0,2] ? 今のところ全く意味不明です。

どのように解決するのですか?

PyTorchは、そのドキュメントで CrossEntropyLoss その

この基準は、ミニバッチサイズの1次元テンソルの各値のターゲットとして、クラスインデックス(0〜C-1)を期待するものである。

言い換えれば、それはあなたの to_one_hot_vector 関数が概念的に組み込まれている CEL であり、ワンショットAPIを公開しない。ワンホットベクトルはクラスラベルを格納するのに比べてメモリ効率が悪いことに注意してください。

ワンホットベクターが与えられ、クラスラベルの形式にする必要がある場合(たとえば CEL を使用することができます。 argmax のようにします。

import torch
 
labels = torch.tensor([1, 2, 3, 5])
one_hot = torch.zeros(4, 6)
one_hot[torch.arange(4), labels] = 1
 
reverted = torch.argmax(one_hot, dim=1)
assert (labels == reverted).all().item()