静かなる名辞

pythonとプログラミングのこと


sklearnのfetch_20newsgroups_vectorizedでベクトル化された20 newsgroupsを試す

はじめに

 20 newsgroupsはこのブログでも過去何回か取り上げまたしが、ベクトル化済みのデータを読み込めるfetch_20newsgroups_vectorizedは意図的にスルーしていました。

 使えるかどうか気になったので、試してみます。

sklearn.datasets.fetch_20newsgroups_vectorized — scikit-learn 0.20.2 documentation

使い方

 単純なので基本的にはドキュメントを見てください。一応引数の説明だけ引用で貼りますが。

subset : ‘train’ or ‘test’, ‘all’, optional
  Select the dataset to load: ‘train’ for the training set, ‘test’ for the test set, ‘all’ for both, with shuffled ordering.

remove : tuple
  May contain any subset of (‘headers’, ‘footers’, ‘quotes’). Each of these are kinds of text that will be detected and removed from the newsgroup posts, preventing classifiers from overfitting on metadata.

  ‘headers’ removes newsgroup headers, ‘footers’ removes blocks at the ends of posts that look like signatures, and ‘quotes’ removes lines that appear to be quoting another post.

data_home : optional, default: None
  Specify an download and cache folder for the datasets. If None, all scikit-learn data is stored in ‘~/scikit_learn_data’ subfolders.

download_if_missing : optional, True by default
  If False, raise an IOError if the data is not locally available instead of trying to download the data from the source site.

return_X_y : boolean, default=False.
  If True, returns (data.data, data.target) instead of a Bunch object.

  New in version 0.20.

 sklearn 0.19とsklearn 0.20で仕様が変わっていて、新バージョンでは何かと親切な引数が増えています。この記事では0.19相当の機能しか使っていませんが。

確認する

 とりあえずデータ数と次元数でも見てみます。初回の実行では時間がかかりますが、内部的にはCountVectorizerでベクトル化を行う実装のようです(ベクトル化されたデータをダウンロードしてくる訳ではない)。二回目以降はキャッシュされてpickleの読み込み時間だけになります。実装は単純なので、確認したい人はドキュメントからsourceのリンクに飛んで見てみてください。

from sklearn.datasets import fetch_20newsgroups_vectorized

for subset in ["train", "test", "all"]:
    dataset =  fetch_20newsgroups_vectorized(subset=subset)
    print(dataset.data.shape)

""" =>
(11314, 130107)
(7532, 130107)
(18846, 130107)
"""

 13107次元という次元数は、昔の記事でも確認した通り低頻度語を落とす処理などは一切為されていない場合の次元数です。

 さて、この関数にはremoveという引数があります。これを使うと無駄なメッセージヘッダだのフッタだの、引用部だのを除去できます。

from sklearn.datasets import fetch_20newsgroups_vectorized

for subset in ["train", "test", "all"]:
    dataset =  fetch_20newsgroups_vectorized(
        subset=subset, remove=("headers", "footers", "quotes"))
    print(dataset.data.shape)

""" =>
(11314, 101631)
(7532, 101631)
(18846, 101631)
"""

 気持ち程度(3万弱)次元が下がります。

 なお、昔記事にしたときには気づきませんでしたが、この辺の機能は素のfetch_20newsgroupsにもあります(だからtrainのデータを更に分割して分類の学習データとテストデータに分けるなんてアホなことをしていた訳ですが……)。なので、これを使うためにfetch_20newsgroups_vectorizedを使う必要があるという話ではまったくありません。

使ってみる

 とりあえず、10万次元なんて使い物にならないことだけは何もしなくてもわかるので、次元数を下げる算段を考えることにします。

 CountVectorizerなら頻度が高すぎる、低すぎる単語は落とすオプションがデフォルトであるのですが、そういう便利機能はfetch_20newsgroups_vectorizedにはありません。また、その変換に対応したモデルもsklearn.preprocessingやsklearn.model_selectionには用意されていません。困ったものです。

 仕方ないので、SelectKBestを引っ張り出し、score_funcはnumpyで書くことにします。

 単に入力の頻度を数える関数はこんな感じです。

def freq(X, y=None):
    return np.sum(X, axis=0).A.reshape(-1)

 なんか無駄なものが後ろに付いている気がする? 私もそう思いますが、なんとSelectKBestがmatrix型で投げてくるので、arrayに直して返却する必要があります。えー、なんでー、って感じ。すでに非推奨扱いなのに

 とにかく、これを使って特徴選択します。高頻度語だけ残すという特徴選択になります。

import numpy as np
from sklearn.feature_selection import SelectKBest
from sklearn.datasets import fetch_20newsgroups_vectorized

def freq(X, y=None):
    return np.sum(X, axis=0).A.reshape(-1)

train =  fetch_20newsgroups_vectorized(
    subset="train", remove=("headers", "footers", "quotes"))

skb = SelectKBest(score_func=freq, k=2000)
print(skb.fit_transform(train.data, train.target).shape)

""" =>
(11314, 2000)
"""

 できそうですね。

 試しに分類やってみましょう。分類器は、BoWなのでMultinomialNBにしてみます。

sklearn.naive_bayes.MultinomialNB — scikit-learn 0.20.2 documentation

import numpy as np
from sklearn.feature_selection import SelectKBest
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline
from sklearn.datasets import fetch_20newsgroups_vectorized
from sklearn.metrics import classification_report

