graph_hip.py 2.0 KB

123456789101112131415161718192021222324252627
  1. import ctypes
  2. from typing import Tuple
  3. import tinygrad.runtime.autogen.hip as hip
  4. from tinygrad.helpers import init_c_var, time_execution_cuda_style
  5. from tinygrad.runtime.ops_hip import check, hip_set_device
  6. from tinygrad.runtime.graph.cuda import CUDAGraph
  7. # TODO: this is only used in graph
  def hip_time_execution(cb, enable=False):
    """Run ``cb`` through ``time_execution_cuda_style`` using the HIP event API.

    ``cb`` and ``enable`` are forwarded unchanged. The HIP event type and the
    create / record / synchronize / destroy / elapsed-time functions are passed
    positionally, in that order, as the event primitives.

    Returns whatever ``time_execution_cuda_style`` returns.
    """
    hip_event_api = (
      hip.hipEvent_t,
      hip.hipEventCreate,
      hip.hipEventRecord,
      hip.hipEventSynchronize,
      hip.hipEventDestroy,
      hip.hipEventElapsedTime,
    )
    return time_execution_cuda_style(cb, *hip_event_api, enable=enable)
  class HIPGraph(CUDAGraph):
    """HIP flavour of ``CUDAGraph``: every method maps onto the matching ``hip*`` call.

    NOTE(review): the base class is not visible here; these methods are presumably
    the backend hooks ``CUDAGraph`` invokes - confirm against ``CUDAGraph``.
    """

    def __del__(self):
      """Release the HIP graph and its executable instance, if they exist.

      Each handle is destroyed only when the corresponding attribute is present
      on the instance. Any error reported through ``check`` still propagates.
      """
      try:
        if hasattr(self, 'graph'): check(hip.hipGraphDestroy(self.graph))
      finally:
        # Runs even if destroying the graph raised, so the executable graph
        # handle is not leaked by a failure in the first release.
        if hasattr(self, 'instance'): check(hip.hipGraphExecDestroy(self.instance))

    def set_device(self):
      """Select ``self.device`` via ``hip_set_device``."""
      hip_set_device(self.device)

    def encode_args_info(self):
      """Return ``(hip.hipDeviceptr_t, (1, 2, 3))``.

      NOTE(review): the meaning of the ``(1, 2, 3)`` tuple is defined by the
      consumer of this value, which is not visible here - confirm.
      """
      return (hip.hipDeviceptr_t, (1,2,3))

    def graph_create(self):
      """Create a new ``hipGraph_t`` with ``hipGraphCreate`` (flags = 0) and return it."""
      return init_c_var(hip.hipGraph_t(), lambda x: check(hip.hipGraphCreate(ctypes.byref(x), 0)))

    def graph_instantiate(self, graph):
      """Instantiate ``graph`` with ``hipGraphInstantiate`` and return the ``hipGraphExec_t``.

      The error-node and log-buffer arguments are passed as ``None`` with a
      buffer size of 0.
      """
      return init_c_var(hip.hipGraphExec_t(), lambda x: check(hip.hipGraphInstantiate(ctypes.byref(x), graph, None, None, 0)))

    def graph_add_kernel_node(self, graph, c_deps, c_params):
      """Add a kernel node to ``graph`` with ``hipGraphAddKernelNode`` and return it.

      ``c_deps`` is the dependency array (falsy when there are no dependencies)
      and ``c_params`` holds the kernel node parameters, passed by reference.
      """
      # The dependency count is derived from the byte size of ``c_deps``.
      # NOTE(review): the divisor 8 presumably is the size of one node handle
      # (a 64-bit pointer) - confirm.
      num_deps = ctypes.sizeof(c_deps)//8 if c_deps else 0
      return init_c_var(hip.hipGraphNode_t(), lambda x: check(hip.hipGraphAddKernelNode(ctypes.byref(x), graph, c_deps, num_deps, ctypes.byref(c_params))))  # noqa: E501

    def graph_launch(self, *args, wait=False):
      """Launch via ``hipGraphLaunch(*args)`` wrapped in ``hip_time_execution``.

      ``wait`` is forwarded as ``enable``; the return value is whatever
      ``hip_time_execution`` returns.
      """
      return hip_time_execution(lambda: check(hip.hipGraphLaunch(*args)), enable=wait)

    def graph_exec_kernel_node_set_params(self, *args):
      """Forward ``args`` to ``hipGraphExecKernelNodeSetParams`` and return the result of ``check``."""
      return check(hip.hipGraphExecKernelNodeSetParams(*args))

    def build_kernel_node_params(self, prg, global_size, local_size, c_config):
      """Build a ``hipKernelNodeParams`` for ``prg`` with the given launch dimensions.

      ``local_size`` and ``global_size`` are unpacked into ``hip.dim3`` values
      and ``prg.clprg.prg`` is cast to ``void*``.
      """
      # NOTE(review): the positional order presumably follows the struct layout
      # (blockDim, extra, func, gridDim, kernelParams, sharedMemBytes) - confirm
      # against the generated bindings.
      kernel_fn = ctypes.cast(prg.clprg.prg, ctypes.c_void_p)
      return hip.hipKernelNodeParams(hip.dim3(*local_size), c_config, kernel_fn, hip.dim3(*global_size), None, 0)

    def set_kernel_node_launch_dims(self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]):
      """Overwrite ``node``'s launch dimensions in place.

      ``local_size`` is written to ``node.blockDim`` and ``global_size`` to
      ``node.gridDim`` (x, y, z each).
      """
      node.blockDim.x, node.blockDim.y, node.blockDim.z = local_size
      node.gridDim.x, node.gridDim.y, node.gridDim.z = global_size