静かなる名辞

pythonとプログラミングのこと


sklearnで混同行列をヒートマップにして描画するplot_confusion_matrix

はじめに

 scikit-learnのv0.22で、混同行列をプロットするための便利関数であるsklearn.metrics.plot_confusion_matrixが追加されました。

 使いやすそうなので試してみます。

使い方

 リファレンスはこちらです。

sklearn.metrics.plot_confusion_matrix — scikit-learn 0.22 documentation

 引数のフォーマットを見ると、

sklearn.metrics.plot_confusion_matrix(estimator, X, y_true, labels=None, sample_weight=None, normalize=None, display_labels=None, include_values=True, xticks_rotation='horizontal', values_format=None, cmap='viridis', ax=None)

 あ、予測器とXとyを入れるタイプの関数だ。なんか微妙に使いづらいですね。この時点でなんか困惑気味ですが、やってみます。

import matplotlib.pyplot as plt
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import plot_confusion_matrix


wine = load_wine()
X_train, X_test, y_train, y_test = train_test_split(
    wine.data, wine.target, stratify=wine.target)

clf = LogisticRegression()
clf.fit(X_train, y_train)
plot_confusion_matrix(clf, X_test, y_test,
                      display_labels=wine.target_names, 
                      cmap=plt.cm.Blues,)
plt.savefig("result.png")

result.png
result.png

 ソースを見る限り、内部で交差検証などしてくれる訳ではないようなので、学習済みモデルとテストデータを渡してプロットさせます。また、labelsという引数がありますが微妙に罠っぽくて、表示に使われるラベルはdisplay_labelsの方です。

 一応各引数の説明など。

  • estimator, X, y_true

 は、説明要らないよね。上で示したのと同じ使い方をします。

  • labels

 ラベルの順序を並び替えたり、一部のラベルのみ取り出してプロットしたいとき使うそうです。y_trueの中身が[0,0,0,1,1,1]なら[0,1]や[1,0]などが指定できます。別に要らないでしょう。

  • sample_weightarray-like of shape (n_samples,), default=None

 サンプルの重み。

  • normalize{"true", "pred", "all"}, default=None

 全体を正規化するかどうか。するならその方法を文字列で指定します。

  • display_labels

 表示されるラベルの名前はこちらで指定します。使用頻度は高いはずです。

  • include_valuesbool, default=True

 Falseに設定すると数字が出てこなくなります。普通は数字があったほうが好ましいでしょう。

  • xticks_rotation{"vertical", "horizontal"}

 x軸のラベルが回転するかどうか。デフォルトでは回転しません。

  • values_format

 "d"や".2f"などが指定できる。表示の書式で、format関数などに準じると思われる。

  • cmap, ax

 matplotlib関連です。デフォルトのcmapの"viridis"がexampleで「ダサいよねこれ」とBluesにされているあたり泣けます。 

使いづらいのでConfusionMatrixDisplayを使うことにする

 なんで混同行列を描くためだけにpredictメソッド走らせなきゃいかんのだと思ったので、仕様を確認します。すると、ConfusionMatrixDisplayなるクラスがあることがわかります。

It is recommend to use plot_confusion_matrix to create a ConfusionMatrixDisplay.
sklearn.metrics.ConfusionMatrixDisplay — scikit-learn 0.22 documentation

 うるせえ、あるもんは使うんじゃ。

 インスタンスを作ってplotメソッドを呼ぶと動きます。引数はだいたい上と共通ですが、インスタンスを作るときに

class sklearn.metrics.ConfusionMatrixDisplay(confusion_matrix, display_labels)

 なので柔軟性が多少上がります。

import matplotlib.pyplot as plt
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay


wine = load_wine()
X_train, X_test, y_train, y_test = train_test_split(
    wine.data, wine.target, stratify=wine.target)

clf = LogisticRegression()
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
cmx = confusion_matrix(y_test, y_pred)
cmd = ConfusionMatrixDisplay(cmx, wine.target_names)
cmd.plot()
plt.savefig("result.png")

 結果の図は同じなので省略。任意の予測ラベルで描画しようと思えばできます。ちょっと微妙な感じもしますが、許容範囲でしょう。

まとめ

 このようなものができましたので、今後は混同行列のプロットではそんなに困らないと思います。

 あとどうでもいい話、そもそもこのブログは混同行列の描き方がわからなくて調べてまとめたのが始まりですが、

www.haya-programming.com

 なんとなく原点回帰した感があって感動しています。