静かなる名辞

pythonとプログラミングのこと


sklearnとmatplotlibでiris(3クラス)の予測確率を可視化した話

はじめに

 よく分類器の性質などを把握するために、2次元で可視化している図があります。

 特に予測確率なんかを平面的に出せるとかっこいいですよね。つまり、こういうのです。

Classifier comparison — scikit-learn 0.21.3 documentation

以前の記事より
以前の記事より
君はKNN(k nearest neighbor)の本当のすごさを知らない - 静かなる名辞

 ただ、これが素直にできるのは2クラス分類までで、3クラス分類だと下のような図にしかなりません。

以前の記事より
以前の記事より
【python】高次元の分離境界をなんとか2次元で見る - 静かなる名辞

 ということでずっと諦めていたのですが、ふと思いました。

「RGBに各クラスの予測確率あてればできるじゃん」

 簡単にできると思ったら思いの外手間取ったので、備忘録として書いておきます。

まずやる

 とりあえずirisを二次元でプロットします。この辺は定石どおりにやるだけです。

import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA

def main():
    iris = load_iris()
    pca = PCA(n_components=2)
    X = pca.fit_transform(iris.data)
    ax = plt.subplot()
    ax.scatter(X[:,0], X[:,1], c=iris.target, cmap="brg")
    plt.savefig("fig1.png")

if __name__ == "__main__":
    main()

fig1.png
fig1.png

 kNNを学習させて、まずは普通に分離境界を描きます。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier

def main():
    iris = load_iris()
    pca = PCA(n_components=2)
    X = pca.fit_transform(iris.data)
    ax = plt.subplot()
    ax.scatter(X[:,0], X[:,1], c=iris.target, cmap="brg")
    
    clf = KNeighborsClassifier()
    clf.fit(X, iris.target)
    XX, YY = np.meshgrid(np.arange(-5, 5, 0.025),
                         np.arange(-2, 2, 0.025))
    Z = clf.predict(np.stack([XX.ravel(), YY.ravel()], axis=1))
    ZZ = Z.reshape(XX.shape)
    ax.pcolormesh(XX, YY, ZZ, alpha=0.05, cmap="brg", shading="gouraud")

    plt.savefig("fig2.png")

if __name__ == "__main__":
    main()


 参考:matplotlibのpcolormeshでalphaを小さくすると網目が出てくる対策 - 静かなる名辞

fig2.png
fig2.png

 さ、次はpredict_probaを呼ぶ訳ですが……pcolormeshとかこの辺の関数にはRGBのデータは渡せません。

matplotlib.pyplot.pcolormesh — Matplotlib 3.1.0 documentation

 しばし思案したあと、imshowならできると思いました。

 なにも考えずに書くと下のようなコードになります。

(これは動かないので注意してください)

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier

def main():
    iris = load_iris()
    pca = PCA(n_components=2)
    X = pca.fit_transform(iris.data)
    ax = plt.subplot()
    ax.scatter(X[:,0], X[:,1], c=iris.target, cmap="brg")
    
    clf = KNeighborsClassifier()
    clf.fit(X, iris.target)
    XX, YY = np.meshgrid(np.arange(-5, 5, 0.025),
                         np.arange(-2, 2, 0.025))
    Z = clf.predict_proba(np.stack([XX.ravel(), YY.ravel()], axis=1))
    ZZ = Z.reshape(XX.shape + (3, ))
    ax.imshow(ZZ, alpha=0.2)

    plt.savefig("fig3.png")

if __name__ == "__main__":
    main()

fig3.png
fig3.png

 なにこれ?

 ああ、縮尺を合わせないといけないんですね。aspect, extentという引数でできそうです。

matplotlib.pyplot.imshow — Matplotlib 3.1.0 documentation

(これもちゃんと動かないので注意してください)

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier

def main():
    iris = load_iris()
    pca = PCA(n_components=2)
    X = pca.fit_transform(iris.data)
    ax = plt.subplot()
    ax.set_xlim((-5, 5))
    ax.set_ylim((-2, 2))
    ax.scatter(X[:,0], X[:,1], c=iris.target, cmap="brg")
    
    clf = KNeighborsClassifier()
    clf.fit(X, iris.target)
    XX, YY = np.meshgrid(np.arange(-5, 5, 0.025),
                         np.arange(-2, 2, 0.025))
    Z = clf.predict_proba(np.stack([XX.ravel(), YY.ravel()], axis=1))
    ZZ = Z.reshape(XX.shape + (3, ))
    ax.imshow(ZZ, alpha=0.2,
              aspect="auto", extent=(-5, 5, -2, 2))

    plt.savefig("fig4.png")

if __name__ == "__main__":
    main()

 まず、先にax.set_xlimとax.set_ylimで図の範囲を指定し、そこにextentをあわせるようにしています。aspectはドキュメントを見た感じだとautoが無難そうに思います。


fig4.png
fig4.png


 どう見ても上下反転しているので、ZZを上下反転します。ついでに、マーカーの色を揃えることにします。

これが動くコードです

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier

def main():
    iris = load_iris()
    pca = PCA(n_components=2)
    X = pca.fit_transform(iris.data)
    ax = plt.subplot()
    ax.set_xlim((-5, 5))
    ax.set_ylim((-2, 2))
    cm = ListedColormap(["b", "g", "r"])
    ax.scatter(X[:,0], X[:,1], c=iris.target, cmap=cm)
    
    clf = KNeighborsClassifier()
    clf.fit(X, iris.target)
    XX, YY = np.meshgrid(np.arange(-5, 5, 0.025),
                         np.arange(-2, 2, 0.025))
    Z = clf.predict_proba(np.stack([XX.ravel(), YY.ravel()], axis=1))
    ZZ = np.flip(Z.reshape(XX.shape + (3, )), axis=1)
    ax.imshow(ZZ, alpha=0.2,
              aspect="auto", extent=(-5, 5, -2, 2))

    plt.savefig("fig5.png")

if __name__ == "__main__":
    main()

fig5.png
fig5.png

 だいたい不満のない結果になりました。ここまで長かった。

他の分類器も試す

 せっかくなのでいろいろやってみます。SVM, ロジスティック回帰, ランダムフォレストを追加してやってみましょう。
 

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier

def main():
    iris = load_iris()
    pca = PCA(n_components=2)
    X = pca.fit_transform(iris.data)

    fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(9, 9))
    knn = KNeighborsClassifier()
    svm = SVC(probability=True)
    lr = LogisticRegression()
    rfc = RandomForestClassifier(n_estimators=100)

    cm = ListedColormap(["b", "g", "r"])    
    XX, YY = np.meshgrid(np.arange(-5, 5, 0.025),
                         np.arange(-2, 2, 0.025))

    for ax, clf in zip(axes.ravel(), [knn, svm, lr, rfc]):
        ax.set_xlim((-5, 5))
        ax.set_ylim((-2, 2))
        ax.scatter(X[:,0], X[:,1], c=iris.target, cmap=cm)
    
        clf.fit(X, iris.target)
        Z = clf.predict_proba(np.stack([XX.ravel(), YY.ravel()], axis=1))
        ZZ = np.flip(Z.reshape(XX.shape + (3, )), axis=1)
        ax.imshow(ZZ, alpha=0.2,
                  aspect="auto", extent=(-5, 5, -2, 2))
        ax.set_title(clf.__class__.__name__)

    plt.tight_layout()
    plt.savefig("fig6.png")

if __name__ == "__main__":
    main()

fig6.png
fig6.png

 こんなもんか、という感じ。

まとめ

 やればできることはわかりました。もう少しかっこいい図にするには、さらなる工夫が要るのかもしれません。