静かなる名辞

pythonとプログラミングのこと


決定木をいろいろな方法で可視化する

はじめに

 決定木はデータが分類される過程がわかりやすいことから、可視化に向いています。特にサンプル数が少なく、データの特徴量の次元数が少ないようなケースではかなり直感的な結果が得られます。

 決定木の可視化では、原理的には単に図を描いて可視化すれば良いのですが、実際には実装によってさまざまな方法が用意されています。今回はscikit-learnの決定木分類器クラス(DecisionTreeClassifier)を使うことを前提として、いろいろな方法で可視化することを試してみようと思います。

sklearnの標準の方法でやる

 素晴らしいことにscikit-learnはデフォルトでそれ用の関数を用意してくれています。export_graphvizといいます。

sklearn.tree.export_graphviz — scikit-learn 0.21.3 documentation

 graphvizというのは外部ツールです。これを使った場合、描画まではやってくれませんが、graphvizで使えるdotというフォーマットのファイルに出力することができます。

 記事によってはgraphvizのpythonバインディングを使ったりしてpython上で処理していることもありますが、一度ファイルに出力してからdotコマンド(UNIXであれば簡単にインストールできます。たとえばapt install graphviz)で画像に変換した方が個人的には使いやすいので、その流れで使うことにします。

何も考えずに使う

 さて、単にirisでやってみましょう。

import subprocess
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

def main():
    iris = load_iris()
    dtc = DecisionTreeClassifier(max_depth=3)
    dtc.fit(iris.data, iris.target)
    export_graphviz(dtc, "tree1.1.dot")
    subprocess.run("dot -Tpng tree1.1.dot -o tree1.1.png".split())

if __name__ == "__main__":
    main()

 あまり木が深いとブログに載せづらくなるので、妥協してmax_depth=3にしています。また、横着してpythonコードからdotコマンドを呼び出しています。

tree1.1.png
tree1.1.png

 左上、左下、右下の3つの葉ノードに概ねすべてのデータが分類されており、分類が難しい一部のデータは完全に分類されないまま葉になっています。

 irisはこんなデータなので、納得感のある結果です。


RGBで可視化を試したときの記事から
RGBで可視化を試したときの記事から

sklearnとmatplotlibでiris(3クラス)の予測確率を可視化した話 - 静かなる名辞

特徴量の名前、クラス名の情報を付与する

 上の図でも慣れていればだいたいわかりますが、少しわかりづらいといえばわかりづらい面もあります。

 クラスの名前、特徴量の名前を付与するとわかりやすくなります。やってみましょう。

import subprocess
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

def main():
    iris = load_iris()
    dtc = DecisionTreeClassifier(max_depth=3)
    dtc.fit(iris.data, iris.target)
    export_graphviz(dtc, "tree1.2.dot",
                    feature_names=iris.feature_names,
                    class_names=iris.target_names)
    subprocess.run("dot -Tpng tree1.2.dot -o tree1.2.png".split())

if __name__ == "__main__":
    main()

tree1.2.png
tree1.2.png


 だいぶわかりやすくなりました。こうするのは基本なので、事情が許す限りはこうしましょう。

見た目をきれいにする

 色をつけたり、角を丸めたりすることができます。見た目の綺麗さに影響する引数には以下のものがあります。

  • filled : bool, optional (default=False)

 Trueにするとクラスを反映した色で塗ってくれます。

  • leaves_parallel : bool, optional (default=False)

 Trueにすると葉ノードがすべて木の下部の同じレベルに表示されます。途中が引き伸ばされるので、Trueにすると変な見た目になるし、誤解も招きやすくなると思います。使いづらい印象です。

  • impurity : bool, optional (default=True)

 不純度の表示を出すかどうか。

  • node_ids : bool, optional (default=False)

 ノードのIDを出すかどうか。

  • proportion : bool, optional (default=False)

 Falseの場合は実際のサンプル数、Trueの場合は総サンプル数に対する比率で各ノードのサンプル数が表現されます。

  • rotate : bool, optional (default=False)

 木の伸びる向きが回転します。

  • rounded : bool, optional (default=False)

 ノードを表す箱の角が丸められます。

  • special_characters : bool, optional (default=False)

 FalseだとPostScript非互換の特殊文字を無視してくれるそうです。

  • precision : int, optional (default=3)

 表示される小数の表示の最大桁数。

 あとは、max_depthもこっち側でいじれたりします。

 適当な設定でやってみましょう。

import subprocess
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

def main():
    iris = load_iris()
    dtc = DecisionTreeClassifier(max_depth=3)
    dtc.fit(iris.data, iris.target)
    export_graphviz(dtc, "tree1.3.dot",
                    feature_names=iris.feature_names,
                    class_names=iris.target_names,
                    filled=True, rounded=True, proportion=True)
    subprocess.run("dot -Tpng tree1.3.dot -o tree1.3.png".split())

