# jit.py
  1. from __future__ import annotations
  2. from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, Optional, Any
  3. import functools, itertools, collections
  4. from tinygrad.tensor import Tensor
  5. from tinygrad.lazy import LazyBuffer
  6. from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, all_int, GraphException, colored, JIT
  7. from tinygrad.device import Buffer, Compiled, Device
  8. from tinygrad.dtype import DType
  9. from tinygrad.shape.shapetracker import ShapeTracker
  10. from tinygrad.shape.symbolic import Variable, sint
  11. from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferXfer, CompiledRunner, Runner
  12. from tinygrad.engine.schedule import _internal_memory_planner
  13. from tinygrad.nn.state import get_parameters
  14. from weakref import WeakKeyDictionary
def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]) -> List[ExecItem]:
  """Condense runs of graphable ExecItems in jit_cache into single graph-backed ExecItems.

  Splits the JIT cache into batches for faster graph execution; this allows the
  accelerator to run some batches while subsequent graphs are still being updated.
  Items that cannot be graphed are passed through unchanged (no-ops are dropped).
  """
  max_batch_size = getenv("JIT_BATCH_SIZE", 32)
  graphed_jit_cache: List[ExecItem] = []
  current_batch: List[ExecItem] = []            # graphable items accumulated for the current device
  current_device: Optional[Compiled] = None
  def flush_batch():
    # Try to turn current_batch into one graph ExecItem; on failure fall back to the raw items.
    nonlocal current_batch, current_device, max_batch_size
    try:
      if len(current_batch) <= 1 or current_device is None: raise GraphException("only one kernel doesn't graph")
      graph_runner = current_device.graph(current_batch, input_rawbuffers, var_vals)
      # clear jit inputs to allow their memory to be freed/reused
      for (j,i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None
      graphed_jit_cache.append(ExecItem(graph_runner, cast(List[Optional[Buffer]], input_rawbuffers)))
      max_batch_size *= 2  # each successful graph doubles the allowed batch size
      if DEBUG >= 2: print(f"\tJIT GRAPHing batch with {len(current_batch)} kernels on device {current_device}")
    except GraphException as e:
      # graphing failed: keep the batch as individual ExecItems
      graphed_jit_cache.extend(current_batch)
      if DEBUG >= 2: print(f"\tJIT GRAPHing failed batch with {len(current_batch)} kernels on device {current_device}: {e}")
    current_batch = []
    current_device = None
  for ji in jit_cache:
    if ji.prg.__class__ in {EmptyOp, ViewOp}: continue  # no-ops never reach the output cache
    ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None.
    if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.device
    elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"CUDA", "NV", "AMD"}:
      # buffer transfers are only graphable on these backends
      ji_graph_dev = Device[ji.bufs[0].device]
    graph_class = (ji_graph_dev.graph.func if isinstance(ji_graph_dev.graph, functools.partial) else ji_graph_dev.graph) if ji_graph_dev else None #type: ignore
    can_be_graphed = ji_graph_dev and ji_graph_dev.graph
    # same device, or a MultiGraphRunner subclass can mix items from same-typed devices
    can_share_graph = (ji_graph_dev == current_device or (isinstance(graph_class, type) and issubclass(graph_class, MultiGraphRunner)) and
                      type(ji_graph_dev) is type(current_device))
    can_extend_graph_batch = can_be_graphed and len(current_batch) < max_batch_size and can_share_graph
    if not can_extend_graph_batch and len(current_batch) > 0: flush_batch()
    if can_be_graphed: current_batch.append(ji)
    else: graphed_jit_cache.append(ji)
    current_device = ji_graph_dev
  if len(current_batch) > 0: flush_batch()
  return graphed_jit_cache
  54. def get_input_replace(jit_cache: List[ExecItem], input_rawbuffers:List[Buffer]) -> Dict[Tuple[int, int], int]:
  55. input_replace: Dict[Tuple[int, int], int] = {}
  56. for j,ji in enumerate(jit_cache):
  57. for i,a in enumerate(ji.bufs):
  58. if a in input_rawbuffers:
  59. input_replace[(j,i)] = input_rawbuffers.index(a)
  60. return input_replace
  61. class GraphRunner(Runner): # pylint: disable=abstract-method
  62. def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
  63. self.jit_cache = jit_cache
  64. self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
  65. self.jc_idx_with_updatable_launch_dims = []
  66. self.jc_idx_with_updatable_var_vals = []
  67. op_estimate: sint = 0
  68. mem_estimate: sint = 0
  69. for j,ji in enumerate(jit_cache):
  70. op_estimate += ji.prg.op_estimate
  71. mem_estimate += ji.prg.mem_estimate
  72. if isinstance(ji.prg, CompiledRunner):
  73. if ji.prg.p.vars: self.jc_idx_with_updatable_var_vals.append(j)
  74. if (ji.prg.p.global_size and not all_int(ji.prg.p.global_size)) or (ji.prg.p.local_size and not all_int(ji.prg.p.local_size)):
  75. self.jc_idx_with_updatable_launch_dims.append(j)
  76. self.vars = sorted(var_vals.keys(), key=lambda v: v.expr)
  77. super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), jit_cache[0].prg.dname.split(":")[0], op_estimate, mem_estimate)
