realize.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. from typing import List, Dict, Optional, cast, Generator, Tuple
  2. import time, pprint
  3. from dataclasses import dataclass, replace
  4. from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, Context, TRACEMETA
  5. from tinygrad.ops import MetaOps, LazyOp
  6. from tinygrad.dtype import dtypes
  7. from tinygrad.device import Device, Buffer
  8. from tinygrad.shape.symbolic import Variable, sym_infer, sint
  9. from tinygrad.renderer import Renderer, Program
  10. from tinygrad.codegen.kernel import Kernel
  11. from tinygrad.engine.schedule import ScheduleItem
  12. # **************** Program Creation ****************
# LOGKERNS: path to append (ast, applied_opts) logs for every compiled kernel; None disables logging.
# LOGKERNS_LEVEL > 1 additionally logs the losing candidates from the BEAM compare (see get_kernel).
logkerns, logkerns_level = open(getenv("LOGKERNS", ""), "a") if getenv("LOGKERNS", "") else None, getenv("LOGKERNS_LEVEL", 1)
def get_kernel(renderer:Renderer, ast:LazyOp) -> Kernel:
  """Build an optimized Kernel for `ast` targeting the device described by `renderer`.

  Optimization pipeline: required optimizations always run; unless NOOPT, tensor
  cores (TC env) or hand-coded heuristics are applied; if BEAM >= 1 a beam search
  runs and (with BEAM_COMPARE, the default) the fastest of beam/heuristic kernels
  is kept. BEAM_COMPARE=2 additionally executes all candidates on random data and
  raises RuntimeError if their outputs disagree.
  """
  if DEBUG >= 5:
    from tinygrad.engine.graph import print_tree
    print_tree(ast)
  k = Kernel(ast, opts=renderer)
  k.required_optimizations()
  if not NOOPT:
    # tensor cores take priority; only fall back to hand-coded opts when TC didn't apply
    if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
    if BEAM >= 1:
      from tinygrad.engine.search import beam_search, time_linearizer, bufs_from_lin
      # beam search starts from a fresh kernel with only required opts; keep the heuristic kernel for comparison
      kb, k_opt = Kernel(ast, opts=renderer), k
      kb.required_optimizations()
      rawbufs = bufs_from_lin(kb, allocate=False)
      k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
      if beam_compare:=getenv("BEAM_COMPARE", 1):
        # TODO: move the HC/TC/BEAM compare to beam_search so it can be optionally cached which choice is better
        lins: List[Tuple[str, Kernel]] = [(f"beam{BEAM.value}", k), (("tc" if used_tensor_cores else "hc"), k_opt)]
        if used_tensor_cores:
          # when tensor cores were used, also time a pure hand-coded kernel
          lins.append(("hc", Kernel(ast, opts=renderer)))
          lins[-1][1].hand_coded_optimizations()
        # time every candidate and keep the fastest (timed[0])
        timed = sorted([(nm, tk, time_linearizer(tk, rawbufs, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
        if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
        k = timed[0][1]
        if logkerns is not None and logkerns_level > 1: logkerns.writelines([f"{(lin.ast, lin.applied_opts)}\n" for (_,lin,_) in timed[1:]])
        if beam_compare == 2:
          # BEAM_COMPARE=2: run every candidate on identical random inputs and verify the outputs match
          from tinygrad import Tensor
          all_outs: List[List[Tensor]] = []
          with Context(DEBUG=0, BEAM=0, CAPTURING=0):
            # floats get normal noise, bools get 0/1, other ints get full-range random values
            rand_bufs = [Tensor.normal(buf.size, std=0.1, dtype=buf.dtype).data() if dtypes.is_float(buf.dtype) else \
              (Tensor.randint(buf.size, low=0, high=2).cast(buf.dtype).data() if buf.dtype == dtypes.bool else \
               Tensor.randint(buf.size, low=dtypes.min(buf.dtype), high=dtypes.max(buf.dtype), dtype=buf.dtype).data()) \
              for buf in rawbufs]
          for _, tk in lins[::-1]:
            # reset all buffers to the same random data before each candidate runs
            for buf,data in zip(rawbufs, rand_bufs): buf.ensure_allocated().copyin(data)
            time_linearizer(tk, rawbufs, allow_test_size=False, clear_l2=True, disable_cache=True)
            all_outs.append([Tensor(bytes(buf.as_buffer()), dtype=buf.dtype) for buf in rawbufs[:len(ast.src)]])
          with Context(DEBUG=0, BEAM=0, CAPTURING=0):
            for bufs in zip(*all_outs):
              for b in bufs[1:]:
                if dtypes.is_float(bufs[0].dtype):
                  # we check both atol and rtol here
                  diff_count = (((b-bufs[0]).abs() > 1e-3) * (((b-bufs[0])/bufs[0]).abs() > 1e-3)).sum().item()
                else:
                  diff_count = (b != bufs[0]).sum().item()
                if diff_count != 0:
                  raise RuntimeError(f"mismatch of {diff_count}/{b.numel()} items with type {b.dtype}, max {(b-bufs[0]).abs().max().item()}")
  if logkerns is not None: logkerns.writelines([f"{(k.ast, k.applied_opts)}\n"])
  if DEBUG >= 5: print((k.ast, k.applied_opts)) # print here to show final applied_opts for all kernels instead of just in beam_search
  return k
  63. # **************** Runners ****************
  64. class Runner:
  65. def __init__(self, display_name:str, dname:str, op_estimate:sint=0, mem_estimate:sint=0):
  66. self.first_run, self.display_name, self.dname, self.op_estimate, self.mem_estimate = True, display_name, dname, op_estimate, mem_estimate
  67. @property
  68. def device(self): return Device[self.dname]
  69. def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
  70. return self(rawbufs, {} if var_vals is None else var_vals)
  71. def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False) -> Optional[float]:
  72. raise NotImplementedError("override this")
class CompiledRunner(Runner):
  """Runner backed by a compiled Program; compiles from source on construction unless `precompiled` bytes are given."""
  def __init__(self, p:Program, precompiled:Optional[bytes]=None):
    if DEBUG >= 4: print(p.src)
    self.p:Program = p
    # compile_cached hits the on-disk compiler cache when available
    self.lib:bytes = precompiled if precompiled is not None else Device[p.dname].compiler.compile_cached(p.src)
    self.clprg = Device[p.dname].runtime(p.function_name, self.lib)
    super().__init__(p.name, p.dname, p.op_estimate, p.mem_estimate)
  # pickle as (Program, compiled lib) so unpickling skips recompilation
  def __reduce__(self): return self.__class__, (self.p, self.lib)
  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False) -> Optional[float]:
    global_size, local_size = self.p.launch_dims(var_vals)
    if global_size is not None and local_size is None and all_int(self.p.global_size): # type: ignore[arg-type]
      # TODO: this is copied from get_program
      # no local size chosen yet: search for one, then fold it out of the global size
      from tinygrad.engine.search import optimize_local_size
      local_size = optimize_local_size(self.clprg, global_size, rawbufs)
      global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
      # persist the chosen launch dims on the Program so the search only happens once
      self.p = replace(self.p, global_size=global_size, local_size=local_size)
    lra = {}
    if global_size:
      lra['global_size'] = global_size
      assert len(global_size) == 3, "global size must have len 3"
    if local_size:
      lra['local_size'] = local_size
      assert len(local_size) == 3, "local size must have len 3"
    # vals are the concrete values of the program's symbolic variables, in self.p.vars order
    return self.clprg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.p.vars), wait=wait)
  97. class CustomOp(Runner):
  98. def __init__(self, fxn):
  99. self.fxn = fxn
  100. super().__init__(self.fxn.__name__, "CUSTOM", 0, 0)
  101. def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False): self.fxn(*rawbufs)
  102. class EmptyOp(Runner):
  103. def __init__(self, buf:Buffer): super().__init__(colored(f"empty {buf.size:10d} {buf.dtype}", "yellow"), buf.device)
  104. def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False): pass
  105. class ViewOp(Runner):
  106. def __init__(self, buf:Buffer): super().__init__(colored(f"view {buf.nbytes:8d} @ {buf.offset:<10d}", "yellow"), buf.device)
  107. def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False):
  108. assert rawbufs[0]._base is not None and rawbufs[0]._base == rawbufs[1].base, f"must be base {rawbufs}"
