読者です 読者をやめる 読者になる 読者になる

静かなる名辞

pythonと読書

【python】混同行列(Confusion matrix)をヒートマップにして描画

python 機械学習 matplotlib seaborn sklearn

 pythonでラクして混同行列を描画したい(sklearnとかpandasとかseabornとか使って)という話。

 そもそもscikit-learnにはsklearn.metrics.confusion_matrixなるメソッドがあって、混同行列がほしいときはこれ使えば解決じゃん、と思う訳だが、このconfusion_matrixは2次元のnumpy配列を返すだけで「あとはユーザーが自分で描画してね♪」というメソッド。なので、とりあえずコンソールに結果を吐かせて、混同行列(の値が入った2次元配列)を確認したあと、ちょっとどう料理してやるか悩む羽目になる。

 表形式で出すのはダサいし見づらいので、ヒートマップにしようというところまではそんなに迷わないと思う。で、pythonのヒートマップの作り方はぶっちゃけよくわからない(日本語資料があまりない)。とりあえずseabornというライブラリを使えば良いらしいんだけど……。

 
 日本語で「python 混同行列 ヒートマップ」みたいな検索をすると、ぶっちゃけ楽そうな(5,6行くらいで書ける)方法がほとんど出てこない(皆無?)のだけど、Stack Overflowにはあった。

python - How can I plot a confusion matrix? - Stack Overflow


 とりあえずやり方はわかったので、使いやすいように正解ラベルと予測ラベルを受け取って描画する関数をメモしておく(骨子だけ)。

import pandas as pd
import seaborn as sn
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt

def print_cmx(y_true, y_pred):
    labels = sorted(list(set(y_true)))
    cmx_data = confusion_matrix(y_true, y_pred, labels=labels)
    
    df_cmx = pd.DataFrame(cmx_data, index=labels, columns=labels)

    plt.figure(figsize = (10,7))
    sn.heatmap(df_cmx, annot=True)
    plt.show()

 これをたとえば、

print_cmx(["a","b","c","d","e"],["a","b","c","d","e"])

 こうやって呼び出すと、
f:id:hayataka2049:20161215224000p:plain
 こんな結果が表示される。あとはmatplotlibなんで、好きにいじれば(文字大きくするとかタイトル/軸ラベル付けるとか)良いと思う。

confusion_matrixを呼ぶとき、明示的にラベルを渡して(sklearnのドキュメントにもなんか正解ラベル+予測ラベルから重複取り除いてソートして使うみたいに書いてあるから、この場合は不要かもしれないけど。自分好みの順番で出力したいときは必要)、そのラベルをpandasデータフレームに渡すindex,columnsに使いまわすのがミソ。