import ctypes, ctypes.util, time
import tinygrad.runtime.autogen.nv_gpu as nv_gpu
from enum import Enum, auto
from extra.mockgpu.gpu import VirtGPU
from tinygrad.helpers import to_mv, init_c_struct_t

def make_qmd_struct_type():
  fields = []
  bits = [(name,dt) for name,dt in nv_gpu.__dict__.items() if name.startswith("NVC6C0_QMDV03_00") and isinstance(dt, tuple)]
  bits += [(name+f"_{i}",dt(i)) for name,dt in nv_gpu.__dict__.items() for i in range(8) if name.startswith("NVC6C0_QMDV03_00") and callable(dt)]
  bits = sorted(bits, key=lambda x: x[1][1])
  for i,(name, data) in enumerate(bits):
    if i > 0 and (gap:=(data[1] - bits[i-1][1][0] - 1)) != 0: fields.append((f"_reserved{i}", ctypes.c_uint32, gap))
    fields.append((name.replace("NVC6C0_QMDV03_00_", "").lower(), ctypes.c_uint32, data[0]-data[1]+1))
  return init_c_struct_t(tuple(fields))
qmd_struct_t = make_qmd_struct_type()
assert ctypes.sizeof(qmd_struct_t) == 0x40 * 4
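# Example of what the generated type gives you: each QMD field becomes a ctypes
# bitfield, so a 0x100-byte QMD can be built or inspected field-by-field, e.g.
#   qmd = qmd_struct_t()
#   qmd.cta_raster_width, qmd.cta_thread_dimension0 = 1, 32
#   bytes(qmd)  # 0x100 bytes, ready to be copied into GPU-visible memory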
try:
  gpuocelot_lib = ctypes.CDLL(ctypes.util.find_library("gpuocelot"))
  gpuocelot_lib.ptx_run.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int] # noqa: E501
except Exception: pass

class SchedResult(Enum): CONT = auto(); YIELD = auto() # noqa: E702

class GPFIFO:
  def __init__(self, token, base, entries_cnt):
    self.token, self.base, self.entries_cnt = token, base, entries_cnt
    self.gpfifo = to_mv(self.base, self.entries_cnt * 8).cast("Q")
    self.ctrl = nv_gpu.AmpereAControlGPFifo.from_address(self.base + self.entries_cnt * 8)
    self.state = {}

    # Buf exec state
    self.buf = None
    self.buf_sz = 0
    self.buf_ptr = 0

  def _next_dword(self):
    assert self.buf is not None
    x = self.buf[self.buf_ptr]
    self.buf_ptr += 1
    return x

  def _next_header(self):
    header = self._next_dword()
    typ = (header >> 28) & 0b111
    size = (header >> 16) & 0xFFF
    subc = (header >> 13) & 0x7
    mthd = (header & 0x1FFF) << 2
    return typ, size, subc, mthd
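  # Worked example: under this decoding, the dword 0x100400C8 yields typ=1
  # (incrementing method), size=4, subc=0, mthd=0x0320, i.e. four payload
  # dwords targeting consecutive methods starting at 0x0320.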
  def _state(self, reg): return self.state[reg]
  def _state64(self, reg): return (self.state[reg] << 32) + self.state[reg + 4]
  def _state64_le(self, reg): return (self.state[reg + 4] << 32) + self.state[reg]

  def _reset_buf_state(self): self.buf, self.buf_ptr = None, 0
  def _set_buf_state(self, gpfifo_entry):
    ptr = ((gpfifo_entry >> 2) & 0xfffffffff) << 2
    sz = ((gpfifo_entry >> 42) & 0x1fffff) << 2
    self.buf = to_mv(ptr, sz).cast("I")
    self.buf_sz = sz // 4
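  # Worked example: a GPFIFO entry for a 0x100-byte pushbuffer at 0x1000 packs
  # as (0x1000 >> 2) << 2 | (0x100 >> 2) << 42; bits 2..41 carry the
  # dword-aligned address and bits 42..62 the length in dwords.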
  def execute(self) -> bool:
    initial_off = self.buf_ptr
    while self.ctrl.GPGet != self.ctrl.GPPut:
      self._set_buf_state(self.gpfifo[self.ctrl.GPGet])
      if not self.execute_buf():
        # Buffer isn't executed fully, check if any progress and report.
        # Do not move GPGet in this case, will continue from the same state next time.
        return self.buf_ptr != initial_off
      self.ctrl.GPGet = (self.ctrl.GPGet + 1) % self.entries_cnt
      self._reset_buf_state()
    return True
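  # Note: a False return means the queue yielded without consuming any new
  # dwords this pass; partial or full progress both report True.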
  def execute_buf(self) -> bool:
    while self.buf_ptr < self.buf_sz:
      init_off = self.buf_ptr
      typ, size, subc, mthd = self._next_header()
      cmd_end_off = self.buf_ptr + size

      while self.buf_ptr < cmd_end_off:
        res = self.execute_cmd(mthd)
        if res == SchedResult.YIELD:
          self.buf_ptr = init_off # just revert to the header
          return False
        mthd += 4
    return True

  def execute_qmd(self, qmd_addr):
    qmd = qmd_struct_t.from_address(qmd_addr)
    prg_addr = qmd.program_address_lower + (qmd.program_address_upper << 32)
    const0 = to_mv(qmd.constant_buffer_addr_lower_0 + (qmd.constant_buffer_addr_upper_0 << 32), 0x160).cast('I')
    args_cnt, vals_cnt = const0[0], const0[1]
    args_addr = qmd.constant_buffer_addr_lower_0 + (qmd.constant_buffer_addr_upper_0 << 32) + 0x160
    args = to_mv(args_addr, args_cnt*8).cast('Q')
    vals = to_mv(args_addr + args_cnt*8, vals_cnt*4).cast('I')
    cargs = [ctypes.cast(args[i], ctypes.c_void_p) for i in range(args_cnt)] + [ctypes.cast(vals[i], ctypes.c_void_p) for i in range(vals_cnt)]
    gx, gy, gz = qmd.cta_raster_width, qmd.cta_raster_height, qmd.cta_raster_depth
    lx, ly, lz = qmd.cta_thread_dimension0, qmd.cta_thread_dimension1, qmd.cta_thread_dimension2
    gpuocelot_lib.ptx_run(ctypes.cast(prg_addr, ctypes.c_char_p), args_cnt+vals_cnt, (ctypes.c_void_p*len(cargs))(*cargs), lx, ly, lz, gx, gy, gz, 0)
    if qmd.release0_enable:
      rel0 = to_mv(qmd.release0_address_lower + (qmd.release0_address_upper << 32), 0x8).cast('Q')
      rel0[0] = qmd.release0_payload_lower + (qmd.release0_payload_upper << 32)
    if qmd.dependent_qmd0_enable:
      if qmd.dependent_qmd0_action == 1: self.execute_qmd(qmd.dependent_qmd0_pointer << 8)
      else: raise RuntimeError("unsupported dependent qmd action")
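  # Note: action 1 is the only dependent-QMD chaining mode modeled here; the
  # dependent QMD (whose pointer is stored as address >> 8) runs synchronously
  # right after the release0 semaphore is written.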
  def execute_cmd(self, cmd) -> SchedResult:
    if cmd == nv_gpu.NVC56F_SEM_EXECUTE: return self._exec_signal()
    elif cmd == nv_gpu.NVC6C0_LAUNCH_DMA: return self._exec_nvc6c0_dma()
    elif cmd == nv_gpu.NVC6B5_LAUNCH_DMA: return self._exec_nvc6b5_dma()
    elif cmd == nv_gpu.NVC6C0_SEND_SIGNALING_PCAS2_B: return self._exec_pcas2()
    elif cmd == 0x0320: return self._exec_load_inline_qmd() # NVC6C0_LOAD_INLINE_QMD_DATA
    else: self.state[cmd] = self._next_dword() # just state update
    return SchedResult.CONT

  def _exec_signal(self) -> SchedResult:
    signal = self._state64_le(nv_gpu.NVC56F_SEM_ADDR_LO)
    val = self._state64_le(nv_gpu.NVC56F_SEM_PAYLOAD_LO)

    flags = self._next_dword()
    typ = (flags >> 0) & 0b111
    timestamp = (flags & (1 << 25)) == (1 << 25)

    if typ == 1:
      to_mv(signal, 8).cast('Q')[0] = val
      if timestamp: to_mv(signal + 8, 8).cast('Q')[0] = int(time.perf_counter() * 1e9)
    elif typ == 3:
      mval = to_mv(signal, 8).cast('Q')[0]
      return SchedResult.CONT if mval >= val else SchedResult.YIELD
    else: raise RuntimeError(f"Unsupported type={typ} in exec wait/signal")
    return SchedResult.CONT
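  # Worked example: flags (1 << 25) | 1 == 0x2000001 is a release (typ=1) with
  # the timestamp bit set, so the 8 bytes after the semaphore receive a
  # nanosecond timestamp; flags == 3 is an acquire that yields until the
  # semaphore value reaches the payload.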
  def _exec_load_inline_qmd(self):
    qmd_addr = self._state64(nv_gpu.NVC6C0_SET_INLINE_QMD_ADDRESS_A) << 8
    assert qmd_addr != 0x0, f"invalid qmd address {qmd_addr}"

    qmd_data = [self._next_dword() for _ in range(0x40)]
    cdata = (ctypes.c_uint32 * len(qmd_data))(*qmd_data)
    ctypes.memmove(qmd_addr, cdata, 0x40 * 4)
    self.execute_qmd(qmd_addr)

  def _exec_nvc6c0_dma(self):
    addr = self._state64(nv_gpu.NVC6C0_OFFSET_OUT_UPPER)
    sz = self._state(nv_gpu.NVC6C0_LINE_LENGTH_IN)
    lanes = self._state(nv_gpu.NVC6C0_LINE_COUNT)
    assert lanes == 1, f"unsupported lanes > 1 in _exec_nvc6c0_dma: {lanes}"
    flags = self._next_dword()
    assert flags == 0x41, f"unsupported flags in _exec_nvc6c0_dma: {flags}"

    typ, dsize, subc, mthd = self._next_header()
    assert typ == 6 and mthd == nv_gpu.NVC6C0_LOAD_INLINE_DATA, f"Expected inline data not found after nvc6c0_dma, {typ=} {mthd=}"
    copy_data = [self._next_dword() for _ in range(dsize)]
    assert len(copy_data) * 4 == sz, f"different copy sizes in _exec_nvc6c0_dma: {len(copy_data) * 4} != {sz}"
    cdata = (ctypes.c_uint32 * len(copy_data))(*copy_data)
    ctypes.memmove(addr, cdata, sz)
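  # Note: only the inline-payload form of the compute-class LAUNCH_DMA is
  # modeled; the launch must be followed by a LOAD_INLINE_DATA header whose
  # payload dwords are copied straight to OFFSET_OUT.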
  def _exec_nvc6b5_dma(self):
    flags = self._next_dword()
    if (flags & 0b11) != 0:
      src = self._state64(nv_gpu.NVC6B5_OFFSET_IN_UPPER)
      dst = self._state64(nv_gpu.NVC6B5_OFFSET_OUT_UPPER)
      sz = self._state(nv_gpu.NVC6B5_LINE_LENGTH_IN)
      assert flags == 0x182, f"unsupported flags in _exec_nvc6b5_dma: {flags}"
      ctypes.memmove(dst, src, sz)
    elif ((flags >> 3) & 0b11) != 0:
      src = to_mv(self._state64(nv_gpu.NVC6B5_SET_SEMAPHORE_A), 0x4).cast('I')
      val = self._state(nv_gpu.NVC6B5_SET_SEMAPHORE_PAYLOAD)
      src[0] = val
    else: raise RuntimeError("unknown nvc6b5_dma flags")

  def _exec_pcas2(self):
    qmd_addr = self._state(nv_gpu.NVC6C0_SEND_PCAS_A) << 8
    typ = self._next_dword()
    if typ == 2 or typ == 9: # schedule
      self.execute_qmd(qmd_addr)
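  # Note: PCAS2 payloads of 2 and 9 are treated as "schedule this QMD";
  # anything else is a no-op in this mock.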
class NVGPU(VirtGPU):
  def __init__(self, gpuid):
    super().__init__(gpuid)
    self.mapped_ranges = set()
    self.queues = []

  def map_range(self, vaddr, size): self.mapped_ranges.add((vaddr, size))
  def unmap_range(self, vaddr, size): self.mapped_ranges.remove((vaddr, size))
  def add_gpfifo(self, base, entries_count):
    self.queues.append(GPFIFO(token:=len(self.queues), base, entries_count))
    return token
  def gpu_uuid(self, sz=16): return self.gpuid.to_bytes(sz, byteorder='big', signed=False)
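# Minimal usage sketch, assuming the caller has already allocated a ring of
# `entries_count` 8-byte GPFIFO entries (followed by the control block) at `base`:
#   gpu = NVGPU(0)
#   token = gpu.add_gpfifo(base, entries_count)
#   ... write entries into the ring, bump GPPut ...
#   gpu.queues[token].execute()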