llvmir.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. from typing import Final, Dict, Callable, Any, List, Optional
  2. from llvmlite import ir
  3. from tinygrad.dtype import DType, PtrDType, dtypes
  4. from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
  5. from tinygrad.codegen.uops import UOps, UOp
  6. from tinygrad.codegen.uopgraph import UOpGraph
  7. from tinygrad.renderer import Renderer
  # Fast-math flags attached to float instructions: every LLVM fast-math flag except
  # 'nnan' and 'ninf', so NaN/Inf operands keep their IEEE behavior.
  MFLAGS = (
    'nsz',       # no signed zeros
    'arcp',      # allow reciprocal
    'contract',  # allow contraction (e.g. FMA)
    'afn',       # approximate functions
    'reassoc',   # allow reassociation
  )
  def is_bool_or_unsigned(dtype: DType):
    """Return True for bool and unsigned integer dtypes (the cases lowered with unsigned integer instructions)."""
    if dtype == dtypes.bool: return True
    return dtypes.is_unsigned(dtype)
  def _llvm_intrinsic(intrinsic_name):
    """Return a unary callback that calls the LLVM intrinsic `intrinsic_name`, overloaded on x's type, with MFLAGS."""
    return lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic(intrinsic_name, [x.type]), [x], fastmath=MFLAGS)

  def _compare(builder, op, x, y, dtype):
    """Emit comparison `op`: unsigned icmp for bool/unsigned, signed icmp for other ints, unordered fcmp (with MFLAGS) otherwise."""
    if is_bool_or_unsigned(dtype): return builder.icmp_unsigned(op, x, y)
    if dtypes.is_int(dtype): return builder.icmp_signed(op, x, y)
    # NOTE(review): unordered compares are true when either operand is NaN (so NaN < y is true) — confirm intended
    return builder.fcmp_unordered(op, x, y, flags=MFLAGS)

  # Lowering table: ALU op -> callback that emits the LLVM instruction(s) through `builder`.
  # Unary callbacks take (builder, x, dtype), binary ones (builder, x, y, dtype);
  # WHERE takes (builder, x, y, z, dtype) and selects y where x holds, else z.
  code_for_op: Final[Dict[Op, Callable]] = {
    # int: integer negate; bool: logical not; float: fneg
    UnaryOps.NEG: lambda builder, x, dtype: (builder.neg(x) if dtypes.is_int(dtype)
                                             else builder.not_(x) if dtype == dtypes.bool
                                             else builder.fneg(x, flags=MFLAGS)),
    UnaryOps.EXP2: _llvm_intrinsic('llvm.exp2'),
    UnaryOps.LOG2: _llvm_intrinsic('llvm.log2'),
    # 1/x as a float division
    UnaryOps.RECIP: lambda builder, x, dtype: builder.fdiv(const(1, dtype), x, flags=MFLAGS),
    UnaryOps.SIN: _llvm_intrinsic('llvm.sin'),
    UnaryOps.SQRT: _llvm_intrinsic('llvm.sqrt'),
    # bool "addition" is logical or
    BinaryOps.ADD: lambda builder, x, y, dtype: (builder.or_(x, y) if dtype == dtypes.bool
                                                 else builder.add(x, y) if dtypes.is_int(dtype)
                                                 else builder.fadd(x, y, flags=MFLAGS)),
    BinaryOps.MUL: lambda builder, x, y, dtype: (builder.mul(x, y) if is_bool_or_unsigned(dtype) or dtypes.is_int(dtype)
                                                 else builder.fmul(x, y, flags=MFLAGS)),
    BinaryOps.IDIV: lambda builder, x, y, dtype: builder.udiv(x, y) if is_bool_or_unsigned(dtype) else builder.sdiv(x, y),
    BinaryOps.CMPLT: lambda builder, x, y, dtype: _compare(builder, "<", x, y, dtype),
    BinaryOps.CMPNE: lambda builder, x, y, dtype: _compare(builder, "!=", x, y, dtype),
    # x if x > y else y; for floats the unordered compare picks x whenever either operand is NaN
    BinaryOps.MAX: lambda builder, x, y, dtype: builder.select(_compare(builder, ">", x, y, dtype), x, y),
    BinaryOps.MOD: lambda builder, x, y, dtype: (builder.urem(x, y) if is_bool_or_unsigned(dtype)
                                                 else builder.srem(x, y) if dtypes.is_int(dtype)
                                                 else builder.frem(x, y)),
    BinaryOps.XOR: lambda builder, x, y, dtype: builder.xor(x, y),
    BinaryOps.AND: lambda builder, x, y, dtype: builder.and_(x, y),
    BinaryOps.OR: lambda builder, x, y, dtype: builder.or_(x, y),
    TernaryOps.WHERE: lambda builder, x, y, z, dtype: builder.select(x, y, z),
  }
  # tinygrad scalar dtype -> llvmlite IR type. Signed and unsigned integers of one width share
  # the same IntType; bfloat16 has no native type here and is carried as raw i16 bits.
  dtype_to_llvm_dtype = {
    dtypes.bool: ir.IntType(1),
    dtypes.int8: ir.IntType(8),
    dtypes.uint8: ir.IntType(8),
    dtypes.int16: ir.IntType(16),
    dtypes.uint16: ir.IntType(16),
    dtypes.int32: ir.IntType(32),
    dtypes.uint32: ir.IntType(32),
    dtypes.int64: ir.IntType(64),
    dtypes.uint64: ir.IntType(64),
    dtypes.float16: ir.HalfType(),
    dtypes.bfloat16: ir.IntType(16),
    dtypes.float32: ir.FloatType(),
    dtypes.float64: ir.DoubleType(),
  }
  def _int_resize(builder, val, llvm_type, extend):
    """Convert integer `val` to integer `llvm_type`: trunc when narrowing, `extend` (zext/sext) when widening."""
    src_bits, dst_bits = val.type.width, llvm_type.width
    # LLVM integers are signless: equal widths (e.g. int32 <-> uint32) need no instruction, and a same-width zext/sext is invalid IR
    if src_bits == dst_bits: return val
    return builder.trunc(val, llvm_type) if src_bits > dst_bits else extend(val, llvm_type)

  def cast(bb, val, input_type, output_type, bitcast=False):
    """Emit LLVM IR converting `val` from tinygrad dtype `input_type` to `output_type`.

    Args:
      bb: list of ir.IRBuilder; instructions are emitted through bb[-1].
      val: the ir.Value to convert.
      input_type: tinygrad DType of `val`.
      output_type: tinygrad DType to produce.
      bitcast: reinterpret the bits as `output_type` instead of converting the value.
    Returns:
      The converted ir.Value (`val` itself when no instruction is needed).
    Raises:
      NotImplementedError: for an unsupported dtype pair.
    """
    if input_type == output_type: return val
    llvm_type = dtype_to_llvm_dtype[output_type]
    builder = bb[-1]
    if bitcast: return builder.bitcast(val, llvm_type)
    if input_type == dtypes.bfloat16:
      # bf16 bits are the upper half of a float32: place the i16 payload in the high 16 bits and reinterpret as float
      bits = builder.shl(builder.zext(val, ir.IntType(32)), ir.Constant(ir.IntType(32), 16))
      # re-dispatch from float32 so that bf16 -> float32 stops here instead of emitting a float->float fptrunc
      return cast(bb, builder.bitcast(bits, ir.FloatType()), dtypes.float32, output_type)
    if output_type == dtypes.bfloat16:
      # keep the upper 16 bits of the float32 representation (mantissa is truncated, not rounded to nearest)
      f32 = cast(bb, val, input_type, dtypes.float32)
      return builder.trunc(builder.lshr(builder.bitcast(f32, ir.IntType(32)), ir.Constant(ir.IntType(32), 16)), ir.IntType(16))
    if dtypes.is_float(input_type):
      if dtypes.is_float(output_type):
        return builder.fpext(val, llvm_type) if output_type.itemsize > input_type.itemsize else builder.fptrunc(val, llvm_type)
      if dtypes.is_int(output_type): return builder.fptoui(val, llvm_type) if dtypes.is_unsigned(output_type) else builder.fptosi(val, llvm_type)
      # compare in the source precision: narrowing float64 to float32 first would turn tiny nonzero values into False
      if output_type == dtypes.bool: return builder.fcmp_unordered('!=', val, ir.Constant(val.type, 0))
    # bool/unsigned must be checked before signed ints, since dtypes.is_int also covers unsigned types
    if dtypes.is_unsigned(input_type) or input_type == dtypes.bool:
      if output_type == dtypes.float16: return builder.fptrunc(builder.uitofp(val, ir.FloatType()), ir.HalfType())
      if dtypes.is_float(output_type): return builder.uitofp(val, llvm_type)
      if dtypes.is_int(output_type): return _int_resize(builder, val, llvm_type, builder.zext)
      if output_type == dtypes.bool: return builder.icmp_unsigned('!=', val, ir.Constant(val.type, 0))
    if dtypes.is_int(input_type):
      if output_type == dtypes.float16: return builder.fptrunc(builder.sitofp(val, ir.FloatType()), ir.HalfType())
      if dtypes.is_float(output_type): return builder.sitofp(val, llvm_type)
      if dtypes.is_int(output_type): return _int_resize(builder, val, llvm_type, builder.sext)
      if output_type == dtypes.bool: return builder.icmp_signed('!=', val, ir.Constant(val.type, 0))
    raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
  def const(args, dtype):
    """Build an llvmlite constant holding the Python value `args`, typed as `dtype`'s LLVM type."""
    llvm_type = dtype_to_llvm_dtype[dtype]
    return ir.Constant(llvm_type, args)
  class LLVMRenderer(Renderer):
    """Renderer that lowers a linearized UOpGraph into textual LLVM IR using llvmlite.

    The generated function returns void and takes one argument per DEFINE_GLOBAL / DEFINE_VAR:
    a pointer for PtrDType buffers, a plain scalar otherwise.
    """
    device = "LLVM"
    supports_float4 = False
    has_local = False
    has_shared = False
    global_max = None

    def render(self, name:str, uops:UOpGraph) -> str:
      """Render `uops` as an LLVM module holding a single function called `name`; return the module's IR text.

      Raises:
        RuntimeError: for a UOp kind this renderer does not handle.
      """
      # all llvm stuff goes into a module
      module = ir.Module(name=__file__)
      # extract global buffers (NOTE: this isn't right if DEFINE_GLOBAL is out of order)
      buf_to_dtype = {u.arg:u.dtype for u in uops if u.op in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}}
      buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
      # create llvm function
      func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values() if dtype is not None]
      func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if isinstance(dt, PtrDType) else x for x,dt in func_dtypes]), name=name)
      # pointer arguments are declared noalias (buffers are assumed not to overlap)
      for a in func.args:
        if a.type.is_pointer: a.add_attribute("noalias")
      # bb gains one builder per new basic block (loop body / loop exit); bb[-1] is the current insertion point
      bb = [ir.IRBuilder(func.append_basic_block("entry"))]
      loop_blocks: List = []  # one entry per open RANGE: (loop body block, [(accumulator uop, its phi)])
      reduce_phis: List = []  # DEFINE_ACC uops seen so far; every RANGE gives each of them a loop-carried phi
      # TODO: newvar probably shouldn't be optional
      lvars: Dict[Optional[UOp], Any] = {} # this Any is an llvm type
      # scalar DEFINE_VAR arguments are bound to lvars when their DEFINE_VAR uop is visited in the loop below
      for u in uops:
        uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
        if uop is UOps.STORE:
          # src = (buffer, index, value[, gate]); the value is cast to the buffer's dtype before storing
          element = cast(bb, lvars[src[2]], src[2].dtype, src[0].dtype)
          if len(src) > 3:
            with bb[-1].if_then(lvars[src[3]]):
              bb[-1].store(element, bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
          else:
            bb[-1].store(element, bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
        elif uop is UOps.ENDRANGE:
          loop_entry_bb, phis = loop_blocks.pop()
          # increment the induction variable and close the phis opened by the matching RANGE
          idx_p1 = bb[-1].add(lvars[src[0]], ir.Constant(ir.IntType(32), 1))
          lvars[src[0]].add_incoming(idx_p1, bb[-1].block)
          for n,phi in phis: phi.add_incoming(lvars[n], bb[-1].block)
          bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
          # bottom-tested loop: the bound is checked only after the body has run
          # NOTE(review): a RANGE with start >= end would still run one iteration — confirm ranges are never empty here
          bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[src[0].src[1]]), loop_entry_bb, bb[-1].block)
        else:
          assert dtype is not None, f"None dtype for uop {uop}"
          if uop is UOps.RANGE:
            # open the loop body block; each accumulator becomes a phi so its value is carried around the loop
            bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
            bb[-2].branch(bb[-1].block)
            phis = []
            for rp in reduce_phis:
              incoming = lvars[rp]
              lvars[rp] = bb[-1].phi(dtype_to_llvm_dtype[rp.dtype])
              lvars[rp].add_incoming(incoming, bb[-2].block)
              phis.append((rp, lvars[rp]))
            # induction variable starts at src[0] when entering from the preceding block
            lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
            lvars[u].add_incoming(lvars[src[0]], bb[-2].block)
            loop_blocks.append((bb[-1].block, phis))
          elif uop is UOps.DEFINE_ACC:
            # starts as its constant initial value; a later RANGE replaces it with a phi
            lvars[u] = const(src[0].arg, dtype)
            reduce_phis.append(u)
          elif uop is UOps.LOAD:
            if len(src) > 2:
              # gated load: index 0 is read when the gate src[2] is false (presumably to keep the address in bounds),
              # then the alternative value src[3] is selected in its place
              aug_idx = bb[-1].select(lvars[src[2]], lvars[src[1]], ir.Constant(ir.IntType(32), 0))
              val = bb[-1].load(bb[-1].gep(lvars[src[0]], [aug_idx], inbounds=True))
              val = bb[-1].select(lvars[src[2]], val, lvars[src[3]])
            else:
              val = bb[-1].load(bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
            lvars[u] = val
          elif uop is UOps.PHI:
            lvars[u] = lvars[src[1]]
            # PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC
            backward = src[0]
            while backward.op is UOps.PHI: backward = backward.src[0]
            lvars[backward] = lvars[u]
          elif uop is UOps.ALU:
            # comparisons are lowered by their operand dtype, not by their bool result dtype
            lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in src], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPNE) else src[0].dtype)
          elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[src[0]], src[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
          elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
          elif uop is UOps.CONST: lvars[u] = const(args, dtype)
          else: raise RuntimeError(f"failed to render {uop}")
      bb[-1].ret_void()
      return str(module)