# speed_compare_cuda_ptx.py: time the same kernels compiled via the CUDA C path and via the PTX renderer/compiler
import itertools

from tinygrad import Device
from tinygrad.engine.realize import CompiledRunner
from tinygrad.helpers import getenv, colorize_float
from extra.optimization.helpers import load_worlds, ast_str_to_lin
from tinygrad.engine.search import bufs_from_lin
from tinygrad.runtime.ops_cuda import PTXCompiler, PTXRenderer, CUDACompiler
  8. if __name__ == "__main__":
  9. ast_strs = load_worlds(filter_reduce=False, filter_novariable=True)
  10. # no bfloat16 for ptx at the moment
  11. ast_strs = [x for x in ast_strs if "dtypes.bfloat16" not in x]
  12. dev = Device["CUDA"]
  13. ptx = PTXRenderer(dev.arch)
  14. # NUM=112 python3 test/external/speed_compare_cuda_ptx.py
  15. single = getenv("NUM", -1)
  16. if single != -1: ast_strs = ast_strs[single:single+1]
  17. average_tm_cuda, average_tm_ptx = 0, 0
  18. for num,ast in enumerate(ast_strs):
  19. # cuda compile
  20. dev.compiler = CUDACompiler(dev.arch)
  21. lin = ast_str_to_lin(ast, opts=dev.renderer)
  22. lin.hand_coded_optimizations()
  23. cuda_prg = CompiledRunner(lin.to_program())
  24. bufs = bufs_from_lin(lin)
  25. # ptx compile
  26. dev.compiler = PTXCompiler(dev.arch)
  27. lin = ast_str_to_lin(ast, opts=ptx)
  28. lin.hand_coded_optimizations()
  29. lin.linearize()
  30. ptx_prg = CompiledRunner(lin.to_program())
  31. # warmup
  32. try:
  33. cuda_prg(bufs, {}, wait=True)
  34. except RuntimeError:
  35. print("cuda failed ast:", num)
  36. continue
  37. ptx_prg(bufs, {}, wait=True)
  38. tm_cuda, tm_ptx = [], []
  39. for i in range(5):
  40. tm_cuda.append(cuda_prg(bufs, {}, wait=True))
  41. tm_ptx.append(ptx_prg(bufs, {}, wait=True))
  42. average_tm_cuda += min(tm_cuda)
  43. average_tm_ptx += min(tm_ptx)
  44. ratio = min(tm_ptx)/min(tm_cuda)
  45. print(f"{average_tm_ptx/average_tm_cuda:5.2f}x -- {num:4d} {colorize_float(ratio)} {min(tm_ptx)*1e6:7.2f} us", lin.name)
  46. if ratio > 1.5:
  47. def fix(x): return x.replace('\t', ' ').strip()
  48. ll1, ll2 = cuda_prg.lib.decode().split('\n'), ptx_prg.lib.decode().split('\n')
  49. if single != -1:
  50. for ln, (l1, l2) in enumerate(itertools.zip_longest(ll1, ll2, fillvalue='')):
  51. print(f"{ln:5d} | {fix(l1):80s} | {fix(l2):80s}")
  52. print(len(ll1), len(ll2), "RATIO", ratio, "us", min(tm_ptx)*1e6)