1. ホーム
  2. python

[解決済み] なぜnumpyは私のFortranルーチンよりずっと速いのですか?

2023-07-19 18:12:12

質問

Fortranで書かれたシミュレーションから、温度分布を表す512^3の配列が得られます。配列はバイナリファイルに保存され、そのサイズは約1/2Gです。私はこの配列の最小、最大、および平均を知る必要があり、いずれにせよすぐに Fortran コードを理解する必要があるので、それを試してみることにし、次の非常に簡単なルーチンを思いつきました。

  integer gridsize,unit,j
  real mini,maxi
  double precision mean

  gridsize=512
  unit=40
  open(unit=unit,file='T.out',status='old',access='stream',&
       form='unformatted',action='read')
  read(unit=unit) tmp
  mini=tmp
  maxi=tmp
  mean=tmp
  do j=2,gridsize**3
      read(unit=unit) tmp
      if(tmp>maxi)then
          maxi=tmp
      elseif(tmp<mini)then
          mini=tmp
      end if
      mean=mean+tmp
  end do
  mean=mean/gridsize**3
  close(unit=unit)

私が使っているマシンでは、1ファイルあたり約25秒かかります。これはかなり長いと思ったので、先にPythonで次のようにしてみました。

    import numpy

    mmap=numpy.memmap('T.out',dtype='float32',mode='r',offset=4,\
                                  shape=(512,512,512),order='F')
    mini=numpy.amin(mmap)
    maxi=numpy.amax(mmap)
    mean=numpy.mean(mmap)

さて、もちろんもっと速くなることは予想していましたが、本当に驚かされました。同一条件下で1秒を切るのです。平均値は、私の Fortran ルーチンが検出したもの (これは 128 ビット浮動小数点数で実行されたので、私は何となくそれをより信頼しています) から逸脱していますが、有効数字の 7 桁目くらいでしか検出されません。

どうしてnumpyはそんなに速いのでしょうか?つまり、これらの値を見つけるために、配列のすべてのエントリを見なければならないのですよね?私のFortranルーチンで何か非常に愚かなことをしているので、これほど時間がかかっているのでしょうか?

EDITです。

コメントでの質問に答えるために

  • はい、また、32 ビットと 64 ビットの浮動小数点数で Fortran ルーチンを実行しましたが、パフォーマンスには影響がありませんでした。
  • 私は iso_fortran_env で、128ビット浮動小数点数を提供します。
  • 32 ビット浮動小数点数を使用すると、私の平均はかなりずれてしまうので、精度は本当に問題です。
  • 私は異なる順序で異なるファイルで両方のルーチンを実行したので、キャッシュは比較で公正であるべきだったと思います ?
  • 私は実際にオープン MP を試しましたが、同時に異なる位置でファイルから読み取るようにしました。あなたのコメントと回答を読むと、これは今となっては本当に馬鹿げているようで、ルーチンもかなり長くかかるようになりました。私は配列操作でそれを試してみるかもしれませんが、おそらくそれは必要ないかもしれません。
  • ファイルのサイズは実際には 1/2G で、これはタイプミスでした。
  • これから配列の実装を試してみます。

EDIT 2です。

Alexander Vogt と @casey が回答で提案したことを実装してみたところ、以下のように高速化されました。 numpy のように高速ですが、@Luaan が指摘したような精度の問題が発生する可能性があります。32 ビット浮動小数点数の配列を使って sum によって計算された平均は 20%オフです。すること

...
real,allocatable :: tmp (:,:,:)
double precision,allocatable :: tmp2(:,:,:)
...
tmp2=tmp
mean=sum(tmp2)/size(tmp)
...

問題は解決されますが、計算時間が増加します(それほど多くはありませんが、顕著です)。 この問題を回避する良い方法はありますか?ファイルからシングルを直接ダブルに読み込む方法は見つかりませんでした。 また、どのようにして numpy はこれを避けることができるのでしょうか?

今までありがとうございました。

どのように解決するのですか?

あなたのFortran実装は、2つの大きな欠点があります。

  • IO と計算を混在させています (そして、ファイルからエントリごとに読み込んでいます)。
  • ベクトル/行列演算を使用しない。

この実装はあなたのものと同じ操作を行いますが、私のマシンでは20倍も高速です。

program test
  integer gridsize,unit
  real mini,maxi,mean
  real, allocatable :: tmp (:,:,:)

  gridsize=512
  unit=40

  allocate( tmp(gridsize, gridsize, gridsize))

  open(unit=unit,file='T.out',status='old',access='stream',&
       form='unformatted',action='read')
  read(unit=unit) tmp

  close(unit=unit)

  mini = minval(tmp)
  maxi = maxval(tmp)
  mean = sum(tmp)/gridsize**3
  print *, mini, maxi, mean

end program

このアイデアは、ファイル全体を一つの配列に読み込むことです。 tmp に一度に読み込むことです。そうすると、関数 MAXVAL , MINVAL そして SUM を配列上に直接置くことができます。


精度の問題については 単に倍精度値を使用して、次のようにオンザフライで変換を行うだけです。

mean = sum(real(tmp, kind=kind(1.d0)))/real(gridsize**3, kind=kind(1.d0))

は計算時間をわずかに増加させるだけです。要素単位やスライスで操作を実行してみましたが、デフォルトの最適化レベルでは必要な時間が増加するだけでした。

-O3 で、要素ごとの加算は配列操作よりも ~3 % 良いパフォーマンスを示します。倍精度演算と単精度演算の差は、私のマシンでは平均して 2% 未満です (個々の実行ははるかに大きく異なります)。


LAPACKを使用した非常に高速な実装を紹介します。

program test
  integer gridsize,unit, i, j
  real mini,maxi
  integer  :: t1, t2, rate
  real, allocatable :: tmp (:,:,:)
  real, allocatable :: work(:)
!  double precision :: mean
  real :: mean
  real :: slange

  call system_clock(count_rate=rate)
  call system_clock(t1)
  gridsize=512
  unit=40

  allocate( tmp(gridsize, gridsize, gridsize), work(gridsize))

  open(unit=unit,file='T.out',status='old',access='stream',&
       form='unformatted',action='read')
  read(unit=unit) tmp

  close(unit=unit)

  mini = minval(tmp)
  maxi = maxval(tmp)

!  mean = sum(tmp)/gridsize**3
!  mean = sum(real(tmp, kind=kind(1.d0)))/real(gridsize**3, kind=kind(1.d0))
  mean = 0.d0
  do j=1,gridsize
    do i=1,gridsize
      mean = mean + slange('1', gridsize, 1, tmp(:,i,j),gridsize, work)
    enddo !i
  enddo !j
  mean = mean / gridsize**3

  print *, mini, maxi, mean
  call system_clock(t2)
  print *,real(t2-t1)/real(rate)

end program


これは単精度行列の1-normを使用します。 SLANGE を行列の列に対して使用します。実行時間は単精度の配列関数を使用する方法よりもさらに速く、精度の問題も表示されません。