静かなる名辞

pythonとプログラミングのこと


cross_val_scoreはもうやめようね。一発で交差検証するにはcross_validateを使う

はじめに

 scikit-learnで交差検証を行い、評価指標を算出する方法としては、cross_val_scoreがよくオススメされています。実際、「sklearn 交差検証」みたいな検索キーワードでググるとこの関数がよく出てきます。しかし、この関数は複数の評価指標を算出することができず、一つのスコアしか出力してくれません。

 これでどういうとき困るかというと、Accuracy, Precision, Recall, F1をすべて出したい・・・というとき、困ります。基本的にこれらはぜんぶ出して評価するものという考え方のもと検証しようとすると、うまくいかないのです。その辺りを柔軟に制御するために、これまで私は自分で交差検証のコードを書いてきました。

 しかし、そんな必要はありませんでした。cross_validateという関数を使えばいいのです。

 ・・・と、大げさに書いてみましたが、実はこの関数はsklearnのオンラインドキュメントのAPI ReferenceのModel Validationの筆頭に出ています。

API Reference — scikit-learn 0.21.2 documentation

 じゃあなんでcross_val_scoreがオススメされるの? というと、cross_validateの方が若干新しいからです。これが使えるのは0.19以降です。それだけの理由ですね。古い情報がそのまま定着してしまって、もっと良いものが出てきてもそれが広まらないというのはよくあることです。

 cross_validateの方が何かと柔軟なので、こちらを使いましょう。以下では淡々と説明していきます。

スポンサーリンク


cross_validate

 ドキュメントの個別ページはここです。

sklearn.model_selection.cross_validate — scikit-learn 0.21.2 documentation

 こんな関数になっています。

sklearn.model_selection.cross_validate(estimator, X, y=None,
 groups=None, scoring=None, cv=None,
 n_jobs=1, verbose=0, fit_params=None, pre_dispatch=‘2*n_jobs’, return_train_score=’warn’)

 あまりcross_val_scoreと代わり映えしないと言えばしないのですが、scoringに指定できるものがcross_val_scoreとは違います。ちょっと長いですが、当該部分を全文引用します。

scoring : string, callable, list/tuple, dict or None, default: None

A single string (see The scoring parameter: defining model evaluation rules) or a callable (see Defining your scoring strategy from metric functions) to evaluate the predictions on the test set.

For evaluating multiple metrics, either give a list of (unique) strings or a dict with names as keys and callables as values.

NOTE that when using custom scorers, each scorer should return a single value. Metric functions returning a list/array of values can be wrapped into multiple scorers that return one value each.

See Specifying multiple metrics for evaluation for an example.

If None, the estimator’s default scorer (if available) is used.

 けっこう凄いことがかいてあります。嬉しいのはlist/tuple, dictが渡せるところです。どう嬉しいかはすぐにわかります。

実例

 実際のコードを見ないと良さが伝わらないと思うので、簡単な例を示します。

# coding: UTF-8

from pprint import pprint

from sklearn.datasets import load_iris
from sklearn.naive_bayes import GaussianNB
from sklearn.model_selection import cross_validate, StratifiedKFold

def main():
    iris = load_iris()
    
    gnb = GaussianNB()

    scoring = {"p": "precision_macro",
               "r": "recall_macro",
               "f":"f1_macro"}

    skf = StratifiedKFold(shuffle=True, random_state=0)
    scores = cross_validate(gnb, iris.data, iris.target,
                            cv=skf, scoring=scoring)
    pprint(scores)

if __name__ == "__main__":
    main()

 scoresの中身がどうなってるのか? というのがミソの部分で、こうなっています。

{'fit_time': array([0.00096989, 0.00186491, 0.00081563]),
 'score_time': array([0.00263953, 0.0025475 , 0.00260162]),
 'test_f': array([0.9212963 , 0.98037518, 0.95816993]),
 'test_p': array([0.9251462 , 0.98148148, 0.96296296]),
 'test_r': array([0.92156863, 0.98039216, 0.95833333]),
 'train_f': array([0.95959596, 0.95955882, 0.95096979]),
 'train_p': array([0.95959596, 0.96067588, 0.95122655]),
 'train_r': array([0.95959596, 0.95959596, 0.95098039])}

 これを見た瞬間、「もう自分で交差検証を書く必要はない」と私は思いました。欲しい情報はぜんぶ出せます。

・・・といいつつ

 不満点はまったくない訳ではなく、

  • 各foldでfitさせた分類器がほしい、とか
  • 各foldの正解/予測ラベルを生でほしい、とか

 
 いろいろ思うところはあるので、そういうものが必要なときは他のモジュールを探すか(把握しきれてないです、はい)、自分で書くことになります。

 それでも普通に数字を出して評価する「だけ」なら一行一発で済む関数があるというのは素晴らしいことです。手間が省けます。素晴らしい。cross_val_scoreはその域には達してないと判断していましたが、cross_validateは使えそうです。

 これで簡単に交差検証が書けるようになりました。

まとめ

 cross_validateを使おう。