静かなる名辞

pythonとプログラミングのこと

【python】ランダムフォレストのOOBエラーが役に立つか確認

はじめに

 RandomForestではOOBエラー(Out-of-bag error、OOB estimate、OOB誤り率)を見ることができます。交差検証と同様に汎化性能を見れます。

 原理の説明とかは他に譲るのですが、これはちゃんと交差検証のように使えるのでしょうか? もちろん原理的には使えるのでしょうが、実際どうなるのかはやってみないとわかりません。

 もしかしたらもう他の人がやっているかもしれませんが*1、自分でやった方が納得感があります*2

 ということで、やってみました。

みたいこと

 とりあえずトイデータでやってみて、交差検証の場合とスコアを比べる。交差検証は分割のkを変えて様子を見る必要があるでしょう。

 また、モデルの性能がよくなったり悪くなったりしたとき、交差検証と同様のスコアの変化が見れるかも確認してみる必要がありそうです。

プログラム

 こんなプログラムを書きました。

import time

import numpy as np
from sklearn.datasets import load_iris, load_digits, load_wine
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import precision_recall_fscore_support as prf

def test_func(dataset):
    rfc = RandomForestClassifier(n_estimators=300, oob_score=True, n_jobs=-1)
    t1 = time.time()
    rfc.fit(dataset.data, dataset.target)
    t2 = time.time()
    oob_pred = rfc.oob_decision_function_.argmax(axis=1)
    print("{0:6} p:{2:.4f} r:{3:.4f} f1:{4:.4f}  time:{1:.4f}".format(
        "oob", t2-t1,
        *prf(dataset.target, oob_pred, average="macro")))

    rfc = RandomForestClassifier(n_estimators=300, oob_score=False, n_jobs=-1)
    for k in [2,4,6,8]:
        skf = StratifiedKFold(n_splits=k)
        trues = []
        preds = []
        t1 = time.time()
        for train_idx, test_idx in skf.split(dataset.data, dataset.target):
            rfc.fit(dataset.data[train_idx], dataset.target[train_idx])
            trues.append(dataset.target[test_idx])
            preds.append(rfc.predict(dataset.data[test_idx]))
        t2 = time.time()

        print("{0:6} p:{2:.4f} r:{3:.4f} f1:{4:.4f}  time:{1:.4f}".format(
            "CV k={}".format(k), t2-t1,
            *prf(np.hstack(trues), np.hstack(preds), average="macro")))        

def main():
    iris = load_iris()
    digits = load_digits()
    wine = load_wine()

    print("iris")
    test_func(iris)

    print("\ndigits")
    test_func(digits)

    print("\nwine")
    test_func(wine)

    print("\niris + noise")
    iris.data += np.random.randn(*iris.data.shape)*iris.data.std()
    test_func(iris)

    print("\ndigits + noise")
    digits.data += np.random.randn(*digits.data.shape)*digits.data.std()
    test_func(digits)

    print("\nwine + noise")
    wine.data += np.random.randn(*wine.data.shape)*wine.data.std()
    test_func(wine)

if __name__ == "__main__":
    main()

 注目ポイント。

  • iris, digits, wineでためしました
  • rfc.oob_decision_function_.argmax(axis=1)でOOBで推定されたラベルが得られるので、それを使って精度、再現率、F1値を計算しています(マクロ平均)。交差検証でも同様に計算することで、同じ指標で比較を可能にしています(accuracyだけだと寂しいので・・・)
  • 処理の所要時間も測った
  • 特徴量にノイズを付与して分類させることで、条件が悪いときのスコアも確認。ノイズは特徴量全体の標準偏差くらいの正規分布を付与しました。理論的な根拠は特にないです(だいたい軸によってスケールが違うのを無視しているのだし・・・)

 だいたいこんな感じで、あとは普通にやってます*3

結果

 テキスト出力をそのまんま。

iris
oob    p:0.9534 r:0.9533 f1:0.9533  time:0.5774
CV k=2 p:0.9534 r:0.9533 f1:0.9533  time:1.3196
CV k=4 p:0.9600 r:0.9600 f1:0.9600  time:2.7399
CV k=6 p:0.9600 r:0.9600 f1:0.9600  time:4.1367
CV k=8 p:0.9600 r:0.9600 f1:0.9600  time:5.9125

