1
0

fuzz_graph.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. import random, ctypes
  2. import numpy as np
  3. from tinygrad.device import Buffer, Device
  4. from tinygrad.helpers import Context, getenv, from_mv
  5. from tinygrad.dtype import dtypes
  6. from tinygrad.tensor import Tensor, _to_np_dtype
  7. from tinygrad.engine.schedule import create_schedule
  8. from tinygrad.engine.realize import ExecItem, BufferXfer, get_runner
  9. from tinygrad.engine.jit import apply_graph_to_jit
  10. BUF_LEN = getenv("BUF_LEN", 128)
  11. cached_prgs = {}
  12. def gen_prg(device, inputs_cnt):
  13. if (device, inputs_cnt) in cached_prgs: return cached_prgs[(device, inputs_cnt)]
  14. with Context(DEBUG=0):
  15. fst = [Tensor.randn(BUF_LEN, dtype=dtypes.int).realize() for i in range(inputs_cnt)]
  16. s = fst[0]
  17. for i in range(1, inputs_cnt): s = s.xor(fst[i])
  18. si = create_schedule([s.lazydata])[-1]
  19. prg = get_runner(device, si.ast)
  20. cached_prgs[(device, inputs_cnt)] = prg
  21. return prg
  22. def alloc_rawbuffer(device, fill=False):
  23. rawbuf = Buffer(device, BUF_LEN, dtypes.int).ensure_allocated()
  24. if fill:
  25. with Context(DEBUG=0):
  26. data = np.random.randint(-10000, 10000, size=rawbuf.size, dtype=_to_np_dtype(rawbuf.dtype))
  27. rawbuf.copyin(Tensor(data).realize().lazydata.realized.as_buffer())
  28. return rawbuf
  29. def gen_kernel_ji(device, deps):
  30. assert len(deps) >= 2
  31. out = alloc_rawbuffer(device)
  32. prg = gen_prg(device, len(deps))
  33. return ExecItem(prg, [out] + deps)
  34. def gen_copy_ji(device, deps):
  35. assert len(deps) == 1
  36. out = alloc_rawbuffer(device)
  37. prg = BufferXfer(deps[0].nbytes, device, deps[0].device)
  38. return ExecItem(prg, [out] + deps)
  39. def gen_graph():
  40. input_buffers = []
  41. all_buffers = []
  42. jis = []
  43. last_n_deps = getenv("LAST_N_DEPS", 0)
  44. kernel_count = random.randint(2, getenv("MAX_KERNELS", 128))
  45. for i in range(kernel_count):
  46. target_device_id = random.randint(0, getenv("MAX_DEVICES", 6) - 1)
  47. target_device = f"{Device.DEFAULT}:{target_device_id}"
  48. is_copy = random.randint(0, 10) < 3
  49. if is_copy:
  50. deps_pool = [buf for buf in all_buffers[-last_n_deps:] if buf.device != target_device]
  51. if len(deps_pool) == 0: deps = []
  52. else: deps = random.sample(deps_pool, 1)
  53. else:
  54. deps_pool = [buf for buf in all_buffers[-last_n_deps:] if buf.device == target_device]
  55. deps_count = random.randint(0, min(getenv("MAX_DEPS_COUNT", 6), len(deps_pool)))
  56. if deps_count == 0: deps = []
  57. else: deps = random.sample(deps_pool, deps_count)
  58. if len(deps) == 0 or (not is_copy and len(deps) < 2):
  59. buf = alloc_rawbuffer(target_device, fill=True)
  60. input_buffers.append(buf)
  61. all_buffers.append(buf)
  62. elif is_copy:
  63. jis.append(gen_copy_ji(target_device, deps))
  64. all_buffers.append(jis[-1].bufs[0])
  65. else:
  66. jis.append(gen_kernel_ji(target_device, deps))
  67. all_buffers.append(jis[-1].bufs[0])
  68. return jis, all_buffers, input_buffers
  69. def run_jit(jis, all_buffers, input_buffers, var_vals):
  70. with Context(DEBUG=0):
  71. for rawbuf in all_buffers:
  72. if rawbuf in input_buffers: continue
  73. mv = memoryview(bytearray(rawbuf.size * rawbuf.dtype.itemsize))
  74. ctypes.memset(from_mv(mv), 0, len(mv))
  75. rawbuf.copyin(mv)
  76. for ei in jis: ei.run(var_vals, jit=True)
  77. with Context(DEBUG=0):
  78. res_buffers = []
  79. for rawbuf in all_buffers: res_buffers.append(rawbuf.as_buffer())
  80. return res_buffers
  81. def fuzz_graph(jis, all_buffers, input_buffers):
  82. ground_thruth_bufs = run_jit(jis, input_buffers, all_buffers, {})
  83. ground_truth_np = [np.frombuffer(x, _to_np_dtype(all_buffers[i].dtype)) for i,x in enumerate(ground_thruth_bufs)]
  84. for _ in range(getenv("FUZZ_GRAPH_SPLIT_RUNS", 64)):
  85. max_split_points = len(jis) // 3
  86. split_points = random.randint(0, min(max_split_points, getenv("FUZZ_GRAPH_MAX_SPLITS", 8)))
  87. split = [0]
  88. for i in range(split_points - 1):
  89. split.append(random.randint(split[-1] + 2, len(jis) - 2 * (max_split_points - i)))
  90. split.append(len(jis))
  91. graphed_jit = []
  92. for sp in range(len(split)-1):
  93. graphed_jit += apply_graph_to_jit(jis[split[sp]:split[sp+1]], [], {})
  94. for _ in range(getenv("FUZZ_GRAPH_SPLIT_RETRY_RUNS", 4)):
  95. test_bufs = run_jit(graphed_jit, input_buffers, all_buffers, {})
  96. test_bufs_np = [np.frombuffer(x, _to_np_dtype(all_buffers[i].dtype)) for i,x in enumerate(test_bufs)]
  97. for i in range(len(ground_thruth_bufs)): np.testing.assert_equal(ground_truth_np[i], test_bufs_np[i])
  98. if __name__ == "__main__":
  99. SEED = getenv("SEED", 42)
  100. random.seed(SEED)
  101. np.random.seed(SEED)
  102. next_graph_id = 0
  103. while True:
  104. print("Running graph", next_graph_id)
  105. jis, all_buffers, input_buffers = gen_graph()
  106. fuzz_graph(jis, all_buffers, input_buffers)
  107. next_graph_id += 1