def freq(X, y=None):
    return np.sum(X, axis=0).A.reshape(-1)

def main():
    train =  fetch_20newsgroups_vectorized(
        subset="train", remove=("headers", "footers", "quotes"))
    test =  fetch_20newsgroups_vectorized(
        subset="test", remove=("headers", "footers", "quotes"))

    skb = SelectKBest(score_func=freq, k=2000)
    nb = MultinomialNB()
    pl = Pipeline([("skb", skb), ("nb", nb)])

    pl.fit(train.data, train.target)
    pred = pl.predict(test.data)
    print(classification_report(test.target, pred, 
                                target_names=test.target_names))

if __name__ == "__main__":
    main()

""" =>
                          precision    recall  f1-score   support

             alt.atheism       0.51      0.07      0.12       319
           comp.graphics       0.50      0.56      0.53       389
 comp.os.ms-windows.misc       0.61      0.53      0.57       394
comp.sys.ibm.pc.hardware       0.54      0.60      0.57       392
   comp.sys.mac.hardware       0.72      0.45      0.55       385
          comp.windows.x       0.65      0.62      0.63       395
            misc.forsale       0.78      0.75      0.77       390
               rec.autos       0.60      0.60      0.60       396
         rec.motorcycles       0.64      0.56      0.59       398
      rec.sport.baseball       0.68      0.57      0.62       397
        rec.sport.hockey       0.50      0.74      0.60       399
               sci.crypt       0.58      0.64      0.61       396
         sci.electronics       0.49      0.38      0.43       393
                 sci.med       0.55      0.56      0.55       396
               sci.space       0.66      0.50      0.57       394
  soc.religion.christian       0.24      0.89      0.38       398
      talk.politics.guns       0.53      0.49      0.51       364
   talk.politics.mideast       0.71      0.62      0.66       376
      talk.politics.misc       0.82      0.05      0.09       310
      talk.religion.misc       0.50      0.00      0.01       251

               micro avg       0.53      0.53      0.53      7532
               macro avg       0.59      0.51      0.50      7532
            weighted avg       0.59      0.53      0.51      7532

"""

 極端に駄目なクラスが幾つかある。なんか上手く行ってない予感。

 次元数を4000まで増やし、MultinomialNBのパラメタで簡単にいじれて意味のありそうなalphaを調整します(alphaはデータ全体に加算されるスムージングパラメータ(たぶん)。余計なことしない方が良いと考えてごく小さめに設定する)。

import numpy as np
from sklearn.feature_selection import SelectKBest
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline
from sklearn.datasets import fetch_20newsgroups_vectorized
from sklearn.metrics import classification_report

def freq(X, y=None):
    return np.sum(X, axis=0).A.reshape(-1)

def main():
    train =  fetch_20newsgroups_vectorized(
        subset="train", remove=("headers", "footers", "quotes"))
    test =  fetch_20newsgroups_vectorized(
        subset="test", remove=("headers", "footers", "quotes"))

    skb = SelectKBest(score_func=freq, k=4000)
    nb = MultinomialNB(alpha=0.01)
    pl = Pipeline([("skb", skb), ("nb", nb)])

    pl.fit(train.data, train.target)
    pred = pl.predict(test.data)
    print(classification_report(test.target, pred, 
                                target_names=test.target_names))

if __name__ == "__main__":
    main()

""" =>
                          precision    recall  f1-score   support

             alt.atheism       0.47      0.38      0.42       319
           comp.graphics       0.54      0.64      0.59       389
 comp.os.ms-windows.misc       0.61      0.55      0.57       394
comp.sys.ibm.pc.hardware       0.60      0.59      0.60       392
   comp.sys.mac.hardware       0.68      0.59      0.63       385
          comp.windows.x       0.75      0.68      0.71       395
            misc.forsale       0.79      0.77      0.78       390
               rec.autos       0.67      0.67      0.67       396
         rec.motorcycles       0.67      0.72      0.69       398
      rec.sport.baseball       0.82      0.74      0.78       397
        rec.sport.hockey       0.55      0.84      0.67       399
               sci.crypt       0.74      0.68      0.71       396
         sci.electronics       0.59      0.50      0.54       393
                 sci.med       0.73      0.66      0.70       396
               sci.space       0.65      0.69      0.67       394
  soc.religion.christian       0.50      0.83      0.63       398
      talk.politics.guns       0.52      0.68      0.59       364
   talk.politics.mideast       0.76      0.74      0.75       376
      talk.politics.misc       0.58      0.35      0.43       310
      talk.religion.misc       0.54      0.15      0.23       251

               micro avg       0.64      0.64      0.64      7532
               macro avg       0.64      0.62      0.62      7532
            weighted avg       0.64      0.64      0.63      7532

"""

 それなりにマシになりました。

まとめ

 fetch_20newsgroups_vectorizedが使えるか? というと、「使うとかえって余計な手間がかかるから、おとなしくfetch_20newsgroups呼んでCountVectorizerでよくね」という結論に達さざるを得ません。

 自然言語処理は特徴抽出が命みたいなところがあるので、そこを雑にやられて投げてこられても魅力を感じません。

 低頻度語とストップワードを除去する前処理でも勝手にかけてくれて、使いやすいサイズ感のデータで返してくれればまた違った感想になったと思うので、ちょっと残念な感じです。