静かなる名辞

pythonとプログラミングのこと

【python】CKY法をpythonで実装

 構文解析アルゴリズムのCKY法の実装について説明する。参考にしたテキストはこれ。

自然言語処理の基礎

自然言語処理の基礎

理論

 教科書読めばぜんぶ書いてあります(ちゃんと解説しようとすると大変なので、自分で説明したくない)。

 ネット上の解説としては、

 この3つを読めば理解できると思います。プログラムとして実装する前に、紙とペンで一回やってみるべきです。

 CKY法は理屈が簡単な割に、プログラムに書き起こすのが面倒くさいタイプのアルゴリズムです(だと思う)。頭でちゃんと理解してから挑むことが望ましい。

問題設定

 教科書のToy problemで行きます。次の文を構文解析するというものです。

astronomers saw stars with ears

 文法は次のように与えられています。

S→NP VP:1.0
PP→P NP:1.0
VP→V NP:0.7
VP→VP PP:0.3
NP→NP PP:0.4
P→with:1.0
V→saw:1.0
NP→astronomers:0.1
NP→ears:0.18
NP→saw:0.04
NP→stars:0.18
NP→telescope:0.1

 コロンの後の数字は文法規則が適用される確率です。今回の例文には多義性があるので(複数の構文解析結果が出て来る)、この確率を使ってもっともらしい結果を選ぼうという訳です。

実装

 書いたソースコードは記事の最後に丸ごと載せてます。以下では実装方法を簡単に解説します。

 とりあえず何も考えず、上記例文と文法をグローバル変数として定義。

example_sentence = "astronomers saw stars with ears"

grammar_text = """S→NP VP:1.0
PP→P NP:1.0
VP→V NP:0.7
VP→VP PP:0.3
NP→NP PP:0.4
P→with:1.0
V→saw:1.0
NP→astronomers:0.1
NP→ears:0.18
NP→saw:0.04
NP→stars:0.18
NP→telescope:0.1"""

 CKYクラスを作ることにする。CKYクラスインスタンスのparseメソッドを呼べば、然るべき型で結果を返してくれるように作る方針で行こう。

