1. ホーム
  2. pytorch

PyTorchのF.cross_entropy()関数

2022-02-21 05:40:36

F.cross_entropy()をPyTorchで理解する

PyTorchはクロスエントロピーを求めるために2つの一般的な関数を提供しています。

1つはF.cross_entropy()で、もう1つは

もう一つは、F.nll_entropy()で、これは

は、F.cross_entropy(input, target)のパラメータtargetについて、以下のように説明されています。
I. クロスエントロピーの式と計算手順
1. クロスエントロピーの計算式です。

H(p,q)=-i∑P(i)logQ(i)

ここで、P P は真の値、Q Q は予測された値である。
2. クロスエントロピーの計算手順
1)ステップの説明

1) predict_scoreに対してsoftmax演算を行い、その結果をpred_scores_softとして記録する。
pred_score_softをロギングし、その結果をpred_score_soft_logとして記す。
(③) pred_scores_soft_log を実測値で計算し、処理する。
という考え方です。
                                                  score→softm→log→computeの順で計算します。
2) 例を挙げて計算を説明する。

P 1 = [ 1 0 0 0 0 0 ]

Q 1 = [ 0.4 0.3 0.05 0.05 0.2 ]

H ( p , q ) = - ∑ i P ( i ) log Q ( i ) = - ( 1∗ l o g 0.4 + 0∗ l o g 0.05 + 0∗ l o g 0.05 + 0∗ l o g 0.2 ) = - l o g 0.4 ≈ 0.916 )
もし
Q 2 = [ 0.98 0.01 0 0 0.01 ]
では
H ( p , q ) = - ∑ i P ( i ) log Q ( i ) = - ( 1∗ l o g 0.98 + 0∗ l o g 0.05 + 0∗ l o g 0 + 0∗ l o g 0.01 ) = - l o g 0.98 ≒ 0.02 )

H ( p , q ) の計算と、Q1 と Q2 の P1 に対する類似度の目視観察から、Q2 は Q1 よりも P1 に類似していることがわかる
II. 公式ドキュメントからの注意点

PyTorchのF.cross_entropy()の中国語の公式ドキュメントは以下のように記述されています。

torch.nn.functional.cross_entropy(input, target, weight=None, size_average=True)

この関数はlog_softmaxとnll_lossを使用します、詳しくはCrossEntropyLossを参照してください。

共通パラメータ

<テーブル パラメータ名 形状 ノート 入力 (N,C) Cはカテゴリ数 ターゲット N 0 <= ターゲット[i] <= C-1


III. 自分自身の理解

公式ドキュメントのノートでは、targetsパラメータの説明として、torch.shapeはtorch.Size([N])、0 <= targets[i] <= C-1とされています。
ネットワークは出力を計算し、関数に送りますが、入力のtorch.shapeはtorch.Size([N,C])で、ソフトマックスとログ演算のため変わりませんが、ターゲットのtorch.shapeはtorch.Size([N])でA行列ではなくスカラーですので、上の例のようにクロスエントロピー計算をどうするのでしょうか。

インポートトーチ
Fとしてtorch.nn.functionalをインポートする。

pred_score = torch.tensor([[13., 3., 2., 5., 1,]])
                           [1., 8., 20., 2., 3.],
                           [1., 14., 3., 5., 3.]])
print(pred_score)
pred_score_soft = F.softmax(pred_score, dim=1)
print(pred_score_soft)
pred_score_soft_log = pred_score_soft.log()
print(pred_score_soft_log)

という結果になる。

tensor([[13., 3., 2., 5., 1,]])
        [ 1., 8., 20., 2., 3.],
        [ 1., 14., 3., 5., 3.]])
tensor([[9.9960e-01, 4.5382e-05, 1.6695e-05, 3.3533e-04, 6.1417e-06]],

        [5.6028e-09, 6.1442e-06, 9.9999e-01, 1.5230e-08, 4.1399e-08]。
        [2.2600e-06, 9.9984e-01, 1.6699e-05, 1.2339e-04, 1.6699e-05]]) 。
tensor([[-4.0366e-04, -1.0000e+01, -1.1000e+01, -8.0004e+00, -1.2000e+01]]),
        [1.9000e+01, -1.2000e+01, -6.1989e-06, -1.8000e+01, -1.7000e+01],
        [-1.3000e+01, -1.5904e-04, -1.1000e+01, -9.0002e+00, -1.1000e+01]))

スカラーターゲットで計算するには?
IV. 分析

F.Cross_entropy(input, target)関数はsoftmaxとlogの操作を含んでおり、つまり、ネットワークはこの2つの操作なしで送られた入力パラメータを計算する。

例えば,分類問題では,入力は torch.Size([N, C]) の行列で表現される.ここで,N はサンプル数,C はカテゴリ数,そして input[i][j] は i 番目のサンプルがカテゴリ jj に属するときの Scores と解釈することができる.Scores の値が大きければ大きいほど,そのカテゴリが j である確率は高くなる(コードブロックに反映されている).

また、一般に分類問題の結果を表形式で表現する場合、ワンショット埋め込みを使用します。例えば、手書き数字認識の分類問題では、数字の0は [ 1 0 0 0 0 0 0 0 0 0 0 0 0 0 ] と表現されることになります。
数字の3は、[ 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ]と表される。
手書き数字認識の問題では、l o s s = ( y - y ^ ) 2 、すなわち、yの埋め込み行列からpred_probability行列の結果行列のノルムを差し引いた行列をl o s sロス損失として計算する。

しかし、ここでは、クロスエントロピーは次のように計算されます。

H ( p , q ) = - ∑ i P ( i ) log Q ( i )

ここで、Pは真値の確率行列、Qは予測値の確率行列である。

そして、P が one-hot embedding を用いる場合、i が正しく分類された場合のみ P ( i ) は 1 になり、そうでない場合は P ( i ) は 0 になる。
例えば、手書きの数字認識において、数字の3のワンショット表現は [ 0 0 0 1 0 0 0 0 0 0 0 0 0 ] となります。
クロスエントロピーについては、H ( p , q ) = - ∑ i P ( i ) l o g Q ( i ) = - P ( 3 ) l o g Q ( 3 ) = - l o g Q ( 3 ) とすることができる。

H ( p , q ) の計算は、P行列に依存せず、Pの真のクラスのインデックスにのみ関係することがわかった。
V. まとめ

つまり、私の理解では、ワンホットコーディングでは、pytorchコードの中でターゲットをワンホットの形で表現するのではなく、スカラーを直接使い、そのスカラーの値が真のカテゴリのインデックスとなる。したがって、クロスエントロピーの式は次のように表現できる。
H ( p , q ) = - ∑ i P ( i ) l o g Q ( i ) = - P ( m ) l o g Q ( m ) = - l o g Q ( m )
ここで、m m は真のカテゴリを表す。
元記事へのリンクです。 https://blog.csdn.net/wuliBob/article/details/104119616