静かなる名辞

pythonとプログラミングのこと


【python】sklearnのLDA(LatentDirichletAllocation)を試してみる

 注意:線形判別分析(LinearDiscriminantAnalysis)ではありません。トピックモデルのLDAです。

はじめに

 LDAといえば、トピックモデルの代表的な手法であり、一昔前の自然言語処理では頻繁に使われていました(最近は分散表現や深層学習に押されて廃れ気味な気もしますが)。

 普通、pythonでLDAといえばgensimの実装を使うことが多いと思います。が、gensimは独自のフレームワークを持っており、少しとっつきづらい感じがするのも事実です。

gensim: models.ldamodel – Latent Dirichlet Allocation

 このLDA、実はsklearnにもモデルがあるので、そっちを試しに使ってみようと思います。

 ライブラリのリンク
sklearn.decomposition.LatentDirichletAllocation — scikit-learn 0.20.1 documentation

何はともあれ分類タスク

 当ブログでは定番になっている20newsgroupsデータセットで文書分類をやってみましょう。

from sklearn.datasets import fetch_20newsgroups
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import CountVectorizer as CV
from sklearn.decomposition import LatentDirichletAllocation as LDA
from sklearn.ensemble import RandomForestClassifier as RFC
from sklearn.pipeline import Pipeline
from sklearn.metrics import classification_report

def main():
    news20 = fetch_20newsgroups()

    X_train, X_test, y_train, y_test = train_test_split(
        news20.data[:5000], news20.target[:5000],
        stratify=news20.target[:5000])

    cv = CV(min_df=0.04, stop_words="english")
    lda = LDA(n_components=100, max_iter=30, n_jobs=-1)
    rfc = RFC(n_estimators=500, n_jobs=-1)
    estimators = [("cv", cv), ("lda", lda), ("rfc", rfc)]
    pl = Pipeline(estimators)

    pl.fit(X_train, y_train)
    y_pred = pl.predict(X_test)
    print(classification_report(
        y_test, y_pred, target_names=news20.target_names))

if __name__ == "__main__":
    main()

 見ての通りのモデルです。CountVectorizerでBoWに変換し、LDAで次元削減したあとランダムフォレストにかけて分類します。パラメータは適当にいじって計算時間と分類性能の兼ね合いで決めました。

 で、肝心の分類性能ですが、classification_reportで出した結果を以下に示します。

                          precision    recall  f1-score   support

             alt.atheism       0.30      0.25      0.27        55
           comp.graphics       0.29      0.27      0.28        63
 comp.os.ms-windows.misc       0.58      0.74      0.65        66
comp.sys.ibm.pc.hardware       0.35      0.31      0.33        64
   comp.sys.mac.hardware       0.16      0.09      0.12        64
          comp.windows.x       0.31      0.25      0.28        67
            misc.forsale       0.46      0.62      0.53        66
               rec.autos       0.48      0.57      0.52        65
         rec.motorcycles       0.22      0.25      0.23        69
      rec.sport.baseball       0.25      0.25      0.25        65
        rec.sport.hockey       0.44      0.60      0.51        65
               sci.crypt       0.47      0.66      0.55        67
         sci.electronics       0.37      0.29      0.33        69
                 sci.med       0.25      0.28      0.27        68
               sci.space       0.65      0.63      0.64        62
  soc.religion.christian       0.49      0.56      0.52        66
      talk.politics.guns       0.31      0.28      0.29        58
   talk.politics.mideast       0.53      0.38      0.45        60
      talk.politics.misc       0.43      0.35      0.38        52
      talk.religion.misc       0.36      0.21      0.26        39

               micro avg       0.40      0.40      0.40      1250
               macro avg       0.39      0.39      0.38      1250
            weighted avg       0.39      0.40      0.39      1250

 ちょっとしょぼい・・・かな。何も考えず、BoWのままランダムフォレストに入れたときでも0.7くらいの評価指標になっていたので、しょぼいと言わざるを得ない感じです。

