ops_hip.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. from __future__ import annotations
  2. import ctypes, functools, subprocess, io
  3. from typing import Tuple, TypeVar, List, Any, cast, Set
  4. import tinygrad.runtime.autogen.hip as hip
  5. from tinygrad.helpers import DEBUG, getenv, init_c_var
  6. from tinygrad.helpers import from_mv, round_up, to_mv, colored, init_c_struct_t
  7. from tinygrad.device import Compiled, LRUAllocator, BufferOptions, Runner, Device, Buffer, MallocAllocator, update_stats, Compiler, CompilerOptions
  8. from tinygrad.renderer.cstyle import HIPRenderer
  9. from tinygrad.runtime.support.hip_comgr import compile_hip
  10. from tinygrad.renderer.rdna import uops_to_rdna
  11. class RDNACompiler(Compiler):
  12. linearizer_opts = LinearizerOptions("HIP", has_tensor_cores=True)
  13. def __init__(self, arch:str):
  14. self.arch = arch
  15. super().__init__(f"compile_rdna_{self.arch}")
  16. def render(self, name:str, uops) -> str: return uops_to_rdna(name, uops)
  17. def compile(self, src:str) -> bytes:
  18. ret = compile_hip(src, self.arch, True)
  19. #with open("/tmp/out.so", "wb") as f: f.write(ret)
  20. return ret
  21. class HIPCompiler(Compiler):
  22. compiler_opts = CompilerOptions("HIP", has_tensor_cores=True, shared_max=65536)
  23. def __init__(self, arch:str):
  24. self.arch = arch
  25. super().__init__(f"compile_hip_{self.arch}")
  26. def render(self, name:str, uops) -> str: return HIPRenderer(name, uops)
  27. def compile(self, src:str) -> bytes: return compile_hip(src, self.arch)
  28. hip_current_device = None
  29. def hip_set_device(d:int):
  30. global hip_current_device
  31. if d == hip_current_device: return
  32. check(hip.hipSetDevice(d))
  33. hip_current_device = d
  34. def check(status):
  35. if status != 0: raise RuntimeError(f"HIP Error {status}, {ctypes.string_at(hip.hipGetErrorString(status)).decode()}")
  36. class HIPProgram:
  37. def __init__(self, device:int, name:str, lib:bytes):
  38. self.device, self.name, self.lib = device, name, lib
  39. if DEBUG >= 6:
  40. asm = subprocess.check_output(["/opt/rocm/llvm/bin/llvm-objdump", '-d', '-'], input=lib)
  41. print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x]))
  42. hip_set_device(self.device)
  43. self.module = init_c_var(hip.hipModule_t(), lambda x: check(hip.hipModuleLoadData(ctypes.byref(x), lib)))
  44. self.prg = init_c_var(hip.hipFunction_t(), lambda x: check(hip.hipModuleGetFunction(ctypes.byref(x), self.module, name.encode("utf-8"))))
  45. def __del__(self):
  46. if hasattr(self, 'module'): check(hip.hipModuleUnload(self.module))
  47. def __call__(self, *args, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
  48. hip_set_device(self.device)
  49. if not hasattr(self, "vargs"):
  50. self.c_args = init_c_struct_t(tuple([(f'f{i}', hip.hipDeviceptr_t) for i in range(len(args))] +
  51. [(f'v{i}', ctypes.c_int) for i in range(len(vals))]))(*args, *vals)
  52. self.vargs = (ctypes.c_void_p * 5)(ctypes.c_void_p(1), ctypes.cast(ctypes.byref(self.c_args), ctypes.c_void_p),
  53. ctypes.c_void_p(2), ctypes.cast(ctypes.byref(ctypes.c_size_t(ctypes.sizeof(self.c_args))), ctypes.c_void_p),
  54. ctypes.c_void_p(3))
  55. else:
  56. for i in range(len(args)): self.c_args.__setattr__(f'f{i}', args[i])
  57. for i in range(len(vals)): self.c_args.__setattr__(f'v{i}', vals[i])
  58. if wait:
  59. evs = [init_c_var(hip.hipEvent_t(), lambda x: hip.hipEventCreate(ctypes.byref(x), 0)) for _ in range(2)]
  60. check(hip.hipEventRecord(evs[0], None))
  61. check(hip.hipModuleLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, self.vargs))
  62. if wait:
  63. check(hip.hipEventRecord(evs[1], None))
  64. check(hip.hipEventSynchronize(evs[1]))
  65. check(hip.hipEventElapsedTime(ctypes.byref(ret := ctypes.c_float()), evs[0], evs[1]))
  66. for ev in evs: check(hip.hipEventDestroy(ev))
  67. return ret.value * 1e-3
  68. return None
  69. T = TypeVar("T")
  70. CHUNK_SIZE, PAGE_SIZE = 256*1024*1024, 0x1000
  71. class HIPAllocator(LRUAllocator):
  72. def __init__(self, device:HIPDevice):
  73. self.device = device
  74. self.track_cross_device: Set[HIPDevice] = set()
  75. super().__init__()
  76. def full_synchronize(self):
  77. self.device.synchronize()
  78. for x in self.track_cross_device: x.synchronize()
  79. self.track_cross_device.clear()
  80. def free_cache(self):
  81. self.full_synchronize()
  82. return super().free_cache()
  83. def _alloc(self, size:int):
  84. hip_set_device(self.device.device)
  85. return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipMalloc(ctypes.byref(x), size)))
  86. def _alloc_with_options(self, size:int, options:BufferOptions):
  87. hip_set_device(self.device.device)
  88. if options.uncached:
  89. return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipExtMallocWithFlags(ctypes.byref(x), size, 3))) # hipDeviceMallocUncached = 3
  90. elif options.host:
  91. return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipHostMalloc(ctypes.byref(x), size, 2 if options.signal else 0)))
  92. else:
  93. raise Exception("no options")
  94. def _free(self, opaque:T): check(hip.hipFree(opaque))
  95. def copy_from_fd(self, dest, fd, offset, size):
  96. hip_set_device(self.device.device)
  97. if not hasattr(self, 'hb'):
  98. self.hb = [self._alloc_with_options(CHUNK_SIZE, BufferOptions(host=True)) for _ in range(2)]
  99. self.hb_events = [None, None]
  100. self.hb_polarity = 0
  101. fo = io.FileIO(fd, "a+b", closefd=False)
  102. fo.seek(offset - (minor_offset:=offset % PAGE_SIZE))
  103. copied_in = 0
  104. for local_offset in range(0, size+minor_offset, CHUNK_SIZE):
  105. local_size = min(round_up(size+minor_offset, PAGE_SIZE)-local_offset, CHUNK_SIZE)
  106. if self.hb_events[self.hb_polarity] is not None:
  107. # NOTE: block doesn't work here because we modify the CPU memory
  108. check(hip.hipEventSynchronize(self.hb_events[self.hb_polarity]))
  109. check(hip.hipEventDestroy(self.hb_events[self.hb_polarity]))
  110. self.hb_events[self.hb_polarity] = None
  111. fo.readinto(to_mv(self.hb[self.hb_polarity], local_size))
  112. check(hip.hipMemcpyAsync(ctypes.c_void_p(dest.value + copied_in), ctypes.c_void_p(self.hb[self.hb_polarity].value + minor_offset),
  113. copy_size:=min(local_size-minor_offset, size-copied_in), hip.hipMemcpyHostToDevice, None))
  114. self.hb_events[self.hb_polarity] = init_c_var(hip.hipEvent_t(), lambda x: check(hip.hipEventCreate(ctypes.byref(x))))
  115. check(hip.hipEventRecord(self.hb_events[self.hb_polarity], None))
  116. copied_in += copy_size
  117. self.hb_polarity = (self.hb_polarity+1) % len(self.hb)
  118. minor_offset = 0 # only on the first
  119. def copyin(self, dest:T, src: memoryview):
  120. hip_set_device(self.device.device)
  121. host_mem = self._alloc_with_options(len(src), BufferOptions(host=True))
  122. self.device.pending_copyin.append(host_mem)
  123. ctypes.memmove(host_mem, from_mv(src), len(src))
  124. check(hip.hipMemcpyAsync(dest, host_mem, len(src), hip.hipMemcpyHostToDevice, None))
  125. def copyout(self, dest:memoryview, src:T):
  126. self.full_synchronize()
  127. hip_set_device(self.device.device)
  128. check(hip.hipMemcpy(from_mv(dest), src, len(dest), hip.hipMemcpyDeviceToHost))
  129. def transfer(self, dest:T, src:T, sz:int, **kwargs):
  130. hip_set_device(self.device.device)
  131. check(hip.hipMemcpyAsync(dest, src, sz, hip.hipMemcpyDeviceToDevice, None))
  132. class HIPSyncEvent(Runner):
  133. def __init__(self, lb):
  134. self.lb, self.device, self.dname = lb, cast(HIPDevice, Device[lb.device]), lb.device
  135. super().__init__()
  136. def __call__(self, rawbufs:List[Buffer], var_vals, wait=False, jit=False):
  137. to_mv(rawbufs[0]._buf, 4).cast("I")[0] = 0
  138. hip_set_device(self.device.device)
  139. check(hip.hipStreamWriteValue32(None, rawbufs[0]._buf, 1, 0))
  140. update_stats(colored("sync", "red"), 0, 0, {}, None, 1, jit, device=self.dname)
  141. class HIPWaitEvent(Runner):
  142. def __init__(self, device):
  143. self.device, self.dname = cast(HIPDevice, Device[device]), device
  144. super().__init__()
  145. def __call__(self, rawbufs:List[Buffer], var_vals, wait=False, jit=False):
  146. hip_set_device(self.device.device)
  147. check(hip.hipStreamWaitValue32(None, rawbufs[0]._buf, 1, 1, 0xFFFFFFFF))
  148. update_stats(colored("wait", "RED"), 0, 0, {}, None, 1, jit, device=self.dname)
  149. if getenv("HIPCPU"):
  150. rhip = ctypes.CDLL("/usr/local/lib/libremu.so")
  151. class RHIPProgram:
  152. def __init__(self, name:str, lib:bytes):
  153. self.name, self.lib = name, lib
  154. def __call__(self, *args, global_size, local_size, vals=(), wait=False):
  155. args = (*args, *vals)
  156. rhip.hipModuleLaunchKernel(self.lib, len(self.lib), *global_size, *local_size, 0, None, None,
  157. len(args), (ctypes.c_void_p * len(args))(*[ctypes.cast(x, ctypes.c_void_p) for x in args]))
  158. class HIPDevice(Compiled):
  159. def __init__(self, device:str=""):
  160. self.device = int(device.split(":")[1]) if ":" in device else 0
  161. self.pending_copyin: List[ctypes.c_void_p] = []
  162. self.track_cross_buffer: List[Any] = []
  163. self.peers: Set[int] = set()
  164. if getenv("HIPCPU"):
  165. super().__init__(device, MallocAllocator, HIPCompiler("gfx1100"), RHIPProgram)
  166. else:
  167. self.arch = init_c_var(hip.hipDeviceProp_t(), lambda x: check(hip.hipGetDeviceProperties(x, self.device))).gcnArchName.decode()
  168. from tinygrad.runtime.graph.hip import HIPGraph
  169. super().__init__(device, HIPAllocator(self), RDNACompiler(self.arch) if getenv("RDNA") else HIPCompiler(self.arch),
  170. functools.partial(HIPProgram, self.device), HIPGraph)
  171. def synchronize(self):
  172. if getenv("HIPCPU"): return
  173. hip_set_device(self.device)
  174. check(hip.hipDeviceSynchronize())
  175. for opaque in self.pending_copyin: check(hip.hipFree(opaque))
  176. self.track_cross_buffer.clear()
  177. self.pending_copyin.clear()
  178. def enable_peer(self, dnum):
  179. if self.device == dnum or dnum in self.peers: return
  180. hip_set_device(self.device)
  181. check(hip.hipDeviceEnablePeerAccess(dnum, 0))
  182. self.peers.add(dnum)