# extra/thneed.py
  1. # this can be constructed from a cl_cache or loaded from a thneed file
  2. import time
  3. import struct
  4. import json
  5. import traceback
  6. import numpy as np
  7. from tinygrad.runtime.ops_gpu import CLProgram, compile_gpu
  8. from tinygrad.device import Device
  9. from tinygrad.helpers import DEBUG, getenv
  10. from collections import defaultdict
  11. import pyopencl as cl
  12. from tinygrad.runtime.ops_gpu import OSX_TIMING_RATIO
  13. CL = Device["GPU"]
  14. DEBUGCL = getenv("DEBUGCL", 0)
  15. FLOAT16 = getenv("FLOAT16", 0)
  16. class Thneed:
  17. def __init__(self, cl_cache=[], inputs={}):
  18. self.cl_cache, self.inputs = cl_cache[:], inputs
  19. self.gobj = 0
  20. # build graph
  21. # NOTE: if CLCACHE=1, this is wrong!
  22. nodes = defaultdict(lambda: {'in_edges': [], 'out_edges': []})
  23. for _, args in self.cl_cache:
  24. # output is always the first parameter
  25. for a in args[3:]:
  26. nodes[a]['out_edges'].append(args[2])
  27. nodes[args[2]]['in_edges'].append(a)
  28. # get buffers to save
  29. self.buffers_to_save = set()
  30. self.outputs = []
  31. for n in nodes.keys():
  32. if len(nodes[n]['in_edges']) == 0:
  33. self.buffers_to_save.add(n)
  34. if len(nodes[n]['out_edges']) == 0:
  35. self.outputs.append(n)
  36. fake_inputs = []
  37. for k,n in self.inputs.items():
  38. if n in self.buffers_to_save:
  39. self.buffers_to_save.remove(n)
  40. else:
  41. print(f"WARNING: {k} was not a used input, removing it")
  42. fake_inputs.append(k)
  43. for k in fake_inputs:
  44. del self.inputs[k]
  45. def load(self, input_fn):
  46. float32 = not FLOAT16
  47. mf = cl.mem_flags
  48. image_fmt = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.FLOAT if float32 else cl.channel_type.HALF_FLOAT)
  49. image_fmt_32 = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.FLOAT)
  50. with open(input_fn, "rb") as f:
  51. json_len = struct.unpack("I", f.read(4))[0]
  52. jdat = json.loads(f.read(json_len).decode('latin_1'))
  53. weights = f.read()
  54. # load in the buffers
  55. bufs = {'\x00\x00\x00\x00\x00\x00\x00\x00': None}
  56. bufs_loaded = {}
  57. ptr = 0
  58. for o in jdat['objects']:
  59. #print(o)
  60. if o['needs_load']:
  61. nptr = ptr + o['size']
  62. o['data'] = weights[ptr:nptr]
  63. ptr = nptr
  64. if o['arg_type'] == "image2d_t" or o['arg_type'] == "image1d_t":
  65. tfmt = image_fmt_32 if 'float32' in o and o['float32'] else image_fmt
  66. if o['arg_type'] == "image2d_t":
  67. if 'buffer_id' in o and o['height'] == 1 and not bufs_loaded[o['buffer_id']]:
  68. # hack: use a image1d since we can back that with a buffer
  69. buf = cl.Image(CL.ctx, mf.READ_WRITE, tfmt, shape=(o['width'],), buffer=bufs[o['buffer_id']])
  70. else:
  71. # buffer isn't supported in image2d, copy buffer into image
  72. if 'buffer_id' in o and bufs_loaded[o['buffer_id']]:
  73. arr = np.zeros(bufs[o['buffer_id']].size // 2, dtype=np.float16)
  74. cl.enqueue_copy(CL.queue, arr, bufs[o['buffer_id']])
  75. buf = cl.Image(CL.ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt,
  76. shape=(o['width'], o['height']), pitches=(o['row_pitch'],), hostbuf=arr)
  77. elif o['needs_load']:
  78. buf = cl.Image(CL.ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt,
  79. shape=(o['width'], o['height']), pitches=(o['row_pitch'],), hostbuf=o['data'])
  80. else:
  81. buf = cl.Image(CL.ctx, mf.READ_WRITE, tfmt, shape=(o['width'], o['height']))
  82. if o['arg_type'] == "image1d_t":
  83. assert not o['needs_load']
  84. assert not bufs_loaded[o['buffer_id']]
  85. buf = cl.Image(CL.ctx, mf.READ_WRITE, tfmt, shape=(o['width'],), buffer=bufs[o['buffer_id']])
  86. else:
  87. if 'data' in o:
  88. buf = cl.Buffer(CL.ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=o['data'])
  89. else:
  90. # zero out buffers
  91. buf = cl.Buffer(CL.ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b'\x00'*o['size'])
  92. bufs[o['id']] = buf
  93. bufs_loaded[o['id']] = 'data' in o
  94. # if it's loaded, it's saved
  95. if 'data' in o:
  96. self.buffers_to_save.add(buf)
  97. # load binaries
  98. prgs = {}
  99. for o in jdat['binaries']:
  100. nptr = ptr + o['length']
  101. prgs[o['name']] = CLProgram(Device["GPU"], o['name'], weights[ptr:nptr])
  102. ptr = nptr
  103. # populate the cl_cache
  104. for i,k in enumerate(jdat['kernels']):
  105. kernel = prgs[k['name']]
  106. aaa = []
  107. for j,(a,sz) in enumerate(zip(k['args'], k['args_size'])):
  108. if len(a) == 0:
  109. aa = cl.LocalMemory(sz)
  110. elif len(a) == 4:
  111. a = a.encode('latin_1')
  112. aa = np.uint32(struct.unpack("I", a)[0])
  113. elif len(a) == 2:
  114. a = a.encode('latin_1')
  115. aa = np.uint16(struct.unpack("H", a)[0])
  116. elif len(a) == 8:
  117. #print(i,j,struct.unpack("Q", a.encode('latin_1'))[0])
  118. aa = bufs[a]
  119. aaa.append(aa)
  120. self.cl_cache.append((kernel, [k['global_work_size'], k['local_work_size'], *aaa]))
  121. if DEBUG >= 1: print(f"thneed: total bufs loaded: {len(bufs.keys())}")
  122. # load inputs
  123. for k in jdat['inputs']:
  124. self.inputs[k['name']] = bufs[k['buffer_id']]
  125. # load outputs
  126. for k in jdat['outputs']:
  127. self.outputs.append(bufs[k['buffer_id']])
  128. def save(self, output_fn):
  129. # this is the struct that will be saved
  130. jdat = {"binaries": [], "programs": {}, "kernels": [], "objects": []}
  131. # build the pieces of this struct
  132. weights = []
  133. binaries = []
  134. saved_objs = set()
  135. saved_binaries = set()
  136. for prg, args in self.cl_cache:
  137. # get binaries for saving
  138. if prg.name not in saved_binaries:
  139. binary = prg.clprogram.get_info(cl.program_info.BINARIES)
  140. assert len(binary) == 1
  141. jdat['binaries'].append({"name":prg.name, "length":len(binary[0])})
  142. binaries.append(binary[0])
  143. saved_binaries.add(prg.name)
  144. # get the args from the kernel, some need the data saved
  145. targs, args_size = [], []
  146. argdtypes = [None]*(len(args)-2)
  147. for a,d in zip(args[2:], argdtypes):
  148. if d == np.int16:
  149. targs.append(struct.pack("H", a).decode("latin_1"))
  150. args_size.append(2)
  151. elif d == np.int32:
  152. targs.append(struct.pack("I", a).decode("latin_1"))
  153. args_size.append(4)
  154. elif isinstance(a, cl.LocalMemory):
  155. targs.append("")
  156. args_size.append(a.size)
  157. elif d is None:
  158. if getattr(a, "global_id", None) is None:
  159. setattr(a, "global_id", self.gobj)
  160. self.gobj += 1
  161. ptr = struct.pack("Q", a.global_id).decode("latin_1")
  162. if ptr not in saved_objs:
  163. if isinstance(a, cl.Buffer):
  164. needs_load = a in self.buffers_to_save
  165. jdat['objects'].append({
  166. "id": ptr, "arg_type": "float*", "needs_load": needs_load, "size": a.size,
  167. })
  168. if needs_load:
  169. data = np.empty(a.size//4, dtype=np.float32)
  170. cl.enqueue_copy(CL.queue, data, a, is_blocking=True)
  171. weights.append(data.tobytes())
  172. elif isinstance(a, cl.Image):
  173. assert a.format == cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.HALF_FLOAT if FLOAT16 else cl.channel_type.FLOAT), "wrong type"
  174. needs_load = a in self.buffers_to_save
  175. row_pitch = (a.shape[0]*4*(2 if FLOAT16 else 4) + 63)//64 * 64
  176. size = row_pitch * a.shape[1]
  177. # this is *2 if float16 and *4 if float32
  178. buf = cl.Buffer(CL.ctx, cl.mem_flags.READ_WRITE, size=size * (2 if FLOAT16 else 1))
  179. # zero out the buffer
  180. cl.enqueue_copy(CL.queue, buf, b'\x00'*buf.size, is_blocking=True)
  181. CLProgram(CL, "from_image_strided", compile_gpu("""
  182. __kernel void from_image_strided(read_only image2d_t in, __global float4 *out, int row_pitch) {
  183. const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
  184. int2 l;
  185. l.y = get_global_id(1);
  186. l.x = get_global_id(0);
  187. out[l.y*row_pitch + l.x] = read_imagef(in, smp, l);
  188. }
  189. """), bufs=2, vars=1)(a, buf, row_pitch//(4*(2 if FLOAT16 else 4)), global_size=a.shape)
  190. # multiple of 32 isn't enough
  191. jdat['objects'].append({
  192. "id": ptr, "needs_load": needs_load, "size": size, "arg_type": "image2d_t",
  193. "width": a.shape[0], "height": a.shape[1], "row_pitch": row_pitch, "float32": not FLOAT16,
  194. })
  195. if needs_load:
  196. data = np.empty(size//(2 if FLOAT16 else 4), dtype=np.float32)
  197. cl.enqueue_copy(CL.queue, data, buf, is_blocking=True)
  198. if FLOAT16: data = data.astype(np.float16)
  199. weights.append(data.tobytes())
  200. else:
  201. raise Exception("unknown object", a)
  202. #print(jdat['objects'][-1])
  203. saved_objs.add(ptr)
  204. targs.append(ptr)
  205. args_size.append(8)
  206. else:
  207. raise Exception("idk this type")
  208. # save the kernel itself
  209. jdat['kernels'].append({
  210. "name": prg.name,
  211. "work_dim": len(args[0]),
  212. "global_work_size": args[0],
  213. # TODO: C++ thneed requires a local_work_size, so we fill it with ones
  214. "local_work_size": [1 for _ in args[0]] if args[1] is None else args[1],
  215. "num_args": len(args)-2,
  216. "args": targs,
  217. "args_size": args_size
  218. })
  219. jdat['outputs'] = [{
  220. "buffer_id": struct.pack("Q", x.global_id).decode("latin_1"),
  221. "size": x.size,
  222. } for x in self.outputs]
  223. jdat['inputs'] = [{
  224. "buffer_id": struct.pack("Q", v.global_id).decode("latin_1"),
  225. "size": v.size,
  226. "name": k
  227. } for k,v in self.inputs.items()][::-1]
  228. print(f"saving thneed to {output_fn}")
  229. with open(output_fn, "wb") as f:
  230. j = json.dumps(jdat, ensure_ascii=False).encode('latin_1')
  231. f.write(struct.pack("I", len(j)))
  232. f.write(j)
  233. f.write(b''.join(weights))
  234. f.write(b''.join(binaries))
  235. def run(self):
  236. events = []
  237. st = time.monotonic()
  238. for prg, args in self.cl_cache:
  239. events.append(prg.clprg(CL.queue, *args))
  240. mt = time.monotonic()
  241. Device["GPU"].synchronize()
  242. et = time.monotonic() - st
  243. print(f"submit in {(mt-st)*1000.0:.2f} ms, total runtime is {et*1000.0:.2f} ms")
  244. if DEBUGCL >= 2:
  245. for i, ((prg, args), e) in enumerate(zip(self.cl_cache, events)):
  246. print(f"{i:3d} {prg.name:25s} " + "queued @ %5.2f ms, submit @ %5.2fms, start @ %5.2f ms, end @ %5.2f ms" % tuple((x*OSX_TIMING_RATIO - st*1e9)/1e6 for x in [e.profile.queued, e.profile.submit, e.profile.start, e.profile.end]))
  247. if DEBUGCL >= 1:
  248. total_runtime = 0
  249. for i, ((prg, args), e) in enumerate(zip(self.cl_cache, events)):
  250. runtime = (e.profile.end - e.profile.start) * OSX_TIMING_RATIO
  251. print(f"{i:3d} time {total_runtime/1e6:5.2f} ms running {prg.name:25s} with {str(args[0]):15s} {str(args[1]):15s} count {len(args)-2:2d} runtime {runtime/1e3:7.2f} us {(getattr(prg, 'op_estimate', float('nan')))/runtime:9.2f} GFLOPS -> {args[2].shape if hasattr(args[2], 'shape') else args[2].size}")
  252. if hasattr(prg, 'prg') and ((DEBUGCL >= 2 and getenv("PRINT_KERNEL", -1) == i) or DEBUGCL >= 3):
  253. print(prg.prg)
  254. total_runtime += runtime
  255. print(f"total runtime: {total_runtime/1e6:.2f} ms wall time: {et*1000.0:.2f} ms")
  256. return total_runtime/1e9
  257. return et