mockgpu.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. import ctypes, ctypes.util, struct, platform, pathlib, re, time, os, builtins, atexit
  2. from extra.mockgpu.nv.nvdriver import NVDriver
  3. from extra.mockgpu.amd.amddriver import AMDDriver
  4. from tinygrad.helpers import from_mv, to_mv
  5. start = time.perf_counter()
  6. # *** ioctl lib ***
  7. libc = ctypes.CDLL(ctypes.util.find_library("c"))
  8. libc.mmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_long]
  9. libc.mmap.restype = ctypes.c_void_p
  10. libc.munmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t]
  11. libc.munmap.restype = ctypes.c_int
  12. libc.fdopendir.argtypes = [ctypes.c_int]
  13. libc.fdopendir.restype = ctypes.c_void_p
  14. processor = platform.processor()
  15. OPEN_SYSCALL = {"aarch64": None, "x86_64": 2}[processor]
  16. CLOSE_SYSCALL = {"aarch64": 57, "x86_64": 3}[processor]
  17. READ_SYSCALL = {"aarch64": 63, "x86_64": 0}[processor]
  18. IOCTL_SYSCALL = {"aarch64": 29, "x86_64": 16}[processor]
  19. MMAP_SYSCALL = {"aarch64": 222, "x86_64": 9}[processor]
  20. LSEEK_SYSCALL = {"aarch64": 62, "x86_64": 8}[processor]
  21. NEWFSTATAT_SYSCALL = {"aarch64": 79, "x86_64": 262}[processor]
  22. GETDENTS64_SYSCALL = {"aarch64": 61, "x86_64": 217}[processor]
  23. def install_hook(c_function, python_function):
  24. python_function_addr = ctypes.cast(ctypes.byref(python_function), ctypes.POINTER(ctypes.c_ulong)).contents.value
  25. if processor == "x86_64":
  26. # tramp = b"\x49\xB8" + struct.pack("Q", python_function_addr) + b"\x41\xFF\xE0"
  27. # push r9
  28. # push r9
  29. # mov r9, 0x1122334455667788
  30. # mov [rsp+8], r9
  31. # pop r9
  32. # ret
  33. tramp = b"\x41\x51\x41\x51\x49\xB9" + struct.pack("Q", python_function_addr) + b"\x4C\x89\x4C\x24\x08\x41\x59\xC3"
  34. else:
  35. raise Exception(f"processor {processor} not supported")
  36. original_bc = (ctypes.c_char * 64)()
  37. # get real ioctl address
  38. ioctl_address = ctypes.cast(ctypes.byref(c_function), ctypes.POINTER(ctypes.c_ulong))
  39. # hook ioctl
  40. ret = libc.mprotect(ctypes.c_ulong((ioctl_address.contents.value//0x1000)*0x1000), 0x2000, 7)
  41. assert ret == 0
  42. libc.memcpy(original_bc, ioctl_address.contents, len(tramp))
  43. libc.memcpy(ioctl_address.contents, ctypes.create_string_buffer(tramp), len(tramp))
  44. # Restore correct functions to close libs after python exits
  45. def __restore(): libc.memcpy(ioctl_address.contents, original_bc, len(tramp))
  46. atexit.register(__restore)
  47. drivers = [AMDDriver(), NVDriver()]
  48. tracked_fds = {}
  49. @ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_ulong)
  50. def _open(name, flags, mode):
  51. for d in drivers:
  52. pyname = name.decode()
  53. for x in d.tracked_files:
  54. if pyname == x.path:
  55. virtfd = d.open(pyname, flags, mode, x)
  56. tracked_fds[virtfd.fd] = virtfd
  57. return virtfd.fd
  58. libc.syscall.argtypes = [ctypes.c_ulong, ctypes.c_char_p, ctypes.c_int, ctypes.c_ulong]
  59. libc.syscall.restype = ctypes.c_int
  60. return libc.syscall(OPEN_SYSCALL, name, flags, mode)
  61. @ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_char_p)
  62. def _opendir(name):
  63. fd = _open(name, os.O_RDONLY| os.O_DIRECTORY, 0)
  64. if fd >= 0x80:
  65. fake_dirfd = _open(".".encode(), os.O_RDONLY| os.O_DIRECTORY, 0)
  66. st = libc.fdopendir(fake_dirfd)
  67. to_mv(st, 8).cast('Q')[0] = fd
  68. return st
  69. else: return libc.fdopendir(fd)
  70. @ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int)
  71. def _close(fd):
  72. if fd in tracked_fds:
  73. tracked_fds[fd].close(fd)
  74. tracked_fds.pop(fd)
  75. return 0
  76. libc.syscall.argtypes = [ctypes.c_ulong, ctypes.c_int]
  77. libc.syscall.restype = ctypes.c_int
  78. return libc.syscall(CLOSE_SYSCALL, fd)
  79. @ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p)
  80. def _closedir(st): return _close(to_mv(st, 8).cast('Q')[0])
  81. @ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int, ctypes.c_ulong, ctypes.c_void_p)
  82. def _ioctl(fd, request, argp):
  83. if fd in tracked_fds: return tracked_fds[fd].ioctl(fd, request, argp)
  84. libc.syscall.argtypes = [ctypes.c_ulong, ctypes.c_int, ctypes.c_ulong, ctypes.c_void_p]
  85. libc.syscall.restype = ctypes.c_int
  86. return libc.syscall(IOCTL_SYSCALL, ctypes.c_int(fd), ctypes.c_ulong(request), ctypes.c_void_p(argp))
  87. @ctypes.CFUNCTYPE(ctypes.c_long, ctypes.c_int, ctypes.c_void_p, ctypes.c_size_t)
  88. def _read(fd, buf, sz):
  89. if fd in tracked_fds: return tracked_fds[fd].read(fd, buf, sz)
  90. libc.syscall.argtypes = [ctypes.c_ulong, ctypes.c_int, ctypes.c_void_p, ctypes.c_size_t]
  91. libc.syscall.restype = ctypes.c_int
  92. return libc.syscall(READ_SYSCALL, ctypes.c_int(fd), ctypes.c_void_p(buf), ctypes.c_size_t(sz))
  93. @ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int, ctypes.c_ulong, ctypes.c_int)
  94. def _lseek64(fd, off, whence):
  95. if fd in tracked_fds: return tracked_fds[fd].lseek(fd, off, whence)
  96. libc.syscall.argtypes = [ctypes.c_ulong, ctypes.c_int, ctypes.c_ulong, ctypes.c_int]
  97. libc.syscall.restype = ctypes.c_int
  98. return libc.syscall(LSEEK_SYSCALL, fd, off, whence)
  99. @ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_char_p, ctypes.c_void_p)
  100. def _stat64(name, buf):
  101. for d in drivers:
  102. pyname = name.decode()
  103. for x in d.tracked_files:
  104. if pyname == x.path:
  105. virtfd = d.open(pyname, 0, 0, x)
  106. return virtfd.fstat(virtfd.fd, buf)
  107. libc.syscall.argtypes = [ctypes.c_ulong, ctypes.c_int, ctypes.c_char_p, ctypes.c_void_p, ctypes.c_ulong]
  108. libc.syscall.restype = ctypes.c_int
  109. return libc.syscall(NEWFSTATAT_SYSCALL, -100, name, ctypes.c_void_p(buf), 0)
  110. @ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int, ctypes.c_void_p)
  111. def _fstat64(fd, buf):
  112. if fd in tracked_fds: return tracked_fds[fd].fstat(fd, buf)
  113. empty_str = (ctypes.c_char*1)()
  114. libc.syscall.argtypes = [ctypes.c_ulong, ctypes.c_int, ctypes.c_char_p, ctypes.c_void_p, ctypes.c_ulong]
  115. libc.syscall.restype = ctypes.c_int
  116. return libc.syscall(NEWFSTATAT_SYSCALL, ctypes.c_int(fd), empty_str, ctypes.c_void_p(buf), 0x1000)
  117. @ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int, ctypes.c_void_p, ctypes.c_ulong)
  118. def _getdents64(fd, buf, sz):
  119. if fd in tracked_fds: return tracked_fds[fd].getdents(fd, buf, sz)
  120. libc.syscall.argtypes = [ctypes.c_ulong, ctypes.c_int, ctypes.c_void_p, ctypes.c_ulong]
  121. libc.syscall.restype = ctypes.c_int
  122. return libc.syscall(GETDENTS64_SYSCALL, fd, buf, sz)
  123. def _mmap(start, sz, prot, flags, fd, offset):
  124. if fd in tracked_fds: return tracked_fds[fd].mmap(start, sz, prot, flags, fd, offset)
  125. return libc.mmap(start, sz, prot, flags, fd, offset)
  126. def _munmap(buf, sz):
  127. return libc.munmap(buf, sz)
  128. orignal_memoryview = builtins.memoryview
  129. class TrackedMemoryView:
  130. def __init__(self, data, rcb, wcb):
  131. self.mv = orignal_memoryview(data)
  132. self.rcb, self.wcb = rcb, wcb
  133. def __getitem__(self, index):
  134. self.rcb(self.mv, index)
  135. return self.mv[index]
  136. def __setitem__(self, index, value):
  137. self.mv[index] = value
  138. self.wcb(self.mv, index)
  139. def cast(self, new_type, **kwargs):
  140. self.mv = self.mv.cast(new_type, **kwargs)
  141. return self
  142. @property
  143. def nbytes(self): return self.mv.nbytes
  144. def __len__(self): return len(self.mv)
  145. def __repr__(self): return repr(self.mv)
  146. def _memoryview(mem):
  147. if isinstance(mem, int) or isinstance(mem, ctypes.Array):
  148. addr = ctypes.addressof(mem) if isinstance(mem, ctypes.Array) else mem
  149. for d in drivers:
  150. for st,en,rcb,wcb in d.tracked_addresses:
  151. if st <= addr <= en: return TrackedMemoryView(mem, rcb, wcb)
  152. return orignal_memoryview(mem)
  153. install_hook(libc.open, _open)
  154. install_hook(libc.opendir, _opendir)
  155. install_hook(libc.close, _close)
  156. install_hook(libc.closedir, _closedir)
  157. install_hook(libc.ioctl, _ioctl)
  158. install_hook(libc.read, _read)
  159. install_hook(libc.lseek64, _lseek64)
  160. install_hook(libc.stat64, _stat64)
  161. install_hook(libc.fstat64, _fstat64)
  162. install_hook(libc.getdents64, _getdents64)
  163. builtins.memoryview = _memoryview # type: ignore
  164. # rewrite autogen's libc mmaps functions.
  165. import tinygrad.runtime.autogen.libc as autogen_libc
  166. autogen_libc.mmap = _mmap # type: ignore
  167. autogen_libc.munmap = _munmap # type: ignore