get_action_space.py 1.2 KB

123456789101112131415161718192021222324252627282930313233343536373839
  1. import random
  2. from extra.optimization.helpers import load_worlds, ast_str_to_lin
  3. from tinygrad.engine.search import actions
  4. from tinygrad.codegen.kernel import Kernel
  5. from tinygrad.helpers import tqdm
  6. tactions = set()
  7. def test_rebuild(lin):
  8. linr = Kernel(lin.ast)
  9. for o in lin.applied_opts:
  10. assert o in actions, f"{o} is not in actions"
  11. tactions.add(o)
  12. linr.apply_opt(o)
  13. assert len(lin.sts) == len(linr.sts)
  14. for st1,st2 in zip(lin.sts, linr.sts):
  15. assert st1 == st2, f"{st1} != {st2}"
  16. if __name__ == "__main__":
  17. ast_strs = load_worlds(False, False, False)
  18. random.shuffle(ast_strs)
  19. ast_strs = ast_strs[:2000]
  20. for ast_str in tqdm(ast_strs):
  21. lin = ast_str_to_lin(ast_str)
  22. #if not lin.apply_tensor_cores():
  23. lin.hand_coded_optimizations()
  24. test_rebuild(lin)
  25. # confirm linearize can be called twice
  26. uops1 = lin.linearize().uops
  27. uops2 = lin.linearize().uops
  28. for x,y in zip(uops1.uops, uops2.uops):
  29. # for some reason DEFINE_ACC is changing the arg
  30. if x.op != y.op or x.dtype != y.dtype: # or x.arg != y.arg:
  31. uops1.print()
  32. uops2.print()
  33. raise Exception(f"UOPS MISMATCH {x} {y}")
  34. print(len(tactions), len(actions))
  35. print(sorted(list(tactions)))