# hip_ioctl.py
  1. # type: ignore
  2. import ctypes, ctypes.util, struct, platform, pathlib, re, time, os
  3. start = time.perf_counter()
  4. # *** ioctl lib ***
  5. libc = ctypes.CDLL(ctypes.util.find_library("c"))
  6. processor = platform.processor()
  7. IOCTL_SYSCALL = {"aarch64": 0x1d, "x86_64":16}[processor]
  8. def get_struct(argp, stype):
  9. return ctypes.cast(ctypes.c_void_p(argp), ctypes.POINTER(stype)).contents
  10. def format_struct(s):
  11. sdats = []
  12. for field_name, field_type in s._fields_:
  13. dat = getattr(s, field_name)
  14. if isinstance(dat, int): sdats.append(f"{field_name}:0x{dat:X}")
  15. else: sdats.append(f"{field_name}:{dat}")
  16. return sdats
  17. def install_hook(c_function, python_function):
  18. python_function_addr = ctypes.cast(ctypes.byref(python_function), ctypes.POINTER(ctypes.c_ulong)).contents.value
  19. # AARCH64 trampoline to ioctl
  20. if processor == "aarch64":
  21. # 0x0000000000000000: 70 00 00 10 adr x16, #0xc
  22. # 0x0000000000000004: 10 02 40 F9 ldr x16, [x16]
  23. # 0x0000000000000008: 00 02 1F D6 br x16
  24. tramp = b"\x70\x00\x00\x10\x10\x02\x40\xf9\x00\x02\x1f\xd6"
  25. tramp += struct.pack("Q", python_function_addr)
  26. elif processor == "x86_64":
  27. # 0x0000000000000000: 49 B8 aa aa aa aa aa aa aa aa movabs r8, <address>
  28. # 0x000000000000000a: 41 FF E0 jmp r8
  29. tramp = b"\x49\xB8" + struct.pack("Q", python_function_addr) + b"\x41\xFF\xE0"
  30. else:
  31. raise Exception(f"processor {processor} not supported")
  32. # get real ioctl address
  33. ioctl_address = ctypes.cast(ctypes.byref(c_function), ctypes.POINTER(ctypes.c_ulong))
  34. # hook ioctl
  35. ret = libc.mprotect(ctypes.c_ulong((ioctl_address.contents.value//0x1000)*0x1000), 0x2000, 7)
  36. assert ret == 0
  37. libc.memcpy(ioctl_address.contents, ctypes.create_string_buffer(tramp), len(tramp))
  38. # *** ioctl lib end ***
  39. import tinygrad.runtime.autogen.kfd as kfd_ioctl
  40. def ioctls_from_header():
  41. hdr = pathlib.Path("/usr/include/linux/kfd_ioctl.h").read_text().replace("\\\n", "")
  42. pattern = r'#define\s+(AMDKFD_IOC_[A-Z0-9_]+)\s+AMDKFD_IOW?R?\((0x[0-9a-fA-F]+),\s+struct\s([A-Za-z0-9_]+)\)'
  43. matches = re.findall(pattern, hdr, re.MULTILINE)
  44. return {int(nr, 0x10):(name, getattr(kfd_ioctl, "struct_"+sname)) for name, nr, sname in matches}
  45. nrs = ioctls_from_header()
  46. @ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int, ctypes.c_ulong, ctypes.c_void_p)
  47. def ioctl(fd, request, argp):
  48. st = time.perf_counter()
  49. ret = libc.syscall(IOCTL_SYSCALL, ctypes.c_int(fd), ctypes.c_ulong(request), ctypes.c_void_p(argp))
  50. et = time.perf_counter()-st
  51. idir, size, itype, nr = (request>>30), (request>>16)&0x3FFF, (request>>8)&0xFF, request&0xFF
  52. if nr in nrs and itype == 75:
  53. # /dev/kfd
  54. name, stype = nrs[nr]
  55. s = get_struct(argp, stype)
  56. print(f"{(st-start)*1000:7.2f} ms +{et*1000.:7.2f} ms : {ret:2d} = {name:40s}", ' '.join(format_struct(s)))
  57. if name == "AMDKFD_IOC_SVM":
  58. out = ctypes.cast(s.attrs, ctypes.POINTER(kfd_ioctl.struct_kfd_ioctl_svm_attribute))
  59. for i in range(s.nattr): print(f"{i}: {kfd_ioctl.kfd_ioctl_svm_attr_type__enumvalues[out[i].type]:40s}: {out[i].value:#x}")
  60. else:
  61. print(f"{(st-start)*1000:7.2f} ms +{et*1000.:7.2f} ms : ioctl",
  62. f"{idir=} {size=} {itype=} {nr=} {fd=} {ret=}", os.readlink(f"/proc/self/fd/{fd}") if fd >= 0 else "")
  63. return ret
# Patch libc's ioctl so that calls to it are routed to the tracing `ioctl`
# callback defined above. This runs at import time, not only under __main__.
install_hook(libc.ioctl, ioctl)

# AMD_LOG_LEVEL=4 HSAKMT_DEBUG_LEVEL=7
if __name__ == "__main__":
    # Demo workload: each stage is announced with a "*****" banner so the
    # ioctl trace lines printed by the hook can be attributed to it.
    print("***** import tinygrad")
    from tinygrad import Tensor, Device, TinyJit
    print("***** access HIP")
    dev = Device["HIP"]
    print("***** create tensor a")
    a = Tensor([1.,2.]*1024*1024, device="HIP").realize()
    print("***** create tensor b")
    b = Tensor([3.,4.]*1024*1024, device="HIP").realize()
    @TinyJit
    def add(a, b): return (a+b).realize()
    # NOTE(review): the indentation of this section was lost in the source;
    # the loop is assumed to end after dev.synchronize(). Confirm.
    for i in range(4):
        print(f"***** add tensors {i}")
        c = add(a, b)
        #dev.synchronize()
        c = add(b, a)
        dev.synchronize()
    print(f"***** copyout")
    nc = c.numpy()
    print(f"***** delete")
    del add, a, b, c, dev
    print(f"***** done")
    # os._exit terminates the process immediately, skipping normal interpreter
    # shutdown (presumably to avoid teardown work under the hook; confirm).
    os._exit(0)