静かなる名辞

pythonとプログラミングのこと

【python】sklearnライクなデータセットを作る

 自作したりネットから拾ってきたデータセットにsklearnライクなインターフェースがあるとそこそこ便利です。なので、作る方法について調べました。

 とりあえずデータセットを読み込んで型を調べます。

>>> from sklearn.datasets import load_iris
>>> iris = load_iris()
>>> type(iris)
<class 'sklearn.utils.Bunch'>

 Bunchという型らしいです。

 ではこのBunchを継承したクラスを定義してデータセットのオブジェクトを作るのかというとそうではなく、実はこのBunchというのは「キーにattribとしてアクセスできる辞書のようなもの」であり、コンストラクタに渡したキーと値のセットを内部で作ります。わかりづらいと思いますので例を示すと、

>>> from sklearn.datasets.base import Bunch
>>> hoge = Bunch(hogehoge="hoge!")
>>> hoge.hogehoge
'hoge!'

 こうやって使えるということです。sklearnのソースを読んでも、このような形で使っているので間違いないようです。

scikit-learn/base.py at master · scikit-learn/scikit-learn · GitHub

 なので、自作のsklearnライクなデータセットを作りたいということであれば、このBunchをimportしてあげた上で、

Bunch(data=自作したnumpy配列のデータなど, target=自作したnumpy配列のターゲットなど)

 こんな形で書けば良いということです。あとはload_***関数を作ってこのBunchのインスタンスをreturnしてやれば完了です。

 ちなみに、Bunchのattribはその気になれば外部から上書きすることも可能。実際のdatasetsの実装もそうなっているようです。些か危ない気もしますが、pythonによくある「紳士協定で使え(間違った使い方なんかしないよね)」という奴でしょう。しかし、in-placeで処理される関数をうっかり呼び出したら普通に事故る気が・・・。

>>> from sklearn.datasets import load_iris
>>> import numpy as np
>>> iris = load_iris()
>>> iris.data[:10]
array([[5.1, 3.5, 1.4, 0.2],
       [4.9, 3. , 1.4, 0.2],
       [4.7, 3.2, 1.3, 0.2],
       [4.6, 3.1, 1.5, 0.2],
       [5. , 3.6, 1.4, 0.2],
       [5.4, 3.9, 1.7, 0.4],
       [4.6, 3.4, 1.4, 0.3],
       [5. , 3.4, 1.5, 0.2],
       [4.4, 2.9, 1.4, 0.2],
       [4.9, 3.1, 1.5, 0.1]])
>>> np.random.shuffle(iris.data)
>>> iris.data[:10]
array([[7.7, 2.8, 6.7, 2. ],
       [5.2, 3.5, 1.5, 0.2],
       [6.5, 3. , 5.8, 2.2],
       [5.1, 3.8, 1.5, 0.3],
       [6.2, 2.8, 4.8, 1.8],
       [4.4, 2.9, 1.4, 0.2],
       [4.9, 3. , 1.4, 0.2],
       [4.3, 3. , 1.1, 0.1],
       [6.4, 3.1, 5.5, 1.8],
       [6.8, 3.2, 5.9, 2.3]])

 やはり事故るのだった。ちょっと怖いですねぇ。read-onlyにするのはそこまで難しくはないと思うのですが、敢えてそうしていないのか、単に実装がヘボいのか