if __name__ == "__main__":
    main()

tree1.3.png
tree1.3.png

 だいぶ人間にフレンドリーな感じになりました。プレゼン用スライドやブログに載せるときはこれくらいのものを使うと良いでしょう。モノクロ資料の場合は色までは塗らない方が良いです。

sklearn.tree.plot_treeで同じことをする

 これまではgraphvizを使う方法で可視化を行ってきましたが、実はsklearn 0.21以降ではmatplotlibで可視化する手段も用意されています。plot_tree関数です。

import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

def main():
    iris = load_iris()
    dtc = DecisionTreeClassifier(max_depth=3)
    dtc.fit(iris.data, iris.target)
    plot_tree(dtc,
              feature_names=iris.feature_names,
              class_names=iris.target_names,
              filled=True, rounded=True, proportion=True)
    plt.savefig("tree1.4.png")
    
if __name__ == "__main__":
    main()

tree1.4.png(plot_treeを使って描画)
tree1.4.png(plot_treeを使って描画)

 見た目のクオリティは若干劣るような気もしますが、ほぼ同じような図をよりすっきりしたコードで描けます。ということで、こちらもおすすめです。

sklearn.tree.plot_treeをJupyter Notebookで使うと決定木の可視化が捗る・・・かな? - 静かなる名辞

dtreevizを使う方法

 dtreevizは決定木の可視化用ライブラリです。sklearnのデフォルトの図がダサいのに業を煮やした人たちが作ったようです。

GitHub - parrt/dtreeviz: A python library for decision tree visualization and model interpretation.

 けっこう活発に開発されており、ポテンシャルは未知数ですが流行りそうな雰囲気があります。日本語の記事もいくつかあります。

決定木の可視化ライブラリ「dtreeviz」が凄かったのでまとめる - St_Hakky’s blog
Pythonの決定木をdtreevizでスマートに可視化する - Qiita

 使い方は色々あるっぽいのですが、公式のAPIリファレンスはまだないし、この先どうなるのかわからないライブラリなので軽い紹介にとどめます。

つまづきポイント

 使おうと思って「あれ?」となった部分についてだけ説明しておきます。

  • 基本的にはdtreeviz.trees.dtreevizを使うということだと思います。引数は、tree_model, X_train, y_train, feature_names, target_name, class_namesの最低5つは渡す必要があります。
  • target_nameには目的変数の名前を渡します。表示に使われるだけ(だと思う)なので、"hoge"でも"fuga"でも""でも構いません。図中の何箇所かに出てきます。
  • class_namesにはシーケンス型、辞書などが渡せますが、numpy配列はなぜか渡せません。tolistで変換してください。
  • 出力を別ウィンドウで見るにはdtreeviz関数の返り値オブジェクト(DtreeVizクラス)のviewメソッドを使います。plt.show()と同じノリです。
  • 出力を画像として保存するには、dtreeviz関数の返り値オブジェクト(DtreeVizクラス)のsaveメソッドを用います。mac環境以外ではsvgファイル以外の選択肢はないようなので、.svgの拡張子のファイル名を指定します。

実践編

 ということで、雰囲気だけ見るためにやってみます。

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from dtreeviz.trees import dtreeviz

def main():
    iris = load_iris()
    dtc = DecisionTreeClassifier(max_depth=3)
    dtc.fit(iris.data, iris.target)
    viz = dtreeviz(dtc, X_train=iris.data, y_train=iris.target, 
                   feature_names=iris.feature_names, target_name="あやめ",
                   class_names=iris.target_names.tolist())
    viz.save("iris1.5.svg")
             
if __name__ == "__main__":
    main()

出力画像
出力画像

 いい感じになりました。

 なお、結果はSVGで出てきます。pngにしたい場合はimagemagicのconvertコマンドで変換できるという記事が検索するとたくさんでてきますが、なんかうまくいかなかったので試行錯誤した結果、rsvg-convertならいけました。

 参考:
svgからpngに変換 - Tan.Starの世界

まとめ

 sklearn標準でも一通りのことはできますし、dtreevizを使うと図がかっこよくなります。

 このような決定木の可視化に慣れておくと、とりあえず決定木に突っ込んで見てみるという選択肢が出てくるのでいいと思います。

 ただし、決定木でうまく可視化できるのは条件が良いデータだけで、主成分分析でうまく見れない(累積寄与率が低すぎて)ようなデータの性質をこれで見れるとは思わないほうが良いです。そういう意味では、使いどころが難しい気もします。うまくデータにハマると強力なときもあるんですけどねー。