fuzz_schedule.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. import itertools
  2. import numpy as np
  3. from typing import DefaultDict, Dict, List, Set, Tuple, TypeVar, Union
  4. from tinygrad.device import Buffer
  5. from tinygrad.engine.realize import CustomOp, capturing, lower_schedule_item
  6. from tinygrad.helpers import DEBUG, MULTIOUTPUT, colored, getenv
  7. from tinygrad.lazy import LazyBuffer
  8. from tinygrad.engine.schedule import _graph_schedule, ScheduleItem
  9. from tinygrad.ops import MetaOps
  10. from tinygrad.tensor import Tensor, _to_np_dtype
  11. ctx_vars = { MULTIOUTPUT: (0, 1) }
  12. FUZZ_SCHEDULE_MAX_PATHS = getenv("FUZZ_SCHEDULE_MAX_PATHS", 10)
  13. def fuzz_schedule(outs:List[LazyBuffer]):
  14. # find toposorts across all tunable params
  15. unique_ts: Dict[Tuple[LazyBuffer, ...], Tuple[Dict, Dict[LazyBuffer, Tuple]]] = {}
  16. for combination in itertools.product(*ctx_vars.values()):
  17. for var, val in zip(ctx_vars, combination): var.value = val
  18. graph, in_degree, prescheduled = _graph_schedule(outs, set())
  19. for ts in find_all_toposorts(graph, in_degree): unique_ts[ts] = (dict(zip([v.key for v in ctx_vars], combination)), prescheduled)
  20. toposorts = list(unique_ts.items())
  21. if DEBUG >= 1: print(colored(f"fuzzing {len(toposorts)} schedule permutations", "yellow"))
  22. # setup ground truth
  23. ground_truth: Dict[LazyBuffer, memoryview] = {}
  24. assign_targets: Dict[LazyBuffer, LazyBuffer] = {}
  25. # IMPORTANT: freeze prerealized bufs before ScheduleItem exec
  26. prerealized: Dict[LazyBuffer, memoryview] = {}
  27. seed = Tensor._seed
  28. ts, (_, prescheduled) = toposorts[0]
  29. for key in ts:
  30. for out in (ps:=prescheduled[key])[0]:
  31. # freeze assign state before exec
  32. if out.op is MetaOps.ASSIGN:
  33. prerealized[out] = out.buffer.as_buffer()
  34. assign_targets[out.srcs[1]] = out
  35. for x in ps[2]:
  36. if x not in ground_truth and x.device != "NPY": prerealized[x] = x.buffer.as_buffer()
  37. si = ScheduleItem(ps[1], tuple(x.buffer for x in ps[0]+ps[2] if x.size != 0))
  38. _exec_si(si, seed)
  39. for out in ps[0]:
  40. ground_truth[out] = out.buffer.as_buffer()
  41. del out.srcs # only schedule the LazyBuffer in this fuzz run
  42. # exec and validate each permutation with new Buffers
  43. for i, (ts, (ctx, prescheduled)) in enumerate(toposorts[1:]):
  44. if DEBUG >= 1: print(colored(f"testing permutation {i} {ctx}", "yellow"))
  45. rawbufs: Dict[LazyBuffer, Buffer] = {}
  46. for key in ts:
  47. for out in (ps:=prescheduled[key])[0]:
  48. rawbufs[out] = Buffer(out.buffer.device, out.buffer.size, out.buffer.dtype)
  49. if out.op is MetaOps.ASSIGN: rawbufs[out].ensure_allocated().copyin(prerealized[out])
  50. for x in ps[2]:
  51. if x not in rawbufs:
  52. # override the assign_target after ASSIGN
  53. if x in assign_targets and assign_targets[x] in rawbufs: rawbufs[x] = rawbufs[assign_targets[x]]
  54. elif x.device == "NPY": rawbufs[x] = x.buffer
  55. # copy the pre realized input
  56. else: rawbufs[x] = Buffer(x.buffer.device, x.buffer.size, x.buffer.dtype, initial_value=prerealized[x])
  57. si = ScheduleItem(ps[1], tuple(rawbufs[x] for x in ps[0]+ps[2] if x.size != 0))
  58. _exec_si(si, seed)
  59. for out in ps[0]:
  60. outbuf = np.frombuffer(rawbufs[out].as_buffer(), _to_np_dtype(out.dtype))
  61. try: np.testing.assert_allclose(outbuf, np.frombuffer(ground_truth[out], _to_np_dtype(out.dtype)), atol=1e-2, rtol=1e-2)
  62. except Exception as e:
  63. print(f"FAILED FOR {out}")
  64. raise e
  65. def _exec_si(si:ScheduleItem, seed:int):
  66. ei = lower_schedule_item(si)
  67. if len(capturing): capturing[0].add(ei)
  68. if isinstance(ei.prg, CustomOp): Tensor._seed = seed
  69. ei.run()
  70. T = TypeVar("T")
  71. def find_all_toposorts(graph:DefaultDict[T, List[T]], in_degree:Union[DefaultDict[T, int], Dict[T, int]]) -> List[Tuple[T, ...]]:
  72. visited: Set[T] = set()
  73. ret: List[Tuple[T, ...]] = []
  74. path: List[T] = []
  75. def recurse_paths(path:List[T]):
  76. for v, d in in_degree.items():
  77. if d != 0 or v in visited: continue
  78. for u in graph[v]: in_degree[u] -= 1
  79. path.append(v)
  80. visited.add(v)
  81. recurse_paths(path)
  82. if len(ret) >= FUZZ_SCHEDULE_MAX_PATHS: return
  83. # backtrack
  84. for u in graph[v]: in_degree[u] += 1
  85. path.pop()
  86. visited.remove(v)
  87. if len(path) == len(in_degree): ret.append(tuple(path))
  88. recurse_paths(path)
  89. if len(ret) == 0: raise RuntimeError("detected cycle in the graph")
  90. # verify all paths are unique
  91. assert len(ret) == len(set(ret))
  92. return ret