# ops_cuda.py — tinygrad CUDA runtime backend
from __future__ import annotations
import subprocess, hashlib, tempfile, ctypes, ctypes.util, functools, re
from pathlib import Path
from typing import Tuple, Optional, List
import tinygrad.runtime.autogen.cuda as cuda
import tinygrad.runtime.autogen.nvrtc as nvrtc
from tinygrad.helpers import DEBUG, getenv, from_mv, to_char_p_p, init_c_var, init_c_struct_t, colored
from tinygrad.device import Compiled, Compiler, CompileError, BufferOptions, LRUAllocator
from tinygrad.renderer.cstyle import CUDARenderer
from tinygrad.renderer.assembly import PTXRenderer
# IOCTL=1 imports the nv ioctl hook purely for its import-time side effects (the name is unused, hence noqa)
if getenv("IOCTL"): import extra.nv_gpu_driver.nv_ioctl # noqa: F401 # pylint: disable=unused-import
def pretty_ptx(s):
  """Colorize a PTX listing with ANSI escapes (used for the DEBUG>=5 kernel dump).

  NOTE: the substitutions are order-sensitive — later patterns run over text that
  already contains the ANSI codes inserted by earlier ones.
  """
  # all expressions match `<valid_before><expr><valid_after>` and replace it with `<valid_before>color(<expr>)<valid_after>`
  s = re.sub(r'([!@<\[\s,\+\-;\n])((?:[_%$][\w%\$_]+(?:\.[xyz])?\:?)|(?:buf\d+))([<>\]\s,\+\-;\n\)])', lambda m:m[1]+colored(m[2], "blue")+m[3], s, flags=re.M) # identifiers # noqa: E501
  s = re.sub(r'(.)((?:b|s|u|f)(?:8|16|32|64)|pred)([\.\s])', lambda m:m[1]+colored(m[2], "green")+m[3], s, flags=re.M) # types
  s = re.sub(r'^(\s*)([\w]+)(.*?;$)', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # instructions
  s = re.sub(r'([<>\[\]\s,\+\-;])((?:0[fF][0-9a-fA-F]{8})|(?:[0-9]+)|(?:0[xX][0-9a-fA-F]+))([<>\[\]\s,\+\-;])', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # numbers # noqa: E501
  s = re.sub(r'(\.)(param|reg|global)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # space
  s = re.sub(r'(\.)(version|target|address_size|visible|entry)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # derivatives
  return s
# PTX=1 selects the PTXRenderer/PTXCompiler path instead of rendering CUDA C and compiling via NVRTC
PTX = getenv("PTX")
  22. def check(status):
  23. if status != 0: raise RuntimeError(f"CUDA Error {status}, {ctypes.string_at(init_c_var(ctypes.POINTER(ctypes.c_char)(), lambda x: cuda.cuGetErrorString(status, ctypes.byref(x)))).decode()}") # noqa: E501
  24. def encode_args(args, vals) -> Tuple[ctypes.Structure, ctypes.Array]:
  25. c_args = init_c_struct_t(tuple([(f'f{i}', cuda.CUdeviceptr_v2) for i in range(len(args))] +
  26. [(f'v{i}', ctypes.c_int) for i in range(len(vals))]))(*args, *vals)
  27. vargs = (ctypes.c_void_p * 5)(ctypes.c_void_p(1), ctypes.cast(ctypes.byref(c_args), ctypes.c_void_p), ctypes.c_void_p(2),
  28. ctypes.cast(ctypes.pointer(ctypes.c_size_t(ctypes.sizeof(c_args))), ctypes.c_void_p), ctypes.c_void_p(0))
  29. return c_args, vargs
  30. def cu_time_execution(cb, enable=False) -> Optional[float]:
  31. if not enable: return cb()
  32. evs = [init_c_var(cuda.CUevent(), lambda x: cuda.cuEventCreate(ctypes.byref(x), 0)) for _ in range(2)]
  33. cuda.cuEventRecord(evs[0], None)
  34. cb()
  35. cuda.cuEventRecord(evs[1], None)
  36. check(cuda.cuEventSynchronize(evs[1]))
  37. cuda.cuEventElapsedTime(ctypes.byref(ret := ctypes.c_float()), evs[0], evs[1])
  38. for ev in evs: cuda.cuEventDestroy_v2(ev)
  39. return ret.value * 1e-3
  40. def _get_bytes(arg, get_str, get_sz, check) -> bytes:
  41. sz = init_c_var(ctypes.c_size_t(), lambda x: check(get_sz(arg, ctypes.byref(x))))
  42. return ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value)
  43. class PTXCompiler(Compiler):
  44. def __init__(self, arch:str):
  45. self.arch = arch
  46. self.version = "7.8" if arch >= "sm_89" else "7.5"
  47. super().__init__(f"compile_ptx_{self.arch}")
  48. def compile(self, src:str) -> bytes: return src.replace("TARGET", self.arch).replace("VERSION", self.version).encode()
  49. class CUDACompiler(Compiler):
  50. def __init__(self, arch:str):
  51. self.arch = arch
  52. check(nvrtc.nvrtcVersion((nvrtcMajor := ctypes.c_int()), (nvrtcMinor := ctypes.c_int())))
  53. self.compile_options = [f'--gpu-architecture={arch}', "-I/usr/local/cuda/include", "-I/usr/include", "-I/opt/cuda/include/"]
  54. if (nvrtcMajor.value, nvrtcMinor.value) >= (12, 4): self.compile_options.append("--minimal")
  55. super().__init__(f"compile_cuda_{self.arch}")
  56. def compile(self, src:str) -> bytes:
  57. check(nvrtc.nvrtcCreateProgram(ctypes.byref(prog := nvrtc.nvrtcProgram()), src.encode(), "<null>".encode(), 0, None, None))
  58. status = nvrtc.nvrtcCompileProgram(prog, len(self.compile_options), to_char_p_p([o.encode() for o in self.compile_options]))
  59. if status != 0: raise CompileError(f"compile failed: {_get_bytes(prog, nvrtc.nvrtcGetProgramLog, nvrtc.nvrtcGetProgramLogSize, check).decode()}")
  60. return _get_bytes(prog, nvrtc.nvrtcGetPTX, nvrtc.nvrtcGetPTXSize, check)
  61. def cuda_disassemble(lib, arch):
  62. try:
  63. fn = (Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(lib).hexdigest()}").as_posix()
  64. with open(fn + ".ptx", "wb") as f: f.write(lib)
  65. subprocess.run(["ptxas", f"-arch={arch}", "-o", fn, fn+".ptx"], check=True)
  66. print(subprocess.check_output(['nvdisasm', fn]).decode('utf-8'))
  67. except Exception as e: print("failed to generate SASS", str(e))
class CUDAProgram:
  """A loaded CUDA kernel: owns a CUmodule and a CUfunction handle, launched via cuLaunchKernel."""
  def __init__(self, device:CUDADevice, name:str, lib:bytes):
    self.device, self.name, self.lib = device, name, lib
    if DEBUG >= 5: print("\n".join([f"{i+1:>3} {line}" for i, line in enumerate(pretty_ptx(lib.decode('utf-8')).split("\n"))]))
    if DEBUG >= 6: cuda_disassemble(lib, device.arch)
    check(cuda.cuCtxSetCurrent(self.device.context))  # driver calls act on the current context
    self.module = cuda.CUmodule()
    status = cuda.cuModuleLoadData(ctypes.byref(self.module), lib)
    if status != 0:
      # drop the attribute so __del__ won't try to unload a module that never loaded
      del self.module
      cuda_disassemble(lib, device.arch)
      raise RuntimeError(f"module load failed with status code {status}: {cuda.cudaError_enum__enumvalues[status]}")
    check(cuda.cuModuleGetFunction(ctypes.byref(prg := cuda.CUfunction()), self.module, name.encode("utf-8")))
    self.prg = prg #type: ignore
  def __del__(self):
    # hasattr guard: self.module is deleted in __init__ when cuModuleLoadData failed
    if hasattr(self, 'module'): check(cuda.cuModuleUnload(self.module))
  def __call__(self, *args, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
    """Launch the kernel. Returns the elapsed time in seconds when wait=True, else None."""
    check(cuda.cuCtxSetCurrent(self.device.context))
    if not hasattr(self, "vargs"):
      # first launch: build the ctypes arg struct once and keep a reference on self so it stays alive
      self.c_args, self.vargs = encode_args(args, vals) #type: ignore
    else:
      # later launches reuse the cached struct, just overwriting the pointer/int fields in place
      for i in range(len(args)): self.c_args.__setattr__(f'f{i}', args[i])
      for i in range(len(vals)): self.c_args.__setattr__(f'v{i}', vals[i])
    return cu_time_execution(lambda: check(cuda.cuLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, self.vargs)), enable=wait)
class CUDAAllocator(LRUAllocator):
  """Allocator for CUDA device memory and page-locked host memory, with LRU buffer reuse."""
  def __init__(self, device:CUDADevice):
    self.device = device
    super().__init__()
  def _alloc(self, size, options:BufferOptions):
    check(cuda.cuCtxSetCurrent(self.device.context))
    # host buffers are pinned; flag 0x01 is CU_MEMHOSTALLOC_PORTABLE (usable from any context)
    if options.host: return init_c_var(ctypes.c_void_p(), lambda x: check(cuda.cuMemHostAlloc(ctypes.byref(x), size, 0x01)))
    return init_c_var(cuda.CUdeviceptr(), lambda x: check(cuda.cuMemAlloc_v2(ctypes.byref(x), size)))
  def _free(self, opaque, options:BufferOptions):
    if options.host: check(cuda.cuMemFreeHost(opaque))
    else: check(cuda.cuMemFree_v2(opaque))
  def copyin(self, dest, src:memoryview):
    check(cuda.cuCtxSetCurrent(self.device.context))
    # stage through a pinned host buffer so the host-to-device copy can run async;
    # the staging buffer is queued on pending_copyin and released at the device's next synchronize()
    host_mem = self.alloc(len(src), BufferOptions(host=True))
    self.device.pending_copyin.append((host_mem, len(src), BufferOptions(host=True)))
    ctypes.memmove(host_mem, from_mv(src), len(src))
    check(cuda.cuMemcpyHtoDAsync_v2(dest, host_mem, len(src), None))
  def copyout(self, dest:memoryview, src):
    # blocking copy: sync all devices first so any pending writes to src have landed
    CUDADevice.synchronize_system()
    check(cuda.cuCtxSetCurrent(self.device.context))
    check(cuda.cuMemcpyDtoH_v2(from_mv(dest), src, len(dest)))
  def transfer(self, dest, src, sz:int, src_dev, dest_dev):
    # async device-to-device copy issued on src_dev, ordered on dest_dev via an event
    check(cuda.cuCtxSetCurrent(src_dev.context))
    check(cuda.cuEventCreate(ctypes.byref(sync_event := cuda.CUevent()), 0))
    check(cuda.cuMemcpyDtoDAsync_v2(dest, src, sz, None))
    check(cuda.cuEventRecord(sync_event, None))
    check(cuda.cuCtxSetCurrent(dest_dev.context))
    check(cuda.cuStreamWaitEvent(None, sync_event, 0)) # sync the default stream on the dest dev
  def offset(self, buf, size:int, offset:int): return ctypes.c_ulong(buf.value + offset)  # plain pointer arithmetic on the device address
class CUDADevice(Compiled):
  """One CUDA GPU: owns a driver context and wires up allocator, renderer, compiler, and graph."""
  devices: List[CUDADevice] = []  # every live device; used for peer access setup and system-wide sync
  peer_access = False
  def __init__(self, device:str):
    device_id = int(device.split(":")[1]) if ":" in device else 0  # "CUDA:1" -> 1, bare "CUDA" -> 0
    check(cuda.cuInit(0))
    self.cu_device = init_c_var(cuda.CUdevice(), lambda x: check(cuda.cuDeviceGet(ctypes.byref(x), device_id)))
    self.context = init_c_var(cuda.CUcontext(), lambda x: check(cuda.cuCtxCreate_v2(ctypes.byref(x), 0, self.cu_device)))
    check(cuda.cuDeviceComputeCapability(ctypes.byref(major := ctypes.c_int()), ctypes.byref(minor := ctypes.c_int()), device_id))
    # enable bidirectional peer access with every previously-created device that supports it;
    # cuCtxEnablePeerAccess acts on the *current* context, hence the context switches
    for dev in CUDADevice.devices:
      check(cuda.cuDeviceCanAccessPeer(ctypes.byref(val := ctypes.c_int()), self.cu_device, dev.cu_device))
      if val.value != 1: continue
      check(cuda.cuCtxSetCurrent(dev.context))
      check(cuda.cuCtxEnablePeerAccess(self.context, 0))
      check(cuda.cuCtxSetCurrent(self.context))
      check(cuda.cuCtxEnablePeerAccess(dev.context, 0))
      CUDADevice.peer_access = True
    self.arch = f"sm_{major.value}{minor.value}"
    self.pending_copyin: List[Tuple[int, int, Optional[BufferOptions]]] = []  # pinned staging buffers, freed on synchronize()
    CUDADevice.devices.append(self)
    from tinygrad.runtime.graph.cuda import CUDAGraph  # local import — presumably avoids a circular import at module load; verify
    super().__init__(device, CUDAAllocator(self), PTXRenderer(self.arch) if PTX else CUDARenderer(self.arch),
                     PTXCompiler(self.arch) if PTX else CUDACompiler(self.arch), functools.partial(CUDAProgram, self), graph=CUDAGraph)
  def synchronize(self):
    check(cuda.cuCtxSetCurrent(self.context))
    check(cuda.cuCtxSynchronize())
    # all async copies are done now, so the staging buffers queued by copyin can be recycled
    for opaque,sz,options in self.pending_copyin: self.allocator.free(opaque, sz, options)
    self.pending_copyin.clear()
  @staticmethod
  def synchronize_system():
    # wait for every CUDA device, e.g. before a cross-device copyout
    for d in CUDADevice.devices: d.synchronize()