# cstyle.py -- C-style source renderers (clang, OpenCL, Metal, CUDA, AMD)
# stdlib typing/collections helpers plus tinygrad's op/dtype/uop definitions
from typing import Dict, List, Optional, Tuple, Union, DefaultDict, cast, Literal, Callable
import os, math
from collections import defaultdict, Counter
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.helpers import strip_parens, getenv, prod, dedup
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
from tinygrad.codegen.uops import UOps, UOp
from tinygrad.codegen.uopgraph import UOpGraph
from tinygrad.renderer import Renderer, TensorCore
class CStyleLanguage(Renderer):
  """Base renderer that turns a linearized UOpGraph into C-style kernel source.

  Backend dialects (OpenCL, Metal, CUDA, AMD, clang, ...) subclass this and
  override the string-valued class attributes and render_* hooks below;
  render() itself walks the uop list once, emitting one statement (or an
  inlined expression) per uop.
  """
  # ---- dialect configuration, overridden by subclasses ----
  kernel_prefix: str = ""            # emitted before "void <name>(" (e.g. "__kernel ")
  buffer_prefix: str = ""            # qualifier for global buffer pointers (e.g. "__global ")
  buffer_suffix: str = ""            # emitted after buffer parameters (e.g. " restrict")
  smem_align: str = ""               # alignment attribute for shared-memory declarations
  smem_prefix: str = ""              # qualifier for shared/local memory (e.g. "__local ")
  smem_prefix_for_cast: bool = True  # if False, pointer casts into local memory use buffer_prefix instead
  arg_int_prefix: str = "const int"  # type used for integer kernel arguments
  barrier: str = ""                  # workgroup barrier statement
  # "g" = group id, "l" = local id, "i" = global id; each maps an axis to a dialect expression
  code_for_workitem: Dict[Union[Literal["g"], Literal["l"], Literal["i"]], Callable] = {}
  extra_args: List[str] = []         # extra formal parameters appended to every kernel signature
  float4: Optional[str] = None       # vector-constructor template, None if vectors unsupported
  uses_vload: bool = False           # use vload_half/vstore_half for half buffers (OpenCL)
  uses_ptr_arithmetic: bool = False  # render *(buf+idx) instead of buf[idx]
  type_map: Dict[DType, str] = {}    # dtype -> dialect-specific type-name overrides
  # expression template per ALU op; each lambda receives rendered operand strings plus the result dtype
  code_for_op: Dict = {
    UnaryOps.NEG: lambda x,dtype: f"(!{x})" if dtype == dtypes.bool else f"(-{x})", UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})",
    UnaryOps.RECIP: lambda x,dtype: f"(1/{x})",
    UnaryOps.EXP2: lambda x,dtype: f"exp2({x})", UnaryOps.LOG2: lambda x,dtype: f"log2({x})", UnaryOps.SIN: lambda x,dtype: f"sin({x})",
    BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})",
    BinaryOps.IDIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
    BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.CMPNE: lambda a,b,dtype: f"({a}!={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
    BinaryOps.AND: lambda a,b,dtype: f"({a}&{b})", BinaryOps.OR: lambda a,b,dtype: f"({a}|{b})",
    TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"}

  # returns a str expression of the casted xs with the given type
  def render_cast(self, x:str, var_dtype:DType, bitcast=False) -> str:
    # bitcast reinterprets the bits via a pointer round-trip; a plain cast converts the value
    if bitcast: return f"(*(({self.buffer_prefix}{self.render_dtype(var_dtype)}*)&{x}))"
    return f"({self.render_dtype(var_dtype)})({x})"

  # returns a str expression of the vectorized xs with the given type
  def render_vectorize(self, x:List[str], var_dtype:DType) -> str:
    assert len(x) == var_dtype.count, f"cast is wrong size {len(x)} != {var_dtype.count}"
    assert self.float4 is not None, "vectorized cast is not supported on this platform"
    # reuse the float4 constructor template with the target vector type substituted in
    return f"{self.float4.replace('float4', self.render_dtype(var_dtype))}({','.join(x)})"

  # returns a str expression of the const with the given type
  def render_const(self, x:ConstType, dtype:DType) -> str:
    if math.isnan(x): val = "NAN"
    elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
    elif dtype.scalar() == dtypes.bool: val = "1" if x else "0"
    elif dtype.scalar() == dtypes.float: val = f"{x}f"
    else: val = str(x)
    # vector constants broadcast the scalar literal to every lane
    if dtype.count > 1: return self.render_vectorize([val] * dtype.count, dtype)
    # float/int/bool literals are valid C as-is; other dtypes need an explicit cast
    return (self.render_cast(val, dtype) if dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)

  # returns a str expression of the loaded value with the output type
  def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
    if isinstance(buf_dtype, ImageDType):
      assert output_dtype == dtypes.float.vec(4), f"images must be float4, getting {output_dtype}"
      return f"read_imagef({buf_name}, smp, {idx})"
    # half buffers loaded as a wider dtype go through vload_half on platforms that need it
    if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and output_dtype.scalar() != dtypes.float16:
      return f"vload_half{'' if output_dtype.count == 1 else str(output_dtype.count)}(0, {buf_name}+{idx})"
    # vector loads are a pointer cast to the vector type followed by a dereference
    if output_dtype.count > 1:
      return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{self.render_dtype(buf_dtype)}{output_dtype.count}*)({buf_name}+{idx}))" # noqa: E501
    return f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"

  # hook for subclasses to prepend attributes to the kernel signature (e.g. workgroup-size hints)
  def get_kernel_modifier(self, uops:UOpGraph) -> str: return ""

  def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:UOpGraph, prefix=None) -> str:
    """Assemble the final source: optional prefix lines, signature, body lines."""
    # image kernels need a sampler declared at the top of the body
    tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501
    # choose a C type for each buffer: image, pointer, or plain int argument
    buftypes = [(name,f"{'write_only' if mutable else 'read_only'} image2d_t" if dtype.name.startswith('image') else
                 ("" if mutable else "const ")+self.buffer_prefix+self.render_dtype(dtype)+"*"+self.buffer_suffix if isinstance(dtype, PtrDType) else
                 self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs]
    prg = ''.join([f"{self.kernel_prefix}void {self.get_kernel_modifier(uops)}{function_name}(",] +
                  [', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
                  [") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
    return prg if prefix is None else "\n".join(prefix)+f"\n{prg}"

  # returns a str statement that does the store
  def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx:str, local=False) -> str:
    if isinstance(buf_dtype, ImageDType):
      assert var_dtype == dtypes.float.vec(4), f"images must be float4, getting {var_dtype}"
      return f"write_imagef({buf_name}, {idx}, {var_name});"
    # mirror of render_load: vstore_half for narrowing stores into half buffers
    if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and var_dtype.scalar() != dtypes.float16:
      return f"vstore_half{'' if var_dtype.count == 1 else str(var_dtype.count)}({var_name}, 0, {buf_name}+{idx});"
    if var_dtype.count > 1:
      prefix = self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix
      return f"*(({prefix}{self.render_dtype(buf_dtype)}{var_dtype.count}*)({buf_name}+{idx})) = {var_name};"
    return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"

  # declaration statement for a shared/local memory array
  def render_local(self, name:str, dtype:DType, size:int): return self.smem_align + self.smem_prefix + f"{self.render_dtype(dtype)} {name}[{size}];"

  # dialect type name for a dtype; falls back to the dtype's own name
  def render_dtype(self, var_dtype:DType) -> str: return self.type_map.get(var_dtype, var_dtype.name)

  def render(self, name:str, uops:UOpGraph) -> str:
    """Walk the linearized uop list and emit the complete kernel source."""
    kernel = []
    bufs: List[Tuple[str, Tuple[DType, bool]]] = []   # kernel arguments discovered while walking
    depth = 1                                         # brace-nesting depth, used only for indentation
    def kk(s): kernel.append(" "*depth+s)             # append one indented source line
    c: DefaultDict[str, int] = defaultdict(int)       # per-prefix counters for SSA-style names
    r: Dict[UOp, str] = {}                            # uop -> rendered expression / variable name
    def ssa(prefix:str, u:Optional[UOp]=None):
      # mint a fresh name like "alu0"; optionally register it as u's rendering
      nonlocal c, r
      ret = f"{prefix}{c[prefix]}"
      if u is not None: r[u] = ret
      c[prefix] += 1
      return ret
    # consumer counts: single-use values can be inlined instead of assigned to a temp
    child_count = Counter(v for ru in uops for v in ru.src)
    seen_vars = set()
    for u in uops:
      uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
      # these four uops don't have output dtypes
      if uop is UOps.IF:
        kk(f"if ({r[src[0]]}) {{")
        depth += 1
      elif uop is UOps.BARRIER: kk(self.barrier)
      elif uop in {UOps.ENDRANGE, UOps.ENDIF}:
        depth -= 1
        kk("}")
      elif uop is UOps.STORE:
        assert src[0].dtype is not None and src[2].dtype is not None
        rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
        # an optional 4th source is the store gate: wrap the store in a conditional
        kk(f"if ({r[src[3]]}) {{ {rendered_store} }}" if len(src) > 3 else rendered_store)
      else:
        assert dtype is not None, f"None dtype for uop {uop}"
        if uop is UOps.RANGE:
          kk(f"for (int {(expr := ssa('ridx',u))} = {r[src[0]]}; {expr} < {r[src[1]]}; {expr}++) {{")
          depth += 1
        elif uop is UOps.ALU:
          # remove parens if ALU types are the same. TODO: can do more here
          if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == args else r[v] for v in src]
          else: operands = [r[v] for v in src]
          val = self.code_for_op[args](*operands, dtype)
          assert child_count[u] != 0, f"childless ALU op found {u}"
          # TODO: fix index rendering issue. fix clang nested max macro issue
          if child_count[u] <= 1 and args is not BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val
          else: kk(f"{self.render_dtype(dtype)} {ssa('alu',u)} = {val};")
        elif uop is UOps.SPECIAL:
          # workitem index; args[1] is the variable name ("gidx0", "lidx1", ...), args[0] the axis
          kk(f"int {args[1]} = {self.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
          r[u] = args[1]
        elif uop is UOps.LOAD:
          val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
          # NOTE: this relies on the load not happening if it's in the unselected branch
          if len(src) > 3 and src[2].op is UOps.ALU: val = self.code_for_op[TernaryOps.WHERE](r[src[2]], val, r[src[3]], dtype)
          kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};")
        elif uop is UOps.PHI:
          # loop-carried value: assign the new value back into the accumulator variable
          kk(f"{r[src[0]]} = {r[src[1]]};")
          r[u] = r[src[0]]
        elif uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}:
          assert len(src) == 1 or (uop is UOps.VECTORIZE and len(src) > 1), "Invalid source length for operation"
          if uop is UOps.BITCAST:
            # materialize the source in a temp so its address can be taken for the reinterpret
            precast = ssa('precast')
            kk(f"{self.render_dtype(cast(DType, src[0].dtype))} {precast} = {r[src[0]]};")
            val = self.render_cast(precast, dtype, bitcast=True)
          elif uop is UOps.CAST: val = self.render_cast(r[src[0]], dtype, bitcast=False)
          else: val = self.render_vectorize([r[x] for x in src], dtype)
          if child_count[u] <= 1: r[u] = val
          else: kk(f"{self.render_dtype(dtype)} {ssa('cast',u)} = {val};")
        elif uop is UOps.DEFINE_LOCAL:
          kk(self.render_local(args[0], dtype, args[1]))
          r[u] = args[0]
        elif uop is UOps.DEFINE_VAR:
          assert args.expr not in seen_vars, f"duplicate variable {args.expr}"
          seen_vars.add(args.expr)
          bufs.append((args.expr, (dtype,False)))
          r[u] = args.expr
        elif uop is UOps.DEFINE_GLOBAL:
          # global buffers are named data<N> and become kernel arguments
          bufs.append((nm:=f"data{args[0]}", (dtype,args[1])))
          r[u] = nm
        elif uop is UOps.WMMA: kk(f"{self.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[src[0]]}, {r[src[1]]}, {r[src[2]]});")
        elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {self.render_const(src[0].arg, dtype)};")
        elif uop is UOps.CONST: r[u] = self.render_const(args, dtype) if args >= 0 else f"({self.render_const(args, dtype)})"
        elif uop is UOps.GEP:
          assert src[0].dtype is not None
          # element extraction: wide vectors use [i] indexing, <=4 lanes use .x/.y/.z/.w
          from_ssa = src[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
          r[u] = (r[src[0]] if from_ssa else f"{(r[src[0]])}") + (f"[{args}]" if src[0].dtype.count > 4 else f".{'xyzw'[args]}")
        else: raise RuntimeError(f"failed to render {u}")
    return self.render_kernel(name, kernel, bufs, uops)
class ClangRenderer(CStyleLanguage):
  """Renderer for plain C compiled with clang (CPU backend): no GPU features."""
  device = "CLANG"
  supports_float4 = False
  has_local = False
  global_max = None

  # language options
  buffer_suffix = " restrict"
  type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"}
  # plain C has no max() builtin, so render it as a ternary
  code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.MAX: lambda a,b,dtype: f"(({a}>{b})?{a}:{b})"}
class OpenCLRenderer(CStyleLanguage):
  """Renderer for OpenCL C kernels."""
  device = "GPU"

  # language options
  kernel_prefix = "__kernel "
  buffer_prefix = "__global "
  smem_align = "__attribute__ ((aligned (16))) "
  smem_prefix = "__local "
  barrier = "barrier(CLK_LOCAL_MEM_FENCE);"
  float4 = "(float4)"
  code_for_workitem = {"g": lambda x: f"get_group_id({x})", "l": lambda x: f"get_local_id({x})", "i": lambda x: f"get_global_id({x})"}
  uses_vload = True
  type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong" }

  def render_cast(self, x, var_dtype, bitcast=False) -> str:
    # OpenCL provides dedicated as_<type>() builtins for bit reinterpretation
    return f"as_{self.render_dtype(var_dtype)}({x})" if bitcast else super().render_cast(x, var_dtype)

  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
    # half support is an OpenCL extension that must be enabled explicitly
    if any(uop.dtype == dtypes.half for uop in uops): prefix = ["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"]
    return super().render_kernel(function_name, kernel, bufs, uops, prefix)
class MetalRenderer(CStyleLanguage):
  """Renderer for Metal Shading Language kernels."""
  device = "METAL"
  shared_max = 32768
  tensor_cores = [TensorCore(dims=(8,8,8), threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.float, dtypes.float), (dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]] # noqa: E501
  # simdgroup tensor cores are only available on Apple Silicon
  def __init__(self): self.tensor_cores = MetalRenderer.tensor_cores if os.uname().machine == "arm64" else []

  # language options
  kernel_prefix = "kernel "
  buffer_prefix = "device "
  smem_prefix = "threadgroup "
  arg_int_prefix = "constant int&"
  barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);"
  float4 = "float4"
  uses_ptr_arithmetic = True
  code_for_workitem = {"g": lambda x: f"gid.{chr(120+x)}", "l": lambda x: f"lid.{chr(120+x)}"}
  # uint3 used for gid/lid - TODO: this should probably be `ushort3 lid [[thread_position_in_threadgroup]]`
  extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']
  type_map = {dtypes.bfloat16: "bfloat"}
  # bfloat16 has no native math functions in MSL: compute in float, cast back
  code_for_op = {**CStyleLanguage().code_for_op,
    BinaryOps.MAX: lambda a,b,dtype: f"(bfloat)max((float){a},(float){b})" if dtype == dtypes.bfloat16 else f"max({a},{b})",
    UnaryOps.SQRT: lambda x,dtype: f"(bfloat)sqrt({x})" if dtype == dtypes.bfloat16 else f"sqrt({x})",
    UnaryOps.EXP2: lambda x,dtype: f"(bfloat)exp2({x})" if dtype == dtypes.bfloat16 else f"exp2({x})",
    UnaryOps.LOG2: lambda x,dtype: f"(bfloat)log2({x})" if dtype == dtypes.bfloat16 else f"log2({x})",
    UnaryOps.SIN: lambda x,dtype: f"(bfloat)precise::sin({x})" if dtype == dtypes.bfloat16 else f"precise::sin({x})",}

  def render_cast(self, x:str, var_dtype:DType, bitcast=False) -> str:
    # Metal's as_type<T>() performs the bit reinterpretation
    return f"as_type<{self.render_dtype(var_dtype)}>({x})" if bitcast else super().render_cast(x, var_dtype)

  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
    # emit one __<name> helper per distinct WMMA arg, built on simdgroup 8x8 matrices
    prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.op is UOps.WMMA])
    for arg in wmma_args: prefix.append(f"""{arg[3].name}2 __{arg[0]}({arg[2].name}2 m, {arg[2].name}2 n, {arg[3].name}2 o) {{
  simdgroup_{arg[3].name}8x8 a,b,c; a.thread_elements()[0] = m.x; a.thread_elements()[1] = m.y; b.thread_elements()[0] = n.x;
  b.thread_elements()[1] = n.y; c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);
  return {arg[3].name}2(c.thread_elements()[0], c.thread_elements()[1]);\n}}""")
    return super().render_kernel(function_name, kernel, bufs, uops, prefix)
