| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156 |
- from typing import Final, Dict, Callable, Any, List, Optional
- from llvmlite import ir
- from tinygrad.dtype import DType, PtrDType, dtypes
- from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
- from tinygrad.codegen.uops import UOps, UOp
- from tinygrad.codegen.uopgraph import UOpGraph
- from tinygrad.renderer import Renderer
# Fast-math flags attached to float instructions: every LLVM fast-math flag except nnan and ninf,
# so the optimizer may not assume float operands are never NaN or infinite.
MFLAGS = tuple("nsz arcp contract afn reassoc".split())
def is_bool_or_unsigned(dtype: DType):
  """Return True when `dtype` is bool or an unsigned integer dtype (the cases lowered with unsigned LLVM ops)."""
  if dtype == dtypes.bool: return True
  return dtypes.is_unsigned(dtype)
def _llvm_intrinsic(intrinsic:str) -> Callable:
  """Build a unary lowering that calls the LLVM intrinsic `intrinsic`, overloaded on the operand's type."""
  def lower(builder, x, dtype):
    fn = builder.module.declare_intrinsic(intrinsic, [x.type])
    return builder.call(fn, [x], fastmath=MFLAGS)
  return lower

def _llvm_compare(builder, op:str, x, y, dtype):
  """Emit `x op y`: unsigned icmp for bool/unsigned, signed icmp for other ints, otherwise an unordered fcmp."""
  if is_bool_or_unsigned(dtype): return builder.icmp_unsigned(op, x, y)
  if dtypes.is_int(dtype): return builder.icmp_signed(op, x, y)
  return builder.fcmp_unordered(op, x, y, flags=MFLAGS)

def _llvm_neg(builder, x, dtype):
  """Negate `x`: integer neg for ints, logical not for bool, fneg for floats."""
  if dtypes.is_int(dtype): return builder.neg(x)
  if dtype == dtypes.bool: return builder.not_(x)
  return builder.fneg(x, flags=MFLAGS)

