静かなる名辞

pythonとプログラミングのこと


【python】MeanShiftのbandwidthを変えるとどうなるか実験してみた

 前回の記事ではMeanShiftクラスタリングを試してみました。

www.haya-programming.com

 このMeanShiftにはbandwidthというパラメータがあり、クラスタ数を決定する上で重要な役割を果たしているはずです。

 いまいち結果に納得がいかないというとき、bandwidthをいじって改善が見込めるのかどうか確認してみます。

プログラム

 例によってirisとwineで比較。簡単に書きました。

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from sklearn.datasets import load_iris, load_wine
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.decomposition import PCA

def process(dataset, name):
    origin_bandwidth = estimate_bandwidth(dataset.data)
    rates = np.logspace(np.log10(0.2), np.log10(5), 11)
    fig, axes = plt.subplots(nrows=3, ncols=4, figsize=(24,18))

    PCA_X = PCA().fit_transform(dataset.data)
    for target in range(3):
        axes[0,0].scatter(PCA_X[dataset.target==target, 0],
                        PCA_X[dataset.target==target, 1],
                        c=cm.Paired(target/3))
    axes[0,0].set_title("original label", fontsize=28)

    for r, ax in zip(rates, axes.ravel()[1:]):
        ms = MeanShift(bandwidth=r*origin_bandwidth, n_jobs=-1)
        y = ms.fit_predict(dataset.data)
        n_cluster = ms.cluster_centers_.shape[0]
        for target in range(n_cluster):
            ax.scatter(PCA_X[y==target, 0],
                       PCA_X[y==target, 1],
                       c=cm.Paired(target/n_cluster))
        ax.set_title("r:{0:.3f} b:{1:.3f}".format(
            r, origin_bandwidth), fontsize=28)
    fig.savefig(name+".png")

def main():
    iris = load_iris()
    wine = load_wine()

    process(iris, "iris")
    process(wine, "wine")

if __name__ == "__main__":
    main()

 bandwidthをsklearn.cluster.estimate_bandwidthの推定値(デフォルトで用いられる値)の1/5倍から5倍まで変化させ、結果をプロットします。

結果

 プロットされた結果を示します。

 結果の図の見方は、まずタイトルが

  • b

 sklearn.cluster.estimate_bandwidthによる推定値

  • r

 かけた比率

 という風に対応しており、あとは便宜的に2次元上に主成分分析で写像した散布図が、クラスタごとに色分けされて出ています。一枚目が本来のクラスに基づく色分け、r=1の図が推定値による色分けです。

 まずiris。

iris.png
iris.png
 きれいに元通りになるrは今回見た中にはありませんでした。クラスタ数的にはr=0.525とr=0.725の間くらいで3クラスタになりそうですが、この図を見るとそれでうまく元通りまとまるかは疑問です。

 次にwine。

wine.png
wine.png
 こちらもうまく元通りにはならないようです。そもそもデータが悪いという話はあると思います。

結論

 確かにクラスタ数は変わるが、クラスタリングの良し悪しが改善するかはなんともいえないですね。

 データをスケーリングしたり、もっと色々頑張ると改善は見込めるかもしれません。