| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394 |
- from typing import Dict, List, Optional, Tuple, Union, DefaultDict, cast, Literal, Callable
- import os, math
- from collections import defaultdict, Counter
- from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
- from tinygrad.helpers import strip_parens, getenv, prod, dedup
- from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
- from tinygrad.codegen.uops import UOps, UOp
- from tinygrad.codegen.uopgraph import UOpGraph
- from tinygrad.renderer import Renderer, TensorCore
class CStyleLanguage(Renderer):
  """Base renderer that lowers a UOpGraph into C-style kernel source.

  Concrete backends (OpenCL, Metal, CUDA, AMD, ...) subclass this, overriding
  the declarative dialect options below and/or a few render_* hooks; the shared
  uop-walking loop lives in `render`.
  """
  kernel_prefix: str = ""             # emitted before the kernel signature, e.g. "__kernel "
  buffer_prefix: str = ""             # address-space qualifier for global buffer pointers
  buffer_suffix: str = ""             # appended after buffer parameters, e.g. " restrict"
  smem_align: str = ""                # alignment attribute prepended to local-memory declarations
  smem_prefix: str = ""               # address-space qualifier for local (shared) memory
  smem_prefix_for_cast: bool = True   # if True, pointer casts into local memory keep smem_prefix
  arg_int_prefix: str = "const int"   # declaration used for plain integer kernel arguments
  barrier: str = ""                   # statement emitted for UOps.BARRIER
  # maps workitem axis kind ("g"=group, "l"=local, "i"=global) to an id-expression builder
  code_for_workitem: Dict[Union[Literal["g"], Literal["l"], Literal["i"]], Callable] = {}
  extra_args: List[str] = []          # extra formal parameters appended to every kernel signature
  float4: Optional[str] = None        # vector-literal constructor syntax; None = vectors unsupported
  uses_vload: bool = False            # use vload_half/vstore_half for half-precision buffers (OpenCL)
  uses_ptr_arithmetic: bool = False   # render loads/stores as *(buf+idx) instead of buf[idx]
  type_map: Dict[DType, str] = {}     # per-dialect overrides for DType -> C type name
  # default source expression for each ALU op; subclasses extend/override entries
  code_for_op: Dict = {
    UnaryOps.NEG: lambda x,dtype: f"(!{x})" if dtype == dtypes.bool else f"(-{x})", UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})",
    UnaryOps.RECIP: lambda x,dtype: f"(1/{x})",
    UnaryOps.EXP2: lambda x,dtype: f"exp2({x})", UnaryOps.LOG2: lambda x,dtype: f"log2({x})", UnaryOps.SIN: lambda x,dtype: f"sin({x})",
    BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})",
    BinaryOps.IDIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
    BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.CMPNE: lambda a,b,dtype: f"({a}!={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
    BinaryOps.AND: lambda a,b,dtype: f"({a}&{b})", BinaryOps.OR: lambda a,b,dtype: f"({a}|{b})",
    TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"}

  # returns a str expression of the casted xs with the given type
  def render_cast(self, x:str, var_dtype:DType, bitcast=False) -> str:
    # bitcast reinterprets the bytes via a pointer cast; a plain cast converts the value
    if bitcast: return f"(*(({self.buffer_prefix}{self.render_dtype(var_dtype)}*)&{x}))"
    return f"({self.render_dtype(var_dtype)})({x})"

  # returns a str expression of the vectorized xs with the given type
  def render_vectorize(self, x:List[str], var_dtype:DType) -> str:
    assert len(x) == var_dtype.count, f"cast is wrong size {len(x)} != {var_dtype.count}"
    assert self.float4 is not None, "vectorized cast is not supported on this platform"
    # the float4 template is reused for all vector widths by substituting the type name
    return f"{self.float4.replace('float4', self.render_dtype(var_dtype))}({','.join(x)})"

  # returns a str expression of the const with the given type
  def render_const(self, x:ConstType, dtype:DType) -> str:
    if math.isnan(x): val = "NAN"
    elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
    elif dtype.scalar() == dtypes.bool: val = "1" if x else "0"
    elif dtype.scalar() == dtypes.float: val = f"{x}f"
    else: val = str(x)
    # vector constants are splatted across all lanes; scalars of non-default types get a cast
    if dtype.count > 1: return self.render_vectorize([val] * dtype.count, dtype)
    return (self.render_cast(val, dtype) if dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)

  # returns a str expression of the loaded value with the output type
  def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
    if isinstance(buf_dtype, ImageDType):
      assert output_dtype == dtypes.float.vec(4), f"images must be float4, getting {output_dtype}"
      return f"read_imagef({buf_name}, smp, {idx})"
    # half buffers read through vload_half* when the consumer isn't half itself
    if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and output_dtype.scalar() != dtypes.float16:
      return f"vload_half{'' if output_dtype.count == 1 else str(output_dtype.count)}(0, {buf_name}+{idx})"
    # vector loads reinterpret the scalar buffer pointer as a vector pointer
    if output_dtype.count > 1:
      return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{self.render_dtype(buf_dtype)}{output_dtype.count}*)({buf_name}+{idx}))"  # noqa: E501
    return f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"

  def get_kernel_modifier(self, uops:UOpGraph) -> str:
    """Extra attribute text placed between kernel_prefix and the function name; default none."""
    return ""

  def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:UOpGraph, prefix=None) -> str:
    """Assemble the final source: optional prefix lines, signature, then the body lines."""
    # image buffers need a sampler; emit one only when some argument is an ImageDType
    tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else ""  # noqa: E501
    # per-argument declaration: image2d_t, qualified pointer, or plain int
    buftypes = [(name,f"{'write_only' if mutable else 'read_only'} image2d_t" if dtype.name.startswith('image') else
                ("" if mutable else "const ")+self.buffer_prefix+self.render_dtype(dtype)+"*"+self.buffer_suffix if isinstance(dtype, PtrDType) else
                self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs]
    prg = ''.join([f"{self.kernel_prefix}void {self.get_kernel_modifier(uops)}{function_name}(",] +
    [', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
    [") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
    return prg if prefix is None else "\n".join(prefix)+f"\n{prg}"

  # returns a str statement that does the store
  def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx:str, local=False) -> str:
    if isinstance(buf_dtype, ImageDType):
      assert var_dtype == dtypes.float.vec(4), f"images must be float4, getting {var_dtype}"
      return f"write_imagef({buf_name}, {idx}, {var_name});"
    # mirror of render_load: half buffers write through vstore_half*
    if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and var_dtype.scalar() != dtypes.float16:
      return f"vstore_half{'' if var_dtype.count == 1 else str(var_dtype.count)}({var_name}, 0, {buf_name}+{idx});"
    if var_dtype.count > 1:
      prefix = self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix
      return f"*(({prefix}{self.render_dtype(buf_dtype)}{var_dtype.count}*)({buf_name}+{idx})) = {var_name};"
    return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"

  def render_local(self, name:str, dtype:DType, size:int):
    """Return the declaration statement for a local (shared) memory array."""
    return self.smem_align + self.smem_prefix + f"{self.render_dtype(dtype)} {name}[{size}];"

  def render_dtype(self, var_dtype:DType) -> str:
    """Map a DType to this dialect's type name, falling back to the DType's own name."""
    return self.type_map.get(var_dtype, var_dtype.name)

  def render(self, name:str, uops:UOpGraph) -> str:
    """Walk the linearized uops and emit the complete kernel source string."""
    kernel = []
    bufs: List[Tuple[str, Tuple[DType, bool]]] = []  # (arg name, (dtype, mutable)) kernel parameters
    depth = 1  # current brace-nesting depth, used for indentation
    def kk(s): kernel.append(" "*depth+s)
    c: DefaultDict[str, int] = defaultdict(int)  # per-prefix counters for SSA-style names
    r: Dict[UOp, str] = {}                       # rendered expression or variable name per uop
    def ssa(prefix:str, u:Optional[UOp]=None):
      # mint a fresh name like "alu0"; optionally register it as the rendering of u
      nonlocal c, r
      ret = f"{prefix}{c[prefix]}"
      if u is not None: r[u] = ret
      c[prefix] += 1
      return ret
    # how many consumers each uop has; single-use values can be inlined instead of named
    child_count = Counter(v for ru in uops for v in ru.src)
    seen_vars = set()
    for u in uops:
      uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
      # these four uops don't have output dtypes
      if uop is UOps.IF:
        kk(f"if ({r[src[0]]}) {{")
        depth += 1
      elif uop is UOps.BARRIER: kk(self.barrier)
      elif uop in {UOps.ENDRANGE, UOps.ENDIF}:
        depth -= 1
        kk("}")
      elif uop is UOps.STORE:
        assert src[0].dtype is not None and src[2].dtype is not None
        rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
        # a fourth source, if present, is the store gate
        kk(f"if ({r[src[3]]}) {{ {rendered_store} }}" if len(src) > 3 else rendered_store)
      else:
        assert dtype is not None, f"None dtype for uop {uop}"
        if uop is UOps.RANGE:
          kk(f"for (int {(expr := ssa('ridx',u))} = {r[src[0]]}; {expr} < {r[src[1]]}; {expr}++) {{")
          depth += 1
        elif uop is UOps.ALU:
          # remove parens if ALU types are the same. TODO: can do more here
          if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == args else r[v] for v in src]
          else: operands = [r[v] for v in src]
          val = self.code_for_op[args](*operands, dtype)
          assert child_count[u] != 0, f"childless ALU op found {u}"
          # TODO: fix index rendering issue. fix clang nested max macro issue
          if child_count[u] <= 1 and args is not BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val
          else: kk(f"{self.render_dtype(dtype)} {ssa('alu',u)} = {val};")
        elif uop is UOps.SPECIAL:
          # args = (axis, name, size); name's first char ("g"/"l"/"i") selects the workitem builder
          kk(f"int {args[1]} = {self.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
          r[u] = args[1]
        elif uop is UOps.LOAD:
          val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
          # NOTE: this relies on the load not happening if it's in the unselected branch
          if len(src) > 3 and src[2].op is UOps.ALU: val = self.code_for_op[TernaryOps.WHERE](r[src[2]], val, r[src[3]], dtype)
          kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};")
        elif uop is UOps.PHI:
          # PHI assigns back into the accumulator variable; the uop renders as that variable
          kk(f"{r[src[0]]} = {r[src[1]]};")
          r[u] = r[src[0]]
        elif uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}:
          assert len(src) == 1 or (uop is UOps.VECTORIZE and len(src) > 1), "Invalid source length for operation"
          if uop is UOps.BITCAST:
            # materialize the source first so its address can be reinterpreted
            precast = ssa('precast')
            kk(f"{self.render_dtype(cast(DType, src[0].dtype))} {precast} = {r[src[0]]};")
            val = self.render_cast(precast, dtype, bitcast=True)
          elif uop is UOps.CAST: val = self.render_cast(r[src[0]], dtype, bitcast=False)
          else: val = self.render_vectorize([r[x] for x in src], dtype)
          if child_count[u] <= 1: r[u] = val
          else: kk(f"{self.render_dtype(dtype)} {ssa('cast',u)} = {val};")
        elif uop is UOps.DEFINE_LOCAL:
          kk(self.render_local(args[0], dtype, args[1]))
          r[u] = args[0]
        elif uop is UOps.DEFINE_VAR:
          assert args.expr not in seen_vars, f"duplicate variable {args.expr}"
          seen_vars.add(args.expr)
          bufs.append((args.expr, (dtype,False)))
          r[u] = args.expr
        elif uop is UOps.DEFINE_GLOBAL:
          # args = (buffer index, mutable); named data<N> in the signature
          bufs.append((nm:=f"data{args[0]}", (dtype,args[1])))
          r[u] = nm
        elif uop is UOps.WMMA: kk(f"{self.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[src[0]]}, {r[src[1]]}, {r[src[2]]});")
        elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {self.render_const(src[0].arg, dtype)};")
        elif uop is UOps.CONST: r[u] = self.render_const(args, dtype) if args >= 0 else f"({self.render_const(args, dtype)})"
        elif uop is UOps.GEP:
          assert src[0].dtype is not None
          from_ssa = src[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
          # wide vectors index with []; <=4-wide use .x/.y/.z/.w swizzles
          r[u] = (r[src[0]] if from_ssa else f"{(r[src[0]])}") + (f"[{args}]" if src[0].dtype.count > 4 else f".{'xyzw'[args]}")
        else: raise RuntimeError(f"failed to render {u}")
    return self.render_kernel(name, kernel, bufs, uops)
class ClangRenderer(CStyleLanguage):
  """Renderer for plain C compiled with clang on the host CPU."""
  device = "CLANG"
  supports_float4 = False  # plain C output has no vector types
  has_local = False        # no local/shared memory or workgroups on CPU
  global_max = None        # no launch-grid size limits
  # language options
  buffer_suffix = " restrict"  # buffers don't alias; lets the compiler vectorize
  type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"}
  # C has no max() builtin, so render it as a ternary
  code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.MAX: lambda a,b,dtype: f"(({a}>{b})?{a}:{b})"}
class OpenCLRenderer(CStyleLanguage):
  """Renderer for the OpenCL C dialect."""
  device = "GPU"
  # language options
  kernel_prefix = "__kernel "
  buffer_prefix = "__global "
  smem_align = "__attribute__ ((aligned (16))) "
  smem_prefix = "__local "
  barrier = "barrier(CLK_LOCAL_MEM_FENCE);"
  float4 = "(float4)"
  code_for_workitem = {"g": lambda x: f"get_group_id({x})", "l": lambda x: f"get_local_id({x})", "i": lambda x: f"get_global_id({x})"}
  uses_vload = True
  type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong" }
  def render_cast(self, x, var_dtype, bitcast=False) -> str:
    # OpenCL provides the as_<type>() builtin for bit reinterpretation
    if not bitcast: return super().render_cast(x, var_dtype)
    return f"as_{self.render_dtype(var_dtype)}({x})"
  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
    # half precision is an extension in OpenCL C and must be enabled explicitly
    needs_fp16 = any(uop.dtype == dtypes.half for uop in uops)
    if needs_fp16: prefix = ["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"]
    return super().render_kernel(function_name, kernel, bufs, uops, prefix)
class MetalRenderer(CStyleLanguage):
  """Renderer for the Metal Shading Language dialect."""
  device = "METAL"
  shared_max = 32768
  tensor_cores = [TensorCore(dims=(8,8,8), threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.float, dtypes.float), (dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]]  # noqa: E501
  # simdgroup tensor cores are only enabled on Apple Silicon (arm64) hosts
  def __init__(self): self.tensor_cores = MetalRenderer.tensor_cores if os.uname().machine == "arm64" else []
  # language options
  kernel_prefix = "kernel "
  buffer_prefix = "device "
  smem_prefix = "threadgroup "
  arg_int_prefix = "constant int&"
  barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);"
  float4 = "float4"
  uses_ptr_arithmetic = True
  # chr(120+x) maps axis 0/1/2 to the .x/.y/.z components of gid/lid
  code_for_workitem = {"g": lambda x: f"gid.{chr(120+x)}", "l": lambda x: f"lid.{chr(120+x)}"}
  # uint3 used for gid/lid - TODO: this should probably be `ushort3 lid [[thread_position_in_threadgroup]]`
  extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']
  type_map = {dtypes.bfloat16: "bfloat"}
  # bfloat ops round-trip through float since Metal's math builtins take float/half
  code_for_op = {**CStyleLanguage().code_for_op,
    BinaryOps.MAX: lambda a,b,dtype: f"(bfloat)max((float){a},(float){b})" if dtype == dtypes.bfloat16 else f"max({a},{b})",
    UnaryOps.SQRT: lambda x,dtype: f"(bfloat)sqrt({x})" if dtype == dtypes.bfloat16 else f"sqrt({x})",
    UnaryOps.EXP2: lambda x,dtype: f"(bfloat)exp2({x})" if dtype == dtypes.bfloat16 else f"exp2({x})",
    UnaryOps.LOG2: lambda x,dtype: f"(bfloat)log2({x})" if dtype == dtypes.bfloat16 else f"log2({x})",
    UnaryOps.SIN: lambda x,dtype: f"(bfloat)precise::sin({x})" if dtype == dtypes.bfloat16 else f"precise::sin({x})",}
  def render_cast(self, x:str, var_dtype:DType, bitcast=False) -> str:
    # Metal uses as_type<T>() for bit reinterpretation
    return f"as_type<{self.render_dtype(var_dtype)}>({x})" if bitcast else super().render_cast(x, var_dtype)
  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
    # synthesize one __<name> helper per distinct WMMA arg, built on simdgroup 8x8 matrices
    prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.op is UOps.WMMA])
    for arg in wmma_args: prefix.append(f"""{arg[3].name}2 __{arg[0]}({arg[2].name}2 m, {arg[2].name}2 n, {arg[3].name}2 o) {{
  simdgroup_{arg[3].name}8x8 a,b,c; a.thread_elements()[0] = m.x; a.thread_elements()[1] = m.y; b.thread_elements()[0] = n.x;
  b.thread_elements()[1] = n.y; c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);
  return {arg[3].name}2(c.thread_elements()[0], c.thread_elements()[1]);\n}}""")
    return super().render_kernel(function_name, kernel, bufs, uops, prefix)
