sklearnのclassification_reportで多クラス分類の結果を簡単に見る
はじめに
多クラス分類をしていると、「どのクラスが上手く分類できてて、どのクラスが上手く行ってないんだろう」と気になることがままあります。
そういった情報を簡単に要約して出力してくれるのがsklearnのclassification_reportで、簡単に使える割に便利なので実験中や開発中に威力を発揮します。
スポンサーリンク
※この記事はsklearn 0.19の時代に書きましたが、その後sklearn 0.20で使い方が変更されたので、2019/03/18に全面的に改稿しました。
使い方
ドキュメントを見るととても簡単そうです。
sklearn.metrics.classification_report — scikit-learn 0.20.3 documentation
sklearn.metrics.classification_report( y_true, y_pred, labels=None, target_names=None, sample_weight=None, digits=2, output_dict=False)
要するに真のラベルと予測ラベル、あとラベルに対応する名前を入れてあげればとりあえず使えます。文字列の返り値が出力になります。sample_weight, digitsはそれぞれサンプルの重みと結果に出力される桁数を表しますが、とりあえず入れなくても大した問題は普通はありません。output_dictはsklearn 0.20から追加された引数で、pandasデータフレームに変換可能な辞書を返します。
さっそく使ってみましょう。
# coding: UTF-8 import numpy as np from sklearn.datasets import load_iris from sklearn.svm import SVC from sklearn.metrics import classification_report from sklearn.model_selection import StratifiedKFold as SKF def main(): # irisでやる iris = load_iris() # svmで分類してみる svm = SVC(C=3, gamma=0.1) # 普通の交差検証 trues = [] preds = [] for train_index, test_index in SKF().split(iris.data, iris.target): svm.fit(iris.data[train_index], iris.target[train_index]) trues.append(iris.target[test_index]) preds.append(svm.predict(iris.data[test_index])) # 今回の記事の話題はここ print("iris") print(classification_report(np.hstack(trues), np.hstack(preds), target_names=iris.target_names)) if __name__ == "__main__": main()
すると、次のような出力が得られます。
iris precision recall f1-score support setosa 1.00 1.00 1.00 50 versicolor 0.96 0.98 0.97 50 virginica 0.98 0.96 0.97 50 micro avg 0.98 0.98 0.98 150 macro avg 0.98 0.98 0.98 150 weighted avg 0.98 0.98 0.98 150
precision, recall, f1-scoreという代表的な評価指標と、support(=y_trueに含まれるデータ数)が、クラスごとと全体の各種平均(後述)で出る、というのが基本的な仕組みです。
まずクラスごとの結果を見ると、setasoは100%分類できていますが、versicolorとvirginicaはどうも混ざっているようです。以前の記事でirisを二次元にした画像を作ったので、再掲します。
RGBの順でsetaso, versicolor, virginicaに対応しているはずです。ということはsetasoが綺麗に分離できてversicolorとvirginicaが混ざるというのは極めて妥当な結果ということになりそうです。
また、下にあるmicro avg, macro avg, weighted avgは、それぞれマイクロ平均、マクロ平均、サンプル数で重み付けられた平均です。
出る評価指標などの詳細については別途記事を書いたので、そちらを御覧ください。
output_dictを使って便利に集計する
sklearn 0.20ではoutput_dictという引数がこの関数に追加されました。これを使うとデフォルトの文字列ではなく辞書形式で結果を得ることができ、結果をプログラム上で取り扱うことが容易になります。
上のコードの出力部分を2行書き換えます。
from pprint import pprint pprint(classification_report(np.hstack(trues), np.hstack(preds), target_names=iris.target_names, output_dict=True))
結果はこのようになります。
{'macro avg': {'f1-score': 0.97999799979998, 'precision': 0.9801253834867282, 'recall': 0.98, 'support': 150}, 'micro avg': {'f1-score': 0.98, 'precision': 0.98, 'recall': 0.98, 'support': 150}, 'setosa': {'f1-score': 1.0, 'precision': 1.0, 'recall': 1.0, 'support': 50}, 'versicolor': {'f1-score': 0.9702970297029702, 'precision': 0.9607843137254902, 'recall': 0.98, 'support': 50}, 'virginica': {'f1-score': 0.9696969696969697, 'precision': 0.9795918367346939, 'recall': 0.96, 'support': 50}, 'weighted avg': {'f1-score': 0.9799979997999799, 'precision': 0.980125383486728, 'recall': 0.98, 'support': 150}}
この辞書の形式はpandasデータフレームに変換することも可能です。
import pandas as pd d = classification_report(np.hstack(trues), np.hstack(preds), target_names=iris.target_names, output_dict=True) df = pd.DataFrame(d) print(df)
とすると、
macro avg micro avg setosa versicolor virginica weighted avg f1-score 0.979998 0.98 1.0 0.970297 0.969697 0.979998 precision 0.980125 0.98 1.0 0.960784 0.979592 0.980125 recall 0.980000 0.98 1.0 0.980000 0.960000 0.980000 support 150.000000 150.00 50.0 50.000000 50.000000 150.000000
のようにデータフレームとして見ることができます。ここからCSV, TeX, HTML, グラフなど任意のフォーマットに変換できるので、なにかと捗ると思います。
classification_reportを使わないとしたら
このように大変便利なのですが、参考のためにこれを使わない方法も紹介しておきます。sklearn.metrics.precision_recall_fscore_supportを使います。
sklearn.metrics.precision_recall_fscore_support — scikit-learn 0.20.3 documentation
使い方はこんな感じです。
from sklearn.metrics import precision_recall_fscore_support precision_recall_fscore_support(y_true, y_pred, average=None)
結果はこんな感じになります(上のプログラムを対象に計算し、返り値をpprintしました)。
(array([1. , 0.96078431, 0.97959184]), array([1. , 0.98, 0.96]), array([1. , 0.97029703, 0.96969697]), array([50, 50, 50]))
numpy配列を格納したタプルが返ってますね。それぞれのnumpy配列がprecision, recall, fscore, supportに対応します。
まとめ
簡単に使えるので、分類結果を見てみたいときはとりあえずこれに放り込むと良いかと思います。また、sklearn 0.20からはかなり便利になったので、汎用的な分類結果集計方法としても使えるようになりました。