# tinygrad/runtime/graph/metal.py
# Typing helpers plus the PyObjC Metal bindings and tinygrad-internal modules
# this graph runner builds on.
from typing import List, Any, Dict, cast, Optional
import Metal
from tinygrad.dtype import dtypes
from tinygrad.helpers import dedup, unwrap2, GraphException
from tinygrad.device import Buffer
from tinygrad.engine.realize import ExecItem, CompiledRunner
from tinygrad.engine.jit import GraphRunner
from tinygrad.shape.symbolic import Variable
from tinygrad.runtime.ops_metal import wait_check
class MetalGraph(GraphRunner):
  """Executes a JIT-captured list of Metal kernels as one batch.

  All kernel dispatches are pre-encoded once into a Metal indirect command
  buffer (ICB) in __init__; each __call__ then only patches the parts that
  change (input buffer bindings, symbolic-variable values, updatable launch
  dims) and submits a single command buffer that replays the whole ICB.
  """

  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
    """Pre-encode every kernel in jit_cache into an indirect command buffer.

    jit_cache: captured kernel launches; every ji.prg must be a CompiledRunner.
    input_rawbuffers: buffers that will be swapped per call (rebound in __call__).
    var_vals: initial symbolic-variable values, used for the fixed launch dims.

    Raises GraphException if any item is not a CompiledRunner or if the device
    cannot create an ICB.
    NOTE(review): self.device, self.vars, self.jc_idx_with_updatable_launch_dims
    and self.input_replace are presumably populated by GraphRunner.__init__ —
    not visible in this file, confirm against tinygrad.engine.jit.
    """
    super().__init__(jit_cache, input_rawbuffers, var_vals)
    # ICB replay can't fall back to interpreted runners; bail early.
    if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException

    # create metal batch exec
    icb_descriptor = Metal.MTLIndirectCommandBufferDescriptor.new()
    icb_descriptor.setCommandTypes_(Metal.MTLIndirectCommandType(Metal.MTLIndirectCommandTypeConcurrentDispatch))
    # Each ICB command carries its own pipeline state and buffer bindings
    # (nothing inherited from the parent encoder), so bindings set here stick.
    icb_descriptor.setInheritBuffers_(False)
    icb_descriptor.setInheritPipelineState_(False)
    icb_descriptor.setMaxKernelBufferBindCount_(31)
    # One ICB slot per cached kernel launch.
    self.icb = self.device.device.newIndirectCommandBufferWithDescriptor_maxCommandCount_options_(icb_descriptor, len(self.jit_cache),
                                                                                                  Metal.MTLResourceOptions(0))
    if self.icb is None: raise GraphException("create indirect command buffer failed, does your system support this?")

    # Single int32 scratch buffer holding the current value of every symbolic
    # variable; kernels read their vars from it via extra buffer bindings.
    if len(self.vars): self.int_buf = self.device.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
    all_resources = [self.int_buf.buf] if len(self.vars) else []
    for j,ji in enumerate(self.jit_cache):
      prg: CompiledRunner = cast(CompiledRunner, ji.prg)
      # Pipeline state must be rebuilt with supportIndirectCommandBuffers,
      # the one cached on the runner was not compiled for ICB use.
      descriptor = Metal.MTLComputePipelineDescriptor.new()
      descriptor.setComputeFunction_(prg.clprg.fxn)
      descriptor.setSupportIndirectCommandBuffers_(True)
      icb_command = self.icb.indirectComputeCommandAtIndex_(j)
      icb_command.setComputePipelineState_(unwrap2(
        self.device.device.newComputePipelineStateWithDescriptor_options_reflection_error_(descriptor, Metal.MTLPipelineOption(0), None, None)))
      for i,b in enumerate(ji.bufs):
        if b is not None:
          icb_command.setKernelBuffer_offset_atIndex_(b._buf.buf, b._buf.offset, i)
          # Collected so __call__ can declare residency with useResources.
          all_resources.append(b._buf.buf)
      # Symbolic vars are bound after the data buffers, as int32 slices of
      # int_buf (offset = index of the var * 4 bytes).
      for i,v in enumerate(prg.p.vars): icb_command.setKernelBuffer_offset_atIndex_(self.int_buf.buf, self.vars.index(v)*4, len(ji.bufs)+i)
      # Fixed launch dims are baked in now; kernels whose dims depend on
      # symbolic vars (jc_idx_with_updatable_launch_dims — TODO confirm
      # semantics in GraphRunner) get theirs re-encoded on every __call__.
      if j not in self.jc_idx_with_updatable_launch_dims:
        global_size, local_size = prg.p.launch_dims(var_vals)
        icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
      # Barrier after every command: later kernels see earlier kernels' writes.
      icb_command.setBarrier()

    self.all_resources = dedup(all_resources)
    # Last submitted command buffer, so a new call can wait for the previous
    # replay before patching the (shared) ICB.
    self.command_buffer: Any = None
    # CPU-side int32 view into int_buf for writing var values each call.
    if len(self.vars): self.int_buf_view = self.int_buf.buf.contents().as_buffer(self.int_buf.buf.length()).cast('i')

  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
    """Patch the pre-encoded ICB for this call's inputs and submit it.

    Returns GPU execution time in seconds when wait=True, else None
    (submission is asynchronous and tracked in mtl_buffers_in_flight).
    """
    # The ICB is mutated in place below, so the previous replay (if still
    # in flight) must finish first.
    if self.command_buffer is not None and self.command_buffer in self.device.mtl_buffers_in_flight: wait_check(self.command_buffer)
    all_resources = dedup(self.all_resources + [x._buf.buf for x in input_rawbuffers])
    # Rebind the per-call input buffers; input_replace maps
    # (jit_cache index, buffer slot) -> index into input_rawbuffers.
    for (j,i),input_idx in self.input_replace.items():
      self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_idx]._buf.buf,
                                                                                 input_rawbuffers[input_idx]._buf.offset, i)
    # Re-encode launch dims that depend on the current var_vals.
    for j in self.jc_idx_with_updatable_launch_dims:
      global_size, local_size = cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals)
      self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size),
                                                                                                       Metal.MTLSize(*local_size))
    # Write current symbolic-variable values where the kernels will read them.
    for j, var in enumerate(self.vars): self.int_buf_view[j] = var_vals[var]

    command_buffer = self.device.mtl_queue.commandBuffer()
    encoder = command_buffer.computeCommandEncoder()
    # ICB commands don't retain their resources; declare residency explicitly.
    encoder.useResources_count_usage_(all_resources, len(all_resources), Metal.MTLResourceUsageRead | Metal.MTLResourceUsageWrite)
    encoder.executeCommandsInBuffer_withRange_(self.icb, Metal.MTLIndirectCommandBufferExecutionRangeMake(0, len(self.jit_cache)))
    encoder.endEncoding()
    command_buffer.commit()
    self.command_buffer = command_buffer
    if wait:
      wait_check(command_buffer)
      # GPU timestamps are in seconds per Metal's MTLCommandBuffer API.
      return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
    self.device.mtl_buffers_in_flight.append(command_buffer)
    return None