# Lowering table: op -> callable(builder, *operands, dtype) returning the emitted llvmlite value.
# `dtype` selects between integer (signed/unsigned), bool and float instruction variants.
code_for_op: Final[Dict[Op, Callable]] = {
  UnaryOps.NEG: _llvm_neg,
  UnaryOps.EXP2: _llvm_intrinsic('llvm.exp2'),
  UnaryOps.LOG2: _llvm_intrinsic('llvm.log2'),
  UnaryOps.RECIP: lambda builder, x, dtype: builder.fdiv(const(1, dtype), x, flags=MFLAGS),
  UnaryOps.SIN: _llvm_intrinsic('llvm.sin'),
  UnaryOps.SQRT: _llvm_intrinsic('llvm.sqrt'),
  # bool addition is logical or
  BinaryOps.ADD: lambda builder, x, y, dtype: builder.or_(x, y) if dtype == dtypes.bool else
                                              builder.add(x, y) if dtypes.is_int(dtype) else
                                              builder.fadd(x, y, flags=MFLAGS),
  BinaryOps.MUL: lambda builder, x, y, dtype: builder.mul(x, y) if is_bool_or_unsigned(dtype) or dtypes.is_int(dtype) else
                                              builder.fmul(x, y, flags=MFLAGS),
  BinaryOps.IDIV: lambda builder, x, y, dtype: builder.udiv(x, y) if is_bool_or_unsigned(dtype) else builder.sdiv(x, y),
  BinaryOps.CMPLT: lambda builder, x, y, dtype: _llvm_compare(builder, "<", x, y, dtype),
  BinaryOps.CMPNE: lambda builder, x, y, dtype: _llvm_compare(builder, "!=", x, y, dtype),
  BinaryOps.MAX: lambda builder, x, y, dtype: builder.select(_llvm_compare(builder, ">", x, y, dtype), x, y),
  BinaryOps.MOD: lambda builder, x, y, dtype: builder.urem(x, y) if is_bool_or_unsigned(dtype) else
                                              builder.srem(x, y) if dtypes.is_int(dtype) else
                                              builder.frem(x, y),
  BinaryOps.XOR: lambda builder, x, y, dtype: builder.xor(x, y),
  BinaryOps.AND: lambda builder, x, y, dtype: builder.and_(x, y),
  BinaryOps.OR: lambda builder, x, y, dtype: builder.or_(x, y),
  TernaryOps.WHERE: lambda builder, x, y, z, dtype: builder.select(x, y, z),
}
# tinygrad dtype -> llvmlite IR type. Signed and unsigned integers of the same width share one LLVM
# integer type (LLVM integers carry no signedness; the instruction chosen decides it). bfloat16 is
# held as its raw 16-bit pattern and converted through float32 by `cast`.
dtype_to_llvm_dtype: Dict[DType, ir.Type] = {
  dtypes.bool: ir.IntType(1),
  **{dt: ir.IntType(bits)
     for bits, pair in ((8, (dtypes.int8, dtypes.uint8)), (16, (dtypes.int16, dtypes.uint16)),
                        (32, (dtypes.int32, dtypes.uint32)), (64, (dtypes.int64, dtypes.uint64)))
     for dt in pair},
  dtypes.float16: ir.HalfType(),
  dtypes.bfloat16: ir.IntType(16),
  dtypes.float32: ir.FloatType(),
  dtypes.float64: ir.DoubleType(),
}
def cast(bb, val, input_type, output_type, bitcast=False):
  """Emit LLVM IR converting `val` from tinygrad dtype `input_type` to `output_type`.

  Args:
    bb: list of llvmlite IRBuilders; instructions are emitted with the last one.
    val: the llvmlite value to convert.
    input_type: tinygrad dtype of `val`.
    output_type: tinygrad dtype of the result.
    bitcast: reinterpret the bits as `output_type` instead of converting the value.

  Returns:
    The converted llvmlite value (`val` itself when the dtypes are equal).

  Raises:
    NotImplementedError: for a dtype pair with no lowering below.
  """
  if input_type == output_type: return val
  builder = bb[-1]
  llvm_type = dtype_to_llvm_dtype[output_type]
  if bitcast: return builder.bitcast(val, llvm_type)

  # bfloat16 is held as a raw i16: its bits are the upper half of a float32, so widen, shift up and reinterpret.
  # (sext vs zext is irrelevant here: the extended bits are shifted out.)
  if input_type == dtypes.bfloat16:
    val = builder.bitcast(builder.shl(builder.sext(val, ir.IntType(32)), ir.Constant(ir.IntType(32), 16)), ir.FloatType())
    input_type = dtypes.float32
    if input_type == output_type: return val
  if output_type == dtypes.bfloat16:
    # go through float32 and keep its upper 16 bits (truncation, not round-to-nearest)
    val = cast(bb, val, input_type, dtypes.float32)
    return builder.trunc(builder.lshr(builder.bitcast(val, ir.IntType(32)), ir.Constant(ir.IntType(32), 16)), ir.IntType(16))

  if dtypes.is_float(input_type):
    if dtypes.is_float(output_type):
      return builder.fpext(val, llvm_type) if output_type.itemsize > input_type.itemsize else builder.fptrunc(val, llvm_type)
    if dtypes.is_int(output_type): return builder.fptoui(val, llvm_type) if dtypes.is_unsigned(output_type) else builder.fptosi(val, llvm_type)
    # compare in the source precision: narrowing float64 to float32 first would turn tiny nonzero values into 0 (False).
    # unordered '!=' is true for NaN, matching bool(nan) == True.
    if output_type == dtypes.bool: return builder.fcmp_unordered('!=', val, ir.Constant(val.type, 0))
  if dtypes.is_unsigned(input_type) or input_type == dtypes.bool:
    if output_type == dtypes.float16: return builder.fptrunc(builder.uitofp(val, ir.FloatType()), ir.HalfType())
    if dtypes.is_float(output_type): return builder.uitofp(val, llvm_type)
    # equal widths go through zext; llvmlite's cast builders return `val` unchanged when the types already match
    if dtypes.is_int(output_type): return builder.trunc(val, llvm_type) if input_type.itemsize > output_type.itemsize else builder.zext(val, llvm_type)
    if output_type == dtypes.bool: return builder.icmp_unsigned('!=', val, ir.Constant(val.type, 0))
  if dtypes.is_int(input_type):
    if output_type == dtypes.float16: return builder.fptrunc(builder.sitofp(val, ir.FloatType()), ir.HalfType())
    if dtypes.is_float(output_type): return builder.sitofp(val, llvm_type)
    if dtypes.is_int(output_type): return builder.trunc(val, llvm_type) if input_type.itemsize > output_type.itemsize else builder.sext(val, llvm_type)
    if output_type == dtypes.bool: return builder.icmp_signed('!=', val, ir.Constant(val.type, 0))
  raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