class MultiGraphRunner(GraphRunner): # pylint: disable=abstract-method
  """GraphRunner that may span multiple devices; tracks read/write dependencies between items."""
  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
    # Both maps are keyed by id() of a buffer's underlying _buf; values are
    # backend-specific dependency handles (opaque to this class).
    self.w_dependency_map: Dict[Any, Any] = {}
    self.r_dependency_map: Dict[Any, List[Any]] = collections.defaultdict(list)
    super().__init__(jit_cache, input_rawbuffers, var_vals)
  def _access_resources(self, read, write, new_dependency:Any):
    # To synchronize access to resources, we monitor the necessary prerequisites for accessing each resource,
    # whether for write or read operations. A resource can be accessed by either a single writer or multiple readers.
    wait_nodes = []
    # any access (read or write) must wait on the last writer of the buffer
    for rawbuf in read + write:
      if id(rawbuf.base._buf) in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[id(rawbuf.base._buf)])
    # a write must additionally wait on (and retire) all outstanding readers
    for rawbuf in write:
      if id(rawbuf.base._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(rawbuf.base._buf)))
    # register new_dependency as the latest reader / sole writer of each buffer
    for rawbuf in read: self.r_dependency_map[id(rawbuf.base._buf)].append(new_dependency)
    for rawbuf in write: self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency
    # dedupe wait nodes by identity, preserving first-seen order
    return list({id(x):x for x in wait_nodes}.values())
