clang.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839
  1. from typing import List, Dict, cast
  2. import ctypes
  3. from tinygrad.helpers import dedup, cpu_time_execution, GraphException, DEBUG
  4. from tinygrad.engine.jit import GraphRunner
  5. from tinygrad.device import Buffer, Device
  6. from tinygrad.engine.realize import ExecItem, CompiledRunner
  7. from tinygrad.shape.symbolic import Variable
  8. from tinygrad.runtime.ops_clang import ClangProgram
  9. from tinygrad.renderer.cstyle import ClangRenderer
# Module-level shortcut to ClangRenderer's dtype renderer; ClangGraph uses it to
# spell each buffer's element type in the C source it generates (e.g. "<type>* argN").
render_dtype = ClangRenderer().render_dtype
  11. class ClangGraph(GraphRunner):
  12. def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
  13. super().__init__(jit_cache, input_rawbuffers, var_vals)
  14. if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException
  15. prgs = '\n'.join(dedup([cast(CompiledRunner, ji.prg).p.src for ji in jit_cache]))
  16. args = [f"{render_dtype(x.dtype)}* arg{i}" for i,x in enumerate(input_rawbuffers)]
  17. args += sorted([f"int {v.expr}" for v in var_vals])
  18. code = ["void batched("+','.join(args)+") {"]
  19. for ji in jit_cache:
  20. args = []
  21. for buf in ji.bufs:
  22. assert buf is not None
  23. if buf in input_rawbuffers:
  24. args.append(f"arg{input_rawbuffers.index(buf)}")
  25. else:
  26. args.append(f"({render_dtype(buf.dtype)}*)0x{ctypes.addressof(buf._buf):X}")
  27. args += [x.expr for x in cast(CompiledRunner, ji.prg).p.vars]
  28. code.append(f" {cast(CompiledRunner, ji.prg).p.function_name}({','.join(args)});")
  29. code.append("}")
  30. if DEBUG >= 4: print("\n".join(code))
  31. compiler = Device["CLANG"].compiler
  32. assert compiler is not None
  33. self.clprg = ClangProgram("batched", compiler.compile(prgs+"\n"+"\n".join(code))) # no point in caching the pointers
  34. def __call__(self, rawbufs: List[Buffer], var_vals: Dict[Variable, int], wait=False):
  35. return cpu_time_execution(
  36. lambda: self.clprg(*[x._buf for x in rawbufs], *[x[1] for x in sorted(var_vals.items(), key=lambda x: x[0].expr)]), enable=wait)