ops_hip.py 3.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. from __future__ import annotations
  2. import ctypes, functools
  3. from typing import Tuple
  4. import tinygrad.runtime.autogen.hip as hip
  5. from tinygrad.helpers import DEBUG, init_c_var, from_mv, init_c_struct_t
  6. from tinygrad.device import Compiled, LRUAllocator, BufferOptions
  7. from tinygrad.runtime.ops_amd import AMDCompiler, disasm
  8. from tinygrad.renderer.cstyle import HIPRenderer
  def check(status):
    """Raise RuntimeError with HIP's own error string when a HIP API call returns a non-zero status."""
    if status == 0: return
    reason = ctypes.string_at(hip.hipGetErrorString(status)).decode()
    raise RuntimeError(f"HIP Error {status}, {reason}")
  class HIPProgram:
    """Kernel `name` from code object `lib`, loaded as a HIP module on `device` and launched via __call__."""
    def __init__(self, device:HIPDevice, name:str, lib:bytes):
      self.device, self.name, self.lib = device, name, lib
      if DEBUG >= 6: print(disasm(lib))
      check(hip.hipSetDevice(self.device.device_id))
      # load the code object into a module, then resolve the kernel symbol from it
      self.module = init_c_var(hip.hipModule_t(), lambda x: check(hip.hipModuleLoadData(ctypes.byref(x), lib)))
      self.prg = init_c_var(hip.hipFunction_t(), lambda x: check(hip.hipModuleGetFunction(ctypes.byref(x), self.module, name.encode("utf-8"))))

    def __del__(self):
      # `module` is absent if __init__ raised before the load completed
      if hasattr(self, 'module'):
        # select the owning device first, as __init__/__call__ do, so the unload does not depend on
        # whichever device happens to be current when this object is collected
        check(hip.hipSetDevice(self.device.device_id))
        check(hip.hipModuleUnload(self.module))

    def __call__(self, *args, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
      """Launch the kernel with device pointers `args` followed by int scalars `vals`.

      If `wait` is True, block until the kernel finishes and return its elapsed time in seconds;
      otherwise return None.
      """
      check(hip.hipSetDevice(self.device.device_id))
      if not hasattr(self, "vargs"):
        # Built once on the first launch and reused afterwards, so later calls must pass the same
        # number of args and vals. Fields f0..fN hold buffer pointers, then v0..vM hold the int vals.
        self.c_args = init_c_struct_t(tuple([(f'f{i}', hip.hipDeviceptr_t) for i in range(len(args))] +
                                            [(f'v{i}', ctypes.c_int) for i in range(len(vals))]))(*args, *vals)
        # "extra" launch-parameter list: 1, &c_args, 2, &sizeof(c_args), 3. These tags are presumably
        # HIP_LAUNCH_PARAM_BUFFER_POINTER / _BUFFER_SIZE / _END; confirm against hip_runtime_api.h.
        self.vargs = (ctypes.c_void_p * 5)(1, ctypes.cast(ctypes.byref(self.c_args), ctypes.c_void_p), 2,
                                           ctypes.cast(ctypes.pointer(ctypes.c_size_t(ctypes.sizeof(self.c_args))), ctypes.c_void_p), 3)
      # refresh the cached struct in place; self.vargs already points at it
      for i, a in enumerate(args): setattr(self.c_args, f'f{i}', a)
      for i, v in enumerate(vals): setattr(self.c_args, f'v{i}', v)
      if wait: check(hip.hipEventRecord(self.device.time_event_st, None))
      check(hip.hipModuleLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, self.vargs))
      if wait:
        check(hip.hipEventRecord(self.device.time_event_en, None))
        check(hip.hipEventSynchronize(self.device.time_event_en))
        # hipEventElapsedTime reports milliseconds; convert to seconds
        check(hip.hipEventElapsedTime(ctypes.byref(ret := ctypes.c_float()), self.device.time_event_st, self.device.time_event_en))
        return ret.value * 1e-3
  class HIPAllocator(LRUAllocator):
    """LRUAllocator subclass that allocates, frees and copies raw HIP device memory for one HIPDevice."""
    def __init__(self, device:HIPDevice):
      self.device = device
      super().__init__()
    def _alloc(self, size:int, options:BufferOptions):
      # select our device before hipMalloc; `options` is not consulted by this backend
      check(hip.hipSetDevice(self.device.device_id))
      return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipMalloc(ctypes.byref(x), size)))
    def _free(self, opaque): check(hip.hipFree(opaque))
    def copyin(self, dest, src: memoryview):
      """Copy the host bytes of `src` into device pointer `dest`."""
      check(hip.hipSetDevice(self.device.device_id))
      # nbytes, not len(): len() counts elements (or rows), which undercounts non-byte or multi-dim views
      check(hip.hipMemcpy(dest, from_mv(src), src.nbytes, hip.hipMemcpyHostToDevice))
    def copyout(self, dest:memoryview, src):
      """Copy device pointer `src` into host buffer `dest` once queued device work has finished."""
      self.device.synchronize()
      # nbytes, not len(), for the same reason as in copyin
      check(hip.hipMemcpy(from_mv(dest), src, dest.nbytes, hip.hipMemcpyDeviceToHost))
  class HIPDevice(Compiled):
    """tinygrad Compiled device backed by the HIP runtime; `device` is e.g. "HIP" or "HIP:<ordinal>"."""
    def __init__(self, device:str=""):
      # "HIP:1" -> ordinal 1; a name without ":" maps to device 0
      self.device_id = int(device.split(":")[1]) if ":" in device else 0
      self.arch = init_c_var(hip.hipDeviceProp_t(), lambda x: check(hip.hipGetDeviceProperties(x, self.device_id))).gcnArchName.decode()
      # HIPProgram records these timing events on this device's null stream, and (per HIP/CUDA semantics)
      # an event is tied to the device current at creation, so select this device first.
      # hipEventCreate takes only the event pointer, and its status is checked like every other HIP call.
      check(hip.hipSetDevice(self.device_id))
      self.time_event_st, self.time_event_en = [init_c_var(hip.hipEvent_t(), lambda x: check(hip.hipEventCreate(ctypes.byref(x)))) for _ in range(2)]
      super().__init__(device, HIPAllocator(self), HIPRenderer(), AMDCompiler(self.arch), functools.partial(HIPProgram, self))
    def synchronize(self):
      """Select this device, then block the host until its queued work has finished (hipDeviceSynchronize)."""
      check(hip.hipSetDevice(self.device_id))
      check(hip.hipDeviceSynchronize())