[解決済み] Pytorchはワンホットベクターをサポートしないのですか?
質問
Pytorchがワンホットベクトルをどのように扱うのか、非常に混乱しています。これには チュートリアル ニューラルネットワークは、出力としてワンホットベクトルを生成します。私が理解する限り、チュートリアルのニューラルネットワークの概略構造は次のようになるはずです。
しかし
labels
は一発ベクトル形式ではありません。以下のようになります。
size
print(labels.size())
print(outputs.size())
output>>> torch.Size([4])
output>>> torch.Size([4, 10])
奇跡的に
outputs
と
labels
から
criterion=CrossEntropyLoss()
であれば、全くエラーになりません。
loss = criterion(outputs, labels) # How come it has no error?
私の仮説
多分、pytorchは自動的に
labels
をワンホットベクトル形式に変換します。そこで、損失関数に渡す前に、ラベルを1-hot vectorに変換するようにしています。
def to_one_hot_vector(num_class, label):
b = np.zeros((label.shape[0], num_class))
b[np.arange(label.shape[0]), label] = 1
return b
labels_one_hot = to_one_hot_vector(10,labels)
labels_one_hot = torch.Tensor(labels_one_hot)
labels_one_hot = labels_one_hot.type(torch.LongTensor)
loss = criterion(outputs, labels_one_hot) # Now it gives me error
しかし、次のようなエラーが発生しました。
RuntimeError: multi-target not supported at /opt/pytorch/pytorch/aten/src/THCUNN/generic/ClassNLLCriterion.cu:15
で、ワンホットベクターはサポートされていません。
Pytorch
? どのように
Pytorch
が計算されます。
cross entropy
に対して、2つのテンソル
outputs = [1,0,0],[0,0,1]
と
labels = [0,2]
? 今のところ全く意味不明です。
どのように解決するのですか?
PyTorchは、そのドキュメントで
CrossEntropyLoss
その
この基準は、ミニバッチサイズの1次元テンソルの各値のターゲットとして、クラスインデックス(0〜C-1)を期待するものである。
言い換えれば、それはあなたの
to_one_hot_vector
関数が概念的に組み込まれている
CEL
であり、ワンショットAPIを公開しない。ワンホットベクトルはクラスラベルを格納するのに比べてメモリ効率が悪いことに注意してください。
ワンホットベクターが与えられ、クラスラベルの形式にする必要がある場合(たとえば
CEL
を使用することができます。
argmax
のようにします。
import torch
labels = torch.tensor([1, 2, 3, 5])
one_hot = torch.zeros(4, 6)
one_hot[torch.arange(4), labels] = 1
reverted = torch.argmax(one_hot, dim=1)
assert (labels == reverted).all().item()
関連
-
pythonを使ったオフィス自動化コード例
-
Python関数の高度な応用を解説
-
PythonはWordの読み書きの変更操作を実装している
-
Python入門 openを使ったファイルの読み書きの方法
-
Python LeNetネットワークの説明とpytorchでの実装
-
[解決済み】numpy: true_divide で無効な値に遭遇
-
[解決済み】 NameError: グローバル名 'xrange' は Python 3 で定義されていません。
-
[解決済み】IndexError: invalid index to scalar variableを修正する方法
-
[解決済み】ValueError: pickleプロトコルがサポートされていません。3、python2 pickleはpython3 pickleでダンプしたファイルを読み込むことができない?
-
[解決済み] 複数の例外を1行でキャッチする(ブロックを除く)
最新
-
nginxです。[emerg] 0.0.0.0:80 への bind() に失敗しました (98: アドレスは既に使用中です)
-
htmlページでギリシャ文字を使うには
-
ピュアhtml+cssでの要素読み込み効果
-
純粋なhtml + cssで五輪を実現するサンプルコード
-
ナビゲーションバー・ドロップダウンメニューのHTML+CSSサンプルコード
-
タイピング効果を実現するピュアhtml+css
-
htmlの選択ボックスのプレースホルダー作成に関する質問
-
html css3 伸縮しない 画像表示効果
-
トップナビゲーションバーメニュー作成用HTML+CSS
-
html+css 実装 サイバーパンク風ボタン
おすすめ
-
Pythonの非常に便利な2つのデコレーターを解説
-
python string splicing.join()とsplitting.split()の説明
-
[解決済み】RuntimeWarning: 割り算で無効な値が発生しました。
-
[解決済み] 'DataFrame' オブジェクトに 'sort' 属性がない
-
[解決済み】Python elifの構文が無効です【終了しました
-
[解決済み】Pythonでgoogle APIのJSONコードを読み込むとエラーになる件
-
[解決済み】インポートエラー。モジュール名 urllib2 がない
-
[解決済み】Python Error: "ValueError: need more than 1 value to unpack" (バリューエラー:解凍に1つ以上の値が必要です
-
[解決済み】 TypeError: += でサポートされていないオペランド型: 'int' および 'list' です。
-
[解決済み】ValueError: xとyは同じサイズでなければならない