# ops_rhip.py (866 B)

  1. import ctypes
  2. from tinygrad.device import Compiled, MallocAllocator
  3. from tinygrad.renderer.cstyle import HIPRenderer
  4. from tinygrad.runtime.ops_hsa import HSACompiler
  5. rhip = ctypes.CDLL("/usr/local/lib/libremu.so")
  6. class RHIPProgram:
  7. def __init__(self, name:str, lib:bytes):
  8. self.name, self.lib = name, lib
  9. def __call__(self, *args, global_size, local_size, vals=(), wait=False):
  10. args = (*args, *vals)
  11. rhip.hipModuleLaunchKernel(self.lib, len(self.lib), *global_size, *local_size, 0, None, None,
  12. len(args), (ctypes.c_void_p * len(args))(*[ctypes.cast(x, ctypes.c_void_p) for x in args]))
  13. class RHIPDevice(Compiled):
  14. def __init__(self, device:str=""):
  15. self.device = int(device.split(":")[1]) if ":" in device else 0
  16. super().__init__(device, MallocAllocator, HIPRenderer(), HSACompiler("gfx1100"), RHIPProgram)