# ops_metal.py: Metal runtime backend for tinygrad
  1. from __future__ import annotations
  2. import os, subprocess, pathlib, ctypes, tempfile, functools
  3. import Metal, libdispatch
  4. from typing import List, Set, Any, Tuple, Optional
  5. from tinygrad.helpers import prod, getenv, DEBUG, unwrap2
  6. from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator
  7. from tinygrad.renderer.cstyle import MetalRenderer
  8. def wait_check(cbuf: Any):
  9. cbuf.waitUntilCompleted()
  10. if (error := cbuf.error()) is not None:
  11. raise RuntimeError(error)
  12. class MetalCompiler(Compiler):
  13. def __init__(self, device:Optional[MetalDevice]):
  14. self.device = device
  15. super().__init__("compile_metal")
  16. def compile(self, src:str) -> bytes:
  17. if self.device is None:
  18. # NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
  19. air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=src.encode('utf-8'))
  20. return subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
  21. options = Metal.MTLCompileOptions.new()
  22. options.setFastMathEnabled_(getenv("METAL_FAST_MATH"))
  23. try: library = unwrap2(self.device.device.newLibraryWithSource_options_error_(src, options, None))
  24. except AssertionError as e: raise CompileError(e) from e
  25. return library.libraryDataContents().bytes().tobytes()
  26. class MetalProgram:
  27. def __init__(self, device:MetalDevice, name:str, lib:bytes):
  28. self.device, self.name, self.lib = device, name, lib
  29. if DEBUG >= 6:
  30. with tempfile.NamedTemporaryFile(delete=True) as shader:
  31. shader.write(lib)
  32. shader.flush()
  33. ret = os.system(f"cd {pathlib.Path(__file__).parents[2]}/extra/disassemblers/applegpu && python3 compiler_explorer.py {shader.name}")
  34. if ret:
  35. print("Error running disassembler: Make sure you have https://github.com/dougallj/applegpu cloned to tinygrad/extra/disassemblers/applegpu")
  36. assert lib[:4] == b"MTLB", "Invalid Metal library. Could be due to using conda. Try system python or METAL_XCODE=1 DISABLE_COMPILER_CACHE=1."
  37. data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
  38. self.library = unwrap2(self.device.device.newLibraryWithData_error_(data, None))
  39. self.fxn = self.library.newFunctionWithName_(name)
  40. self.pipeline_state = unwrap2(self.device.device.newComputePipelineStateWithFunction_error_(self.fxn, None))
  41. def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
  42. if prod(local_size) > self.pipeline_state.maxTotalThreadsPerThreadgroup(): raise RuntimeError(f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}") # noqa: E501
  43. command_buffer = self.device.mtl_queue.commandBuffer()
  44. encoder = command_buffer.computeCommandEncoder()
  45. encoder.setComputePipelineState_(self.pipeline_state)
  46. for i,a in enumerate(bufs): encoder.setBuffer_offset_atIndex_(a.buf, a.offset, i)
  47. for i,a in enumerate(vals,start=len(bufs)): encoder.setBytes_length_atIndex_(ctypes.c_int32(a), 4, i)
  48. encoder.dispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
  49. encoder.endEncoding()
  50. command_buffer.commit()
  51. if wait:
  52. wait_check(command_buffer)
  53. return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
  54. self.device.mtl_buffers_in_flight.append(command_buffer)
  55. class MetalBuffer:
  56. def __init__(self, buf:Any, size:int, offset=0): self.buf, self.size, self.offset = buf, size, offset
  57. class MetalAllocator(LRUAllocator):
  58. def __init__(self, device:MetalDevice):
  59. self.device:MetalDevice = device
  60. self.track_cross_device: Set[MetalDevice] = set()
  61. super().__init__()
  62. def free_cache(self):
  63. self.device.synchronize()
  64. for x in self.track_cross_device: x.synchronize()
  65. self.track_cross_device.clear()
  66. return super().free_cache()
  67. def _alloc(self, size:int, options) -> MetalBuffer:
  68. ret = self.device.device.newBufferWithLength_options_(size, Metal.MTLResourceStorageModeShared)
  69. if ret is None: raise MemoryError(f"Metal OOM while allocating {size=}")
  70. return MetalBuffer(ret, size)
  71. def _free(self, opaque:MetalBuffer, options): opaque.buf.release()
  72. def transfer(self, dest:MetalBuffer, src:MetalBuffer, sz:int, src_dev: MetalDevice, **kwargs):
  73. src_dev.synchronize()
  74. command_buffer = self.device.mtl_queue.commandBuffer()
  75. encoder = command_buffer.blitCommandEncoder()
  76. encoder.copyFromBuffer_sourceOffset_toBuffer_destinationOffset_size_(src.buf, src.offset, dest.buf, dest.offset, sz)
  77. encoder.endEncoding()
  78. command_buffer.commit()
  79. self.device.mtl_buffers_in_flight.append(command_buffer)
  80. def from_buffer(self, src:memoryview) -> Optional[Any]:
  81. ret = self.device.device.newBufferWithBytesNoCopy_length_options_deallocator_(src, src.nbytes, Metal.MTLResourceStorageModeShared, None)
  82. if ret: self.device.mv_in_metal.append(src)
  83. return MetalBuffer(ret, src.nbytes)
  84. def as_buffer(self, src:MetalBuffer) -> memoryview:
  85. self.device.synchronize()
  86. return src.buf.contents().as_buffer(src.offset+src.size)[src.offset:]
  87. def copyin(self, dest:MetalBuffer, src:memoryview): self.as_buffer(dest)[:] = src
  88. def copyout(self, dest:memoryview, src:MetalBuffer): dest[:] = self.as_buffer(src)
  89. def offset(self, buf:MetalBuffer, size:int, offset:int): return MetalBuffer(buf.buf, size, offset)
  90. class MetalDevice(Compiled):
  91. def __init__(self, device:str):
  92. self.device = Metal.MTLCreateSystemDefaultDevice()
  93. self.mtl_queue = self.device.newCommandQueueWithMaxCommandBufferCount_(1024)
  94. self.mtl_buffers_in_flight: List[Any] = []
  95. self.mv_in_metal: List[memoryview] = []
  96. self.track_cross_buffer: List[Any] = []
  97. from tinygrad.runtime.graph.metal import MetalGraph
  98. super().__init__(device, MetalAllocator(self), MetalRenderer(), MetalCompiler(None if getenv("METAL_XCODE") else self),
  99. functools.partial(MetalProgram, self), MetalGraph)
  100. def synchronize(self):
  101. for cbuf in self.mtl_buffers_in_flight: wait_check(cbuf)
  102. self.mv_in_metal.clear()
  103. self.mtl_buffers_in_flight.clear()
  104. self.track_cross_buffer.clear()