# torch_gemm.py — times square matmuls (fp16/fp32) on the GPU via torch
  1. import time
  2. import torch
  3. for dtype in [torch.float16, torch.float32]:
  4. for N in [256, 512, 1024, 2048, 4096]:
  5. FLOPS = N*N*N*2
  6. b = torch.rand((N,N), dtype=dtype).cuda()
  7. c = torch.rand((N,N), dtype=dtype).cuda()
  8. def torch_prog(b, c):
  9. st = time.perf_counter()
  10. a = b@c
  11. torch.cuda.synchronize()
  12. return time.perf_counter() - st
  13. tm = min([torch_prog(b, c) for _ in range(20)])
  14. print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS {N:4d}x{N:4d}x{N:4d} matmul in {dtype}")