fuzz_symbolic.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. import itertools
  2. import random
  3. from tinygrad.helpers import DEBUG
  4. from tinygrad.shape.symbolic import Variable, NumNode
  5. random.seed(42)
  6. def add_v(expr, rng=None):
  7. if rng is None: rng = random.randint(0,2)
  8. return expr + v[rng], rng
  9. def div(expr, rng=None):
  10. if rng is None: rng = random.randint(1,9)
  11. return expr // rng, rng
  12. def mul(expr, rng=None):
  13. if rng is None: rng = random.randint(-4,4)
  14. return expr * rng, rng
  15. def mod(expr, rng=None):
  16. if rng is None: rng = random.randint(1,9)
  17. return expr % rng, rng
  18. def add_num(expr, rng=None):
  19. if rng is None: rng = random.randint(-4,4)
  20. return expr + rng, rng
  21. def lt(expr, rng=None):
  22. if rng is None: rng = random.randint(-4,4)
  23. return expr < rng, rng
  24. def ge(expr, rng=None):
  25. if rng is None: rng = random.randint(-4,4)
  26. return expr >= rng, rng
  27. def le(expr, rng=None):
  28. if rng is None: rng = random.randint(-4,4)
  29. return expr <= rng, rng
  30. def gt(expr, rng=None):
  31. if rng is None: rng = random.randint(-4,4)
  32. return expr > rng, rng
  33. if __name__ == "__main__":
  34. ops = [add_v, div, mul, add_num, mod]
  35. for _ in range(1000):
  36. upper_bounds = [*list(range(1, 10)), 16, 32, 64, 128, 256]
  37. u1 = Variable("v1", 0, random.choice(upper_bounds))
  38. u2 = Variable("v2", 0, random.choice(upper_bounds))
  39. u3 = Variable("v3", 0, random.choice(upper_bounds))
  40. v = [u1,u2,u3]
  41. tape = [random.choice(ops) for _ in range(random.randint(2, 30))]
  42. # 10% of the time, add one of lt, le, gt, ge
  43. if random.random() < 0.1: tape.append(random.choice([lt, le, gt, ge]))
  44. expr = NumNode(0)
  45. rngs = []
  46. for t in tape:
  47. expr, rng = t(expr)
  48. if DEBUG >= 1: print(t.__name__, rng)
  49. rngs.append(rng)
  50. if DEBUG >=1: print(expr)
  51. space = list(itertools.product(range(u1.min, u1.max+1), range(u2.min, u2.max+1), range(u3.min, u3.max+1)))
  52. volume = len(space)
  53. for (v1, v2, v3) in random.sample(space, min(100, volume)):
  54. v = [v1,v2,v3]
  55. rn = 0
  56. for t,r in zip(tape, rngs): rn, _ = t(rn, r)
  57. num = eval(expr.render())
  58. assert num == rn, f"mismatched {expr.render()} at {v1=} {v2=} {v3=} = {num} != {rn}"
  59. if DEBUG >= 1: print(f"matched {expr.render()} at {v1=} {v2=} {v3=} = {num} == {rn}")