1. ホーム
  2. python

[解決済み] numpy配列の0要素を効率的にカウントする?

2022-02-05 04:25:42

質問

の中にあるゼロの要素の数を数える必要があります。 numpy 配列になります。私が知っているのは numpy.count_nonzero 関数がありますが、ゼロの要素をカウントするアナログはないようです。

私の配列はそれほど大きくはありませんが(通常は1E5要素未満)、この操作は数百万回実行されます。

もちろん len(arr) - np.count_nonzero(arr) しかし、もっと効率的な方法はないだろうか。

現在私が行っている方法をMWEで紹介します。

import numpy as np
import timeit

arrs = []
for _ in range(1000):
    arrs.append(np.random.randint(-5, 5, 10000))


def func1():
    for arr in arrs:
        zero_els = len(arr) - np.count_nonzero(arr)


print(timeit.timeit(func1, number=10))

解決方法は?

A 2x を使用した方が早いでしょう。 np.count_nonzero() を使用していますが 条件 を必要に応じて表示します。

In [3]: arr
Out[3]: 
array([[1, 2, 0, 3],
      [3, 9, 0, 4]])

In [4]: np.count_nonzero(arr==0)
Out[4]: 2

In [5]:def func_cnt():
            for arr in arrs:
                zero_els = np.count_nonzero(arr==0)
                # here, it counts the frequency of zeroes actually

を使用することもできます。 np.where() よりも遅いのですが np.count_nonzero()

In [6]: np.where( arr == 0)
Out[6]: (array([0, 1]), array([2, 2]))

In [7]: len(np.where( arr == 0))
Out[7]: 2


効率性。 (降順)

In [8]: %timeit func_cnt()
10 loops, best of 3: 29.2 ms per loop

In [9]: %timeit func1()
10 loops, best of 3: 46.5 ms per loop

In [10]: %timeit func_where()
10 loops, best of 3: 61.2 ms per loop


アクセラレータでさらに高速化

以上を実現することが可能になりました。 3桁のスピードアップ の助けを借りて JAX アクセラレータ(GPU/TPU)が利用できる場合。JAXを使うもう一つの利点は、NumPyのコードをJAX互換にするために、ほとんど修正する必要がないことです。以下は、再現可能な例です。

In [1]: import jax.numpy as jnp
In [2]: from jax import jit

# set up inputs
In [3]: arrs = []
In [4]: for _ in range(1000):
   ...:     arrs.append(np.random.randint(-5, 5, 10000))

# JIT'd function that performs the counting task
In [5]: @jit
   ...: def func_cnt():
   ...:     for arr in arrs:
   ...:         zero_els = jnp.count_nonzero(arr==0)


# efficiency test
In [8]: %timeit func_cnt()
15.6 µs ± 391 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)