digits
oob    p:0.9795 r:0.9794 f1:0.9794  time:0.9511
CV k=2 p:0.9282 r:0.9271 f1:0.9272  time:1.9101
CV k=4 p:0.9429 r:0.9422 f1:0.9420  time:4.0144
CV k=6 p:0.9519 r:0.9515 f1:0.9515  time:5.9703
CV k=8 p:0.9500 r:0.9494 f1:0.9494  time:7.8814

wine
oob    p:0.9762 r:0.9803 f1:0.9780  time:0.6950
CV k=2 p:0.9748 r:0.9812 f1:0.9774  time:1.2872
CV k=4 p:0.9714 r:0.9746 f1:0.9728  time:3.2011
CV k=6 p:0.9603 r:0.9643 f1:0.9619  time:4.4640
CV k=8 p:0.9714 r:0.9746 f1:0.9728  time:6.1454

iris + noise
oob    p:0.4723 r:0.4800 f1:0.4759  time:0.6677
CV k=2 p:0.5298 r:0.5267 f1:0.5279  time:1.3729
CV k=4 p:0.5456 r:0.5533 f1:0.5489  time:3.1605
CV k=6 p:0.4776 r:0.4800 f1:0.4788  time:4.1999
CV k=8 p:0.5043 r:0.5000 f1:0.5009  time:6.0875

digits + noise
oob    p:0.7850 r:0.7853 f1:0.7834  time:1.4698
CV k=2 p:0.7365 r:0.7377 f1:0.7342  time:2.8328
CV k=4 p:0.7567 r:0.7568 f1:0.7543  time:5.4918
CV k=6 p:0.7682 r:0.7697 f1:0.7671  time:9.1825
CV k=8 p:0.7717 r:0.7714 f1:0.7692  time:12.1008

wine + noise
oob    p:0.5377 r:0.5483 f1:0.5348  time:0.7344
CV k=2 p:0.4911 r:0.5034 f1:0.4873  time:1.4287
CV k=4 p:0.4828 r:0.5058 f1:0.4855  time:3.1024
CV k=6 p:0.5375 r:0.5439 f1:0.5320  time:4.3943
CV k=8 p:0.4778 r:0.5129 f1:0.4840  time:5.9890

 これからわかることとしては、

  • 全体的にOOBエラーとCVで求めたスコアはそこそこ近いので、OOBエラーはそこそこ信頼できると思います
  • OOBは短い時間で済むので、お得です
  • CVの場合はkを大きくすると性能が上がりますが、これは学習に使うデータ量がk=2なら全体の1/2、k=4なら3/4、k=8なら7/8という風に増加していくからです
  • OOBエラーがCVのスコアを上回る場合、下回る場合ともにあるようです。OOBエラーは、学習しているデータ量はほぼleave one outに近いものの、木の本数が設定値の約1/3くらいになるという性質があります。学習データ量の有効性が高いデータセットではCVの場合より高いスコアに、木の本数の有効性が高いデータセットではCVの場合に対して低いスコアになるということでしょう

 まあ、とりあえず妥当に評価できるんじゃねえの? という感じがします。もちろん同じ条件下で計測したスコアではないので、OOBエラーと交差検証の結果を直接比較することはできませんが、OOBエラー同士の優劣で性能を見積もる分にはたぶん問題ないでしょう*4

 注意点としては、OOBエラーは全体の木の約1/3(厳密には36%くらい)を使って予測するので、実際の結果よりは悪めに出る可能性があります。木の本数を多めにおごってやると良いでしょう。

結論

 OOBでもいい。

参考

qiita.com

*1:というか論文は確実にあると思いますが

*2:し、簡単なことなら人の書いた論文を読み解くより自分でやった方が楽だったりします

*3:プログラムを書き上げて回してから記事にしているので説明が雑・・・

*4:逆に言えば、他の分類器との比較には使えないというけっこう致命的な欠点がある訳ですが・・・