# CUDA-family overrides: route max and the transcendentals through the half-precision
# __hmax/h* intrinsics when the dtype is half or bfloat16, else fall back to float forms
code_for_op_half = {UnaryOps.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"1/{x}",
  BinaryOps.MAX: lambda a,b,dtype: f"__hmax({a},{b})" if dtype in (dtypes.half, dtypes.bfloat16) else f"max({a},{b})",
  UnaryOps.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
  UnaryOps.SIN: lambda x,dtype: f"hsin({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sin({x})",
  UnaryOps.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})",
  UnaryOps.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",}

# per-lane field names for generated vector types: x,y,z,w then a..l for wider vectors
_nms = "xyzwabcdefghijkl"
- def _make_cuda_dtype(base_type, name, cnt):
- vec, elems, header = f"{name}{cnt}", ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
- return f"struct {vec} {{ {base_type} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"
class CUDARenderer(CStyleLanguage):
  """Renderer for the CUDA C++ dialect."""
  device = "CUDA"
  global_max = (2147483647, 65535, 65535)
  local_max = (1024, 1024, 64)
  shared_max = 49152
  tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float), (dtypes.bfloat16, dtypes.float)])]  # noqa: E501
  # tensor cores require sm_80 (Ampere) or newer; arch is like "sm_86"
  def __init__(self, arch:str): self.tensor_cores = CUDARenderer.tensor_cores if int(arch[3:]) >= 80 else []
  # language options
  kernel_prefix = "extern \"C\" __global__ "
  smem_prefix = "__shared__ "
  smem_prefix_for_cast = False
  barrier = "__syncthreads();"
  float4 = "make_float4"
  # chr(120+x) maps axis 0/1/2 to the .x/.y/.z components of blockIdx/threadIdx
  code_for_workitem = {"g": lambda x: f"blockIdx.{chr(120+x)}", "l": lambda x: f"threadIdx.{chr(120+x)}",
    "i": lambda x: f"(blockIdx.{chr(120+x)}*blockDim.{chr(120+x)}+threadIdx.{chr(120+x)})"}
  code_for_op = {**CStyleLanguage().code_for_op, **code_for_op_half}
  type_map = {dtypes.bfloat16: "nv_bfloat16"}
  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
    # TODO: why is dtypes.bfloat16.name == "__bf16"? would be easier not override dtypes.name
    dt_map = { dtypes.float: ("float","f32"), dtypes.half: ("half","f16"), dtypes.bfloat16: ("bfloat16","bf16"), }
    prefix = ["#define INFINITY (__int_as_float(0x7f800000))","#define NAN (__int_as_float(0x7fffffff))"]
    # pull in the half/bf16 headers and generated vector types only when used
    if any(uop.dtype == dtypes.half for uop in uops):
      prefix += ["#include <cuda_fp16.h>"] + [_make_cuda_dtype("half", "half", x) for x in [4, 8]]
    if any(uop.dtype == dtypes.bfloat16 for uop in uops):
      prefix += ["#include <cuda_bf16.h>"] + [_make_cuda_dtype("nv_bfloat16", "bfloat16", x) for x in [4, 8]]
    # TODO: this has to be way better to generate for arbitrary M,N,K: use arg[1] for MNK, use arg[4] for vec sizes, encode register packing
    for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]):
      fn, ti, to, ci, co = arg[0], dt_map[arg[2]][0], dt_map[arg[3]][0], dt_map[arg[2]][1], dt_map[arg[3]][1]
      prefix.append(f"""__device__ {to}4 __{fn}({ti}8 a, {ti}4 b, {to}4 c) {{ int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
asm( "mma.sync.aligned.m16n8k16.row.col.{co}.{ci}.{ci}.{co} {{ %0, %1, %2, %3 }}, {{ %4, %5, %6, %7 }}, {{ %8, %9 }}, {{ %0, %1, %2, %3 }};"
  : "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
return c;}}""")
    return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)