# half-precision overrides shared by the CUDA-family backends: use the hardware
# h* intrinsics when the dtype is half/bfloat16, else fall back to float math.
code_for_op_half = {UnaryOps.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"1/{x}",
                    BinaryOps.MAX: lambda a,b,dtype: f"__hmax({a},{b})" if dtype in (dtypes.half, dtypes.bfloat16) else f"max({a},{b})",
                    UnaryOps.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
                    UnaryOps.SIN: lambda x,dtype: f"hsin({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sin({x})",
                    UnaryOps.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})",
                    UnaryOps.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",}
  233. _nms = "xyzwabcdefghijkl"
  234. def _make_cuda_dtype(base_type, name, cnt):
  235. vec, elems, header = f"{name}{cnt}", ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
  236. return f"struct {vec} {{ {base_type} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"
class CUDARenderer(CStyleLanguage):
  """Renderer for CUDA C++ kernels."""
  device = "CUDA"
  global_max = (2147483647, 65535, 65535)
  local_max = (1024, 1024, 64)
  shared_max = 49152
  tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float), (dtypes.bfloat16, dtypes.float)])] # noqa: E501
  # tensor cores require sm_80 (Ampere) or newer; arch looks like "sm_86"
  def __init__(self, arch:str): self.tensor_cores = CUDARenderer.tensor_cores if int(arch[3:]) >= 80 else []

  # language options
  kernel_prefix = "extern \"C\" __global__ "
  smem_prefix = "__shared__ "
  smem_prefix_for_cast = False
  barrier = "__syncthreads();"
  float4 = "make_float4"
  code_for_workitem = {"g": lambda x: f"blockIdx.{chr(120+x)}", "l": lambda x: f"threadIdx.{chr(120+x)}",
                       "i": lambda x: f"(blockIdx.{chr(120+x)}*blockDim.{chr(120+x)}+threadIdx.{chr(120+x)})"}
  code_for_op = {**CStyleLanguage().code_for_op, **code_for_op_half}
  type_map = {dtypes.bfloat16: "nv_bfloat16"}

  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
    # TODO: why is dtypes.bfloat16.name == "__bf16"? would be easier not override dtypes.name
    dt_map = { dtypes.float: ("float","f32"), dtypes.half: ("half","f16"), dtypes.bfloat16: ("bfloat16","bf16"), }
    prefix = ["#define INFINITY (__int_as_float(0x7f800000))","#define NAN (__int_as_float(0x7fffffff))"]
    # pull in the half/bf16 headers and vector-struct helpers only when needed
    if any(uop.dtype == dtypes.half for uop in uops):
      prefix += ["#include <cuda_fp16.h>"] + [_make_cuda_dtype("half", "half", x) for x in [4, 8]]
    if any(uop.dtype == dtypes.bfloat16 for uop in uops):
      prefix += ["#include <cuda_bf16.h>"] + [_make_cuda_dtype("nv_bfloat16", "bfloat16", x) for x in [4, 8]]
    # TODO: this has to be way better to generate for arbitrary M,N,K: use arg[1] for MNK, use arg[4] for vec sizes, encode register packing
    for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]):
      fn, ti, to, ci, co = arg[0], dt_map[arg[2]][0], dt_map[arg[3]][0], dt_map[arg[2]][1], dt_map[arg[3]][1]
      prefix.append(f"""__device__ {to}4 __{fn}({ti}8 a, {ti}4 b, {to}4 c) {{ int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
asm( "mma.sync.aligned.m16n8k16.row.col.{co}.{ci}.{ci}.{co} {{ %0, %1, %2, %3 }}, {{ %4, %5, %6, %7 }}, {{ %8, %9 }}, {{ %0, %1, %2, %3 }};"
  : "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
return c;}}""")
    return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)
