fuzz_linearizer.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. import random, traceback, ctypes, argparse
  2. from typing import List, Tuple, DefaultDict
  3. import numpy as np
  4. from collections import defaultdict
  5. from extra.optimization.helpers import load_worlds, ast_str_to_lin, kern_str_to_lin
  6. from tinygrad import Tensor, Device, dtypes
  7. from tinygrad.tensor import _to_np_dtype
  8. from tinygrad.codegen.kernel import Kernel
  9. from tinygrad.codegen.uops import UOp
  10. from tinygrad.codegen.kernel import Opt, OptOps
  11. from tinygrad.engine.search import get_kernel_actions, bufs_from_lin
  12. from tinygrad.engine.graph import print_tree
  13. from tinygrad.engine.realize import CompiledRunner
  14. from tinygrad.helpers import getenv, from_mv, prod, colored, Context, DEBUG
  15. from tinygrad.ops import LazyOp, UnaryOps, BufferOps
  16. from test.helpers import is_dtype_supported
  17. def tuplize_uops(uops:List[UOp]) -> Tuple:
  18. return tuple([(x.op, x.dtype, tuple(uops.index(x) for x in x.src), x.arg) for x in uops])
  19. device = Device[Device.DEFAULT]
  20. def get_fuzz_rawbufs(lin):
  21. rawbufs = bufs_from_lin(lin)
  22. # Reallocate output buffer with additional area to detect out-of-bounds writes.
  23. RED_AREA_SIZE = 1024
  24. # setting output # TODO: multi-output kernel
  25. rawbufs[0] = get_fuzz_rawbuf_like(rawbufs[0], zero=True, size=rawbufs[0].size+RED_AREA_SIZE)
  26. # setting inputs
  27. with Context(DEBUG=0):
  28. for rawbuf in rawbufs[1:]:
  29. if dtypes.is_unsigned(rawbuf.dtype):
  30. data = np.random.randint(0, 100, size=rawbuf.size, dtype=_to_np_dtype(rawbuf.dtype))
  31. elif dtypes.is_int(rawbuf.dtype):
  32. data = np.random.randint(-100, 100, size=rawbuf.size, dtype=_to_np_dtype(rawbuf.dtype))
  33. elif rawbuf.dtype == dtypes.bool:
  34. data = np.random.choice([True, False], size=rawbuf.size)
  35. elif rawbuf.dtype == dtypes.half:
  36. data = np.random.uniform(-1, 1, size=rawbuf.size).astype(dtype=_to_np_dtype(rawbuf.dtype))
  37. else:
  38. data = np.random.uniform(-10, 10, size=rawbuf.size).astype(dtype=_to_np_dtype(rawbuf.dtype))
  39. rawbuf.copyin(Tensor(data).realize().lazydata.realized.as_buffer())
  40. return rawbufs
  41. def get_fuzz_rawbuf_like(rawbuf, zero=False, size=None):
  42. rawbuf = type(rawbuf)(Device.DEFAULT, rawbuf.size if size is None else size, rawbuf.dtype).allocate()
  43. if zero:
  44. with Context(DEBUG=0):
  45. mv = memoryview(bytearray(rawbuf.size * rawbuf.dtype.itemsize))
  46. ctypes.memset(from_mv(mv), 0, len(mv))
  47. rawbuf.copyin(mv)
  48. return rawbuf
  49. def run_linearizer(lin: Kernel, rawbufs=None, var_vals=None):
  50. if rawbufs is None: rawbufs = bufs_from_lin(lin)
  51. if var_vals is None: var_vals = {v: v.min for v in lin.ast[0].vars()}
  52. # TODO: images needs required_optimization
  53. try:
  54. prg = CompiledRunner(lin.to_program())
  55. except Exception:
  56. traceback.print_exc()
  57. return "COMPILE_ERROR"
  58. try:
  59. prg(rawbufs, var_vals, wait=True)
  60. except Exception:
  61. traceback.print_exc()
  62. return "EXEC_ERROR"
  63. return "PASS"
  64. def compare_linearizer(lin: Kernel, rawbufs=None, var_vals=None, ground_truth=None, rtol=1e-2, atol=1e-2):
  65. # TODO: for bfloat16 it compiles linearizer, but it does not run because numpy cannot generate bf16 buffer.
  66. has_bf16 = any(b.dtype == dtypes.bfloat16 for b in lin.membufs)
  67. # TODO: raise specific fuzzing errors instead of str, and propagate the error message
  68. try:
  69. if rawbufs is None:
  70. rawbufs = get_fuzz_rawbufs(lin)
  71. else:
  72. rawbufs[0] = get_fuzz_rawbuf_like(rawbufs[0], zero=True) # get a new output buffer
  73. except BaseException:
  74. return ("RAWBUFS_ERROR", rawbufs, var_vals, ground_truth,)
  75. if var_vals is None:
  76. # TODO: handle symbolic max case
  77. var_vals = {v: random.randint(v.min, v.max if isinstance(v.max, int) else v.min) for v in lin.ast.vars()}
  78. if ground_truth is None and not has_bf16:
  79. unoptimized = Kernel(lin.ast)
  80. unoptimized.required_optimizations()
  81. if run_linearizer(unoptimized, rawbufs, var_vals) != "PASS":
  82. return ("BASELINE_ERROR", rawbufs, var_vals, ground_truth,)
  83. ground_truth = np.frombuffer(rawbufs[0].as_buffer(), _to_np_dtype(rawbufs[0].dtype)).copy()
  84. rawbufs[0] = get_fuzz_rawbuf_like(rawbufs[0], zero=True) # get a new output buffer
  85. if (run_msg := run_linearizer(lin, rawbufs, var_vals)) != "PASS":
  86. return (run_msg, rawbufs, var_vals, ground_truth,)
  87. try:
  88. if not has_bf16:
  89. result = np.frombuffer(rawbufs[0].as_buffer(), _to_np_dtype(rawbufs[0].dtype))
  90. np.testing.assert_allclose(result, ground_truth, rtol=rtol, atol=atol)
  91. except AssertionError as e:
  92. if DEBUG >= 2:
  93. print(f"COMPARE_ERROR details: {e}")
  94. if getenv("DEBUG_VALUES") > 0:
  95. mismatch_indices = np.where(~np.isclose(result, ground_truth, rtol=rtol, atol=atol))
  96. mismatched_result = result[mismatch_indices]
  97. mismatched_ground_truth = ground_truth[mismatch_indices]
  98. for i, idx in enumerate(mismatch_indices[0]):
  99. print(f"mismatch at {idx=}: result={mismatched_result[i]} <> ground_truth={mismatched_ground_truth[i]}")
  100. return ("COMPARE_ERROR", rawbufs, var_vals, ground_truth,)
  101. return ("PASS", rawbufs, var_vals, ground_truth,)
  102. def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2):
  103. SEED = getenv("SEED", 42)
  104. random.seed(SEED)
  105. np.random.seed(SEED)
  106. print_tree(lin.ast)
  107. print(lin.colored_shape())
  108. seen_uops = {}
  109. last_lins = [lin]
  110. failures:DefaultDict[str, List[Tuple[Tuple[LazyOp,...],List[Opt]]]] = defaultdict(list)
  111. rawbufs, var_vals, ground_truth = None, None, None
  112. FUZZ_ALL_ACTIONS = getenv("FUZZ_ALL_ACTIONS", 0)
  113. FUZZ_MAX_SIZE = getenv("FUZZ_MAX_SIZE", 0)
  114. FUZZ_IGNORE_SIMPLE_OPS = getenv("FUZZ_IGNORE_SIMPLE_OPS", 1)
  115. if FUZZ_MAX_SIZE > 0 and prod(lin.full_shape) > FUZZ_MAX_SIZE:
  116. print("skipping large kernel")
  117. return failures
  118. if FUZZ_IGNORE_SIMPLE_OPS and _is_simple(lin):
  119. print("skipping simple kernel")
  120. return failures
  121. for depth in range(getenv("DEPTH", 1 if FUZZ_ALL_ACTIONS else 10)):
  122. next_lins = []
  123. for lin in last_lins:
  124. actions = get_kernel_actions(lin, include_0=False)
  125. if not actions: continue
  126. if depth == 0 and getenv("FUZZ_REQUIRE_TC", 0):
  127. tc_acts = {i: k for k in actions.values() if k.applied_opts[0].op == OptOps.TC}
  128. if len(tc_acts) == 0: return failures
  129. else: actions = tc_acts
  130. test_lins = list(actions.values())
  131. if FUZZ_ALL_ACTIONS: print(f"testing {lin.applied_opts=} with {len(actions)} actions")
  132. else: test_lins = [random.choice(test_lins)]
  133. for test_lin in test_lins:
  134. if not FUZZ_ALL_ACTIONS and test_lin.applied_opts: print(f"applied opts: {test_lin.applied_opts}")
  135. # stop if kernel uops repeat
  136. try: tuops = tuplize_uops(test_lin.linearize().uops.uops)
  137. except BaseException as e:
  138. print(test_lin.ast)
  139. print(test_lin.applied_opts)
  140. print(e)
  141. failures["LINEARIZE_ERROR"].append((test_lin.ast, test_lin.applied_opts))
  142. continue
  143. if tuops in seen_uops: continue
  144. seen_uops[tuops] = tuple(test_lin.applied_opts)
  145. if not FUZZ_ALL_ACTIONS: print(test_lin.colored_shape())
  146. (msg, rawbufs, var_vals, ground_truth) = compare_linearizer(test_lin, rawbufs, var_vals, ground_truth, rtol=rtol, atol=atol)
  147. if msg != "PASS":
  148. print(test_lin.ast)
  149. print(test_lin.applied_opts)
  150. print(msg)
  151. failures[msg].append((test_lin.ast, test_lin.applied_opts))
  152. continue
  153. next_lins.append(test_lin)
  154. last_lins = next_lins
  155. if FUZZ_ALL_ACTIONS: print(f"depth={depth} total_lins={len(last_lins)} {failures=}")
  156. return failures
  157. def _is_simple(lin: Kernel) -> bool:
  158. if len(lin.ast.src) > 1: return False
  159. ast:LazyOp = lin.ast.src[0]
  160. if ast.src[0] and ast.src[0].op is UnaryOps.CAST and ast.src[0].src[0] and ast.src[0].src[0].op is BufferOps.LOAD: return True
  161. return False
  162. if __name__ == "__main__":
  163. parser = argparse.ArgumentParser(description="Run a fuzz testing on one or more kernels", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  164. parser.add_argument("--ast", type=str, default=None, help="the ast for the kernel to be optimized")
  165. parser.add_argument("--file", type=str, default=None, help="a file containing asts to be optimized, one per line")
  166. parser.add_argument("--logfile", type=str, default=None, help="a file containing a tuple of ast and applied_opts, one per line")
  167. parser.add_argument("--expected-failures", type=int, default=0, help="the number of expected failed kernels")
  168. parser.add_argument("--rtol", type=float, default=1e-2, help="relative tolerance for numerical comparison")
  169. parser.add_argument("--atol", type=float, default=1e-2, help="absolute tolerance for numerical comparison")
  170. args = parser.parse_args()
  171. if args.ast is not None:
  172. print("loaded AST from CLI")
  173. ast_strs = [args.ast]
  174. elif args.file is not None:
  175. print(f"loading ASTs from file '{args.file}'")
  176. with open(args.file, 'r') as file:
  177. ast_strs = file.readlines()
  178. elif args.logfile is not None:
  179. print(f"loading ASTs from LOGKERNS file '{args.file}'")
  180. with open(args.logfile, 'r') as file:
  181. kern_strs = file.readlines()
  182. test_lins = [kern_str_to_lin(kern_str) for kern_str in kern_strs]
  183. ast_strs = [f"{lin.ast}" for lin in test_lins]
  184. else:
  185. print("loading ASTs from world")
  186. ast_strs = load_worlds(filter_reduce=False, filter_novariable=False)
  187. print(f"{len(ast_strs)=}")
  188. tested = 0
  189. failed_ids = []
  190. failures = defaultdict(list)
  191. seen_ast_strs = set()
  192. for i, ast in enumerate(ast_strs[:getenv("FUZZ_N", len(ast_strs))]):
  193. if (nth := getenv("FUZZ_NTH", -1)) != -1 and i != nth: continue
  194. if "dtypes.image" in ast and Device.DEFAULT != "GPU": continue # IMAGE is only for GPU
  195. if ast in seen_ast_strs: continue
  196. seen_ast_strs.add(ast)
  197. lin = ast_str_to_lin(ast)
  198. if not all(is_dtype_supported(buf.dtype) for buf in lin.bufs):
  199. print("skipping kernel due to not supported dtype")
  200. continue
  201. print(f"testing ast {i}")
  202. tested += 1
  203. fuzz_failures = fuzz_linearizer(lin, rtol=args.rtol, atol=args.atol)
  204. if fuzz_failures: failed_ids.append(i)
  205. for k, v in fuzz_failures.items():
  206. for f in v:
  207. failures[k].append(f)
  208. for msg, errors in failures.items():
  209. for i, (ast, opts) in enumerate(errors):
  210. print(f"{msg} {i} kernel: {(ast,opts)}") # easier to use with output with verify_kernel.py
  211. print(f"{tested=}")
  212. if failures:
  213. print(f"{failed_ids=}")
  214. for msg, errors in failures.items():
  215. print(f"{msg}: {len(errors)}")
  216. if len(failed_ids) == args.expected_failures:
  217. print(colored(f"{len(failed_ids)} failed as expected", "yellow"))
  218. if len(failed_ids) != args.expected_failures:
  219. raise RuntimeError(f"failed on {len(failed_ids)} kernels, expected {args.expected_failures}")
  220. else:
  221. print(colored("all passed", "green"))