def const(args, dtype):
  """Build an llvmlite constant of the LLVM type mapped from tinygrad `dtype`, holding the value `args`."""
  llvm_type = dtype_to_llvm_dtype[dtype]
  return ir.Constant(llvm_type, args)
class LLVMRenderer(Renderer):
  """Renderer for the "LLVM" device.

  `render` lowers a linearized UOpGraph, via llvmlite, into one `ir.Module` containing a single void
  function and returns that module as LLVM IR text.
  """
  device = "LLVM"
  supports_float4 = False
  has_local = False
  has_shared = False
  global_max = None
  def render(self, name:str, uops:UOpGraph) -> str:
    """Lower `uops` into an LLVM IR module holding one void function called `name`.

    DEFINE_GLOBAL/DEFINE_VAR uops become the function's parameters (PtrDType ones as `noalias` pointers),
    RANGE/ENDRANGE pairs become loops with phi nodes, and every other uop is emitted into the current block.

    Returns:
      The module rendered as LLVM IR text.

    Raises:
      AssertionError: if a uop other than STORE/ENDRANGE has a None dtype.
      RuntimeError: for a uop kind this renderer does not handle.
    """
    # all llvm stuff goes into a module
    module = ir.Module(name=__file__)
    # extract global buffers (NOTE: this isn't right if DEFINE_GLOBAL is out of order)
    # buf_to_dtype: DEFINE_GLOBAL/DEFINE_VAR arg -> dtype; buf_index: that arg -> function parameter position
    buf_to_dtype = {u.arg:u.dtype for u in uops if u.op in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}}
    buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
    # create llvm function: pointer parameters for PtrDType buffers, scalar parameters otherwise
    # NOTE(review): entries with a None dtype are dropped here but still counted in buf_index, which would
    # shift the parameter positions of later entries — confirm None never occurs
    func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values() if dtype is not None]
    func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if isinstance(dt, PtrDType) else x for x,dt in func_dtypes]), name=name)
    for a in func.args:
      if a.type.is_pointer: a.add_attribute("noalias")
    # bb is a stack of IRBuilders; bb[-1] is always the builder for the block currently being emitted
    bb = [ir.IRBuilder(func.append_basic_block("entry"))]
    # open loops: (loop body block, [(accumulator uop, its phi)]) pushed by RANGE, popped by ENDRANGE
    loop_blocks: List = []
    # DEFINE_ACC uops seen so far; every RANGE rebinds each of them to a fresh phi in the new loop body
    reduce_phis: List = []
    # TODO: newvar probably shouldn't be optional
    lvars: Dict[Optional[UOp], Any] = {} # uop -> emitted llvm value (this Any is an llvm type)
    # NOTE(review): these entries are keyed by the DEFINE_VAR arg rather than by a UOp, while the DEFINE_VAR
    # branch below binds lvars[u] to the raw parameter; nothing visible here reads them, and an i32 -> i32
    # sext appears to have no effect — confirm before relying on this
    for bufname,dtype in buf_to_dtype.items():
      if not isinstance(dtype, PtrDType) and dtype == dtypes.int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
    for u in uops:
      uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
      if uop is UOps.STORE:
        # src: (buffer, index, value[, gate]); the value is cast from its dtype to the buffer's dtype
        element = cast(bb, lvars[src[2]], src[2].dtype, src[0].dtype)
        if len(src) > 3:
          # gated store: only performed when src[3] is true
          with bb[-1].if_then(lvars[src[3]]):
            bb[-1].store(element, bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
        else:
          bb[-1].store(element, bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
      elif uop is UOps.ENDRANGE:
        # close the innermost loop: increment the induction variable and feed it, plus the current
        # accumulator values, back into the loop's phis from the block we are leaving
        loop_entry_bb, phis = loop_blocks.pop()
        idx_p1 = bb[-1].add(lvars[src[0]], ir.Constant(ir.IntType(32), 1))
        lvars[src[0]].add_incoming(idx_p1, bb[-1].block)
        for n,phi in phis: phi.add_incoming(lvars[n], bb[-1].block)
        bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
        # branch back to the body while idx+1 < the RANGE's src[1] (unsigned compare), else continue in the exit block.
        # NOTE: the test is only at the bottom — RANGE enters the body unconditionally, so the body runs at least once
        bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[src[0].src[1]]), loop_entry_bb, bb[-1].block)
      else:
        assert dtype is not None, f"None dtype for uop {uop}"
        if uop is UOps.RANGE:
          # open a loop: start a new body block and branch into it from the current block
          bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
          bb[-2].branch(bb[-1].block)
          # every accumulator becomes a phi whose first incoming value is its value before the loop
          phis = []
          for rp in reduce_phis:
            incoming = lvars[rp]
            lvars[rp] = bb[-1].phi(dtype_to_llvm_dtype[rp.dtype])
            lvars[rp].add_incoming(incoming, bb[-2].block)
            phis.append((rp, lvars[rp]))
          # i32 induction variable starting at src[0]; its back-edge value is added at ENDRANGE
          lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
          lvars[u].add_incoming(lvars[src[0]], bb[-2].block)
          loop_blocks.append((bb[-1].block, phis))
        elif uop is UOps.DEFINE_ACC:
          # accumulator starts as the constant src[0].arg; RANGE and PHI rebind lvars[u] later
          lvars[u] = const(src[0].arg, dtype)
          reduce_phis.append(u)
        elif uop is UOps.LOAD:
          if len(src) > 2:
            # gated load: src[2] is the gate, src[3] the fallback. The index is forced to 0 when the gate is
            # false (presumably so the load stays in bounds), and the fallback is selected in that case.
            aug_idx = bb[-1].select(lvars[src[2]], lvars[src[1]], ir.Constant(ir.IntType(32), 0))
            val = bb[-1].load(bb[-1].gep(lvars[src[0]], [aug_idx], inbounds=True))
            val = bb[-1].select(lvars[src[2]], val, lvars[src[3]])
          else:
            val = bb[-1].load(bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
          lvars[u] = val
        elif uop is UOps.PHI:
          # the updated value src[1] becomes the accumulator's current value
          lvars[u] = lvars[src[1]]
          # PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC
          backward = src[0]
          while backward.op is UOps.PHI: backward = backward.src[0]
          lvars[backward] = lvars[u]
        elif uop is UOps.ALU:
          # comparisons produce bool, so they are lowered using the operand dtype (src[0]) instead of the result dtype
          lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in src], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPNE) else src[0].dtype)
        elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[src[0]], src[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
        # buffers and variables are the function's parameters, located through buf_index
        elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
        elif uop is UOps.CONST: lvars[u] = const(args, dtype)
        else: raise RuntimeError(f"failed to render {uop}")
    # the function is void: results only leave through STOREs
    bb[-1].ret_void()
    return str(module)
|