# ops_webgpu.py

  1. from wgpu.utils.device import get_default_device
  2. from tinygrad.device import Compiled, Allocator, CompilerOptions
  3. from tinygrad.renderer.cstyle import WGSLRenderer
  4. import wgpu
  5. wgpu_device = get_default_device()
  6. def create_uniform(val: int) -> wgpu.GPUBuffer:
  7. buf = wgpu_device.create_buffer(size=4, usage=wgpu.BufferUsage.UNIFORM | wgpu.BufferUsage.COPY_DST)
  8. wgpu_device.queue.write_buffer(buf, 0, val.to_bytes(4, "little"))
  9. return buf
  10. class WebGPUProgram:
  11. def __init__(self, name:str, lib:bytes):
  12. self.name, self.lib, self.prg = name, lib, wgpu_device.create_shader_module(code=lib) # NOTE: this is the compiler
  13. def __call__(self, *bufs, global_size, local_size, vals=(), wait=False):
  14. assert len(bufs) <= 8, "WEBGPU only supports 8 buffers"
  15. binding_layouts = [{"binding": i, "visibility": wgpu.ShaderStage.COMPUTE, "buffer": {"type": wgpu.BufferBindingType.uniform if i >= len(bufs) else wgpu.BufferBindingType.storage }} for i in range(len(bufs)+len(vals))] # noqa: E501
  16. bindings = [{"binding": i, "resource": {"buffer": create_uniform(x) if i >= len(bufs) else x, "offset": 0, "size": 4 if i >= len(bufs) else x.size}} for i,x in enumerate(bufs+vals)] # noqa: E501
  17. bind_group_layout = wgpu_device.create_bind_group_layout(entries=binding_layouts)
  18. pipeline_layout = wgpu_device.create_pipeline_layout(bind_group_layouts=[bind_group_layout])
  19. bind_group = wgpu_device.create_bind_group(layout=bind_group_layout, entries=bindings)
  20. compute_pipeline = wgpu_device.create_compute_pipeline(layout=pipeline_layout,compute={"module": self.prg, "entry_point": self.name},)
  21. command_encoder = wgpu_device.create_command_encoder()
  22. compute_pass = command_encoder.begin_compute_pass()
  23. compute_pass.set_pipeline(compute_pipeline)
  24. compute_pass.set_bind_group(0, bind_group, [], 0, 999999) # last 2 not used
  25. compute_pass.dispatch_workgroups(*global_size) # x y z
  26. compute_pass.end()
  27. wgpu_device.queue.submit([command_encoder.finish()])
  28. class WebGpuAllocator(Allocator):
  29. def _alloc(self, size: int):
  30. return wgpu_device.create_buffer(size=size, usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_DST | wgpu.BufferUsage.COPY_SRC)
  31. def copyin(self, dest, src: memoryview): wgpu_device.queue.write_buffer(dest, 0, src)
  32. def copyout(self, dest, src: memoryview): dest[:] = wgpu_device.queue.read_buffer(src, 0) # TODO: remove this copy
  33. class WebGpuDevice(Compiled):
  34. def __init__(self, device:str):
  35. super().__init__(WebGpuAllocator(), CompilerOptions(device="WEBGPU", supports_float4=False, local_max=[256, 256, 64],
  36. global_max=[65535, 65535, 65535]), WGSLRenderer, lambda x: x, WebGPUProgram)