fuzz_uops.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. import itertools
  2. from collections import defaultdict
  3. import numpy as np
  4. from dataclasses import replace
  5. from typing import DefaultDict, Dict, List, Tuple
  6. from tinygrad.codegen.uops import END_FOR_UOP, UOp
  7. from tinygrad.codegen.uopgraph import UOpGraph
  8. from tinygrad.device import Buffer, Device
  9. from tinygrad.engine.realize import CompiledRunner
  10. from tinygrad.helpers import DEBUG, colored
  11. from tinygrad.shape.symbolic import Variable
  12. from tinygrad.tensor import _to_np_dtype
  13. from test.external.fuzz_schedule import FUZZ_SCHEDULE_MAX_PATHS, find_all_toposorts
  14. def fuzz_uops(uops:UOpGraph) -> List[Tuple[UOp, ...]]:
  15. blocks: List[List[UOp]] = [[]]
  16. for u in uops:
  17. if u.op in END_FOR_UOP: blocks.append([u])
  18. elif u.op in {x[1] for x in END_FOR_UOP.values()}: blocks.extend([[u], []])
  19. else: blocks[-1].append(u)
  20. paths_for_block: Dict[int, List[Tuple[UOp, ...]]] = {}
  21. for bi, bb in enumerate(blocks):
  22. children: DefaultDict[UOp, List[UOp]] = defaultdict(list)
  23. in_degree: Dict[UOp, int] = {}
  24. for u in bb:
  25. in_degree[u] = 0
  26. for x in u.src:
  27. if x in bb:
  28. children[x].append(u)
  29. in_degree[u] += 1
  30. paths_for_block[bi] = find_all_toposorts(children, in_degree)
  31. paths: Dict[Tuple[UOp, ...], None] = {}
  32. for up in itertools.product(*paths_for_block.values()):
  33. paths[tuple(uop for path in up for uop in path)] = None
  34. if len(paths) >= FUZZ_SCHEDULE_MAX_PATHS: break
  35. return list(paths)
  36. class UOpsFuzzerRunner(CompiledRunner):
  37. def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False):
  38. assert self.p.uops is not None and len(self.p.uops._fuzz_paths) >= 1
  39. init_rawbufs, init_name = {x:x.as_buffer() for x in rawbufs}, self.p.function_name
  40. init_globals = {i[0]:buf for i, buf in zip(self.p.globals, rawbufs)}
  41. if DEBUG >= 1: print(colored(f"fuzzing {len(self.p.uops._fuzz_paths)} uop permutations for {init_name}", "yellow"))
  42. super().__call__(rawbufs, var_vals, wait)
  43. ground_truth = {x:np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in rawbufs}
  44. for i, path in enumerate(self.p.uops._fuzz_paths):
  45. # setup prg
  46. uops = UOpGraph([])
  47. uops._uops = list(path)
  48. if DEBUG >= 5: uops.print()
  49. self.p = replace(self.p, name=(name:=f"{init_name}fuzz{i}"), src=Device[self.p.dname].renderer.render(name, uops), uops=uops)
  50. if DEBUG >= 4: print(self.p.src)
  51. self.lib = Device[self.p.dname].compiler.compile_cached(self.p.src)
  52. self.clprg = Device[self.p.dname].runtime(name, self.lib)
  53. for x in (rawbufs:=[init_globals[i[0]] for i in self.p.globals]): x.copyin(init_rawbufs[x])
  54. # verify
  55. super().__call__(rawbufs, var_vals, wait)
  56. for i, x in enumerate(rawbufs):
  57. try:
  58. np.testing.assert_allclose(np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)), ground_truth[x], atol=1e-6, rtol=1e-6)
  59. if DEBUG >= 2: print(colored(name, "green"))
  60. except AssertionError as e:
  61. print(colored(name, "red"))
  62. raise e