[解決済み] PyTorchのクロスエントロピー
2022-03-07 19:32:14
質問
PyTorchのクロスエントロピーの損失について、少し混乱しています。
この例で考えると
import torch
import torch.nn as nn
from torch.autograd import Variable
output = Variable(torch.FloatTensor([0,0,0,1])).view(1, -1)
target = Variable(torch.LongTensor([3]))
criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)
print(loss)
損失が0になることを期待しますが、得します。
Variable containing:
0.7437
[torch.FloatTensor of size 1]
私の知る限り、クロスエントロピーはこのように計算できます。
しかし、1*log(1) = 0 という結果になってはいけないのでしょうか?
一発符号化など別の入力も試してみましたが、これは全くうまくいかないので、損失関数の入力形状は問題ないようです。
どなたか、私の間違いがどこにあるのか、教えていただけると本当にありがたいです。
ありがとうございました。
どのように解決するのですか?
あなたの例では、出力
[0, 0, 0, 1]
クロス・エントロピーの数学的定義で要求されているように、確率として扱います。 しかし、PyTorchはそれらを出力として扱い、和をとる必要はありません。
1
そのため、まず確率に変換し、ソフトマックス関数を使用する必要があります。
そこで
H(p, q)
になります。
H(p, softmax(output))
出力を翻訳する
[0, 0, 0, 1]
を確率に変換する。
softmax([0, 0, 0, 1]) = [0.1749, 0.1749, 0.1749, 0.4754]
というわけで。
-log(0.4754) = 0.7437
関連
-
PythonによるLeNetネットワークモデルの学習と予測
-
Python 人工知能 人間学習 描画 機械学習モデル作成
-
Pythonコードの可読性を向上させるツール「pycodestyle」の使い方を詳しく解説します
-
Python 入出力と高次代入の基礎知識
-
[解決済み】TypeErrorの修正方法。Unicodeオブジェクトは、ハッシュ化する前にエンコードする必要がある?
-
[解決済み】OSError: [WinError 193] %1 は有効な Win32 アプリケーションではありません。
-
[解決済み】Pythonでgoogle APIのJSONコードを読み込むとエラーになる件
-
[解決済み】NameError: 名前 'self' が定義されていません。
-
[解決済み】「OverflowError: Python int too large to convert to C long" on windows but not mac
-
[解決済み】 'numpy.float64' オブジェクトは反復可能ではない
最新
-
nginxです。[emerg] 0.0.0.0:80 への bind() に失敗しました (98: アドレスは既に使用中です)
-
htmlページでギリシャ文字を使うには
-
ピュアhtml+cssでの要素読み込み効果
-
純粋なhtml + cssで五輪を実現するサンプルコード
-
ナビゲーションバー・ドロップダウンメニューのHTML+CSSサンプルコード
-
タイピング効果を実現するピュアhtml+css
-
htmlの選択ボックスのプレースホルダー作成に関する質問
-
html css3 伸縮しない 画像表示効果
-
トップナビゲーションバーメニュー作成用HTML+CSS
-
html+css 実装 サイバーパンク風ボタン
おすすめ
-
Pythonの非常に便利な2つのデコレーターを解説
-
Python機械学習Githubが8.9Kstarsに達したモデルインタープリタLIME
-
python implement mysql add delete check change サンプルコード
-
Pythonを使って簡単なzipファイルの解凍パスワードを手作業で解く
-
Evidentlyを用いたPythonデータマイニングによる機械学習モデルダッシュボードの作成
-
風力制御におけるKS原理を深く理解するためのpythonアルゴリズム
-
[解決済み] データ型が理解できない
-
[解決済み】Python elifの構文が無効です【終了しました
-
[解決済み】「OverflowError: Python int too large to convert to C long" on windows but not mac
-
[解決済み】django インポートエラー - core.managementという名前のモジュールがない