# AMD/HIP overrides: route sqrt/sin/log2/exp2/max through the __ocml device library,
# selecting the f16/f32/f64 variant from the dtype
code_for_op_hip = { UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
  UnaryOps.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
  UnaryOps.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
  UnaryOps.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
  # TODO: MAX with int uses fmax_f32?
  BinaryOps.MAX: lambda a,b,dtype: f"__ocml_fmax_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32) }({a},{b})",}
def _make_hip_code_for_op():
  """Build the HIP code_for_op table: the base table plus __ocml overrides, with every
  entry wrapped so bfloat16 operands are upcast to float (hip_bfloat16 has no native math)."""
  def _bf16_wrap(op, render_fn):
    def rendered(*args):
      # last positional arg is always the dtype
      if args[-1] != dtypes.bfloat16: return render_fn(*args)
      # WHERE's condition operand is left untouched; all other operands upcast to float
      keep, rest = ((args[0],), args[1:-1]) if op is TernaryOps.WHERE else ((), args[:-1])
      floated = tuple(f"(float)({a})" for a in rest)
      return f"(hip_bfloat16)({render_fn(*(keep + floated), dtypes.float)})"
    return rendered
  base = {**CStyleLanguage().code_for_op, **code_for_op_hip}
  return {op: _bf16_wrap(op, fn) for op, fn in base.items()}