class BufferCopy(Runner):
  """Runner that copies one buffer into another, choosing the fastest available path per device pair."""
  def __init__(self, total_sz, dest_device, src_device):
    # display name: copy size (M for >= 1e6 bytes), dest <- src
    if total_sz >= 1e6: name = f"{type(self).__name__[6:].lower()} {total_sz/1e6:7.2f}M, {dest_device[:7]:>7s} <- {src_device[:7]:7s}"
    else: name = f"{type(self).__name__[6:].lower()} {total_sz:8d}, {dest_device[:7]:>7s} <- {src_device[:7]:7s}"
    super().__init__(colored(name, "yellow"), dest_device, 0, total_sz)
  def copy(self, dest, src):
    # fastest: DISK source with io_uring + fd support copying straight into the dest allocator (>= 4096 bytes)
    disk_supports_fast_copyout = src.device.startswith("DISK") and hasattr(src.allocator.device, 'io_uring') and hasattr(src.allocator.device, 'fd')
    if src.device.startswith("DISK") and hasattr(dest.allocator, 'copy_from_disk') and disk_supports_fast_copyout and src.nbytes >= 4096:
      dest.allocator.copy_from_disk(dest._buf, src._buf, src.nbytes)
    elif src.device.startswith("DISK") and hasattr(dest.allocator, 'as_buffer'):
      # fast(ish) path, uses readinto in diskbuffers
      src.allocator.copyout(dest.allocator.as_buffer(dest._buf), src._buf)
    else:
      # generic fallback: round-trip through a host-visible buffer
      dest.copyin(src.as_buffer(allow_zero_copy=True)) # may allocate a CPU buffer depending on allow_zero_copy
  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False):
    dest, src = rawbufs[0:2]
    assert dest.size == src.size and dest.dtype == src.dtype, f"buffer copy mismatch, {dest.size} != {src.size}, {dest.dtype} != {src.dtype}"
    st = time.perf_counter()
    self.copy(dest, src)
    if wait:
      # wall-clock timing; only meaningful after a device sync
      Device[dest.device].synchronize()
      return time.perf_counter() - st
  131. class BufferXfer(BufferCopy):
  132. def copy(self, dest, src):
  133. if hasattr(dest.allocator.device, "track_cross_buffer") and hasattr(src.allocator, "track_cross_device"):
  134. dest.allocator.device.track_cross_buffer.append(src)
  135. src.allocator.track_cross_device.add(dest.allocator.device)
  136. dest.allocator.transfer(dest._buf, src._buf, dest.nbytes, src_dev=src.allocator.device, dest_dev=dest.allocator.device)
  137. # **************** method cache ****************
