[解決済み] scikit learnでmulticlassの場合のprecision, recall, accuracy, f1-scoreはどのように計算するのでしょうか?
質問
私は感情分析の問題に取り組んでいます。データは以下のようなものです。
label instances
5 1190
4 838
3 239
1 204
2 127
つまり、私のデータは1190年からアンバランスになっています。
instances
でラベル付けされています。
5
. 分類のために、私は scikit の
SVC
. 問題は、マルチクラスの場合の精度、再現性、正確性、f1-scoreを正確に計算するために、正しい方法でデータのバランスを取る方法がわからないことです。そこで、以下のようなアプローチを試みました。
まず
wclf = SVC(kernel='linear', C= 1, class_weight={1: 10})
wclf.fit(X, y)
weighted_prediction = wclf.predict(X_test)
print 'Accuracy:', accuracy_score(y_test, weighted_prediction)
print 'F1 score:', f1_score(y_test, weighted_prediction,average='weighted')
print 'Recall:', recall_score(y_test, weighted_prediction,
average='weighted')
print 'Precision:', precision_score(y_test, weighted_prediction,
average='weighted')
print '\n clasification report:\n', classification_report(y_test, weighted_prediction)
print '\n confussion matrix:\n',confusion_matrix(y_test, weighted_prediction)
2番目
auto_wclf = SVC(kernel='linear', C= 1, class_weight='auto')
auto_wclf.fit(X, y)
auto_weighted_prediction = auto_wclf.predict(X_test)
print 'Accuracy:', accuracy_score(y_test, auto_weighted_prediction)
print 'F1 score:', f1_score(y_test, auto_weighted_prediction,
average='weighted')
print 'Recall:', recall_score(y_test, auto_weighted_prediction,
average='weighted')
print 'Precision:', precision_score(y_test, auto_weighted_prediction,
average='weighted')
print '\n clasification report:\n', classification_report(y_test,auto_weighted_prediction)
print '\n confussion matrix:\n',confusion_matrix(y_test, auto_weighted_prediction)
第三に
clf = SVC(kernel='linear', C= 1)
clf.fit(X, y)
prediction = clf.predict(X_test)
from sklearn.metrics import precision_score, \
recall_score, confusion_matrix, classification_report, \
accuracy_score, f1_score
print 'Accuracy:', accuracy_score(y_test, prediction)
print 'F1 score:', f1_score(y_test, prediction)
print 'Recall:', recall_score(y_test, prediction)
print 'Precision:', precision_score(y_test, prediction)
print '\n clasification report:\n', classification_report(y_test,prediction)
print '\n confussion matrix:\n',confusion_matrix(y_test, prediction)
F1 score:/usr/local/lib/python2.7/site-packages/sklearn/metrics/classification.py:676: DeprecationWarning: The default `weighted` averaging is deprecated, and from version 0.18, use of precision, recall or F-score with multiclass or multilabel data or pos_label=None will result in an exception. Please set an explicit value for `average`, one of (None, 'micro', 'macro', 'weighted', 'samples'). In cross validation use, for instance, scoring="f1_weighted" instead of scoring="f1".
sample_weight=sample_weight)
/usr/local/lib/python2.7/site-packages/sklearn/metrics/classification.py:1172: DeprecationWarning: The default `weighted` averaging is deprecated, and from version 0.18, use of precision, recall or F-score with multiclass or multilabel data or pos_label=None will result in an exception. Please set an explicit value for `average`, one of (None, 'micro', 'macro', 'weighted', 'samples'). In cross validation use, for instance, scoring="f1_weighted" instead of scoring="f1".
sample_weight=sample_weight)
/usr/local/lib/python2.7/site-packages/sklearn/metrics/classification.py:1082: DeprecationWarning: The default `weighted` averaging is deprecated, and from version 0.18, use of precision, recall or F-score with multiclass or multilabel data or pos_label=None will result in an exception. Please set an explicit value for `average`, one of (None, 'micro', 'macro', 'weighted', 'samples'). In cross validation use, for instance, scoring="f1_weighted" instead of scoring="f1".
sample_weight=sample_weight)
0.930416613529
しかし、このような警告が表示されます。
/usr/local/lib/python2.7/site-packages/sklearn/metrics/classification.py:1172:
DeprecationWarning: The default `weighted` averaging is deprecated,
and from version 0.18, use of precision, recall or F-score with
multiclass or multilabel data or pos_label=None will result in an
exception. Please set an explicit value for `average`, one of (None,
'micro', 'macro', 'weighted', 'samples'). In cross validation use, for
instance, scoring="f1_weighted" instead of scoring="f1"
分類器のメトリクスを正しく計算するために、アンバランスなデータを正しく扱うにはどうしたらよいでしょうか?
どのように解決するのですか?
どのウェイトが何に使われるかについて、多くの混乱があると思います。私は何があなたを悩ませているのか正確に知っているわけではないので、異なるトピックをカバーするつもりですが、我慢してください;)。
クラス ウェイト
の重みは
class_weight
パラメータは
分類器を学習します。
.
これらは
の計算には使用されません。
: クラスの重みが異なると、分類器が異なるだけで数値が異なってきます。
基本的にすべてのscikit-learnの分類器において、クラスの重みはクラスがどれだけ重要かをモデルに伝えるために使用されます。つまり、学習の間、分類器は高い重みを持つクラスを適切に分類するために特別な努力をすることになります。
どのようにそれを行うかは、アルゴリズムに依存します。もし、SVCでどのように動作するかの詳細が必要で、ドキュメントが意味をなさない場合は、遠慮なく言及してください。
測定基準
分類器ができたら、それがどの程度うまく機能しているかを知りたくなります。
ここでは、先ほどのメトリクスを使用することができます。
accuracy
,
recall_score
,
f1_score
...
通常、クラス分布が不均衡な場合、最も頻度の高いクラスを予測するだけのモデルに高いスコアを与えるため、accuracyは悪い選択とみなされます。
これらのメトリクスのすべてを詳しく説明することはしませんが、以下の例外に注意してください。
accuracy
を除いて、これらは当然クラスレベルで適用されることに注意してください。
print
にあるように、それぞれのクラスに対して定義されます。これらは次のような概念に依存しています。
true positives
または
false negative
を定義する必要がある場合、どのクラスが
正
であることを定義する必要があります。
precision recall f1-score support
0 0.65 1.00 0.79 17
1 0.57 0.75 0.65 16
2 0.33 0.06 0.10 17
avg / total 0.52 0.60 0.51 50
警告
F1 score:/usr/local/lib/python2.7/site-packages/sklearn/metrics/classification.py:676: DeprecationWarning: The
default `weighted` averaging is deprecated, and from version 0.18,
use of precision, recall or F-score with multiclass or multilabel data
or pos_label=None will result in an exception. Please set an explicit
value for `average`, one of (None, 'micro', 'macro', 'weighted',
'samples'). In cross validation use, for instance,
scoring="f1_weighted" instead of scoring="f1".
この警告が出るのは、f1-score、recall、precisionをどのように計算すべきかを定義せずに使っているためです! この質問は次のように言い換えることができます。 1 をどのように出力しますか? あなたはできる。
-
各クラスのf1-scoreの平均をとります:これは
avg / total
の結果です。これはまた マクロ の平均化です。 - 真陽性/偽陰性などのグローバルカウントを使用してf1-scoreを計算します。(各クラスの真陽性/偽陰性の数を合計する)。別名 マイクロ の平均化。
-
f1-スコアの加重平均を計算します。使用方法
'weighted'
を使用すると、クラスのサポートによってf1-スコアが重み付けされます:クラスがより多くの要素を持つほど、このクラスのf1-スコアは計算でより重要になります。
これらはscikit-learnのオプションのうちの3つです。
を選ばなければなりません。
. なので
average
の引数を指定する必要があります。
どちらを選択するかは、分類器の性能をどのように測定したいかによります。例えば、マクロ平均はクラスの不均衡を考慮しないので、クラス1のf1スコアはクラス5のf1スコアと同じように重要視されるでしょう。しかし、加重平均を使用すると、クラス5がより重要視されるようになります。
これらのメトリクスにおける引数の仕様全体は、現在scikit-learnではあまり明確ではありません。彼らはいくつかの明白でない標準的な動作を削除しており、開発者がそれに気づくように警告を発しています。
スコアの計算
最後に(ご存知の方は読み飛ばしていただいて結構です)、スコアは、分類器である が見たことのないデータ . これは非常に重要なことで、分類器のフィッティングに使用されたデータで得られたスコアは全く意味がありません。
を使って行う方法を示します。
StratifiedShuffleSplit
を使う方法です。これは、ラベルの分布を維持したまま、(シャッフルした後の)データのランダムな分割を行います。
from sklearn.datasets import make_classification
from sklearn.cross_validation import StratifiedShuffleSplit
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, confusion_matrix
# We use a utility to generate artificial classification data.
X, y = make_classification(n_samples=100, n_informative=10, n_classes=3)
sss = StratifiedShuffleSplit(y, n_iter=1, test_size=0.5, random_state=0)
for train_idx, test_idx in sss:
X_train, X_test, y_train, y_test = X[train_idx], X[test_idx], y[train_idx], y[test_idx]
svc.fit(X_train, y_train)
y_pred = svc.predict(X_test)
print(f1_score(y_test, y_pred, average="macro"))
print(precision_score(y_test, y_pred, average="macro"))
print(recall_score(y_test, y_pred, average="macro"))
これが役立つといいのですが。
関連
-
opencvとpillowを用いた顔認証システム(デモあり)
-
Pythonによるjieba分割ライブラリ
-
Python 入出力と高次代入の基礎知識
-
[解決済み】TypeError: unhashable type: 'numpy.ndarray'.
-
[解決済み】TypeError: re.findall()でバイトのようなオブジェクトに文字列パターンを使用することはできません。)
-
[解決済み】Python elifの構文が無効です【終了しました
-
[解決済み】NameError: 名前 'self' が定義されていません。
-
[解決済み】Flaskのテンプレートが見つからない【重複あり
-
[解決済み】 'numpy.float64' オブジェクトは反復可能ではない
-
[解決済み] scikit-learnの決定木から決定規則を抽出する方法は?
最新
-
nginxです。[emerg] 0.0.0.0:80 への bind() に失敗しました (98: アドレスは既に使用中です)
-
htmlページでギリシャ文字を使うには
-
ピュアhtml+cssでの要素読み込み効果
-
純粋なhtml + cssで五輪を実現するサンプルコード
-
ナビゲーションバー・ドロップダウンメニューのHTML+CSSサンプルコード
-
タイピング効果を実現するピュアhtml+css
-
htmlの選択ボックスのプレースホルダー作成に関する質問
-
html css3 伸縮しない 画像表示効果
-
トップナビゲーションバーメニュー作成用HTML+CSS
-
html+css 実装 サイバーパンク風ボタン
おすすめ
-
Python関数の高度な応用を解説
-
Python百行で韓服サークルの画像クロールを実現する
-
Python 可視化 big_screen ライブラリ サンプル 詳細
-
Evidentlyを用いたPythonデータマイニングによる機械学習モデルダッシュボードの作成
-
[解決済み】socket.error: [Errno 48] アドレスはすでに使用中です。
-
[解決済み】TypeError: re.findall()でバイトのようなオブジェクトに文字列パターンを使用することはできません。)
-
[解決済み] 'DataFrame' オブジェクトに 'sort' 属性がない
-
[解決済み】Python Error: "ValueError: need more than 1 value to unpack" (バリューエラー:解凍に1つ以上の値が必要です
-
[解決済み】ImportError: bs4という名前のモジュールがない(BeautifulSoup)
-
[解決済み】ValueError: xとyは同じサイズでなければならない