hcq.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. import collections, time
  2. from typing import List, Any, Dict, cast, Optional, Tuple, Set
  3. from tinygrad.helpers import round_up, to_mv, PROFILE
  4. from tinygrad.device import Buffer, BufferOptions, Compiled, Device
  5. from tinygrad.shape.symbolic import Variable
  6. from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
  7. from tinygrad.engine.jit import MultiGraphRunner
class HCQGraph(MultiGraphRunner):
  """JIT graph runner for HCQ (hardware command queue) backends.

  Builds, once at construction time, per-device compute and copy hardware queues
  covering the whole jit_cache. A later __call__ only patches input buffer
  pointers, variable values and launch dims in place, then resubmits the
  pre-built queues.
  """
  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
    super().__init__(jit_cache, input_rawbuffers, var_vals)

    # Every device touched by any buffer of any jit item participates in the graph.
    self.devices = list(set(cast(Any, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))

    # Allocate kernel args: one cpu-accessible allocation per device, sized as the
    # sum of each kernel's kernarg area rounded up to 16 bytes.
    kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
    for ji in self.jit_cache:
      if not isinstance(ji.prg, CompiledRunner): continue
      kernargs_size[ji.prg.device] += round_up(ji.prg.clprg.kernargs_alloc_size, 16)
    self.kernargs_bufs: Dict[Compiled, Any] = {dev:dev.allocator._alloc(sz, BufferOptions(cpu_access=True)) for dev,sz in kernargs_size.items()}
    kernargs_ptrs: Dict[Compiled, int] = {dev:buf.va_addr for dev,buf in self.kernargs_bufs.items()}  # bump pointers per device

    # Fill initial arguments.
    self.kargs_addrs: Dict[int, int] = {}          # ji index -> base address of its kernarg slot
    self.ji_args_bufs: Dict[int, memoryview] = {}  # ji index -> u64 view over its buffer pointers (patched in __call__)
    self.ji_args_vars: Dict[int, memoryview] = {}  # ji index -> u32 view over its variable values (patched in __call__)
    for j,ji in enumerate(self.jit_cache):
      if not isinstance(ji.prg, CompiledRunner): continue
      self.kargs_addrs[j] = kernargs_ptrs[ji.prg.device]
      kernargs_ptrs[ji.prg.device] += round_up(ji.prg.clprg.kernargs_alloc_size, 16)
      ji.prg.clprg.fill_kernargs(self.kargs_addrs[j], [cast(Buffer, b)._buf for b in ji.bufs], [var_vals[v] for v in ji.prg.p.vars])
      # Layout within the args area: buffer pointers (8 bytes each) first, then vars (4 bytes each).
      self.ji_args_bufs[j] = to_mv(self.kargs_addrs[j] + ji.prg.clprg.kernargs_args_offset, len(ji.bufs) * 8).cast('Q')
      self.ji_args_vars[j] = to_mv(self.kargs_addrs[j] + ji.prg.clprg.kernargs_args_offset + len(ji.bufs) * 8, len(ji.prg.p.vars) * 4).cast('I')

    # Schedule Dependencies.
    # There are two types of queues on each device: copy and compute. Both must synchronize with all external operations before launching any
    # graph-related tasks. This synchronization uses a global timeline signal per device. Within the graph, the compute queue coordinates with
    # global operations and sets a kickoff signal. Any queue accessing a buffer from another device waits for this signal from the device's
    # compute queue to ensure exclusive access. The compute queue signals the completion of the graph, synchronizing with the device's copy queue.
    self.comp_queues: Dict[Compiled, Any] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
    self.copy_queues: Dict[Compiled, Any] = {dev: dev.hw_copy_queue_t() for dev in self.devices}

    self.signal_sched: Dict[int, Tuple[List, Any, Optional[int], Optional[List]]] = {} # Dict[ji_idx, (deps, signal, sigval, prof_info)]
    # One progress signal per queue; all signals are allocated through devices[0].
    self.signals = {q: self.devices[0]._alloc_signal(value=0) for q in list(self.comp_queues.values())+list(self.copy_queues.values())}
    self.dev_kickoff_signal = {dev: self.devices[0]._alloc_signal(value=0) for dev in self.devices + ['CPU']} # Dict[dev, signal]
    self.kickoff_value = 0  # bumped once per graph launch; all kickoff waits/signals are patched to it

    # save_devs[q]: set of devices queue q is already synchronized with (device compute queues start in sync with their own device).
    self.save_devs: Dict[Any, Set] = {q: set() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
    for dev in self.devices: self.save_devs[self.comp_queues[dev]].add(dev)

    self.graph_timeline = {dev: 0 for dev in self.devices} # Dict[dev, last graph sigval]
    self.last_ji: Dict[Any, Any] = {q: None for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}  # last ji index per queue

    for j,ji in enumerate(self.jit_cache):
      # Kernels enqueue on their device's compute queue; transfers on the source device's copy queue (bufs[1] is the source).
      enqueue_dev = ji.prg.device if isinstance(ji.prg, CompiledRunner) else Device[ji.bufs[1].device] #type:ignore
      enqueue_queue = self.comp_queues[enqueue_dev] if isinstance(ji.prg, CompiledRunner) else self.copy_queues[enqueue_dev]
      out_signal = self.signals[enqueue_queue]
      writable_buffers = ji.prg.p.outcount if isinstance(ji.prg, CompiledRunner) else 1
      # deps: list of (signal, value) this ji must wait on; j+1 is this ji's own signal value.
      deps = self.access_resources(enqueue_queue, ji.bufs[writable_buffers:], ji.bufs[:writable_buffers], j + 1)

      if isinstance(ji.prg, CompiledRunner):
        # Update signal on compute kernel to depend on the previous kernel.
        if (last_j:=self.last_ji[enqueue_queue]) is not None: deps = [x for x in deps if id(x[0]) != id(out_signal)] + [(out_signal, last_j + 1)]

        # Remove self-dependency for AMD or NV with only 1 same-queue dep, since NV chains 2+ execs in this case, eliminating dep need.
        if (dname:=enqueue_dev.dname.split(":", 1)[0]) == "AMD" or (dname == "NV" and len(deps) == 1 and id(deps[0][0]) == id(out_signal)):
          deps = [x for x in deps if id(x[0]) != id(out_signal)]
      elif isinstance(ji.prg, BufferXfer): deps = [x for x in deps if id(x[0]) != id(out_signal)]

      # Go through all dependencies and, if we need the signal from that ji, enable it by setting the signal value in the signal schedule.
      for sig, val in deps:
        if id(sig) in [id(x) for x in self.signals.values()]:
          # val - 1 is the producing ji's index; set its sigval slot (element 2 of the tuple) to val.
          self.signal_sched[val - 1] = self.signal_sched[val - 1][:2] + (val,) + self.signal_sched[val - 1][3:]

      prof_ji_desc = ji.prg.clprg.name if isinstance(ji.prg, CompiledRunner) else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore
      # prof_info: [start signal, end signal, device, description, is_copy] when profiling, else None.
      prof_info = ([enqueue_dev._alloc_signal() for _ in range(2)] + [enqueue_dev, prof_ji_desc, isinstance(ji.prg, BufferXfer)]) if PROFILE else None
      # Kernels signal lazily (only when a consumer flips sigval on), transfers always signal j + 1.
      self.signal_sched[j] = (deps, out_signal, None if isinstance(ji.prg, CompiledRunner) else (j + 1), prof_info)
      self.last_ji[enqueue_queue] = j

    # Build hardware queues.
    self.op_cmd_idx: Dict[int, Tuple[Any, int]] = {}  # ji index -> (queue, command index) for later in-place patching
    self.copy_to_devs: Dict[Compiled, Set[Compiled]] = {dev: set() for dev in self.devices}  # dest device -> source devices it copied from
    self.kickoff_wait_cmds: Dict[Any, List] = {q: list() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}

    for dev in self.devices:
      # Preamble: barrier, wait for external work, wait for CPU kickoff, then announce this device's kickoff.
      # __call__ patches commands 1 (timeline wait), 2 (kickoff wait) and 3 (kickoff signal) by index.
      self.comp_queues[dev].memory_barrier().wait(dev.timeline_signal, dev.timeline_value - 1) \
                           .wait(self.dev_kickoff_signal['CPU'], self.kickoff_value).signal(self.dev_kickoff_signal[dev], self.kickoff_value)

    for j,ji in enumerate(self.jit_cache):
      deps, signal, signal_val, prof_info = self.signal_sched[j]
      enqueue_queue = self.copy_queues[Device[ji.bufs[1].device]] if isinstance(ji.prg, BufferXfer) else self.comp_queues[ji.prg.device] #type:ignore

      # Encode waits and start profile timestamp (if needed).
      for sig, val in deps:
        enqueue_queue.wait(sig, val)
        # Kickoff waits must be re-patched each launch; remember their command index.
        if id(sig) in [id(x) for x in self.dev_kickoff_signal.values()]: self.kickoff_wait_cmds[enqueue_queue].append(len(enqueue_queue) - 1)
      if prof_info: enqueue_queue.timestamp(prof_info[0])

      # Encode main commands based on ji type.
      if isinstance(ji.prg, CompiledRunner): enqueue_queue.exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals))
      elif isinstance(ji.prg, BufferXfer):
        dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
        Device[src.device]._gpu_map(dest._buf) #type: ignore
        enqueue_queue.copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes)
        self.copy_to_devs[Device[dest.device]].add(Device[src.device])
      self.op_cmd_idx[j] = (enqueue_queue, len(enqueue_queue) - 1)

      if signal_val is not None: enqueue_queue.signal(signal, signal_val)

      # Encode finish profile timestamp (if needed).
      if prof_info: enqueue_queue.timestamp(prof_info[1])

    for dev in self.devices:
      # Before finishing, each compute queue waits for its own copy queue and for any copy queue that wrote into this device.
      for dep_dev in list(self.copy_to_devs[dev]) + [dev]:
        if (last_j:=self.last_ji[self.copy_queues[dep_dev]]) is None: continue
        self.comp_queues[dev].wait(self.signals[self.copy_queues[dep_dev]], self.signal_sched[last_j][2])

      # Final command: bump the device timeline. __call__ patches it by index (len - 1).
      self.comp_queues[dev].signal(dev.timeline_signal, dev.timeline_value)
      if hasattr(self.comp_queues[dev], 'bind'): self.comp_queues[dev].bind(dev)
      if hasattr(self.copy_queues[dev], 'bind') and self.last_ji[self.copy_queues[dev]] is not None: self.copy_queues[dev].bind(dev)
  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
    """Patch the pre-built queues with new inputs/vars and resubmit them.

    Returns the wall-clock wait time in seconds when wait=True, else None.
    """
    # Wait and restore signals
    self.kickoff_value += 1
    # Wait for the previous launch of this graph to fully drain before reusing its signals.
    for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev])
    for queue in self.comp_queues.values(): self.devices[0]._set_signal(self.signals[queue], 0)
    for queue in self.copy_queues.values(): self.devices[0]._set_signal(self.signals[queue], 0)
    # Setting the CPU kickoff signal releases all device preambles waiting on it.
    self.devices[0]._set_signal(self.dev_kickoff_signal['CPU'], self.kickoff_value)

    # Collect profile timestamps written by the previous launch (none exist on the very first call).
    if PROFILE and self.kickoff_value > 1:
      for _,_,_,(st,en,dev,desc,is_cp) in self.signal_sched.values(): #type: ignore
        dev.raw_prof_records += [(dev._read_timestamp(st), dev._read_timestamp(en), desc, is_cp)]

    # Update rawbuffers
    for (j,i),input_idx in self.input_replace.items():
      # Kernel args are patched through the memoryview; copy commands are patched in the queue itself (i == 0 is dest, i == 1 is src).
      if j in self.ji_args_bufs: self.ji_args_bufs[j][i] = input_rawbuffers[input_idx]._buf.va_addr
      else: self.op_cmd_idx[j][0].update_copy(self.op_cmd_idx[j][1], **{('dest' if i == 0 else 'src'): input_rawbuffers[input_idx]._buf.va_addr})

    # Update var_vals
    for j in self.jc_idx_with_updatable_var_vals:
      for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars): self.ji_args_vars[j][i] = var_vals[v]

    for j in self.jc_idx_with_updatable_launch_dims:
      queue, cmd_ptr = self.op_cmd_idx[j]
      queue.update_exec(cmd_ptr, *cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals))

    for dev in self.devices:
      # Patch the fixed preamble commands (1: timeline wait, 2: kickoff wait, 3: kickoff signal) and the
      # final timeline signal; these indices match the encode order in __init__.
      self.comp_queues[dev].update_wait(1, dev.timeline_signal, dev.timeline_value - 1).update_wait(2, value=self.kickoff_value) \
                           .update_signal(3, value=self.kickoff_value) \
                           .update_signal(len(self.comp_queues[dev]) - 1, dev.timeline_signal, dev.timeline_value).submit(dev)

      # Copy queues are only submitted when they actually contain work.
      if self.last_ji[(cp_queue:=self.copy_queues[dev])] is not None:
        for cmd_idx in self.kickoff_wait_cmds[cp_queue]: cp_queue.update_wait(cmd_idx, value=self.kickoff_value)
        cp_queue.submit(dev)

      self.graph_timeline[dev] = dev.timeline_value
      dev.timeline_value += 1

    if wait:
      st = time.perf_counter()
      for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev])
      return time.perf_counter() - st
    return None
  def access_resources(self, queue, read, write, new_val):
    """Register this ji's buffer accesses and return its wait list.

    Returns (signal, value) pairs: one per dependent queue (deduplicated, keeping
    the highest required value) plus a kickoff wait for each device this queue
    touches for the first time.
    """
    deps = self._access_resources(read, write, (queue, new_val))

    # Inherit synchronization state from every queue we now depend on.
    for dep_queue, _ in deps: self.save_devs[queue].update(self.save_devs[dep_queue])

    # First-time access to a device's buffers requires waiting on that device's kickoff signal.
    sync_signals = []
    for buf in read + write:
      if buf.device not in self.save_devs[queue]:
        self.save_devs[queue].add(buf.device)
        sync_signals.append((self.dev_kickoff_signal[Device[buf.device]], self.kickoff_value))

    # Collapse per-queue deps to a single wait at the maximum required value, preserving first-seen order.
    best: Dict[int, Any] = {}
    for dep_queue, val in deps:
      key = id(dep_queue)
      if key not in best or val > best[key][1]: best[key] = (dep_queue, val)

    return [(self.signals[dep_queue], val) for dep_queue, val in best.values()] + sync_signals
  def __del__(self):
    """Tear down the graph: drain in-flight work, then release signals and kernarg memory."""
    # Make sure no device is still executing this graph before freeing its resources.
    for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev])

    # Graph is destructed. No need to keep signals any more, so return them as part of profiling.
    if PROFILE and self.kickoff_value > 1:
      for _,_,_,(st,en,dev,desc,is_cp) in self.signal_sched.values(): dev.sig_prof_records += [(st, en, desc, is_cp)] #type: ignore

    # BUGFIX: the original used map(self.devices[0]._free_signal, ...) and discarded the
    # result; map is lazy, so _free_signal never ran and every signal leaked. Iterate explicitly.
    for sig in list(self.dev_kickoff_signal.values()) + list(self.signals.values()): self.devices[0]._free_signal(sig)
    for dev, buf in self.kernargs_bufs.items(): dev.allocator._free(buf, BufferOptions(cpu_access=True))