# keyed by (device name, ast, BEAM value, is_base_device_key); get_runner stores each compiled kernel
# under both the exact device name (False) and the base device name without the :N suffix (True)
method_cache: Dict[Tuple[str, LazyOp, int, bool], CompiledRunner] = {}
def get_runner(dname:str, ast:LazyOp) -> CompiledRunner:
  """Return a (cached) CompiledRunner executing `ast` on device `dname`.

  Uses two method_cache keys: the exact device name and its base name (without the
  ":N" suffix), so a kernel compiled for one device instance is reused on sibling
  instances by retargeting the Program's dname without recompiling.
  """
  ckey = (dname, ast, BEAM.value, False)
  if cret:=method_cache.get(ckey): return cret
  bkey = (dname.split(":")[0], ast, BEAM.value, True)
  if bret:=method_cache.get(bkey):
    # the base device already compiled this ast: reuse its lib, just swap dname
    method_cache[ckey] = ret = CompiledRunner(replace(bret.p, dname=dname), bret.lib)
  else:
    prg: Program = get_kernel(Device[dname].renderer, ast).to_program()
    if hasattr(prg.uops, "_fuzz_paths"):
      # uop-fuzzing test hook: returns a fuzzer runner and deliberately skips the cache
      from test.external.fuzz_uops import UOpsFuzzerRunner
      return UOpsFuzzerRunner(replace(prg, dname=dname))
    method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, dname=dname))
  return ret
  152. # **************** lowering functions ****************