- def _make_hip_dtype(base_type, name, cnt):
- elems, header = ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
- return f"typedef {base_type} {name}{cnt} __attribute__((ext_vector_type({cnt})));\n" + \
- f"static inline __attribute__((device)) {name}{cnt} make_{name}{cnt}({header}) {{ return {{{elems}}}; }}"
class AMDRenderer(CStyleLanguage):
  """Renderer for AMD GPUs, emitting freestanding HIP-style C++ against the ockl/ocml device libs."""
  device = "AMD"
  shared_max = 65536
  tensor_cores = [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)], thread_local_sizes=[[16],[16],[4,2]], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]]  # noqa: E501
  # language options
  # freestanding compile: declare the ockl workitem intrinsics and ocml math entry points
  # (f32/f64/f16 variants) ourselves instead of including HIP headers
  kernel_prefix = """extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_id(unsigned int);
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_group_id(unsigned int);
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_size(unsigned int);
extern "C" {\n""" + "".join([
f"""  __attribute__((device)) __attribute__((const)) {dt} __ocml_fmax_f{n}({dt}, {dt});
  __attribute__((device)) __attribute__((pure)) {dt} __ocml_exp2_f{n}({dt});
  __attribute__((device)) __attribute__((pure)) {dt} __ocml_log2_f{n}({dt});
  __attribute__((device)) __attribute__((const)) {dt} __ocml_sqrt_f{n}({dt});
  __attribute__((device)) {dt} __ocml_sin_f{n}({dt});\n""" for dt,n in [("float",32), ("double",64), ("_Float16",16)]]) +\