【python】sklearnのfetch_20newsgroupsで文書分類を試す(2) - 静かなる名辞

 トピック数やmax_iterなどのパラメータを上げると改善する傾向ですが、計算時間は増加します。気合い入れてパラメータを適正に選んでぶん回せばたぶんちゃんとした結果になるのだと思いますが、そうするコストは高そうです。

トピックの中身を見てみる

 LDAは分類の前処理に使うより、むしろ出来上がったトピックを見て「はーんこんな風に分かれるのね」と楽しむ方が面白いような気もするので、軽く見てみます。ただし、膨大な文書をLDAで回すのも、その結果を人手で解釈するのも大変なので、

  • 1000文書だけ入れる
  • 50トピックでLDAモデルを構築し、うち先頭10トピックだけ見ることにする
  • 各トピックの大きい係数5個だけ見る

 という手抜き解釈です。

import numpy as np
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer as CV
from sklearn.decomposition import LatentDirichletAllocation as LDA

def main():
    news20 = fetch_20newsgroups()
    X = news20.data[:1000]
    y = news20.target[:1000]

    cv = CV(min_df=0.04, stop_words="english")
    lda = LDA(n_components=50, max_iter=50, n_jobs=-1)
    lda.fit(cv.fit_transform(X, y))

    feature_names = np.array(cv.get_feature_names())
    for i, component in enumerate(lda.components_[:10]):
        print("component:", i)
        idx = component.argsort()[::-1][:5]
        for j in idx:
            print(feature_names[j], component[j])

if __name__ == "__main__":
    main()

 結果は以下の通り。コメントで解釈を付けてみました。

component: 0  # carとspeedなので車関係
car 105.0307361645338
speed 66.0512803728286
years 34.28189251426528
year 28.634803273219863
ago 22.20999274754642
component: 1  # よくわからん
time 88.93501118238727
david 40.73976076823079
group 37.700769220589336
information 25.679148778591387
services 24.1215050863642
component: 2  # IBMなので恐らくコンピュータ関連
com 178.86471720750257
ibm 103.60651295996851
writes 53.83219596330523
article 38.98071797104704
lines 31.65901015205818
component: 3  # 教育機関とか?
edu 293.2163826198764
cc 106.47895192641978
writes 96.85399760416148
organization 67.46869749413264
article 67.44141504814338
component: 4  # 上のトピックと似ていてeduがトップに来る。なんでしょうねぇ(メールアドレスとかURLだったりして)
edu 303.2418519434582
article 97.83240562135947
writes 89.61827003849298
technology 78.82229006007034
organization 65.80274084351083
component: 5  # windowsとかソフト関連
windows 150.91253925843284
00 91.50205933157577
15 24.611812739517074
software 19.04945545853369
20 14.902153206775973
component: 6  # よくわからない
thought 28.89413050306475
kind 21.98192709084761
david 13.591746503570576
code 8.357254795587956
little 7.6968699914925
component: 7  # メーリングリスト関連?
list 62.55156212975074
address 50.37113047912307
data 45.349827179735904
information 39.6113013625701
mail 33.98010842887469
component: 8  # カネ、政府、ソフトウェア……
money 36.13150161231052
gov 28.618673196459778
software 21.193931015050016
source 20.603010804981693
cost 20.35807225674739
component: 9  # 大学とかのネットワーク関連っぽい
edu 439.0633403498364
university 165.04391386786645
posting 162.13921320408093
nntp 159.2048859320109
host 158.37699187575328

 恐ろしく雑な解釈ですが、一応トピックごとにまとまっているっぽいことはなんとなくわかりました。

 はっきり内容がわかるのは半分以下ですが……。

sklearnのLDAの欠点

 たぶんperplexityとかCoherenceとかは見れないと思う。

 体感ではgensimのより遅い気がします。ちゃんと測っていないので、あくまでもなんとなくですが。

まとめ

 sklearnでも一応できることはわかりました。

 機能がしょぼいので本格的に使いたければgensimの方がおすすめです。gensimが使えない環境でLDAしないといけなくなったときには代用品として使えなくはありません。