mlx_matmul.py 233 B

12345678910
  1. import mlx.core as mx
  2. from tinygrad.helpers import Timing
  3. N = 4096
  4. x = mx.random.normal((N,N))
  5. w = mx.random.normal((N,N))
  6. FLOPS = N*N*N*2
  7. for i in range(10):
  8. with Timing("", lambda x: f" {FLOPS/x:.2f} GFLOPS"):
  9. mx.eval(x@w)