# HIP/AMD op overrides: route math through the __ocml_* device library,
# picking the f16/f32/f64 variant from the dtype.
code_for_op_hip = { UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
                    UnaryOps.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
                    UnaryOps.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
                    UnaryOps.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
                    # TODO: MAX with int uses fmax_f32?
                    BinaryOps.MAX: lambda a,b,dtype: f"__ocml_fmax_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32) }({a},{b})",}
def _make_hip_code_for_op():
  """Build the AMD op table: wrap every base op so bfloat16 math is done in float.

  HIP's hip_bfloat16 has no native arithmetic, so bf16 operands are cast to
  float, the op is rendered with dtypes.float, and the result is cast back.
  """
  def wrapper(key, func):
    def cast_bf16(*args):
      # the last positional arg is always the dtype
      if args[-1] == dtypes.bfloat16:
        # for WHERE the first arg is the boolean condition and must NOT be cast
        operands = tuple(f"(float)({arg})" for arg in (args[1:-1] if key is TernaryOps.WHERE else args[:-1]))
        return f"(hip_bfloat16)({func(*(((args[0],) if key is TernaryOps.WHERE else ()) + operands), dtypes.float)})"
      return func(*args)
    return cast_bf16
  return { k:wrapper(k,v) for k,v in {**CStyleLanguage().code_for_op, **code_for_op_hip}.items() }
  285. def _make_hip_dtype(base_type, name, cnt):
  286. elems, header = ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
  287. return f"typedef {base_type} {name}{cnt} __attribute__((ext_vector_type({cnt})));\n" + \
  288. f"static inline __attribute__((device)) {name}{cnt} make_{name}{cnt}({header}) {{ return {{{elems}}}; }}"
