fuzz_shapetracker.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. import random
  2. from tinygrad.helpers import DEBUG, getenv
  3. from test.unit.test_shapetracker import CheckingShapeTracker
  4. def do_permute(st):
  5. perm = list(range(0, len(st.shape)))
  6. random.shuffle(perm)
  7. perm = tuple(perm)
  8. if DEBUG >= 1: print("st.permute(", perm, ")")
  9. st.permute(perm)
  10. def do_pad(st):
  11. c = random.randint(0, len(st.shape)-1)
  12. pad = tuple((random.randint(0,2), random.randint(0,2)) if i==c else (0,0) for i in range(len(st.shape)))
  13. if DEBUG >= 1: print("st.pad(", pad, ")")
  14. st.pad(pad)
  15. def do_reshape_split_one(st):
  16. c = random.randint(0, len(st.shape)-1)
  17. poss = [n for n in [1,2,3,4,5] if st.shape[c]%n == 0]
  18. spl = random.choice(poss)
  19. shp = st.shape[0:c] + (st.shape[c]//spl, spl) + st.shape[c+1:]
  20. if DEBUG >= 1: print("st.reshape(", shp, ")")
  21. st.reshape(shp)
  22. def do_reshape_combine_two(st):
  23. if len(st.shape) < 2: return
  24. c = random.randint(0, len(st.shape)-2)
  25. shp = st.shape[:c] + (st.shape[c] * st.shape[c+1], ) + st.shape[c+2:]
  26. if DEBUG >= 1: print("st.reshape(", shp, ")")
  27. st.reshape(shp)
  28. def do_shrink(st):
  29. c = random.randint(0, len(st.shape)-1)
  30. while 1:
  31. shrink = tuple((random.randint(0,s), random.randint(0,s)) if i == c else (0,s) for i,s in enumerate(st.shape))
  32. if all(x<y for (x,y) in shrink): break
  33. if DEBUG >= 1: print("st.shrink(", shrink, ")")
  34. st.shrink(shrink)
  35. def do_stride(st):
  36. c = random.randint(0, len(st.shape)-1)
  37. stride = tuple(random.choice([-2,-1,2]) if i==c else 1 for i in range(len(st.shape)))
  38. if DEBUG >= 1: print("st.stride(", stride, ")")
  39. st.stride(stride)
  40. def do_flip(st):
  41. c = random.randint(0, len(st.shape)-1)
  42. stride = tuple(-1 if i==c else 1 for i in range(len(st.shape)))
  43. if DEBUG >= 1: print("st.stride(", stride, ")")
  44. st.stride(stride)
  45. def do_expand(st):
  46. c = [i for i,s in enumerate(st.shape) if s==1]
  47. if len(c) == 0: return
  48. c = random.choice(c)
  49. expand = tuple(random.choice([2,3,4]) if i==c else s for i,s in enumerate(st.shape))
  50. if DEBUG >= 1: print("st.expand(", expand, ")")
  51. st.expand(expand)
  52. shapetracker_ops = [do_permute, do_pad, do_shrink, do_reshape_split_one, do_reshape_combine_two, do_stride, do_expand]
  53. if __name__ == "__main__":
  54. random.seed(42)
  55. for _ in range(getenv("CNT", 200)):
  56. st = CheckingShapeTracker((random.randint(2, 10), random.randint(2, 10), random.randint(2, 10)))
  57. for i in range(8): random.choice(shapetracker_ops)(st)
  58. st.assert_same()