ReturnType = TypeVar('ReturnType')
class TinyJit(Generic[ReturnType]):
  """JIT wrapper for a function operating on Tensors.

  Call 1 (cnt==0) runs the function normally, call 2 (cnt==1) captures the launched
  kernels into jit_cache, and calls 3+ (cnt>=2) replay the cached kernels with the
  new call's input buffers patched in.
  """
  def __init__(self, fxn:Callable[..., ReturnType]):
    self.fxn = fxn
    self.reset()
  def add_buffer(self, b:Buffer) -> Buffer:
    # Return a JIT-owned replacement for b; buffers that are already allocated or
    # still referenced by a lazybuffer are kept as-is. Base buffers are replaced
    # recursively so views stay views of the replacement.
    if found:=self.buffer_replace.get(b, None): return found
    if b.is_allocated() or b.lb_refcount > 0: return b
    if b._base is not None:
      self.buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, base=self.add_buffer(b._base), offset=b.offset)
    else:
      self.buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, options=b.options)
    return ret
  def add(self, ei:ExecItem):
    # Called during capture: record the ExecItem with its buffers swapped for JIT-owned ones.
    self.jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.bufs if buf is not None]))
  def reset(self):
    # Clear all captured state so the next call starts a fresh warmup/capture cycle.
    self.jit_cache: List[ExecItem] = []
    self.input_replace: Dict[Tuple[int, int], int] = {}
    self.extra_view_inputs: List[Tuple[int, int, str, int, DType]] = []
    self.buffer_replace: WeakKeyDictionary[Buffer, Buffer] = WeakKeyDictionary()
    self.cnt: int = 0
  def __get__(self, obj, objtype): return functools.partial(self.__call__, obj) # add support for instance methods
  def __call__(self, *args, **kwargs) -> ReturnType:
    # collect (name, Tensor) pairs from positional and keyword args; kwargs are sorted for a stable order
    input_tensors: List[Tuple[Union[int, str], Tensor]] = \
      [(cast(Union[int, str], name),t) for name,t in itertools.chain(enumerate(args), sorted(kwargs.items())) if t.__class__ is Tensor]
    if input_tensors: Tensor.realize(*[t for _,t in input_tensors])
    names: List[Union[int, str]] = [name for name,_ in input_tensors]
    lbs: List[LazyBuffer] = flatten([t.lazydata.lbs for _,t in input_tensors])
    st_varvals_dtype_device = [(*lb.st.unbind(), lb.dtype, lb.device) for lb in lbs]
    input_buffers: List[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None]
    assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
    # merge variable bindings from the input shapetrackers and any Variables passed directly
    var_vals: Dict[Variable, int] = merge_dicts([varvals for _,varvals,_,_ in st_varvals_dtype_device] + \
                                                [dict(v.unbind() for v in itertools.chain(args, kwargs.values()) if isinstance(v, Variable))])
    st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varvals_dtype_device]
    if not JIT or self.cnt == 0:
      # jit ignore: run the function as-is (BEAM optionally disabled on the first call)
      with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value):
        self.ret = self.fxn(*args, **kwargs)
        if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
    elif self.cnt == 1:
      # jit capture
      self.expected_names: List[Union[int, str]] = names
      self.expected_st_vars_dtype_device: List[Tuple[ShapeTracker, Tuple[Variable, ...], DType, str]] = st_vars_dtype_device
      if capturing: raise RuntimeError(f"having TinyJit inside another TinyJit is not supported {len(capturing)=} {capturing=}")
      with Context(GRAPH=getenv("JITGRAPH", GRAPH.value), BEAM=getenv("JITBEAM", BEAM.value)):
        capturing.append(self)
        try:
          self.ret = self.fxn(*args, **kwargs)
          if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
        # NOTE(review): `except ... raise e` is redundant (finally already runs); kept as-is
        except Exception as e: raise e
        finally: capturing.clear()
      del self.buffer_replace
      assert len(self.jit_cache), "didn't JIT anything!"
      if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_buffers)} inputs")
      # track inputs that are views of buffers
      for item in self.jit_cache:
        for b in item.bufs:
          if b is not None and b._base is not None and b._base in input_buffers:
            input_buffers.append(b)
            self.extra_view_inputs.append((input_buffers.index(b.base), b.offset, b.device, b.size, b.dtype))
      # memory planning (optional)
      # Exclude buffers involved in transfer ops to preserve parallelism.
      noopt_buffers = {b for ji in self.jit_cache if isinstance(ji.prg, BufferXfer) for b in ji.bufs}
      assigned = _internal_memory_planner([cast(List[Buffer], item.bufs) for item in self.jit_cache], noopt_buffers, debug_prefix="JIT ")
      self.jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]) for item in self.jit_cache]
      # Condense the items into a graph executor.
      if JIT < 2: self.jit_cache = apply_graph_to_jit(self.jit_cache, input_buffers, var_vals)
      self.input_replace = get_input_replace(self.jit_cache, input_buffers)
      if DEBUG >= 1 and len(set(self.input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")
    elif self.cnt >= 2:
      # jit exec: validate that this call matches the captured signature, then replay
      assert self.expected_names == names, f"args mismatch in JIT: {self.expected_names=} != {names}"
      assert self.expected_st_vars_dtype_device == st_vars_dtype_device, \
        f"args mismatch in JIT: {self.expected_st_vars_dtype_device=} != {st_vars_dtype_device=}"
      # recreate the captured view-of-input buffers on top of this call's inputs
      for idx, offset, device, size, dtype in self.extra_view_inputs:
        input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
      for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].bufs[i] = input_buffers[input_idx]
      if DEBUG >= 1 and len(self.jit_cache) >= 10: print(f"jit execs {len(self.jit_cache)} kernels")
      for ei in self.jit_cache: ei.run(var_vals, jit=True)
    # clear jit inputs (so input buffer memory can be freed between calls)
    for (j,i) in self.input_replace.keys(): self.jit_cache[j].bufs[i] = None
    self.cnt += 1
    return self.ret