# gemm.py
  1. #!/usr/bin/env python3
  2. import os
  3. #os.environ['OMP_NUM_THREADS'] = '1'
  4. import time
  5. import numpy as np
  6. N = 512
  7. if __name__ == "__main__":
  8. # N^2
  9. A = np.random.randn(N, N).astype(np.float32)
  10. # N^2
  11. B = np.random.randn(N, N).astype(np.float32)
  12. # 2N compute in N^2 output cells
  13. flop = 2*N*N*N
  14. #print(f"{flop / 1e9:.2f} GFLOP")
  15. for i in range(10):
  16. st = time.monotonic()
  17. C = A @ B.T
  18. et = time.monotonic()
  19. s = et-st
  20. print(f"{flop/s * 1e-9:.2f} GFLOP/S, {s*1e3:.2f} ms")
  21. with open("/tmp/matmul", "wb") as f:
  22. f.write(A.data)
  23. f.write(B.data)
  24. f.write(C.data)