# triton.py (7.3 KB)

  1. from typing import Dict, List, Final, Callable, DefaultDict
  2. from collections import defaultdict
  3. from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Op
  4. from tinygrad.helpers import DType, PtrDType, dtypes, ImageDType, DEBUG, getenv
  5. from tinygrad.codegen.kernel import UOp, UOps
  6. from triton.compiler import compile as triton_compile
  7. import linecache
  8. import math
  9. import re
  10. triton_dtypes = {dtypes.double: "tl.float64", dtypes.float32: "tl.float32", dtypes.float16: "tl.float16", dtypes.bool: "tl.int1", dtypes.int8: "tl.int8", dtypes.uint8: "tl.uint8", dtypes.int32: "tl.int32", dtypes.int64: "tl.int64", dtypes.uint32: "tl.uint32", dtypes.uint64: "tl.uint64", dtypes.int16: "tl.int16", dtypes.uint16: "tl.uint16"}
  11. signature_dtypes = {dtypes.double: "fp64",dtypes.float32: "fp32", dtypes.float16: "fp16", dtypes.bool: "i8", dtypes.int8: "i1", dtypes.uint8: "u8", dtypes.int32: "i32", dtypes.int64: "i64", dtypes.uint32: "u32", dtypes.uint64: "u64", dtypes.int16: "i16", dtypes.uint16: "u16"}
  12. def next_power_of_2(x):
  13. return 1 << (x - 1).bit_length()
  14. def render_valid(valid):
  15. return '(' * (len(valid) -1) + ') and '.join(valid) if len(valid) else 'True'
  16. #NOTE Triton requires matching dimensions for load/store, disable this and see TestOps::test_output_padded_conv_transpose2d fail to compile
  17. def fill_dims_for_idx(idx, dims):
  18. return "(" + idx + "+ (" + (f"0*({'+'.join(d for d in dims)})))") if len(dims) else idx
  19. def get_max(var):
  20. if isinstance(var, int): return var
  21. return re.sub(r'\[(.*?)\]', '', str(var))[1:-1]
  22. #NOTE can be removed after https://github.com/gpuocelot/gpuocelot/issues/8 gets resolved
  23. def remove_single_scalar_curly_braces(ptx_code):
  24. return '\n'.join([re.sub(r'\{\s*(%\w+)\s*\}', r'\1', line) for line in ptx_code.split('\n')])
  25. def render_const(args,dtype:DType):
  26. return (('-' if args<0 else '') + 'tl.where(1,float("inf"),0)') if math.isinf(args) else ('tl.where(1,float("nan"),0)' if math.isnan(args) else f"{int(args)}" if dtypes.is_int(dtype) else str(args))
  27. def render_cast(x:str, dtype:DType, bitcast=False):
  28. return f"{x}.to({triton_dtypes[dtype]}, bitcast={bitcast})"
  29. def define_scalar(local_size, dtype, args):
  30. if len(local_size) > 0: return f"tl.full(({','.join([str(next_power_of_2(x)) for x in local_size])},),{render_const(args,dtype)}, dtype={triton_dtypes[dtype]})"
  31. return render_const(args,dtype)
  32. def uops_to_triton(function_name:str, uops:List[UOp]):
  33. local_size: List[int] = []
  34. depth = 1
  35. signatures, dims, bufs, kernel, valid = [], [], [], [], [] #type: ignore
  36. c: DefaultDict[str, int] = defaultdict(int)
  37. r: Dict[UOp, str] = {}
  38. def ssa(u, prefix="t"):
  39. nonlocal c, r
  40. c[prefix] += 1
  41. r[u]=f"{prefix}{c[prefix]-1}"
  42. return r[u]
  43. child_count: DefaultDict[UOp, int] = defaultdict(int)
  44. for ru in uops:
  45. for v in ru.vin:
  46. child_count[v] += 1
  47. def kk(s): kernel.append(" "*depth+s)
  48. code_for_op: Final[Dict[Op, Callable]] = {
  49. UnaryOps.EXP2: lambda x,dtype,: f"tl.math.exp2({x})",
  50. UnaryOps.LOG2: lambda x,dtype,: f"tl.math.log2({x})",
  51. UnaryOps.SIN: lambda x,dtype: f"tl.sin({x})",
  52. UnaryOps.SQRT: lambda x,dtype: f"tl.sqrt({x})",
  53. UnaryOps.NEG: lambda x,dtype: f"-{x}",
  54. BinaryOps.ADD: lambda x,y,dtype: f"({x}+{y})", BinaryOps.SUB: lambda x,y,: f"({x}-{y})",
  55. BinaryOps.MUL: lambda x,y,dtype: f"({x}*{y})", BinaryOps.DIV: lambda x,y,: f"({x}/{y})" if y != '0.0' else f"{x}*tl.where({x}==0.0, float('nan'), float('inf'))",
  56. BinaryOps.MAX: lambda x,y,dtype: f"tl.maximum({x},{y})",
  57. BinaryOps.CMPLT: lambda x,y,dtype: f"({x}<{y})",
  58. BinaryOps.MOD: lambda x,y,dtype: f"tl.abs({x})%tl.abs({y})*tl.where({x}<0,-1,1)",
  59. TernaryOps.MULACC: lambda x,y,z,dtype: f"(({x}*{y})+{z})",
  60. TernaryOps.WHERE: lambda x,y,z,dtype: f"tl.where({x},{y},{z})",
  61. }
  62. def int_div(x,y): return f"({x}//{y})" if y != '0' else f"{x}*tl.where({x}==0, float('nan'), float('inf'))"
  63. for u in uops:
  64. uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
  65. if uop == UOps.LOOP:
  66. kk(f"for {ssa(u, 'ridx')} in range({vin[0].arg}, {r[vin[1]]}):")
  67. depth += 1
  68. elif uop == UOps.END: depth -= 1
  69. elif uop == UOps.ALU:
  70. assert dtype is not None
  71. val = code_for_op[args](*[r[x] for x in vin])
  72. if child_count[u] <=1 or dtypes.is_int(dtype): r[u] = int_div(*[r[x] for x in vin]) if args == BinaryOps.DIV and dtypes.is_int(dtype) else val
  73. else: kk(f"{ssa(u, 'alu')} = ({val})")
  74. elif uop == UOps.LOAD:
  75. assert dtype is not None
  76. if len(vin) == 2: kk(f"{ssa(u, 'val')} = {render_cast(f'tl.load({r[vin[0]]} + { fill_dims_for_idx(r[vin[1]], dims)}, mask = {render_valid(valid)})', dtype)}")
  77. else: kk(f"{ssa(u, 'val')} = {render_cast(f'tl.where({r[vin[2]]}, tl.load({r[vin[0]]}+{fill_dims_for_idx(r[vin[1]],dims)} , mask={render_valid(valid+[r[vin[2]]])}), 0.0)', dtype)}")
  78. elif uop == UOps.DEFINE_ACC: kk(f"{ssa(u, 'acc')} = {define_scalar(local_size, dtype, args).replace('//', '/')}")
  79. elif uop == UOps.CONST: r[u] = define_scalar([], dtype, args)
  80. elif uop == UOps.PHI:
  81. kk(f"{r[vin[0]]} = {r[vin[1]].replace('//', '/')}")
  82. r[u] = r[vin[0]]
  83. elif uop == UOps.STORE:
  84. assert not isinstance(dtype, ImageDType), "unimplemented: image store"
  85. kk(f"{'if '+r[vin[3]]+': ' if len(vin)>3 else ''}tl.store({r[vin[0]]} + {r[vin[1]]}, {r[vin[2]].replace('//', '/')}, mask = {render_valid(valid)}) ")
  86. elif uop == UOps.DEFINE_GLOBAL:
  87. bufs.append(args)
  88. signatures.append("*" if isinstance(dtype, PtrDType) else "" + signature_dtypes[dtype])
  89. r[u] = args
  90. elif uop == UOps.SPECIAL:
  91. dims.append(args[1])
  92. valid.append(f"{args[1]}<{get_max(args[2])}")
  93. if args[1].startswith("g"): kk(f"{args[1]} = tl.program_id({args[0]}) # {args[2]}")
  94. elif args[1].startswith("l"):
  95. kk(f"{args[1]} = tl.arange({0}, {next_power_of_2(args[2])})")
  96. local_size.append(args[2])
  97. r[u] = args[1]
  98. elif uop == UOps.CAST and dtype is not None: r[u] = render_cast(r[vin[0]], dtype, isinstance(args, tuple) and args[1])
  99. else: raise NotImplementedError(f"unimplemented: {uop}")
  100. prg = f"import triton\nimport triton.language as tl\ntl.core.TRITON_MAX_TENSOR_NUMEL = float('inf')\n@triton.jit\ndef {function_name}("+','.join(bufs)+"):\n"
  101. for i, line in enumerate(list(filter(lambda line: "tl.arange" in line, kernel))): kernel[kernel.index(line)] += f"[{', '.join([':' if i == j else 'None' for j in range(len(local_size))])}]"
  102. prg += "\n".join(kernel)
  103. acc_local_size = 1
  104. for x in local_size: acc_local_size *= next_power_of_2(x)
  105. local_size = [acc_local_size] + [1] * (len(local_size) - 1)
  106. if DEBUG >= 4: print(prg)
  107. getlines = linecache.getlines
  108. linecache.getlines = lambda filename, module_globals=None: prg.splitlines(keepends=True) if "<triton>" == filename else getlines(filename, module_globals)
  109. exec(compile(prg, "<triton>", "exec"), globals()) # pylint: disable=W0122\
  110. compiled = triton_compile(globals()[function_name], signature=",".join(signatures), device_type="cuda", debug=False, cc=(35 if getenv("CUDACPU", 0) else None))
  111. prg = remove_single_scalar_curly_braces(compiled.asm["ptx"].split(".file")[0].split(".visible .func")[0])
  112. max_local_size = [int(x) for x in prg.split(".maxntid ")[1].split("\n")[0].split(", ")]
  113. for i in range(len(local_size)): local_size[i] = min(local_size[i], max_local_size[i])
  114. return prg, {"shared":compiled.metadata["shared"], "local_size":local_size + [1]*(3-len(local_size))}