archprobe.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. # copying the kernels from https://github.com/microsoft/ArchProbe into Python
  2. import numpy as np
  3. import pickle
  4. from tinygrad.runtime.ops_gpu import CLProgram, CLBuffer
  5. from tinygrad import dtypes
  6. from tqdm import trange, tqdm
  7. from matplotlib import pyplot as plt
  8. tests = {}
  9. def register_test(fxn):
  10. tests[fxn.__name__] = fxn
  11. def warp_size2(nthread):
  12. prg = """__kernel void warp_size2(
  13. __global float* src,
  14. __global int* dst,
  15. const int niter,
  16. const int prime_number
  17. ) {
  18. int drain = 0;
  19. for (int j = 0; j < niter; ++j) {
  20. drain += j / prime_number;
  21. barrier(0);
  22. }
  23. dst[get_local_id(0)] = drain;
  24. }"""
  25. src_buf = CLBuffer(1, dtypes.float32)
  26. dst_buf = CLBuffer(1, dtypes.int32)
  27. cl = CLProgram("warp_size2", prg, argdtypes=[None, None, np.int32, np.int32])
  28. return min([cl([nthread, 1024, 1], [nthread, 1, 1], src_buf, dst_buf, 10, 3, wait=True) for _ in range(5)])*1e9
  29. @register_test
  30. def test_warp_size():
  31. return [(nthread, warp_size2(nthread)) for nthread in trange(1,256)]
  32. def reg_count(nthread, ngrp, nreg):
  33. reg_declr = ''.join([f"float reg_data{i} = (float)niter + {i};\n" for i in range(nreg)])
  34. reg_comp = ''.join([f"reg_data{i} *= {(i-1)%nreg};\n" for i in range(nreg)])
  35. reg_reduce = ''.join([f"out_buf[{i}] = reg_data{i};\n" for i in range(nreg)])
  36. prg = f"""__kernel void reg_count(
  37. __global float* out_buf,
  38. __private const int niter
  39. ) {{
  40. {reg_declr}
  41. int i = 0;
  42. for (; i < niter; ++i) {{
  43. {reg_comp}
  44. }}
  45. i = i >> 31;
  46. {reg_reduce}
  47. }}"""
  48. out_buf = CLBuffer(1, dtypes.float32)
  49. cl = CLProgram("reg_count", prg, argdtypes=[None, np.int32])
  50. return min([cl([nthread, ngrp, 1], [nthread, 1, 1], out_buf, 20, wait=True) for _ in range(10)])*1e9
  51. @register_test
  52. def test_reg_count(nthread=1, ngrp=1):
  53. base = reg_count(nthread, ngrp, 1)
  54. return [(nreg, (reg_count(nthread, ngrp, nreg)-base)/nreg) for nreg in trange(4, 513, 4)]
  55. def buf_cache_hierarchy_pchase(ndata, stride=1, NCOMP=1, steps=65536):
  56. ndata //= NCOMP*4 # ptr size
  57. prg = f"""__kernel void buf_cache_hierarchy_pchase(
  58. __global int{str(NCOMP) if NCOMP > 1 else ''}* src,
  59. __global int* dst,
  60. const int niter
  61. ) {{
  62. int idx = 0;
  63. for (int i = 0; i < niter; ++i) {{
  64. idx = src[idx]{'.x' if NCOMP > 1 else ''};
  65. }}
  66. *dst = idx;
  67. }}"""
  68. idx_buf = np.zeros(ndata*NCOMP, dtype=np.int32)
  69. for i in range(ndata): idx_buf[i*NCOMP] = (i + stride) % ndata
  70. in_buf = CLBuffer.fromCPU(idx_buf)
  71. out_buf = CLBuffer(1, dtypes.int32)
  72. cl = CLProgram("buf_cache_hierarchy_pchase", prg, argdtypes=[None, None, np.int32])
  73. return min([cl([1, 1, 1], [1, 1, 1], in_buf, out_buf, steps, wait=True)/steps for _ in range(5)])*1e9
  74. @register_test
  75. def test_memory_latency():
  76. # requires cacheline < 16
  77. szs = [int(1.3**x) for x in range(20, 70)]
  78. return [(ndata, buf_cache_hierarchy_pchase(ndata, NCOMP=16, steps=128*1024)) for ndata in tqdm(szs)]
  79. @register_test
  80. def test_cacheline_size():
  81. # TODO: this buffer must be at least 2x the L1 cache for this test to work
  82. return [(stride, buf_cache_hierarchy_pchase(4*65536, stride, steps=65536)) for stride in trange(1,64)]
  83. def cl_read(sz, niter=1):
  84. prg = f"""__kernel void copy(
  85. __global float4* src,
  86. __global float* dst) {{
  87. int gid = get_global_id(0);
  88. if (src[gid].x == 99+get_global_id(1)) *dst = 1;
  89. }}"""
  90. in_buf = CLBuffer(sz//4, dtypes.float32)
  91. out_buf = CLBuffer(1, dtypes.float32)
  92. cl = CLProgram("copy", prg)
  93. # NOTE: if nay of the niters form a local group, this is wrong
  94. return min([cl([sz//16, niter, 1], [1, 1, 1], in_buf, out_buf, wait=True) for _ in range(10)])*1e9
  95. @register_test
  96. def test_read_bandwidth():
  97. szs = list(range(128*1024, 20*1024*1024, 128*1024))
  98. NITER = 8
  99. base = cl_read(16, niter=NITER)
  100. return [(sz, (sz*NITER)/(cl_read(sz, niter=NITER)-base)) for sz in tqdm(szs)]
  101. def gflops(niter=4, nroll=4, ngroups=4096):
  102. NCOMP = 8
  103. prg = f"""__kernel void gflops(
  104. __global float* out_buf
  105. ) {{
  106. float{NCOMP} x = (float{NCOMP})({",".join(f"get_local_id(0)+{i}" for i in range(NCOMP))});
  107. float{NCOMP} y = (float{NCOMP})({",".join(f"get_local_id(1)+{i}" for i in range(NCOMP))});
  108. for (int i = 0; i < {niter}; i++) {{
  109. {''.join(['x = mad(y, y, x); y = mad(x, x, y);'+chr(10)]*nroll)}
  110. }}
  111. out_buf[get_global_id(0) >> 31] = {'+'.join(f"y.s{'0123456789abcdef'[i]}" for i in range(NCOMP))};
  112. }}"""
  113. out_buf = CLBuffer(1, dtypes.float32)
  114. cl = CLProgram("gflops", prg, options="-cl-mad-enable -cl-fast-relaxed-math")
  115. FLOPS = NCOMP*2*2 * niter * nroll * ngroups * 32
  116. # NOTE: if nay of the niters form a local group, this is wrong
  117. return FLOPS/(min([cl([32, ngroups, 1], [32, 1, 1], out_buf, wait=True) for _ in range(10)])*1e9)
  118. @register_test
  119. def test_gflops():
  120. return [(niter, gflops(niter=niter, nroll=32)) for niter in trange(1, 32, 1)]
  121. if __name__ == "__main__":
  122. cache = {}
  123. #cache = pickle.load(open("/tmp/cache.pkl", "rb"))
  124. #tests = {"test_cacheline_size": tests["test_cacheline_size"]}
  125. plt.figure(figsize=(16, 9))
  126. for i,(k,test) in enumerate(tests.items()):
  127. print(f"running {k}")
  128. plt.subplot(2, (len(tests)+1)//2, i+1)
  129. plt.title(k)
  130. if k == "test_memory_latency": plt.xscale('log')
  131. if k not in cache: cache[k] = test()
  132. plt.plot(*zip(*cache[k]))
  133. #pickle.dump(cache, open("/tmp/cache.pkl", "wb"))
  134. plt.tight_layout(pad=0.5)
  135. plt.savefig("/tmp/results.png")
  136. plt.show()