静かなる名辞

pythonとプログラミングのこと


sklearnのtrain_test_splitを使うときはstratifyを指定した方が良い

はじめに

 train_test_splitはsklearnをはじめて学んだ頃からよくお世話になっています。しかし、stratifyを指定しないとまずいことが起こり得ると最近気づきました。

stratifyって何?

 層化という言葉を聞いたことがある方が一定数いると思いますが、それです。あるいは、交差検証でStratifiedKFoldを使ったことのある人もだいたい理解しているでしょう。

 要するに、クラスラベル(など)ごとにサンプルを取ってくるということを意味します。2クラス分類、100サンプルで元の各クラスの比率が50,50であれば、10件取り出しても5,5ずつになることが保証される、というのがここでいうstratify(層化)の意味です。

 これを指定しないと、けっこういい加減なことになります。

指定しないで試してみる

 まずdigitsでやります。こんな感じ。

from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import classification_report

def main():
    dataset = load_digits()
    X_train, X_test, y_train, y_test\
        = train_test_split(dataset.data, dataset.target)
    svm = SVC(gamma="scale")
    svm.fit(X_train, y_train)
    prediction = svm.predict(X_test)
    print(classification_report(y_test, prediction))

if __name__ == "__main__":
    main()

 結果

              precision    recall  f1-score   support

           0       1.00      1.00      1.00        58
           1       0.94      1.00      0.97        49
           2       1.00      1.00      1.00        46
           3       1.00      1.00      1.00        51
           4       1.00      0.98      0.99        42
           5       1.00      0.98      0.99        42
           6       0.98      1.00      0.99        43
           7       1.00      1.00      1.00        46
           8       0.98      0.93      0.96        46
           9       1.00      1.00      1.00        27

    accuracy                           0.99       450
   macro avg       0.99      0.99      0.99       450
weighted avg       0.99      0.99      0.99       450

 supportに注目してください。0は58, 9は27で、倍以上ばらついています。あまり良い評価とは言えないということです。

 stratifyを指定します。

    X_train, X_test, y_train, y_test\
        = train_test_split(dataset.data, dataset.target,
                           stratify=dataset.target)

 結果

              precision    recall  f1-score   support

           0       1.00      0.98      0.99        45
           1       0.96      1.00      0.98        46
           2       1.00      1.00      1.00        44
           3       1.00      1.00      1.00        46
           4       0.98      0.96      0.97        45
           5       1.00      1.00      1.00        46
           6       1.00      1.00      1.00        45
           7       1.00      0.98      0.99        45
           8       0.98      0.95      0.96        43
           9       0.96      1.00      0.98        45

    accuracy                           0.99       450
   macro avg       0.99      0.99      0.99       450
weighted avg       0.99      0.99      0.99       450

 若干のぶれがありますが、概ね数がそろっています。これによってより妥当な評価が可能になると思われます。

 次に不均衡データの場合について考えてみます。同じデータセットで、5かどうかを判別するとしましょう。

from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import classification_report

def main():
    dataset = load_digits()
    y = dataset.target == 5
    X_train, X_test, y_train, y_test\
        = train_test_split(dataset.data, y)
    svm = SVC(gamma="scale")
    svm.fit(X_train, y_train)
    prediction = svm.predict(X_test)
    print(classification_report(
        y_test, prediction))

if __name__ == "__main__":
    main()

 結果(何回かやった中で偏りがひどかった奴です)

              precision    recall  f1-score   support

       False       1.00      1.00      1.00       415
        True       1.00      0.97      0.99        35

    accuracy                           1.00       450
   macro avg       1.00      0.99      0.99       450
weighted avg       1.00      1.00      1.00       450

 stratifyを指定します。

    X_train, X_test, y_train, y_test\
        = train_test_split(dataset.data, y, stratify=y)

 結果

              precision    recall  f1-score   support

       False       1.00      1.00      1.00       404
        True       0.98      0.96      0.97        46

    accuracy                           0.99       450
   macro avg       0.99      0.98      0.98       450
weighted avg       0.99      0.99      0.99       450

 この場合は、何回やってもデータ個数の比率は保証されます。

まとめ

 ということで、stratifyを指定した方が良いでしょう。

 testに回すサンプル数が少ないときは特に相対的にクラス間の比率のばらつきが大きくなるので、やっておくべきだと思います。testに回すサンプル数が多い場合はまだやらなくてもなんとかなるかもしれませんが、それでもやった方が結果が安定するはずです。