【python】混合ガウスモデル (GMM)でハード・ソフトクラスタリング
はじめに
先日はFuzzy c-meansによるソフトクラスタリングを行いました。
【python】skfuzzyのFuzzy c-meansでソフトクラスタリング - 静かなる名辞
ソフトクラスタリングの有名な手法としてはFuzzy c-meansの他に、混合ガウスモデル(混合正規分布モデル)を使った手法があります。この手法はデータが「複数の正規分布から構成されている」と仮定し、その正規分布のパラメタ*1をEMアルゴリズム(expectation–maximization algorithm)という手法を使って最尤推定します。
ごちゃごちゃと書きましたが、要するに「3つのクラスタにクラスタリングしたければ、(各クラスタのデータの分布が正規分布に従うと仮定して)3つの正規分布が重なりあってると思ってGMMを使って解く」という乱暴なお話です。正規分布が重なりあっているとみなすということは、どの分布に属するかも確率でわかる訳で、これがソフトクラスタリングに使える理由です。ハードクラスタリングに使いたいときは、確率最大のクラスタラベルに振ることになるかと思います。
このGMM、pythonではsklearnに入っているので簡単に使えます。
sklearn.mixture.GaussianMixture — scikit-learn 0.20.1 documentation
ということで、他のクラスタリング手法と比較してみることにしました。
スポンサーリンク
実験の説明
先日の記事でやったのと同様、irisをPCAで二次元に落としたデータに対してクラスタリングを行います。クラスタリング結果(所属するクラスタの確率)はirisが3クラスのデータなのを利用し、色(RGB)で表現します。
比較するクラスタリング手法はk-means(ハード)、Fuzzy c-means(ソフト)、GMM(ハード・ソフト)です。
前回はFuzzy c-meansのパラメタmを動かして結果を見たりしましたが、今回これは2で決め打ちにします。
実験用ソースコードは次のものです。走らせるにはいつもの定番ライブラリ以外にscikit-fuzzyというライブラリを入れる必要があります(あるいはFuzzy c-means関連の部分をコメントアウトするか。でもskfuzzyはpipで一発で入るし、入れておいても別に損はない)。
# coding: UTF-8 import numpy as np from sklearn.datasets import load_iris from sklearn.decomposition import PCA from sklearn.cluster import KMeans as KM from sklearn.mixture import GaussianMixture as GMM from matplotlib import pyplot as plt from skfuzzy.cluster import cmeans def target_to_color(target): if type(target) == np.ndarray: return (target[0], target[1], target[2]) else: return "rgb"[target] def plot_data(data, target, filename="fig.png"): plt.figure() plt.scatter(data[:,0], data[:,1], c=[target_to_color(t) for t in target]) plt.savefig(filename) def gen_data(): iris = load_iris() pca = PCA(n_components=2) return pca.fit_transform(iris.data), iris.target def main(): data, target = gen_data() plot_data(data, target, filename="origin.png") km = KM(n_clusters=3) km_target = km.fit_predict(data) plot_data(data, km_target, filename="kmeans.png") cm_result = cmeans(data.T, 3, 2, 0.003, 10000) plot_data(data, cm_result[1].T, filename="cmeans_2.png") gmm = GMM(n_components=3, max_iter=1000) gmm.fit(data) gmm_target = gmm.predict(data) gmm_target_proba = gmm.predict_proba(data) plot_data(data, gmm_target, filename="gmm.png") plot_data(data, gmm_target_proba, filename="gmm_proba.png") if __name__ == "__main__": main()
結果
オリジナルデータ
元データ
k-means
k-means
c-means
Fuzzy c-means
こうして見るとc-meansは「ファジー理論を入れて境界を曖昧にしたk-means」という気がしてきます。実際アルゴリズムもそんな感じなんですけど。
GMM
GMM-based clustering (hard)
GMM-based clustering (soft)
どうしてこうなるのかというと、「irisのデータが正規分布していた」ということに尽きます。ま、アヤメの花びらの大きさとかのデータですから、正規分布しているんでしょう、きっと。
こうして見るとGMMの方が良さそうな気もしますが、「ちゃんと正規分布してるか」が怪しいとちょっと適用するのを躊躇うのと、あと計算コスト自体はk-meansより高いはずなので*2、いまいちk-meansと比べて使われていない、というのが実情に近いかもしれません。
まとめ
GMMを使ってみたらけっこう良かったです。