@dataclass(frozen=True)
class ExecItem:
  """A Runner bound to the concrete buffers (and optional tensor metadata) it runs on."""
  prg: Runner
  bufs: List[Optional[Buffer]]
  metadata: Optional[List[Metadata]] = None
  def run(self, var_vals:Optional[Dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]:
    """Execute the runner, optionally updating GlobalCounters stats; returns elapsed seconds when available."""
    # in the jit path buffers are assumed already allocated; otherwise allocate on demand
    bufs = [cast(Buffer, x) for x in self.bufs] if jit else [cast(Buffer, x).ensure_allocated() for x in self.bufs]
    et = self.prg(bufs, var_vals if var_vals is not None else {}, wait=wait or DEBUG >= 2)
    if do_update_stats:
      GlobalCounters.kernel_count += 1
      GlobalCounters.global_ops += (op_estimate:=sym_infer(self.prg.op_estimate, var_vals))
      GlobalCounters.global_mem += (mem_estimate:=sym_infer(self.prg.mem_estimate, var_vals))
      if et is not None: GlobalCounters.time_sum_s += et
      if DEBUG >= 2:
        # >10ms renders in yellow ms, otherwise plain us
        ptm = (colored(f"{et*1e3:9.2f}ms", "yellow") if et > 0.01 else f"{et*1e6:9.2f}us") if et is not None else ""
        print(f"{colored(f'*** {self.prg.dname[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(38-ansilen(self.prg.display_name))} arg {len(self.bufs):3d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501
              (str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)" + # noqa: E501
               f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in self.metadata] if self.metadata else ''}"))
      # mutates the (mutable) Runner, not this frozen dataclass
      self.prg.first_run = False
    return et
def lower_schedule_item(si:ScheduleItem) -> ExecItem:
  """Lower a single ScheduleItem to an ExecItem by dispatching on its meta op."""
  # all buffers must live on one device, except for COPY (or when USE_COPY_KERNEL is set)
  assert len(set(x.device for x in si.bufs)) == 1 or si.ast.op is MetaOps.COPY or getenv("USE_COPY_KERNEL")
  if si.ast.op is MetaOps.KERNEL:
    runner = get_runner(si.outputs[0].device, si.ast)
    # pass only the buffers the program references, in the program's globals order
    return ExecItem(runner, [si.bufs[x[0]] for x in runner.p.globals], si.metadata)
  out = si.outputs[0]
  if si.ast.op is MetaOps.COPY:
    kernel_type = BufferCopy
    # same base backend on both sides with a transfer method -> direct device-to-device xfer
    if hasattr(Device[out.device].allocator, 'transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]:
      kernel_type = BufferXfer
    return ExecItem(kernel_type(si.ast.arg, out.device, si.inputs[0].device), list(si.bufs))
  if si.ast.op is MetaOps.CUSTOM: return ExecItem(CustomOp(si.ast.arg), list(si.bufs))
  if si.ast.op is MetaOps.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs))
  if si.ast.op is MetaOps.VIEW: return ExecItem(ViewOp(out), list(si.bufs))
  raise RuntimeError(f"don't know how to lower {si.ast}")
  188. def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]:
  189. while len(schedule):
  190. si = schedule.pop(0)
  191. try: yield lower_schedule_item(si)
  192. except Exception as e:
  193. if DEBUG >= 2:
  194. print(f"error lowering {si.ast.op}")
  195. print("tensor operations:")
  196. pprint.pprint(si.metadata, indent=2)
  197. raise e
  198. # **************** main run function ****************
capturing: List = []  # put classes with an add method in here; run_schedule feeds each ExecItem to capturing[0].add when CAPTURING is set
  200. def run_schedule(schedule:List[ScheduleItem], var_vals:Optional[Dict[Variable, int]]=None, do_update_stats=True):
  201. for ei in lower_schedule(schedule):
  202. if len(capturing) and CAPTURING: capturing[0].add(ei)
  203. ei.run(var_vals, do_update_stats=do_update_stats)