replay_schedule.py 1.2 KB

123456789101112131415161718192021222324252627282930313233
  1. #!/usr/bin/env python3
  2. import subprocess, pickle, shlex, sys, os
  3. from typing import Dict, List, Tuple
  4. from tinygrad.engine.graph import print_tree
  5. from tinygrad.helpers import colored
  6. from tinygrad.ops import LazyOp
  def _run(name:str, cmd:List[str], env:Dict[str, str]) -> List[Tuple[LazyOp, ...]]:
    """Check out git revision `name`, run `cmd` there, and return the schedule that run pickled.

    `name` is resolved with `git rev-parse` to a commit hash, and that commit is checked out (detached HEAD).
    `cmd` is then run with `env` plus SAVE_SCHEDULE_PATH="<commit>.pkl" (relative to the current directory).
    The pickle written by that run is loaded and returned.

    Raises:
      subprocess.CalledProcessError: if `git rev-parse` or `git checkout` fails.
      FileNotFoundError: if `cmd` finished without writing the pickle.
    """
    commit = subprocess.check_output(["git", "rev-parse", name], encoding="utf-8").strip()
    subprocess.run(["git", "checkout", commit], check=True)
    pkl_path = f"{commit}.pkl"
    # remove a pickle left behind by an earlier invocation; otherwise a run that crashes before saving
    # would make us silently compare a stale schedule
    if os.path.exists(pkl_path): os.remove(pkl_path)
    # the command's exit status is not checked here, only whether it produced the pickle
    ret = subprocess.run(cmd, env={**env, "SAVE_SCHEDULE_PATH": pkl_path})
    if not os.path.exists(pkl_path):
      raise FileNotFoundError(f"{cmd} (exit code {ret.returncode}) did not write {pkl_path} on commit {commit}")
    # the pickle was just produced by our own run, so loading it is trusted
    with open(pkl_path, "rb") as f: return pickle.load(f)
  12. def _get_cmd():
  13. parts, env = shlex.split(sys.argv[1]), {**os.environ, "SAVE_SCHEDULE": "1", "CAPTURE_AST": "1"}
  14. env.update({k: v for p in parts if "=" in p for k, v in [p.split("=")]})
  15. return [p for p in parts if "=" not in p], env
  if __name__ == "__main__":
    # Run the given command on the current HEAD and on master with schedule capture enabled,
    # then require both runs to produce identical schedules, printing the first differing item.
    cmd, env = _get_cmd()
    # _run checks out detached commits; remember the starting branch (or commit, if already detached)
    # so the working tree is put back where the user left it, even if a run fails
    start = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"], encoding="utf-8").strip()
    if start == "HEAD": start = subprocess.check_output(["git", "rev-parse", "HEAD"], encoding="utf-8").strip()
    try:
      feat = _run("HEAD", cmd, env)
      master = _run("master", cmd, env)
    finally:
      subprocess.run(["git", "checkout", start], check=True)
    # explicit raises instead of `assert` so the check still happens under `python -O`
    if len(master) != len(feat):
      raise AssertionError(f"schedule length mismatch: master has {len(master)} items, HEAD has {len(feat)}")
    for i, (m, f) in enumerate(zip(master, feat)):
      if m == f: continue
      print(colored("FAILED FOR AST: ", "red"))
      print("expected:")
      for op in m: print_tree(op)
      print("got:")
      for op in f: print_tree(op)
      raise AssertionError(f"schedule item {i} differs between master and HEAD")