1. ホーム
  2. python

[解決済み] PyTorchのクロスエントロピー

2022-03-07 19:32:14

質問

PyTorchのクロスエントロピーの損失について、少し混乱しています。

この例で考えると

import torch
import torch.nn as nn
from torch.autograd import Variable

output = Variable(torch.FloatTensor([0,0,0,1])).view(1, -1)
target = Variable(torch.LongTensor([3]))

criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)
print(loss)

損失が0になることを期待しますが、得します。

Variable containing:
 0.7437
[torch.FloatTensor of size 1]

私の知る限り、クロスエントロピーはこのように計算できます。

しかし、1*log(1) = 0 という結果になってはいけないのでしょうか?

一発符号化など別の入力も試してみましたが、これは全くうまくいかないので、損失関数の入力形状は問題ないようです。

どなたか、私の間違いがどこにあるのか、教えていただけると本当にありがたいです。

ありがとうございました。

どのように解決するのですか?

あなたの例では、出力 [0, 0, 0, 1] クロス・エントロピーの数学的定義で要求されているように、確率として扱います。 しかし、PyTorchはそれらを出力として扱い、和をとる必要はありません。 1 そのため、まず確率に変換し、ソフトマックス関数を使用する必要があります。

そこで H(p, q) になります。

H(p, softmax(output))

出力を翻訳する [0, 0, 0, 1] を確率に変換する。

softmax([0, 0, 0, 1]) = [0.1749, 0.1749, 0.1749, 0.4754]

というわけで。

-log(0.4754) = 0.7437