# ops_python.py
  1. # pylint: disable=cell-var-from-loop
  2. # a python uops emulator
  3. # works to test the tensor cores, and all the uops in general
  4. # this is the (living) definition of uops
  5. from typing import Tuple, List, Optional, Any, Dict
  6. import pickle, base64, itertools, time, struct
  7. from tinygrad.dtype import DType, dtypes, ImageDType
  8. from tinygrad.helpers import all_same, getenv, flatten
  9. from tinygrad.device import Compiled, Compiler, Allocator
  10. from tinygrad.codegen.uops import UOps
  11. from tinygrad.codegen.uopgraph import UOpGraph
  12. from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, truncate
  13. from tinygrad.renderer import Renderer
  14. from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer
  15. def _load(m, i):
  16. if i < 0 or i >= len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
  17. return m[i]
  18. def load(inp, j=0):
  19. if len(inp) == 4: return [_load(m, x+j) if gate else default for m,x,gate,default in zip(*inp)]
  20. return [_load(m, x+j) for m,x in zip(inp[0], inp[1])]
  21. def _store(m, i, v):
  22. if i < 0 or i >= len(m): raise IndexError(f"store out of bounds, size is {len(m)}, access is {i}, value is {v}")
  23. m[i] = v
class PythonProgram:
  """A pure-python interpreter for a pickled list of uops.

  Slow but exact: this serves as the living definition of uop semantics, and it
  can emulate the METAL/AMD/CUDA tensor-core layouts for testing without hardware.
  """
  def __init__(self, name:str, lib:bytes):
    # lib is the output of PythonRenderer.render after base64 decoding: a pickled
    # list of (op, dtype, source-index list, arg) tuples in program order.
    # NOTE(review): pickle.loads is unsafe on untrusted input; assumed trusted here.
    self.uops: List[Tuple[UOps, Optional[DType], List[int], Any]] = pickle.loads(lib)
  def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
    # Returns wall-clock execution time. wait is accepted for call-interface
    # compatibility; timing happens unconditionally.
    st = time.perf_counter()
    # One "warp" is the full set of local ids. Every uop value below is a
    # per-lane list of length warp_size, so all lanes step through the program
    # in lockstep.
    warp = list(itertools.product(*[range(x) for x in local_size[::-1]]))
    warp_size = len(warp)
    # Workgroups run sequentially: one full pass over the uop program per global index.
    for idxs in itertools.product(*[range(x) for x in global_size[::-1]]):
      ul: Dict[int, Any] = {}    # uop index -> per-lane value(s)
      dl: Dict[int, DType] = {}  # uop index -> dtype
      pbufs: List[memoryview] = list(bufs)  # consumed in order by DEFINE_GLOBAL
      pvals: List[int] = list(vals)         # consumed in order by DEFINE_VAR
      i = 0  # program counter into self.uops
      loop_ends: Dict[int, int] = {}  # RANGE uop index -> index of its ENDRANGE
      while i < len(self.uops):
        uop, dtype, idp, arg = self.uops[i]
        void_ops = {UOps.STORE, UOps.ENDRANGE, UOps.BARRIER, UOps.IF, UOps.ENDIF}
        # DEFINE_ACC carries extra sources; only the first (the init value) is read here
        if uop is UOps.DEFINE_ACC: idp = [idp[0]]
        inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops]
        dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
        if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
        if uop is UOps.STORE:
          if len(inp) == 3: inp.append([True] * len(inp[0]))  # set the gate to True
          if isinstance(dtp[0], ImageDType):
            # image store: index is an (x, y) pair, value is always a 4-vector
            assert dtp[2].count == 4
            for j,val in enumerate(inp[2]):
              for m,ox,oy,v,g in zip(inp[0], inp[1][0], inp[1][1], val, inp[3]):
                assert ox >= 0 and ox < dtp[0].shape[1] and oy >= 0 and oy < dtp[0].shape[0]
                if g: _store(m, ox*4 + oy*dtp[0].shape[1]*4 + j, v)
          elif dtp[2].count > 1:
            # vector store: element j of each lane's value goes to offset o+j
            for j,val in enumerate(inp[2]):
              for m,o,v,g in zip(inp[0], inp[1], val, inp[3]):
                if g: _store(m, o+j, v)
          else:
            # scalar store
            for m,o,v,g in zip(*inp):
              if g: _store(m, o, v)
          i += 1
          continue
        if uop is UOps.ENDRANGE:
          # remember where this loop's body ends, then jump back to the RANGE header
          loop_ends[idp[0]] = i
          i = idp[0]
          continue
        if uop in (UOps.BARRIER, UOps.IF, UOps.ENDIF):
          # in the python emulator, the warp is always in sync
          i += 1
          continue
        assert dtype is not None, f"{uop} is missing a dtype"
        dl[i] = dtype
        if uop is UOps.DEFINE_GLOBAL:
          assert dtype.fmt is not None
          # every lane aliases the same memoryview, so stores are visible warp-wide
          ul[i] = [pbufs.pop(0).cast(dtype.fmt)] * warp_size
        elif uop is UOps.DEFINE_LOCAL:
          assert dtype.fmt is not None
          # one shared backing buffer (arg[1] elements), aliased by all lanes
          lbuf = memoryview(bytearray(arg[1]*dtype.itemsize))
          ul[i] = [lbuf.cast(dtype.fmt)] * warp_size
        elif uop is UOps.DEFINE_VAR:
          ul[i] = [pvals.pop(0)] * warp_size
        elif uop is UOps.SPECIAL:
          # idxs/warp were built from reversed size tuples, hence the 2-arg[0] flip
          if arg[1][0] == 'g':
            ul[i] = [idxs[2-arg[0]]] * warp_size  # global id: uniform across the warp
          elif arg[1][0] == 'l':
            ul[i] = [x[2-arg[0]] for x in warp]   # local id: differs per lane
        elif uop is UOps.CONST:
          # vector dtypes hold one per-lane list per element
          ul[i] = [[arg] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg] * warp_size
        elif uop is UOps.DEFINE_ACC:
          # fresh accumulator per lane, seeded from the init-value source
          ul[i] = [[inp[0][0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [inp[0][0]] * warp_size
        elif uop is UOps.RANGE:
          if i not in ul: ul[i] = [inp[0][0]] * warp_size  # first entry: start the counter
          else:
            # re-entered via ENDRANGE's back-jump: bump the counter, exit when done
            for j in range(len(ul[i])):
              ul[i][j] += 1
            if ul[i][0] == inp[1][0]:
              del ul[i]
              i = loop_ends[i] + 1  # jump past the ENDRANGE
              continue
        elif uop is UOps.VECTORIZE: ul[i] = inp
        elif uop in {UOps.CAST, UOps.BITCAST}:
          assert dtp[0].fmt and dtype.fmt
          pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
          # BITCAST reinterprets raw bytes; CAST converts values and then emulates the
          # target type (int wraparound / float truncation) via a struct round-trip
          if uop is UOps.BITCAST: ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
          else:
            casted = [dtypes.as_const(x, dtype) for x in inp[0]]
            if dtypes.is_int(dtype):
              # two's-complement wraparound for signed ints, plain modulo for unsigned
              overflow_adjust = 2**(dtype.itemsize*8 - 1) if not dtypes.is_unsigned(dtype) else 0
              casted = [((x + overflow_adjust) % 2**(dtype.itemsize*8) - overflow_adjust) for x in casted]
            elif dtypes.is_float(dtype):
              casted = [truncate.get(dtype, lambda dt: dt)(x) for x in casted]
            ul[i] = list(struct.unpack(unpack_format, struct.pack(unpack_format, *casted)))
        elif uop is UOps.LOAD:
          if isinstance(dtp[0], ImageDType):
            # image load always yields a 4-vector; out-of-bounds (x, y) reads return 0
            assert dtype.count == 4
            ul[i] = []
            for j in range(dtype.count):
              ret = []
              for m,ox,oy in zip(inp[0], inp[1][0], inp[1][1]):
                if ox < 0 or ox >= dtp[0].shape[1] or oy < 0 or oy >= dtp[0].shape[0]: ret.append(0)
                else: ret.append(_load(m, ox*4 + oy*dtp[0].shape[1]*4 + j))
              ul[i].append(ret)
          elif dtype.count > 1:
            # vector load: one gather per element; scalar sources broadcast to all elements
            # (the comprehension's i/j shadow the outer ones — see the pylint disable up top)
            ul[i] = [load([inp[i][j] if dtp[i].count > 1 else inp[i] for i in range(len(inp))], j) for j in range(dtype.count)]
          else:
            ul[i] = load(inp)
        elif uop is UOps.PHI:
          # write the new value into the accumulator's lists in place, so every
          # earlier reference to the DEFINE_ACC observes the update
          for j in range(len(inp[0])): inp[0][j] = inp[1][j]
          ul[i] = inp[0]
        elif uop is UOps.GEP:
          ul[i] = inp[0][arg]
        elif uop is UOps.WMMA:
          # here are the models for the WMMA instruction on the different hardware
          def wmma_helper(WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_elem, b_elem, c_map):
            # generic warp-level D = A @ B + C over groups of WARP_THREADS lanes:
            # a_elem/b_elem fetch matrix elements given (elements, k, row/col, group
            # offset); c_map maps (lane, element) to the (i, j) position in C
            assert len(inp[0]) == NUM_A, f"A must have {NUM_A} elements per thread, it has {len(inp[0])}"
            assert len(inp[1]) == NUM_B, f"B must have {NUM_B} elements per thread, it has {len(inp[1])}"
            assert len(inp[2]) == NUM_C, f"C must have {NUM_C} elements per thread, it has {len(inp[2])}"
            assert len(flatten(inp[0])) == NUM_A * warp_size, f"WMMA must have {NUM_A * warp_size} total elements for A in WMMA"
            assert len(flatten(inp[1])) == NUM_B * warp_size, f"WMMA must have {NUM_B * warp_size} total elements for B in WMMA"
            assert len(flatten(inp[2])) == NUM_C * warp_size, f"WMMA must have {NUM_C * warp_size} total elements for C in WMMA"
            assert warp_size > 0 and warp_size % WARP_THREADS == 0, f"must have multiples of {WARP_THREADS} warp threads"
            out = [inp[2][elem_idx][:] for elem_idx in range(NUM_C)]  # copy C so the accumulator isn't mutated
            for goff in range(0, warp_size, WARP_THREADS):
              for lane_id in range(WARP_THREADS):
                for elem_idx in range(NUM_C): # calculate new muls and add to acc
                  (c_i, c_j) = c_map(lane_id, elem_idx)
                  out[elem_idx][goff+lane_id] += sum(a_elem(inp[0], _k, c_j, goff) * b_elem(inp[1], c_i, _k, goff) for _k in range(K))
            return out
          # TODO: refactor these to a shared TensorCoreLayout in kernel.py
          if arg[5] == "METAL":
            # A (2 elements on 32 threads): row major
            def a_b_elem(x, i, j, goff): return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16]
            # (i, j), C, D (2 elements on 32 threads): row major same as A/B
            def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
            ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
          elif arg[5] == "AMD":
            # A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
            def a_elem(x, i, j, goff):
              assert x[i][goff+j] == x[i][goff+j+16], "warp elements not duplicated properly across lanes"
              return x[i][goff+j]
            # B (16 elements on 32 threads): row major, lane 16-32 == lane 0-15
            def b_elem(x, i, j, goff): return a_elem(x, j, i, goff)  # pylint: disable=arguments-out-of-order
            def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
            ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
          elif arg[5] == "CUDA":
            # A (8 elements on 32 threads)
            def a_elem(x, i, j, goff): return x[(i%2)+(j//8)*2+(i//8)*4][goff+((i//2)%4)+(j%8)*4]
            # B (4 elements on 32 threads)
            def b_elem(x, i, j, goff): return x[(j%2)+(j//8)*2][goff+(j//2)%4+(i)*4]
            # (i, j), C, D (4 elements on 32 threads)
            def c_map(lane, elem): return ((elem%2)+(lane%4)*2, (lane//4)+(elem//2)*8)
            ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
          else: raise NotImplementedError(f"unimplemented tensor core {arg}")
        elif uop is UOps.ALU:
          assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {arg}"
          # comparisons and WHERE legitimately mix dtypes (bool vs. value dtype)
          assert all_same([dtype] + dtp) or arg in {BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE}, f"dtype mismatch on {arg}"
          ul[i] = [exec_alu(arg, dtype, p) for p in zip(*inp)]
        assert i in ul, (uop, dtype, idp, arg)  # every non-void uop must produce a value
        i += 1
    return time.perf_counter() - st
  181. class PythonRenderer(Renderer):
  182. device = "PYTHON"
  183. def __init__(self):
  184. if getenv("EMULATE_METAL"): self.device, self.tensor_cores = "METAL", MetalRenderer.tensor_cores
  185. if getenv("EMULATE_AMD"): self.device, self.tensor_cores = "AMD", AMDRenderer.tensor_cores
  186. if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tensor_cores
  187. def render(self, name:str, uops:UOpGraph) -> str:
  188. lops = [(u.op, u.dtype, [uops.uops.index(v) for v in u.src], u.arg) for u in uops]
  189. return base64.b64encode(pickle.dumps(lops)).decode()
  190. class PythonCompiler(Compiler):
  191. def compile(self, src:str) -> bytes: return base64.b64decode(src)
  192. class PythonAllocator(Allocator):
  193. def _alloc(self, size, options): return memoryview(bytearray(size))
  194. def copyin(self, dest, src:memoryview): dest[:] = src
  195. def copyout(self, dest:memoryview, src): dest[:] = src
class PythonDevice(Compiled):
  """The PYTHON device: wires together the pure-python allocator, renderer,
  compiler, and the uop-interpreting program."""
  def __init__(self, device:str):
    super().__init__(device, PythonAllocator(), PythonRenderer(), PythonCompiler(), PythonProgram)