# cuda.py
import ctypes
from typing import Any, Optional, Tuple, Dict, List, cast
import tinygrad.runtime.autogen.cuda as cuda
from tinygrad.helpers import init_c_var, GraphException, dedup
from tinygrad.device import Buffer, Device
from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
from tinygrad.shape.symbolic import Variable
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
from tinygrad.engine.jit import MultiGraphRunner
class CUDAGraph(MultiGraphRunner):
  """Replays a JIT'd sequence of ExecItems as a single CUDA graph.

  At construction time each jit-cache item becomes one graph node (a kernel
  node for a CompiledRunner, a 3D-memcpy node for a BufferXfer), the graph is
  instantiated once into an executable CUgraphExec, and every __call__ patches
  only the nodes whose inputs / var_vals / launch dims may change before
  relaunching the instantiated graph.
  """

  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
    super().__init__(jit_cache, input_rawbuffers, var_vals)

    # Check all jit items are compatible: only compiled kernels and buffer transfers can be turned into graph nodes.
    if not all(isinstance(ji.prg, (CompiledRunner, BufferXfer)) for ji in jit_cache): raise GraphException

    # jit-cache indices that consume at least one replaceable input rawbuffer (keys of input_replace are (jc index, buf index)).
    self.jc_idx_with_updatable_rawbufs = dedup([x[0] for x in self.input_replace.keys()])
    # Nodes we must patch per call. The ctypes param structs are kept alive here on purpose:
    # the driver reads them again via *SetParams, so they must not be garbage collected.
    self.updatable_nodes: Dict[int, Tuple[Any, Any, Any, bool]] = {} # Dict[jc index] = tuple(graph node, node params, input kernel params, is memcpy)

    self.graph = init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))

    for j,ji in enumerate(self.jit_cache):
      if isinstance(ji.prg, CompiledRunner):
        global_size, local_size = ji.prg.p.launch_dims(var_vals)

        new_node = cuda.CUgraphNode()
        # bufs[:outcount] are outputs (writes), bufs[outcount:] are inputs (reads);
        # deps are the previously-added nodes that touched these buffers, registering new_node as their successor.
        deps = self._access_resources([x.base for x in ji.bufs[ji.prg.p.outcount:] if x is not None],
                                      [x.base for x in ji.bufs[:ji.prg.p.outcount] if x is not None], new_dependency=new_node)
        c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None

        # c_args holds the buffer pointers and symbolic-variable values; vargs is the kernelParams pointer block.
        c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals[x] for x in ji.prg.p.vars])
        kern_params = cuda.CUDA_KERNEL_NODE_PARAMS(ji.prg.clprg.prg, *global_size, *local_size, 0, None, vargs)
        check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params)))

        # Remember this node if anything about it (dims, var_vals, or input buffers) can change between calls.
        if j in self.jc_idx_with_updatable_launch_dims or j in self.jc_idx_with_updatable_var_vals or j in self.jc_idx_with_updatable_rawbufs:
          self.updatable_nodes[j] = (new_node, kern_params, c_args, False)
      elif isinstance(ji.prg, BufferXfer):
        dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
        src_dev = cast(CUDADevice, Device[src.device])
        node_from = cuda.CUgraphNode()
        deps = self._access_resources(read=[src.base], write=[dest.base], new_dependency=node_from)
        c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
        # Express the linear device-to-device copy as a degenerate 3D copy: one row of nbytes, height/depth 1.
        cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1,
                                          dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1,
                                          WidthInBytes=dest.nbytes, Height=1, Depth=1)
        check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_from), self.graph, c_deps, len(deps), ctypes.byref(cp_params), src_dev.context))
        # For memcpy nodes the third tuple slot carries the source device's context (needed by cuGraphExecMemcpyNodeSetParams).
        if j in self.jc_idx_with_updatable_rawbufs: self.updatable_nodes[j] = (node_from, cp_params, src_dev.context, True)

    # Instantiate the built graph once; __call__ only patches and relaunches this executable instance.
    self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0)))

  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
    """Patch updatable nodes with the new buffers/var_vals/dims and launch the graph.

    Returns the measured execution time in seconds when wait=True, else None
    (whatever cu_time_execution yields for enable=False).
    """
    # Update rawbuffers in the c_args struct.
    for (j,i),input_idx in self.input_replace.items():
      # Kernel nodes store pointers as f{i} fields in c_args; memcpy nodes store them in the CUDA_MEMCPY3D_v2 params.
      if not self.updatable_nodes[j][3]: setattr(self.updatable_nodes[j][2], f'f{i}', input_rawbuffers[input_idx]._buf)
      else:
        # For a BufferXfer, buf index 0 is the destination and 1 the source (matches the ji.bufs[0:2] unpack above).
        if i == 0: self.updatable_nodes[j][1].destDevice = input_rawbuffers[input_idx]._buf
        elif i == 1: self.updatable_nodes[j][1].srcDevice = input_rawbuffers[input_idx]._buf

    # Update var_vals in the c_args struct.
    for j in self.jc_idx_with_updatable_var_vals:
      for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
        setattr(self.updatable_nodes[j][2], f'v{i}', var_vals[v])

    # Update launch dims in the kern_params struct.
    for j in self.jc_idx_with_updatable_launch_dims:
      self.set_kernel_node_launch_dims(self.updatable_nodes[j][1], *cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals))

    # Update graph nodes with the updated structs.
    for node, c_node_params, c_args, is_copy in self.updatable_nodes.values():
      if not is_copy: check(cuda.cuGraphExecKernelNodeSetParams(self.instance, node, ctypes.byref(c_node_params)))
      else: check(cuda.cuGraphExecMemcpyNodeSetParams(self.instance, node, ctypes.byref(c_node_params), c_args))

    return cu_time_execution(lambda: check(cuda.cuGraphLaunch(self.instance, None)), enable=wait)

  def __del__(self):
    # hasattr guards: __init__ can raise (GraphException) before either handle exists.
    if hasattr(self, 'graph'): check(cuda.cuGraphDestroy(self.graph))
    if hasattr(self, 'instance'): check(cuda.cuGraphExecDestroy(self.instance))

  def set_kernel_node_launch_dims(self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]) -> None:
    """Write (x,y,z) block and grid dims into a CUDA_KERNEL_NODE_PARAMS struct in place."""
    node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_size, *global_size