class CKY:
    def __init__(self, grammar_text):
        self.grammar_dict = defaultdict(set)
        for line in grammar_text.split("\n"):
            rule, p = line.split(":")
            l, r = rule.split("→")
            self.grammar_dict[r].add((l, float(p)))

        self.cky_array = None

 文法の情報はクラス内で持ってないと困るので、defaultdict(set)で格納。解析の過程を考えると、文法の右側の要素から左側の要素(あと確率)が取り出せると嬉しいので、そうする。今回、keyは"astronomers"みたいな終端記号を表す文字列か、"P NP"みたいな非終端記号のペアを表す文字列にしている。非終端記号のペアはtupleにして……とか考えるとかえって面倒くさい。

 なんでsetにするのか? 右が同じだけど左が異なるパターンがあるから。「V→saw:1.0」と「NP→saw:0.04」とかですね。

 self.cky_arrayはとりあえずNoneにしておく。文の長さが決まらないとinitializeできない。ということは、initializeするメソッドも作っておく必要がある(別にメソッドにしないでparseメソッド内でやっても良いんだが)。

 このcky_array、CKYテーブル、詰まるところ三角行列をどう実装するかは悩みどころで、適当に作るとインデックスでエラく苦労する。とりあえず、今回は多義性がある文を解析するので、行列の一つのセルに複数の要素が入るので、三重リストみたいなものにしないといけない。

 という訳で、単語数*単語数*空リストの三重リストとして実装する。こうすると下半分が無駄にメモリを食うけど、大した実害はない。ちょっと無駄っぽいけど。

    def _init_cky_array(self, length):
        self.cky_array = [[[] for _ in range(length)]
                          for i in range(length)]
        return self.cky_array

 あとは空リストに適当に値を突っ込んでいけば、CKYテーブルは作れる。適当に、と書いたけど、ここが一番つらい。とりあえずparseメソッドを書き始める。

    def parse(self, text):
        words = text.split()
        self.length = len(words)
        self._init_cky_array(self.length)

 まず行列の対角成分(NT→Tの文法の部分)を埋める。

        for i, word in enumerate(words):
            for l, p in self.grammar_dict[word]:
                self.cky_array[i][i].append((l, word, p))

 テーブルのセルに入れる値は、(左辺値(str), 単語(str), 確率(float))の形のtuple。対角成分以外では、(左辺値(str), (右辺の左のindex(tuple), 右辺の右のindex(tuple)), 確率(float))とする方針。こういうところに独自定義のオブジェクトを入れたがる人がたまにいるが、経験上かえって面倒くさくなることが多い。CKY以外のクラスは定義しないで書く。

 そして謎のfor文で一気にCKY配列を埋める。コメントを書いたので頑張って理解して。

        # 対角成分の1つ右,2つ右,...と処理を回すループ
        for d in range(1, self.length):
            
            # 斜め下に進んでいくループ
            # i,jでどのセルを処理対象とするか決める
            for i in range(self.length - d):
                j = i + d
                
                # セルの中身を埋めるループ
                for k in range(i, j):

                    # 右辺の可能な組み合わせを列挙してる
                    for a, b in product(
                            range(len(self.cky_array[i][k])),
                            range(len(self.cky_array[k+1][j]))):

                        # 辞書のキーを作る
                        s = "{0} {1}".format(
                            self.cky_array[i][k][a][0],
                            self.cky_array[k+1][j][b][0])
                        
                        # キーに合致する文法をぜんぶ出す
                        for l,p in self.grammar_dict[s]:

                            # セルに中身を入れる
                            self.cky_array[i][j].append(
                                (l, ((i,k,a), (k+1,j,b)), p))

 なんとforループが5つもある。五重のforと名付けよう。なお、CKY法は O(n^3)アルゴリズムである。一番内側の2つのループは基本的に定数項で、計算量には効かない。

 このforループが終わると、CKYテーブルはすでに完成している。後は、これを辿って構文木を出力するだけだ。セルにindexを入れたことがここで効いてくる。なお、紙とペンでやるときはNP1とか通し番号を振り、NP1(astronomers)とかPP1(P1, NP2)みたいに書くと混乱が少ない。
 
 構文木を辿る方法は、当然再帰である。indexを見て次のセルに飛べば良い。indexを表現するtupleではなく、終端記号を表現するstrが格納されていたら、再帰の終了条件を満たしたとみなす。

 構文木の出力形式は、XMLで行く。僕はlxmlを使って処理するのに慣れているので、今回も使うことにする。

 以上の方針を決めた上で、次のコードを書き足す。

        # parseの最後
        return self._gen_xml_etree_list()

    def _traverse_tree(self, index=(0,0,0)):
        # 構文木を辿る
        i,j,k = index
        node = self.cky_array[i][j][k]
        elem = etree.Element(node[0])
        child = node[1]
        p = node[2]
        elem.attrib["p"] = str(p)

        if type(child) == str:
            elem.text = child
            return elem
        else:
            l, r = child
            elem.append(self._traverse_tree(index=l))            
            elem.append(self._traverse_tree(index=r))
            return elem

    def _gen_xml_etree_list(self):
        # 再帰呼出しを開始する
        lst = []
        for i, s in enumerate(self.cky_array[0][self.length - 1]):
            if s[0] != "S":
                pass
            else:
                # etreeのまま返すことにしよう...
                lst.append(self._traverse_tree((0,4,i)))
        return lst

 お疲れ様でした。これでCKYクラスの実装はおしまいです。あとはmainを書くだけです。mainは確率の総乗を計算し、またetreeを文字列に変換して表示します。

def main():
    cky = CKY(grammar_text)
    lst = cky.parse(example_sentence)
    for xml_tree in lst:
        p = 1
        for elem in xml_tree.iter():
            p *= float(elem.attrib["p"])
        print(p)
        print(
            etree.tostring(xml_tree, pretty_print=True).decode())

 あとはmainの呼び出しを書けば終了です。import等はここでは省略しました。記事末尾のソースコードには載せています。

結果

 実行結果を見せます。

0.0009071999999999998
<S p="1.0">
  <NP p="0.1">astronomers</NP>
  <VP p="0.7">
    <V p="1.0">saw</V>
    <NP p="0.4">
      <NP p="0.18">stars</NP>
      <PP p="1.0">
        <P p="1.0">with</P>
        <NP p="0.18">ears</NP>
      </PP>
    </NP>
  </VP>
</S>

0.0006803999999999998
<S p="1.0">
  <NP p="0.1">astronomers</NP>
  <VP p="0.3">
    <VP p="0.7">
      <V p="1.0">saw</V>
      <NP p="0.18">stars</NP>
    </VP>
    <PP p="1.0">
      <P p="1.0">with</P>
      <NP p="0.18">ears</NP>
    </PP>
  </VP>
</S>

 まあ、良いのでは。「天文学者は耳と一緒の星を見た」と「天文学者は耳で星を見た」の二通りの解析結果があり、前者の方が良い感じ、みたいな結果・・・だと思います。

感想

 やっぱりアルゴリズムが簡単な割に書くのが大変だった。特にindexの範囲をミスると簡単に死ねるので、自分で実装するときはindexを随時printして(あるいはデバッガで確認して)正しい値が出ているか確認しながらやるのが良いです。

付録 ソースコード