Example results on bare metal

I ran these on the following hardware:

  • Intel Xeon E5-2650 v4 @ 2.20 GHz
  • 512 GB DDR4 memory (not that we would need it)
  • NVIDIA Tesla P100 (16 GB memory)

Software stack:

  • CentOS 7

  • GNU compiler toolkit 8.3.0

  • Python 3.8.12

  • CUDA 11.2

  • Packages pulled from pip

  • Backend versions:

    aesara==2.2.4
    cupy==9.5.0
    jax==0.2.24
    numba==0.54.1
    numpy==1.19.5
    pytorch==1.10.0
    tensorflow==2.6.0

Contents

  • Equation of state
  • Isoneutral mixing
  • Turbulent kinetic energy

Equation of state

An equation consisting of >100 terms with no data dependencies and only elementary math. This benchmark should represent a best-case scenario for vector instructions and GPU performance.
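
To give a feel for the kind of code being benchmarked, here is a toy sketch (not the actual >100-term expression, which lives in the benchmark source) of an elementwise kernel written once in array notation and run through both NumPy and JAX. The coefficients below are made up for illustration only.

```python
import numpy as np
import jax
import jax.numpy as jnp

def eos_like_kernel(salt, temp, p):
    # Purely elementwise arithmetic: each output element depends only on the
    # corresponding input elements, so the expression maps cleanly onto SIMD
    # lanes on the CPU and onto one thread per element on the GPU.
    return (999.84 + 0.068 * temp - 0.009 * temp**2
            + 0.82 * salt - 0.005 * salt * temp
            + 0.05 * p - 6.0e-6 * p**2)

n = 1_048_576
rng = np.random.default_rng(0)
salt, temp, p = rng.random(n), rng.random(n), rng.random(n)

rho_np = eos_like_kernel(salt, temp, p)   # eager NumPy: one temporary per term
rho_jax = jax.jit(eos_like_kernel)(jnp.asarray(salt), jnp.asarray(temp), jnp.asarray(p))
np.testing.assert_allclose(rho_np, np.asarray(rho_jax), rtol=1e-4)  # JAX defaults to float32
```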

CPU

$ taskset -c 23 python run.py benchmarks/equation_of_state/

benchmarks.equation_of_state
============================
Running on CPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  pytorch       10,000     0.000     0.000     0.000     0.000     0.000     0.000     0.003     6.188
       4,096  jax           10,000     0.000     0.000     0.000     0.000     0.000     0.000     0.004     4.581
       4,096  numba         10,000     0.001     0.000     0.001     0.001     0.001     0.001     0.001     2.808
       4,096  aesara        10,000     0.001     0.000     0.001     0.001     0.001     0.001     0.004     2.517
       4,096  tensorflow    10,000     0.001     0.000     0.001     0.001     0.001     0.001     0.004     2.507
       4,096  numpy         10,000     0.002     0.000     0.002     0.002     0.002     0.002     0.005     1.000

      16,384  pytorch       10,000     0.002     0.000     0.002     0.002     0.002     0.002     0.002     4.947
      16,384  jax           10,000     0.002     0.000     0.002     0.002     0.002     0.002     0.002     4.193
      16,384  tensorflow    10,000     0.002     0.000     0.002     0.002     0.002     0.002     0.006     3.598
      16,384  numba         10,000     0.003     0.000     0.003     0.003     0.003     0.003     0.006     2.861
      16,384  aesara         1,000     0.003     0.000     0.003     0.003     0.003     0.003     0.003     2.734
      16,384  numpy          1,000     0.008     0.000     0.008     0.008     0.008     0.008     0.008     1.000

      65,536  pytorch        1,000     0.007     0.000     0.007     0.007     0.007     0.007     0.007     4.655
      65,536  tensorflow     1,000     0.007     0.000     0.007     0.007     0.007     0.007     0.007     4.405
      65,536  jax            1,000     0.008     0.000     0.008     0.008     0.008     0.008     0.012     3.997
      65,536  numba          1,000     0.011     0.000     0.011     0.011     0.011     0.011     0.011     2.954
      65,536  aesara         1,000     0.011     0.000     0.011     0.011     0.011     0.011     0.011     2.899
      65,536  numpy            100     0.032     0.000     0.032     0.032     0.032     0.032     0.032     1.000

     262,144  pytorch        1,000     0.023     0.000     0.023     0.023     0.023     0.023     0.024     5.570
     262,144  tensorflow     1,000     0.024     0.000     0.023     0.024     0.024     0.024     0.024     5.501
     262,144  jax              100     0.024     0.000     0.024     0.024     0.024     0.024     0.025     5.423
     262,144  numba            100     0.039     0.000     0.039     0.039     0.039     0.039     0.039     3.319
     262,144  aesara           100     0.040     0.000     0.040     0.040     0.040     0.040     0.040     3.280
     262,144  numpy            100     0.130     0.001     0.129     0.130     0.130     0.131     0.131     1.000

   1,048,576  pytorch          100     0.092     0.000     0.092     0.092     0.092     0.092     0.092     7.118
   1,048,576  jax              100     0.103     0.000     0.103     0.103     0.103     0.104     0.104     6.308
   1,048,576  tensorflow       100     0.105     0.000     0.105     0.105     0.105     0.105     0.105     6.210
   1,048,576  numba            100     0.161     0.000     0.160     0.161     0.161     0.161     0.161     4.060
   1,048,576  aesara           100     0.164     0.000     0.163     0.164     0.164     0.164     0.164     3.987
   1,048,576  numpy             10     0.653     0.002     0.647     0.653     0.653     0.654     0.654     1.000

   4,194,304  pytorch           10     0.388     0.000     0.388     0.388     0.388     0.388     0.388     9.440
   4,194,304  jax               10     0.397     0.000     0.397     0.397     0.397     0.398     0.398     9.220
   4,194,304  tensorflow        10     0.418     0.000     0.418     0.418     0.418     0.418     0.419     8.755
   4,194,304  numba             10     0.630     0.001     0.629     0.630     0.630     0.631     0.631     5.812
   4,194,304  aesara            10     0.647     0.000     0.647     0.647     0.647     0.647     0.648     5.659
   4,194,304  numpy             10     3.662     0.002     3.659     3.661     3.663     3.664     3.665     1.000

(time in wall seconds, less is better)
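
For reference, the Δ column appears to be the speedup relative to the NumPy baseline at the same problem size. A minimal sanity check of that reading, using the mean wall times from the largest size above:

```python
# Assumption: Δ = numpy_mean / backend_mean (speedup over the NumPy baseline).
numpy_mean = 3.662     # NumPy mean at size 4,194,304 (from the table above)
pytorch_mean = 0.388   # PyTorch mean at the same size
print(f"Δ ≈ {numpy_mean / pytorch_mean:.2f}")  # ≈ 9.44, matching the table up to rounding
```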

$ taskset -c 23 python run.py benchmarks/equation_of_state/ -s 16777216

benchmarks.equation_of_state
============================
Running on CPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
  16,777,216  pytorch           10     1.380     0.009     1.372     1.375     1.376     1.383     1.402    10.413
  16,777,216  tensorflow        10     1.665     0.001     1.664     1.665     1.665     1.665     1.667     8.628
  16,777,216  jax               10     1.737     0.001     1.736     1.737     1.737     1.738     1.740     8.270
  16,777,216  numba             10     2.436     0.002     2.432     2.436     2.436     2.437     2.438     5.899
  16,777,216  aesara            10     2.549     0.001     2.548     2.549     2.549     2.549     2.553     5.636
  16,777,216  numpy             10    14.369     0.004    14.362    14.366    14.368    14.371    14.377     1.000

(time in wall seconds, less is better)

GPU

$ for backend in cupy jax pytorch tensorflow; do CUDA_VISIBLE_DEVICES="0" python run.py benchmarks/equation_of_state/ --gpu -b $backend -b numpy; done

benchmarks.equation_of_state
============================
Running on GPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  numpy         10,000     0.002     0.001     0.002     0.002     0.002     0.002     0.016     1.000
       4,096  cupy           1,000     0.007     0.001     0.006     0.007     0.007     0.007     0.020     0.273

      16,384  cupy           1,000     0.007     0.001     0.007     0.007     0.007     0.007     0.020     1.190
      16,384  numpy          1,000     0.008     0.002     0.007     0.008     0.008     0.008     0.022     1.000

      65,536  cupy           1,000     0.007     0.001     0.007     0.007     0.007     0.007     0.020     6.930
      65,536  numpy            100     0.047     0.004     0.032     0.043     0.050     0.050     0.052     1.000

     262,144  cupy           1,000     0.007     0.001     0.007     0.007     0.007     0.007     0.021    29.436
     262,144  numpy            100     0.200     0.008     0.125     0.198     0.203     0.203     0.205     1.000

   1,048,576  cupy             100     0.016     0.000     0.016     0.016     0.016     0.016     0.017    49.831
   1,048,576  numpy             10     0.811     0.001     0.810     0.810     0.811     0.811     0.812     1.000

   4,194,304  cupy             100     0.061     0.000     0.061     0.061     0.061     0.061     0.062    60.944
   4,194,304  numpy             10     3.694     0.002     3.691     3.693     3.694     3.695     3.698     1.000

(time in wall seconds, less is better)

benchmarks.equation_of_state
============================
Running on GPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  jax           10,000     0.000     0.001     0.000     0.000     0.000     0.000     0.015    14.306
       4,096  numpy         10,000     0.002     0.001     0.002     0.002     0.002     0.002     0.017     1.000

      16,384  jax           10,000     0.000     0.001     0.000     0.000     0.000     0.000     0.015    67.323
      16,384  numpy          1,000     0.009     0.002     0.007     0.008     0.008     0.008     0.022     1.000

      65,536  jax           10,000     0.000     0.001     0.000     0.000     0.000     0.000     0.016   380.146
      65,536  numpy            100     0.051     0.002     0.043     0.051     0.051     0.051     0.056     1.000

     262,144  jax            1,000     0.000     0.001     0.000     0.000     0.000     0.000     0.015  1276.789
     262,144  numpy            100     0.205     0.001     0.204     0.204     0.205     0.205     0.208     1.000

   1,048,576  jax              100     0.000     0.000     0.000     0.000     0.000     0.000     0.002  2165.099
   1,048,576  numpy             10     0.849     0.001     0.848     0.848     0.849     0.850     0.851     1.000

   4,194,304  jax              100     0.001     0.000     0.001     0.001     0.001     0.001     0.001  3323.075
   4,194,304  numpy             10     3.763     0.002     3.760     3.761     3.764     3.765     3.765     1.000

(time in wall seconds, less is better)

benchmarks.equation_of_state
============================
Running on GPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  pytorch       10,000     0.000     0.001     0.000     0.000     0.000     0.000     0.013    15.127
       4,096  numpy         10,000     0.002     0.001     0.002     0.002     0.002     0.002     0.015     1.000

      16,384  pytorch       10,000     0.000     0.001     0.000     0.000     0.000     0.000     0.013    64.685
      16,384  numpy          1,000     0.008     0.002     0.007     0.008     0.008     0.008     0.020     1.000

      65,536  pytorch        1,000     0.000     0.001     0.000     0.000     0.000     0.000     0.013   359.766
      65,536  numpy            100     0.047     0.002     0.045     0.046     0.046     0.047     0.053     1.000

     262,144  pytorch        1,000     0.000     0.001     0.000     0.000     0.000     0.000     0.012  1067.722
     262,144  numpy            100     0.200     0.002     0.187     0.199     0.200     0.201     0.206     1.000

   1,048,576  pytorch        1,000     0.000     0.001     0.000     0.000     0.000     0.000     0.012  1673.047
   1,048,576  numpy             10     0.772     0.001     0.771     0.771     0.772     0.772     0.774     1.000

   4,194,304  pytorch          100     0.001     0.000     0.001     0.001     0.001     0.001     0.001  3173.256
   4,194,304  numpy             10     3.690     0.002     3.687     3.688     3.690     3.691     3.693     1.000

(time in wall seconds, less is better)

benchmarks.equation_of_state
============================
Running on GPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  tensorflow    10,000     0.000     0.001     0.000     0.000     0.000     0.000     0.015     4.807
       4,096  numpy         10,000     0.002     0.001     0.002     0.002     0.002     0.002     0.017     1.000

      16,384  tensorflow    10,000     0.000     0.001     0.000     0.000     0.000     0.000     0.016    22.519
      16,384  numpy          1,000     0.008     0.001     0.007     0.008     0.008     0.008     0.022     1.000

      65,536  tensorflow    10,000     0.000     0.001     0.000     0.000     0.000     0.000     0.016   113.232
      65,536  numpy            100     0.042     0.005     0.035     0.039     0.039     0.040     0.052     1.000

     262,144  tensorflow     1,000     0.000     0.000     0.000     0.000     0.000     0.000     0.012   555.981
     262,144  numpy            100     0.199     0.003     0.196     0.197     0.198     0.198     0.207     1.000

   1,048,576  tensorflow     1,000     0.001     0.001     0.001     0.001     0.001     0.001     0.015  1160.568
   1,048,576  numpy             10     0.791     0.001     0.790     0.791     0.791     0.792     0.793     1.000

   4,194,304  tensorflow       100     0.001     0.001     0.001     0.001     0.001     0.001     0.009  5053.383
   4,194,304  numpy             10     3.752     0.002     3.749     3.750     3.751     3.754     3.756     1.000

(time in wall seconds, less is better)

Isoneutral mixing

A more balanced routine with many data dependencies (stencil operations) and tensor shapes of up to 5 dimensions. This is the most expensive part of Veros, so in a way this is the benchmark that interests me the most.
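
As a rough illustration of what "stencil operation" means here (a toy sketch, not the actual Veros isoneutral-mixing code): neighbouring grid cells are combined via shifted array slices, which introduces the data dependencies that keep this routine from being as purely elementwise as the equation of state.

```python
import numpy as np

def centered_difference_x(field, dx):
    # Stencil along the first (x) axis: each output cell combines its two
    # x-neighbours, so outputs are no longer independent of neighbouring inputs.
    out = np.zeros_like(field)
    out[1:-1, ...] = (field[2:, ...] - field[:-2, ...]) / (2.0 * dx)
    return out

# A small 5-dimensional field (x, y, z, tracer, time level) as a stand-in
# for the tensor shapes mentioned above.
field = np.random.default_rng(0).random((16, 16, 8, 2, 3))
grad_x = centered_difference_x(field, dx=1.0)
print(grad_x.shape)  # (16, 16, 8, 2, 3)
```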

CPU

$ taskset -c 23 python run.py benchmarks/isoneutral_mixing/

benchmarks.isoneutral_mixing
============================
Running on CPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  jax            1,000     0.001     0.000     0.001     0.001     0.001     0.001     0.003     3.617
       4,096  numba          1,000     0.001     0.000     0.001     0.001     0.001     0.001     0.004     3.074
       4,096  aesara         1,000     0.003     0.000     0.003     0.003     0.003     0.003     0.005     1.566
       4,096  pytorch        1,000     0.004     0.000     0.004     0.004     0.004     0.004     0.006     1.061
       4,096  numpy          1,000     0.004     0.000     0.004     0.004     0.004     0.004     0.007     1.000

      16,384  jax            1,000     0.006     0.000     0.006     0.006     0.006     0.006     0.009     2.759
      16,384  numba          1,000     0.007     0.000     0.007     0.007     0.007     0.007     0.009     2.348
      16,384  pytorch        1,000     0.011     0.000     0.010     0.010     0.011     0.011     0.014     1.528
      16,384  aesara         1,000     0.011     0.000     0.011     0.011     0.011     0.011     0.014     1.451
      16,384  numpy          1,000     0.016     0.000     0.016     0.016     0.016     0.016     0.019     1.000

      65,536  jax              100     0.028     0.000     0.028     0.028     0.028     0.028     0.030     2.238
      65,536  numba            100     0.030     0.000     0.030     0.030     0.030     0.030     0.031     2.081
      65,536  pytorch          100     0.037     0.000     0.037     0.037     0.037     0.037     0.038     1.702
      65,536  aesara           100     0.047     0.000     0.047     0.047     0.047     0.047     0.049     1.352
      65,536  numpy            100     0.063     0.000     0.063     0.063     0.063     0.063     0.066     1.000

     262,144  numba            100     0.116     0.000     0.115     0.116     0.116     0.116     0.116     2.157
     262,144  jax              100     0.122     0.000     0.122     0.122     0.122     0.122     0.124     2.047
     262,144  pytorch          100     0.147     0.000     0.147     0.147     0.147     0.148     0.149     1.693
     262,144  aesara            10     0.179     0.000     0.179     0.179     0.179     0.179     0.180     1.393
     262,144  numpy             10     0.250     0.000     0.249     0.250     0.250     0.250     0.250     1.000

   1,048,576  numba             10     0.516     0.004     0.512     0.512     0.519     0.519     0.520     2.221
   1,048,576  jax               10     0.623     0.004     0.616     0.620     0.626     0.626     0.627     1.840
   1,048,576  pytorch           10     0.751     0.002     0.747     0.752     0.752     0.752     0.752     1.527
   1,048,576  aesara            10     0.851     0.001     0.850     0.850     0.851     0.851     0.852     1.348
   1,048,576  numpy             10     1.147     0.002     1.142     1.147     1.147     1.148     1.148     1.000

   4,194,304  numba             10     2.247     0.003     2.243     2.244     2.246     2.250     2.252     2.282
   4,194,304  jax               10     2.569     0.003     2.563     2.569     2.570     2.571     2.573     1.995
   4,194,304  aesara            10     3.773     0.002     3.769     3.772     3.774     3.776     3.776     1.359
   4,194,304  pytorch           10     3.797     0.022     3.751     3.784     3.797     3.815     3.826     1.350
   4,194,304  numpy             10     5.126     0.003     5.119     5.124     5.128     5.128     5.131     1.000

(time in wall seconds, less is better)

$ taskset -c 23 python run.py benchmarks/isoneutral_mixing/ -s 16777216

benchmarks.isoneutral_mixing
============================
Running on CPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
  16,777,216  numba             10     9.239     0.098     9.042     9.178     9.295     9.301     9.327     2.614
  16,777,216  jax               10    10.134     0.006    10.125    10.131    10.134    10.137    10.144     2.383
  16,777,216  aesara            10    15.387     0.046    15.323    15.343    15.389    15.433    15.450     1.569
  16,777,216  pytorch           10    17.916     0.024    17.856    17.910    17.925    17.931    17.939     1.348
  16,777,216  numpy             10    24.148     0.018    24.117    24.137    24.147    24.161    24.179     1.000

(time in wall seconds, less is better)

GPU

$ for backend in cupy jax pytorch; do CUDA_VISIBLE_DEVICES="0" python run.py benchmarks/isoneutral_mixing/ --gpu -b $backend -b numpy; done

benchmarks.isoneutral_mixing
============================
Running on GPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  numpy          1,000     0.004     0.000     0.004     0.004     0.004     0.004     0.008     1.000
       4,096  cupy           1,000     0.011     0.000     0.010     0.010     0.011     0.011     0.014     0.401

      16,384  cupy           1,000     0.011     0.000     0.010     0.011     0.011     0.011     0.014     1.519
      16,384  numpy          1,000     0.016     0.000     0.016     0.016     0.016     0.016     0.020     1.000

      65,536  cupy             100     0.011     0.000     0.011     0.011     0.011     0.011     0.011     5.851
      65,536  numpy            100     0.063     0.001     0.063     0.063     0.063     0.063     0.071     1.000

     262,144  cupy             100     0.011     0.000     0.011     0.011     0.011     0.011     0.013    24.889
     262,144  numpy             10     0.273     0.005     0.257     0.274     0.274     0.275     0.276     1.000

   1,048,576  cupy              10     0.021     0.000     0.021     0.021     0.021     0.022     0.022    56.671
   1,048,576  numpy             10     1.211     0.004     1.208     1.209     1.209     1.211     1.222     1.000

   4,194,304  cupy              10     0.080     0.001     0.079     0.079     0.079     0.081     0.082    64.295
   4,194,304  numpy             10     5.133     0.003     5.127     5.133     5.134     5.135     5.138     1.000

(time in wall seconds, less is better)

benchmarks.isoneutral_mixing
============================
Running on GPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  jax            1,000     0.001     0.000     0.001     0.001     0.001     0.001     0.005     8.233
       4,096  numpy          1,000     0.005     0.000     0.005     0.005     0.005     0.005     0.009     1.000

      16,384  jax            1,000     0.001     0.000     0.001     0.001     0.001     0.001     0.003    25.151
      16,384  numpy          1,000     0.017     0.000     0.016     0.016     0.017     0.017     0.019     1.000

      65,536  jax              100     0.001     0.000     0.001     0.001     0.001     0.001     0.001    56.285
      65,536  numpy            100     0.065     0.001     0.063     0.065     0.065     0.065     0.068     1.000

     262,144  jax              100     0.004     0.000     0.004     0.004     0.004     0.004     0.007    67.464
     262,144  numpy             10     0.270     0.009     0.258     0.262     0.272     0.273     0.283     1.000

   1,048,576  jax               10     0.015     0.000     0.015     0.015     0.015     0.015     0.015    81.204
   1,048,576  numpy             10     1.241     0.001     1.239     1.240     1.241     1.242     1.242     1.000

   4,194,304  jax               10     0.057     0.000     0.057     0.057     0.057     0.057     0.057    91.099
   4,194,304  numpy             10     5.222     0.020     5.209     5.215     5.216     5.219     5.281     1.000

(time in wall seconds, less is better)

benchmarks.isoneutral_mixing
============================
Running on GPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  numpy          1,000     0.004     0.000     0.004     0.004     0.004     0.004     0.007     1.000
       4,096  pytorch        1,000     0.005     0.000     0.005     0.005     0.005     0.005     0.008     0.827

      16,384  pytorch        1,000     0.005     0.000     0.005     0.005     0.005     0.005     0.008     3.072
      16,384  numpy          1,000     0.016     0.000     0.016     0.016     0.016     0.016     0.020     1.000

      65,536  pytorch          100     0.006     0.000     0.006     0.006     0.006     0.006     0.006    11.134
      65,536  numpy            100     0.063     0.001     0.063     0.063     0.063     0.063     0.066     1.000

     262,144  pytorch          100     0.006     0.000     0.006     0.006     0.006     0.006     0.006    45.709
     262,144  numpy             10     0.271     0.009     0.250     0.267     0.274     0.276     0.285     1.000

   1,048,576  pytorch           10     0.016     0.001     0.016     0.016     0.016     0.016     0.018    74.789
   1,048,576  numpy             10     1.208     0.001     1.207     1.207     1.208     1.209     1.209     1.000

   4,194,304  pytorch           10     0.056     0.000     0.056     0.056     0.056     0.057     0.057    91.363
   4,194,304  numpy             10     5.139     0.002     5.135     5.136     5.139     5.141     5.141     1.000

(time in wall seconds, less is better)

Turbulent kinetic energy

This routine consists of some stencil operations and some linear algebra (a tridiagonal matrix solver, which cannot be vectorized along the solve dimension).
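
To make "cannot be vectorized" concrete, here is a minimal sketch of the Thomas algorithm for a single tridiagonal system (not the benchmark's own solver): the forward sweep uses row i-1 to update row i, so the loop has to run sequentially along the solve dimension.

```python
import numpy as np

def thomas_solve(a, b, c, d):
    # Solve a tridiagonal system with sub-, main and super-diagonals a, b, c
    # and right-hand side d. The recurrence on cp/dp carries a dependency
    # from row i-1 to row i, which is why this part cannot be expressed as a
    # vectorized elementwise operation along the solve dimension.
    n = len(d)
    cp, dp = np.empty(n), np.empty(n)
    cp[0], dp[0] = c[0] / b[0], d[0] / b[0]
    for i in range(1, n):
        denom = b[i] - a[i] * cp[i - 1]
        cp[i] = c[i] / denom
        dp[i] = (d[i] - a[i] * dp[i - 1]) / denom
    x = np.empty(n)
    x[-1] = dp[-1]
    for i in range(n - 2, -1, -1):  # backward substitution, also sequential
        x[i] = dp[i] - cp[i] * x[i + 1]
    return x

# Tiny example: a [2, -1] Laplacian-like system whose solution is all ones.
a = np.array([0.0, -1.0, -1.0, -1.0])   # sub-diagonal (a[0] unused)
b = np.array([2.0, 2.0, 2.0, 2.0])      # main diagonal
c = np.array([-1.0, -1.0, -1.0, 0.0])   # super-diagonal (c[-1] unused)
d = np.array([1.0, 0.0, 0.0, 1.0])
print(thomas_solve(a, b, c, d))          # -> [1. 1. 1. 1.]
```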

CPU

$ taskset -c 23 python run.py benchmarks/turbulent_kinetic_energy/

benchmarks.turbulent_kinetic_energy
===================================
Running on CPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  jax            1,000     0.000     0.000     0.000     0.000     0.000     0.000     0.001     5.826
       4,096  numba          1,000     0.001     0.000     0.001     0.001     0.001     0.001     0.002     1.838
       4,096  pytorch        1,000     0.002     0.000     0.002     0.002     0.002     0.002     0.003     1.221
       4,096  numpy          1,000     0.002     0.000     0.002     0.002     0.002     0.002     0.003     1.000

      16,384  jax            1,000     0.002     0.000     0.002     0.002     0.002     0.002     0.003     3.914
      16,384  pytorch        1,000     0.004     0.000     0.004     0.004     0.004     0.004     0.005     1.862
      16,384  numba          1,000     0.004     0.000     0.004     0.004     0.004     0.004     0.005     1.802
      16,384  numpy          1,000     0.008     0.000     0.008     0.008     0.008     0.008     0.011     1.000

      65,536  jax              100     0.009     0.000     0.009     0.009     0.009     0.009     0.010     3.263
      65,536  pytorch          100     0.013     0.000     0.013     0.013     0.013     0.013     0.013     2.209
      65,536  numba            100     0.014     0.000     0.014     0.014     0.014     0.014     0.014     2.006
      65,536  numpy            100     0.029     0.000     0.028     0.028     0.028     0.029     0.029     1.000

     262,144  jax              100     0.042     0.000     0.042     0.042     0.042     0.042     0.043     2.569
     262,144  numba            100     0.047     0.000     0.047     0.047     0.047     0.048     0.048     2.295
     262,144  pytorch          100     0.050     0.000     0.050     0.050     0.050     0.050     0.051     2.171
     262,144  numpy             10     0.109     0.000     0.108     0.109     0.109     0.109     0.110     1.000

   1,048,576  numba             10     0.187     0.000     0.187     0.187     0.187     0.187     0.188     2.711
   1,048,576  jax               10     0.237     0.000     0.237     0.237     0.237     0.237     0.238     2.140
   1,048,576  pytorch           10     0.276     0.000     0.275     0.276     0.276     0.276     0.276     1.839
   1,048,576  numpy             10     0.507     0.001     0.506     0.507     0.507     0.508     0.508     1.000

   4,194,304  numba             10     0.689     0.002     0.686     0.687     0.690     0.691     0.693     3.022
   4,194,304  jax               10     1.043     0.001     1.042     1.043     1.043     1.044     1.044     1.997
   4,194,304  pytorch           10     1.314     0.003     1.310     1.312     1.314     1.317     1.318     1.585
   4,194,304  numpy             10     2.084     0.003     2.079     2.080     2.085     2.086     2.088     1.000

(time in wall seconds, less is better)

$ taskset -c 23 python run.py benchmarks/turbulent_kinetic_energy/ -s 16777216

benchmarks.turbulent_kinetic_energy
===================================
Running on CPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
  16,777,216  numba             10     2.997     0.005     2.991     2.994     2.996     2.999     3.007     3.616
  16,777,216  jax               10     4.168     0.003     4.164     4.165     4.168     4.170     4.174     2.600
  16,777,216  pytorch           10     6.270     0.009     6.249     6.266     6.270     6.277     6.282     1.729
  16,777,216  numpy             10    10.839     0.011    10.823    10.829    10.836    10.850    10.853     1.000

(time in wall seconds, less is better)

GPU

$ for backend in jax pytorch; do CUDA_VISIBLE_DEVICES="0" python run.py benchmarks/turbulent_kinetic_energy/ --gpu -b $backend -b numpy; done

benchmarks.turbulent_kinetic_energy
===================================
Running on GPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  jax            1,000     0.001     0.000     0.000     0.001     0.001     0.001     0.008     4.261
       4,096  numpy          1,000     0.003     0.000     0.002     0.002     0.003     0.003     0.010     1.000

      16,384  jax            1,000     0.001     0.001     0.001     0.001     0.001     0.001     0.008    11.252
      16,384  numpy          1,000     0.008     0.000     0.008     0.008     0.008     0.008     0.013     1.000

      65,536  jax              100     0.001     0.001     0.001     0.001     0.001     0.001     0.007    26.405
      65,536  numpy            100     0.030     0.002     0.029     0.029     0.029     0.029     0.037     1.000

     262,144  jax              100     0.003     0.001     0.003     0.003     0.003     0.003     0.006    46.720
     262,144  numpy             10     0.135     0.007     0.115     0.134     0.136     0.137     0.143     1.000

   1,048,576  jax               10     0.016     0.000     0.015     0.016     0.016     0.016     0.016    36.351
   1,048,576  numpy             10     0.579     0.008     0.567     0.571     0.584     0.585     0.588     1.000

   4,194,304  jax               10     0.039     0.000     0.039     0.039     0.039     0.039     0.039    55.896
   4,194,304  numpy             10     2.190     0.032     2.149     2.152     2.212     2.218     2.220     1.000

(time in wall seconds, less is better)

benchmarks.turbulent_kinetic_energy
===================================
Running on GPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  numpy          1,000     0.002     0.000     0.002     0.002     0.002     0.002     0.005     1.000
       4,096  pytorch        1,000     0.003     0.000     0.003     0.003     0.003     0.003     0.004     0.834

      16,384  pytorch        1,000     0.003     0.000     0.003     0.003     0.003     0.003     0.005     2.455
      16,384  numpy          1,000     0.008     0.000     0.008     0.008     0.008     0.008     0.008     1.000

      65,536  pytorch          100     0.004     0.000     0.004     0.004     0.004     0.004     0.004     7.323
      65,536  numpy            100     0.029     0.000     0.029     0.029     0.029     0.029     0.030     1.000

     262,144  pytorch          100     0.005     0.000     0.005     0.005     0.005     0.005     0.005    23.533
     262,144  numpy             10     0.111     0.000     0.110     0.111     0.111     0.111     0.111     1.000

   1,048,576  pytorch           10     0.008     0.000     0.008     0.008     0.008     0.008     0.008    72.466
   1,048,576  numpy             10     0.573     0.003     0.567     0.571     0.574     0.574     0.576     1.000

   4,194,304  pytorch           10     0.029     0.000     0.029     0.029     0.030     0.030     0.030    73.957
   4,194,304  numpy             10     2.175     0.002     2.172     2.174     2.175     2.176     2.177     1.000

(time in wall seconds, less is better)