'}\nextern "C" __attribute__((global))'
  code_for_workitem = {"g": lambda x: f"__ockl_get_group_id({x})", "l": lambda x: f"__ockl_get_local_id({x})",
    "i": lambda x: f"(__ockl_get_group_id({x})*__ockl_get_local_size({x})+__ockl_get_local_id({x}))"}
  code_for_op = _make_hip_code_for_op()
  smem_prefix = "__attribute__((shared))"
  # full acquire/release fences around the barrier to order workgroup memory accesses
  barrier = '__builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup");' + '__builtin_amdgcn_s_barrier();' + \
    '__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");'
  float4 = "make_float4"
  uses_ptr_arithmetic = False  # NOTE: this fixes TestLinearizerOverflowAlt
  type_map = {dtypes.bfloat16: "hip_bfloat16"}
  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
    # freestanding prelude: INFINITY/NAN and size_t since no system headers are included
    prefix = ["#define INFINITY (__builtin_inff())", "#define NAN (__builtin_nanf(\"\"))", "typedef long unsigned int size_t;"]
    vec_dts = [("float", "float", 2), ("float", "float", 4), ("float", "float", 8), ("signed int", "int", 4), ("signed int", "int", 2)]
    # TODO: add BF16 vec dts
    # hip_bfloat16 is defined inline (round-to-nearest-even on construction) when bf16 is used
    if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("""
struct hip_bfloat16 {
  unsigned short data;
  inline __attribute__((device)) hip_bfloat16(float val) {
    union { float fp32; unsigned int u32; } u = {val};
    if (~u.u32 & 0x7f800000) { u.u32 += 0x7fff + ((u.u32 >> 16) & 1); } else if (u.u32 & 0xffff) { u.u32 |= 0x10000; }
    data = (u.u32 >> 16);
  }
  inline __attribute__((device)) operator float() const {
    unsigned int uval = data << 16;
    return *reinterpret_cast<float*>(&uval);
  }
};
static inline __attribute__((device)) bool operator<(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) < ((float)b); }
static inline __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) == ((float)b); }
""")
    if any(uop.dtype == dtypes.half for uop in uops):
      prefix.append("#define half _Float16")
      vec_dts += [("_Float16", "half", 2), ("_Float16", "half", 4), ("_Float16", "half", 8), ("_Float16", "half", 16)]
    prefix += [_make_hip_dtype(*x) for x in vec_dts]
    for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]):  # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
      # f32 accumulate maps straight onto the builtin; f16 accumulate needs repacking
      # since the builtin expects the c fragment in even lanes of a half16
      if arg[3] == dtypes.float: prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_f16_w32")
      else: prefix.append(f"static inline __attribute__((device)) half8 __{arg[0]}"+"""(half16 a, half16 b, half8 c) {
  half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; }
  c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(a, b, c_frag, false);
  for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;\n}""")
    return super().render_kernel(function_name, kernel, bufs, uops, prefix)
  def get_kernel_modifier(self, uops:UOpGraph) -> str:
    """Cap the flat work-group size to the product of the kernel's local dims."""
    requiredMaxThreadsPerBlock = prod(u.arg[2] for u in uops if u.op is UOps.SPECIAL and u.arg[1][0] == "l")
    # https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
    # NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
    return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"
# the NV backend shares the CUDA dialect, differing only in device name
class NVRenderer(CUDARenderer): device = "NV"
# the HIP backend shares the AMD dialect, differing only in device name
class HIPRenderer(AMDRenderer): device = "HIP"
|