  1. #!/usr/bin/env python3
  2. # compare kernels created by HEAD against master
  3. import difflib, pickle, multiprocessing, os, logging
  4. from tinygrad.codegen.kernel import Kernel
  5. from tinygrad.helpers import Context, ContextVar, colored, db_connection, VERSION, getenv, tqdm
# rows fetched per worker invocation; one "page" of captured kernels
PAGE_SIZE = 100
# table is scoped to this CI run (GITHUB_RUN_ID, or 'HEAD' locally) and the helpers' VERSION
TABLE_NAME = f"process_replay_{getenv('GITHUB_RUN_ID', 'HEAD')}_{VERSION}"
# raise on any diff when "[run_process_replay]" appears in the commit message or PR title
# NOTE: the env-var lookups default to k itself, so this evaluates true when COMMIT_MESSAGE/PR_TITLE are unset
ASSERT_DIFF = getenv("ASSERT_PROCESS_REPLAY", int((k:="[run_process_replay]") in os.getenv("COMMIT_MESSAGE", k) or k in os.getenv("PR_TITLE", k)))
# skip the whole replay when "[skip_process_replay]" appears in the commit message or PR title
SKIP_PROCESS_REPLAY = int((k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", ""))
# stop generating diffs for a page once more than this percentage of its kernels changed
MAX_DIFF_PCT = getenv("PROCESS_REPLAY_MAX_DIFF_PCT", 20)
assert MAX_DIFF_PCT < 100
# shared flag: any worker tripping the diff threshold tells the others to stop early
early_stop = multiprocessing.Event()
logging.basicConfig(level=logging.INFO, format='%(message)s')
  14. def process_replay(offset:int):
  15. if early_stop.is_set(): return
  16. conn = db_connection()
  17. cur = conn.cursor()
  18. cur.execute(f"SELECT val FROM '{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset))
  19. changed = 0
  20. for row in cur.fetchall():
  21. ast, applied_opts = None, None
  22. # try unpickle and linearize
  23. try:
  24. ast, opts, applied_opts, name, compare_src, ctx = pickle.loads(row[0])
  25. with Context(**{k:v for k,v in ctx.items() if k in ContextVar._cache and k != "DEBUG"}):
  26. k = Kernel(ast, opts=opts)
  27. for opt in applied_opts: k.apply_opt(opt)
  28. good_src = k.opts.render(name, k.linearize().uops)
  29. except Exception as e:
  30. logging.warn("FAILED TO RECREATE KERNEL")
  31. logging.info(ast)
  32. logging.info(applied_opts)
  33. logging.info(e)
  34. if ASSERT_DIFF: raise e
  35. continue
  36. # try compare
  37. try: assert compare_src == good_src
  38. except AssertionError as e:
  39. changed += 1
  40. logging.info("PROCESS REPLAY DETECTED CHANGE")
  41. logging.info(ast)
  42. logging.info(applied_opts)
  43. diff = list(difflib.unified_diff(good_src.splitlines(), compare_src.splitlines()))
  44. for line in diff:
  45. logging.info(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None))
  46. if ASSERT_DIFF: raise e
  47. if changed > MAX_DIFF_PCT:
  48. logging.warn(f"detected changes in over {MAX_DIFF_PCT}% of kernels. skipping further diff generation.")
  49. early_stop.set()
  50. break
  51. conn.commit()
  52. cur.close()
  53. if __name__ == "__main__":
  54. if SKIP_PROCESS_REPLAY:
  55. logging.info("skipping process replay.")
  56. exit(0)
  57. conn = db_connection()
  58. cur = conn.cursor()
  59. row_count = cur.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0]
  60. conn.commit()
  61. cur.close()
  62. offsets = range(0, row_count, PAGE_SIZE)
  63. with multiprocessing.Pool(processes=multiprocessing.cpu_count()) as pool:
  64. list(tqdm(pool.imap(process_replay, offsets), total=len(offsets)))
  65. pool.close()
  66. pool.join()