静かなる名辞

pythonとプログラミングのこと


【python】itertools.chainを使って複数のiterableを一つにまとめる

概要

 複数のiterable(リストとか)を結合させてループさせたいときがあります。

>>> lst1 = [1, 2, 3]
>>> lst2 = [4, 5, 6]
>>> # 1, 2, 3, 4, 5, 6というループをやりたい

 連結すればできたりしますが、余計なメモリを確保するのでスマートではないし、パフォーマンスが気になります。

>>> for x in lst1 + lst2:
...     print(x)
... 
1
2
3
4
5
6

 というマニアックなお悩みを解決してくれるのがitertools.chainです。

使い方

 リファレンスを見ると使い方が書いてあります。

itertools --- 効率的なループ実行のためのイテレータ生成関数 — Python 3.7.4 ドキュメント

 まあまったく難しいことはなく、

>>> from itertools import chain
>>> for x in chain(lst1, lst2):
...     print(x)
... 
1
2
3
4
5
6

 たったこれだけで使えます。単純ですね。

速度を測る

 実際にパフォーマンスがいいのかどうか試してみましょう。

import timeit
from itertools import chain

def f1(a, b):
    for x in a + b:
        pass

def f2(a, b):
    for x in chain(a, b):
        pass

def main():
    for n in [10, 100, 1000, 10000]:
        print(n)
        a, b = list(range(n)), list(range(n, 2*n))
        t1 = timeit.timeit(lambda : f1(a, b), number=1000)
        t2 = timeit.timeit(lambda : f2(a, b), number=1000)
        print("{0:.6f}\n{1:.6f}".format(t1, t2))
        
if __name__ == "__main__":
    main()

""" =>
10
0.000657
0.000822
100
0.005128
0.004024
1000
0.032347
0.025703
10000
0.249460
0.232873
"""

 なんかどっちでも良いような気が……論理的にはchainを使ったほうがスマートですが、

  • リストの結合と値の取り出しは速い
  • chainはかえって遅い

 というあたりが本質かな、という気がします。

 素のリストとchainの速度も比較してみる。

import timeit
from itertools import chain

def f1(lst):
    for x in lst:
        pass

def f2(lst):
    for x in chain(lst):
        pass

def main():
    lst = list(range(10**4))
    t1 = timeit.timeit(lambda : f1(lst), number=1000)
    t2 = timeit.timeit(lambda : f2(lst), number=1000)
    print("{0:.6f}\n{1:.6f}".format(t1, t2))
    # 0.092083
    # 0.130389

if __name__ == "__main__":
    main()

 かえって遅い可能性もあるということです。

諦めて二重forで書く

 これで良いんじゃ……

import timeit
from itertools import chain

def f1(a, b):
    for x in a + b:
        pass

def f2(a, b):
    for x in chain(a, b):
        pass

def f3(a, b):
    for lst in [a, b]:
        for x in lst:
            pass

def main():
    for n in [10, 100, 1000, 10000]:
        print(n)
        a, b = list(range(n)), list(range(n, 2*n))
        t1 = timeit.timeit(lambda : f1(a, b), number=1000)
        t2 = timeit.timeit(lambda : f2(a, b), number=1000)
        t3 = timeit.timeit(lambda : f3(a, b), number=1000)
        print("{0:.6f}\n{1:.6f}\n{2:.6f}".format(t1, t2, t3))
        
if __name__ == "__main__":
    main()

""" =>
10
0.000675
0.000837
0.000726
100
0.003831
0.003594
0.002471
1000
0.033191
0.025650
0.016758
10000
0.279039
0.225997
0.163289
"""

 こんなのが一番速いです。無理に一つにまとめない方がパフォーマンス上良い、というしょうもない結論に。
 

まとめ

 なんかchainするとパフォーマンス上は微妙な気もしますが、使うと連結よりスマートな感じのコードになるので、好みで使うと良いでしょう。