class AMDRenderer(CStyleLanguage):
  """Renderer for AMD kernels compiled directly with clang (no HIP headers)."""
  device = "AMD"
  shared_max = 65536
  tensor_cores = [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)], thread_local_sizes=[[16],[16],[4,2]], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]] # noqa: E501

  # language options
  # freestanding prelude: declare the __ockl workitem builtins and the __ocml
  # math functions for each float width, since no headers are included
  kernel_prefix = """extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_id(unsigned int);
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_group_id(unsigned int);
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_size(unsigned int);
extern "C" {\n""" + "".join([
f"""  __attribute__((device)) __attribute__((const)) {dt} __ocml_fmax_f{n}({dt}, {dt});
  __attribute__((device)) __attribute__((pure)) {dt} __ocml_exp2_f{n}({dt});
  __attribute__((device)) __attribute__((pure)) {dt} __ocml_log2_f{n}({dt});
  __attribute__((device)) __attribute__((const)) {dt} __ocml_sqrt_f{n}({dt});
  __attribute__((device)) {dt} __ocml_sin_f{n}({dt});\n""" for dt,n in [("float",32), ("double",64), ("_Float16",16)]]) +\
'}\nextern "C" __attribute__((global))'
  code_for_workitem = {"g": lambda x: f"__ockl_get_group_id({x})", "l": lambda x: f"__ockl_get_local_id({x})",
                       "i": lambda x: f"(__ockl_get_group_id({x})*__ockl_get_local_size({x})+__ockl_get_local_id({x}))"}
  code_for_op = _make_hip_code_for_op()
  smem_prefix = "__attribute__((shared))"
  barrier = '__builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup");' + '__builtin_amdgcn_s_barrier();' + \
            '__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");'
  float4 = "make_float4"
  uses_ptr_arithmetic = False # NOTE: this fixes TestLinearizerOverflowAlt
  type_map = {dtypes.bfloat16: "hip_bfloat16"}

  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
    prefix = ["#define INFINITY (__builtin_inff())", "#define NAN (__builtin_nanf(\"\"))", "typedef long unsigned int size_t;"]
    vec_dts = [("float", "float", 2), ("float", "float", 4), ("float", "float", 8), ("signed int", "int", 4), ("signed int", "int", 2)]
    # TODO: add BF16 vec dts
    # hip_bfloat16 is defined inline (round-to-nearest-even on construction)
    if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("""
struct hip_bfloat16 {
  unsigned short data;
  inline __attribute__((device)) hip_bfloat16(float val) {
    union { float fp32; unsigned int u32; } u = {val};
    if (~u.u32 & 0x7f800000) { u.u32 += 0x7fff + ((u.u32 >> 16) & 1); } else if (u.u32 & 0xffff) { u.u32 |= 0x10000; }
    data = (u.u32 >> 16);
  }
  inline __attribute__((device)) operator float() const {
    unsigned int uval = data << 16;
    return *reinterpret_cast<float*>(&uval);
  }
};
static inline __attribute__((device)) bool operator<(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) < ((float)b); }
static inline __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) == ((float)b); }
""")
    if any(uop.dtype == dtypes.half for uop in uops):
      prefix.append("#define half _Float16")
      vec_dts += [("_Float16", "half", 2), ("_Float16", "half", 4), ("_Float16", "half", 8), ("_Float16", "half", 16)]
    prefix += [_make_hip_dtype(*x) for x in vec_dts]
    # emit WMMA helpers on top of the gfx11 wmma builtins
    for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
      if arg[3] == dtypes.float: prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_f16_w32")
      else: prefix.append(f"static inline __attribute__((device)) half8 __{arg[0]}"+"""(half16 a, half16 b, half8 c) {
  half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; }
  c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(a, b, c_frag, false);
  for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;\n}""")
    return super().render_kernel(function_name, kernel, bufs, uops, prefix)

  def get_kernel_modifier(self, uops:UOpGraph) -> str:
    # cap the flat workgroup size at the product of the local dimensions
    requiredMaxThreadsPerBlock = prod(u.arg[2] for u in uops if u.op is UOps.SPECIAL and u.arg[1][0] == "l")
    # https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
    # NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
    return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"
  349. class NVRenderer(CUDARenderer): device = "NV"
  350. class HIPRenderer(AMDRenderer): device = "HIP"