| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111 |
- from __future__ import annotations
- import os, subprocess, pathlib, ctypes, tempfile, functools
- import Metal, libdispatch
- from typing import List, Set, Any, Tuple, Optional
- from tinygrad.helpers import prod, getenv, DEBUG, unwrap2
- from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator
- from tinygrad.renderer.cstyle import MetalRenderer
def wait_check(cbuf: Any):
  """Block until the command buffer ``cbuf`` has completed, then surface any error it recorded.

  Raises:
    RuntimeError: wrapping ``cbuf.error()`` when that is not None.
  """
  cbuf.waitUntilCompleted()
  err = cbuf.error()
  if err is None: return
  raise RuntimeError(err)
class MetalCompiler(Compiler):
  """Compile Metal shader source to metallib bytes.

  When constructed with ``device=None`` the Xcode toolchain is invoked via ``xcrun``; otherwise the source is
  compiled in-process through ``device.device`` (the MTLDevice).
  """
  def __init__(self, device:Optional[MetalDevice]):
    self.device = device
    super().__init__("compile_metal")

  def compile(self, src:str) -> bytes:
    """Return the metallib bytes for ``src``.

    Raises:
      CompileError: in-process path, when compiling raises AssertionError (e.g. from unwrap2).
      subprocess.CalledProcessError: xcrun path, when either tool exits non-zero.
    """
    if self.device is not None:
      opts = Metal.MTLCompileOptions.new()
      opts.setFastMathEnabled_(getenv("METAL_FAST_MATH"))
      try: lib = unwrap2(self.device.device.newLibraryWithSource_options_error_(src, opts, None))
      except AssertionError as err: raise CompileError(err) from err
      return lib.libraryDataContents().bytes().tobytes()
    # offline path: source -> "air" (running llvm-dis on it shows the llvm bytecode) -> metallib, piped via stdin/stdout
    xcrun = ['xcrun', '-sdk', 'macosx']
    air = subprocess.check_output(xcrun + ['metal', '-x', 'metal', '-c', '-', '-o', '-'], input=src.encode('utf-8'))
    return subprocess.check_output(xcrun + ['metallib', '-', '-o', '-'], input=air)
class MetalProgram:
  """A single kernel loaded from a metallib, with its compute pipeline state, launchable on ``device``."""
  def __init__(self, device:MetalDevice, name:str, lib:bytes):
    """Load ``lib`` (metallib bytes) on ``device`` and build a compute pipeline for the function ``name``.

    At DEBUG >= 6 the library is first handed to the applegpu disassembler, if it is checked out.

    Raises:
      AssertionError: if ``lib`` does not begin with the b"MTLB" magic.
    """
    self.device, self.name, self.lib = device, name, lib
    if DEBUG >= 6: self._disassemble(lib)
    assert lib[:4] == b"MTLB", "Invalid Metal library. Could be due to using conda. Try system python or METAL_XCODE=1 DISABLE_COMPILER_CACHE=1."
    data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
    self.library = unwrap2(self.device.device.newLibraryWithData_error_(data, None))
    self.fxn = self.library.newFunctionWithName_(name)
    self.pipeline_state = unwrap2(self.device.device.newComputePipelineStateWithFunction_error_(self.fxn, None))

  @staticmethod
  def _disassemble(lib:bytes) -> None:
    """Write ``lib`` to a temp file and run applegpu's compiler_explorer.py on it; print a setup hint on failure."""
    applegpu = pathlib.Path(__file__).parents[2] / "extra" / "disassemblers" / "applegpu"
    with tempfile.NamedTemporaryFile(delete=True) as shader:
      shader.write(lib)
      shader.flush()
      # argv list + cwd instead of an os.system("cd ... && ...") shell string: a checkout path with spaces or shell
      # metacharacters no longer breaks the command. A missing directory/interpreter raises OSError -> same hint.
      try: ret = subprocess.run(["python3", "compiler_explorer.py", shader.name], cwd=applegpu, check=False).returncode
      except OSError: ret = 1
      if ret:
        print("Error running disassembler: Make sure you have https://github.com/dougallj/applegpu cloned to tinygrad/extra/disassemblers/applegpu")

  def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
    """Encode and commit one dispatch of this kernel on the device's command queue.

    ``bufs`` (objects with ``.buf`` and ``.offset``) are bound at indices 0..len(bufs)-1, and ``vals`` follow them as
    4-byte int32 values. ``global_size`` is the threadgroup count and ``local_size`` the threads per threadgroup.

    Returns:
      With ``wait=True``: ``GPUEndTime() - GPUStartTime()`` of the completed command buffer.
      Otherwise None; the command buffer is appended to ``device.mtl_buffers_in_flight`` for a later synchronize().

    Raises:
      RuntimeError: if prod(local_size) exceeds the pipeline's maxTotalThreadsPerThreadgroup, or (wait=True) if the
        command buffer reports an error.
    """
    pso = self.pipeline_state
    if prod(local_size) > pso.maxTotalThreadsPerThreadgroup():
      raise RuntimeError(f"local size {local_size} bigger than {pso.maxTotalThreadsPerThreadgroup()} with exec width {pso.threadExecutionWidth()} memory length {pso.staticThreadgroupMemoryLength()}") # noqa: E501
    command_buffer = self.device.mtl_queue.commandBuffer()
    encoder = command_buffer.computeCommandEncoder()
    encoder.setComputePipelineState_(pso)
    for i,a in enumerate(bufs): encoder.setBuffer_offset_atIndex_(a.buf, a.offset, i)
    for i,a in enumerate(vals,start=len(bufs)): encoder.setBytes_length_atIndex_(ctypes.c_int32(a), 4, i)
    encoder.dispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
    encoder.endEncoding()
    command_buffer.commit()
    if wait:
      wait_check(command_buffer)
      return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
    self.device.mtl_buffers_in_flight.append(command_buffer)
