1
0

verify_kernel.py 3.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. import argparse
  2. from collections import defaultdict
  3. from extra.optimization.helpers import kern_str_to_lin
  4. from test.external.fuzz_linearizer import compare_linearizer
  5. from tinygrad.helpers import colored
  6. from tinygrad.codegen.kernel import Kernel
  7. from tinygrad.engine.graph import print_tree
  8. from tinygrad.engine.search import time_linearizer
# Use this with the LOGKERNS option to verify that all executed kernels are valid and evaluate to the same ground truth results
# Example for GPT2:
# 1) Run the model to log all kernels: `PYTHONPATH=. LOGKERNS=/tmp/gpt2_kerns.txt JIT=1 HALF=1 BEAM=2 CACHELEVEL=0 python3 examples/gpt2.py --count 10 --temperature 0 --timing` # noqa: E501
# 2) Validate the kernel correctness: `PYTHONPATH=. python3 ./test/external/verify_kernel.py --file /tmp/gpt2_kerns.txt`
  13. if __name__ == "__main__":
  14. parser = argparse.ArgumentParser(description="Verify the correctness of one or more kernel", formatter_class=argparse.ArgumentDefaultsHelpFormatter) # noqa: E501
  15. parser.add_argument("--kernel", type=str, default=None, help="a string of a tuple of (ast, applied_opts,)")
  16. parser.add_argument("--file", type=str, default=None, help="a file containing a tuple of ast and applied_opts, one per line")
  17. parser.add_argument("--pkl", type=str, default=None, help="a pickle file containing a single tuple of ast and applied_opts")
  18. parser.add_argument("--rtol", type=float, default=1e-2, help="relative tolerance for numerical comparison")
  19. parser.add_argument("--atol", type=float, default=1e-2, help="absolute tolerance for numerical comparison")
  20. parser.add_argument("--timing", action='store_true', help="show final timing for the kernel")
  21. parser.add_argument("--expected-failures", type=int, default=0, help="the number of expected failed kernels")
  22. args = parser.parse_args()
  23. if args.kernel is not None:
  24. print("loading kernel from args")
  25. test_lins = [kern_str_to_lin(args.kernel)]
  26. elif args.file is not None:
  27. print(f"loading kernel from file '{args.file}'")
  28. with open(args.file, 'r') as file:
  29. kern_strs = file.readlines()
  30. test_lins = [kern_str_to_lin(kern_str) for kern_str in kern_strs]
  31. elif args.pkl is not None:
  32. print(f"loading kernel from pickle file '{args.file}'")
  33. import pickle
  34. with open(args.pkl, 'rb') as file:
  35. (ast, applied_opts,) = pickle.load(file)
  36. lin = Kernel(ast)
  37. for opt in applied_opts:
  38. lin.apply_opt(opt)
  39. test_lins = [lin]
  40. else:
  41. raise RuntimeError("no kernel specified; use --kernel, --file, or --pkl options")
  42. print(f"verifying {len(test_lins)} kernels")
  43. failed_ids = []
  44. failures = defaultdict(list)
  45. for i, test_lin in enumerate(test_lins):
  46. print(f"testing kernel {i}")
  47. print_tree(test_lin.ast)
  48. print(test_lin.ast)
  49. print(test_lin.applied_opts)
  50. unoptimized_lin = Kernel(test_lin.ast)
  51. unoptimized_lin.required_optimizations()
  52. print(f"{unoptimized_lin.colored_shape()} -> {test_lin.colored_shape()}")
  53. (msg,rb,vv,gt) = compare_linearizer(test_lin, None, None, None, rtol=args.rtol, atol=args.atol)
  54. if msg != "PASS":
  55. failed_ids.append(i)
  56. failures[msg].append((test_lin.ast, test_lin.applied_opts))
  57. if args.timing:
  58. tm = time_linearizer(test_lin, rb, allow_test_size=False, cnt=10)
  59. print(f"final time {tm*1e6:9.0f} us")
  60. for msg, errors in failures.items():
  61. for i, (ast, opts) in enumerate(errors):
  62. print(f"{msg} {i} AST: {ast}")
  63. print(f"{msg} {i} OPTS: {opts}\n")
  64. print(f"tested {len(test_lins)} kernels")
  65. if failures:
  66. print(f"{failed_ids=}")
  67. for msg, errors in failures.items():
  68. print(f"{msg}: {len(errors)}")
  69. if len(failed_ids) == args.expected_failures:
  70. print(colored(f"{len(failed_ids)} failed as expected", "yellow"))
  71. if len(failed_ids) != args.expected_failures:
  72. raise RuntimeError(f"failed on {len(failed_ids)} kernels, expected {args.expected_failures}")
  73. else:
  74. print(colored("all passed", "green"))