1. ホーム
  2. python

[解決済み] numpy の argpartition 出力を理解できない

2022-02-05 14:38:49

質問

numpyからarpgpartitionを使おうとしているのですが、何かうまくいかないようで、なかなか解決できません。以下がその内容です。

これらは、ソートされた配列の最初の5つの要素です。 norms

np.sort(norms)[:5]
array([ 53.64759445,  54.91434479,  60.11617279,  64.09630585,  64.75318909], dtype=float32)

しかし indices_sorted = np.argpartition(norms, 5)[:5]

norms[indices_sorted]
array([ 60.11617279,  64.09630585,  53.64759445,  54.91434479,  64.75318909], dtype=float32)

ソートされた配列と同じ結果が得られるはずだと思うのですが?

3をパラメータにするとうまくいきます。 indices_sorted = np.argpartition(norms, 3)[:3]

norms[indices_sorted]
array([ 53.64759445,  54.91434479,  60.11617279], dtype=float32)

これは私にはあまり意味をなさないのですが、どなたか洞察を与えていただけませんか?

EDIT: この質問を、argpartitionがk個の分割された要素の順序を維持するかどうか、と言い換えると、より意味があります。

どのように解決するのですか?

k番目のパラメータをスカラーとして与える代わりに、ソートされた順序で保持されるインデックスのリストを使用する必要があります。したがって,最初の 5 要素の代わりに np.argpartition(a,5)[:5] は、単に - を行う。

np.argpartition(a,range(5))[:5]

以下は、わかりやすくするためのサンプル実行です。

In [84]: a = np.random.rand(10)

In [85]: a
Out[85]: 
array([ 0.85017222,  0.19406266,  0.7879974 ,  0.40444978,  0.46057793,
        0.51428578,  0.03419694,  0.47708   ,  0.73924536,  0.14437159])

In [86]: a[np.argpartition(a,5)[:5]]
Out[86]: array([ 0.19406266,  0.14437159,  0.03419694,  0.40444978,  0.46057793])

In [87]: a[np.argpartition(a,range(5))[:5]]
Out[87]: array([ 0.03419694,  0.14437159,  0.19406266,  0.40444978,  0.46057793])

ご注意ください argpartition はパフォーマンス面で理にかなったもので、要素の小さなサブセットに対してソートされたインデックスを取得したい場合、例えば k の数は、elems の総数のごく一部です。

もっと大きなデータセットを使って、すべての要素についてソートされたインデックスを取得することで、上記の点を明確にしてみましょう -。

In [51]: a = np.random.rand(10000)*100

In [52]: %timeit np.argpartition(a,range(a.size-1))[:5]
10 loops, best of 3: 105 ms per loop

In [53]: %timeit a.argsort()
1000 loops, best of 3: 893 µs per loop

このように、すべてのエレメントをソートするには np.argpartition は向かない。

さて、この大きなデータセットで、最初の5つの要素だけについてソートされたインデックスを取得し、それらの順序も維持したいとしましょう -。

In [68]: a = np.random.rand(10000)*100

In [69]: np.argpartition(a,range(5))[:5]
Out[69]: array([1647,  942, 2167, 1371, 2571])

In [70]: a.argsort()[:5]
Out[70]: array([1647,  942, 2167, 1371, 2571])

In [71]: %timeit np.argpartition(a,range(5))[:5]
10000 loops, best of 3: 112 µs per loop

In [72]: %timeit a.argsort()[:5]
1000 loops, best of 3: 888 µs per loop

とても便利です。