class MetalBuffer:
  """A Metal buffer handle plus the byte ``size`` and byte ``offset`` of the region it refers to.

  Several MetalBuffer objects may share one handle, each viewing a different region of it.
  """
  def __init__(self, buf:Any, size:int, offset=0):
    self.buf = buf
    self.size = size
    self.offset = offset
class MetalAllocator(LRUAllocator):
  """LRU-caching allocator for one MetalDevice, handing out shared-storage MTLBuffers wrapped in MetalBuffer."""
  def __init__(self, device:MetalDevice):
    self.device:MetalDevice = device
    # devices synchronized (then forgotten) by free_cache(); nothing in this class adds to it -- presumably
    # populated by callers elsewhere (NOTE(review): verify)
    self.track_cross_device: Set[MetalDevice] = set()
    super().__init__()

  def free_cache(self):
    """Synchronize this device and every tracked cross device, then defer to LRUAllocator.free_cache()."""
    self.device.synchronize()
    for other in self.track_cross_device: other.synchronize()
    self.track_cross_device.clear()
    return super().free_cache()

  def _alloc(self, size:int, options) -> MetalBuffer:
    """Allocate ``size`` bytes with MTLResourceStorageModeShared; ``options`` is not used.

    Raises:
      MemoryError: if Metal returns no buffer.
    """
    ret = self.device.device.newBufferWithLength_options_(size, Metal.MTLResourceStorageModeShared)
    if ret is None: raise MemoryError(f"Metal OOM while allocating {size=}")
    return MetalBuffer(ret, size)

  def _free(self, opaque:MetalBuffer, options): opaque.buf.release()

  def transfer(self, dest:MetalBuffer, src:MetalBuffer, sz:int, src_dev: MetalDevice, **kwargs):
    """Copy ``sz`` bytes from ``src`` (on ``src_dev``) into ``dest`` with a blit on this device's queue.

    ``src_dev`` is synchronized first (waiting on its in-flight work). The blit is committed without waiting and
    tracked in this device's ``mtl_buffers_in_flight``.
    """
    src_dev.synchronize()
    command_buffer = self.device.mtl_queue.commandBuffer()
    encoder = command_buffer.blitCommandEncoder()
    encoder.copyFromBuffer_sourceOffset_toBuffer_destinationOffset_size_(src.buf, src.offset, dest.buf, dest.offset, sz)
    encoder.endEncoding()
    command_buffer.commit()
    self.device.mtl_buffers_in_flight.append(command_buffer)

  def from_buffer(self, src:memoryview) -> Optional[Any]:
    """Wrap ``src``'s memory in a shared-storage MTLBuffer without copying it.

    Returns:
      A MetalBuffer aliasing ``src``, or None if Metal refuses to wrap the memory (previously a MetalBuffer
      holding a None handle was returned, which only failed later when used).
    """
    ret = self.device.device.newBufferWithBytesNoCopy_length_options_deallocator_(src, src.nbytes, Metal.MTLResourceStorageModeShared, None)
    if not ret: return None
    # the MTLBuffer aliases src's memory (NoCopy), so keep the memoryview referenced until the next synchronize()
    self.device.mv_in_metal.append(src)
    return MetalBuffer(ret, src.nbytes)

  def as_buffer(self, src:MetalBuffer) -> memoryview:
    """Synchronize the device, then return a memoryview of bytes [offset, offset+size) of ``src``'s contents."""
    self.device.synchronize()
    return src.buf.contents().as_buffer(src.offset+src.size)[src.offset:]

  def copyin(self, dest:MetalBuffer, src:memoryview): self.as_buffer(dest)[:] = src
  def copyout(self, dest:memoryview, src:MetalBuffer): dest[:] = self.as_buffer(src)

  # NOTE(review): ``offset`` is used as-is rather than added to buf.offset, so this assumes ``buf`` is a base
  # buffer with offset 0 -- confirm against callers.
  def offset(self, buf:MetalBuffer, size:int, offset:int): return MetalBuffer(buf.buf, size, offset)
class MetalDevice(Compiled):
  """tinygrad device backed by the system default MTLDevice with a single command queue."""
  def __init__(self, device:str):
    self.device = Metal.MTLCreateSystemDefaultDevice()
    self.mtl_queue = self.device.newCommandQueueWithMaxCommandBufferCount_(1024)
    # command buffers committed without waiting; drained by synchronize()
    self.mtl_buffers_in_flight: List[Any] = []
    # host memoryviews aliased by no-copy MTLBuffers, held until the next synchronize()
    self.mv_in_metal: List[memoryview] = []
    # cleared by synchronize(); not appended to in this class -- presumably filled elsewhere (NOTE(review): verify)
    self.track_cross_buffer: List[Any] = []
    from tinygrad.runtime.graph.metal import MetalGraph
    allocator = MetalAllocator(self)
    renderer = MetalRenderer()
    # METAL_XCODE selects the out-of-process xcrun toolchain (compiler built with device=None)
    compiler = MetalCompiler(None if getenv("METAL_XCODE") else self)
    runtime = functools.partial(MetalProgram, self)
    super().__init__(device, allocator, renderer, compiler, runtime, MetalGraph)

  def synchronize(self):
    """Wait for every in-flight command buffer, then drop all tracked references.

    Raises:
      RuntimeError: from the first command buffer reporting an error; in that case nothing is cleared.
    """
    for pending in self.mtl_buffers_in_flight: wait_check(pending)
    self.mv_in_metal.clear()
    self.mtl_buffers_in_flight.clear()
    self.track_cross_buffer.clear()
|