# export_model.py
  1. from typing import Tuple, Dict, List
  2. from tinygrad.dtype import DType
  3. from tinygrad.renderer import Program
  4. from tinygrad.tensor import Device, Tensor
  5. from tinygrad.engine.jit import TinyJit
  6. from tinygrad.nn.state import get_state_dict
  7. from tinygrad.helpers import Context
  8. from tinygrad.dtype import dtypes
  9. import json
# Backends that export_model() accepts as Device.DEFAULT.
EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "WEBGL", "CLANG", "CUDA", "GPU"]
# JavaScript helper snippets spliced verbatim into the generated WebGL/WebGPU sources.
web_utils = {
  # Slice one tensor's bytes out of a safetensors blob using its (absolute) data_offsets.
  "getTensorBuffer":
  """const getTensorBuffer = (safetensorBuffer, tensorMetadata) => {
return safetensorBuffer.subarray(...tensorMetadata.data_offsets);
}""",
  # Parse the safetensors header: 8-byte little-endian JSON length, then the JSON
  # metadata; drops "__metadata__" and rebases each data_offsets past the header.
  "getTensorMetadata": """const getTensorMetadata = (safetensorBuffer) => {
const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
return Object.fromEntries(Object.entries(metadata).filter(([k, v]) => k !== "__metadata__").map(([k, v]) => [k, {...v, data_offsets: v.data_offsets.map(x => 8 + metadataLength + x)}]));
};"""
}
def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]:
  """Flatten a TinyJit's jit_cache into exportable pieces.

  Returns (functions, statements, bufs, bufs_to_save):
    functions: kernel name -> kernel source string
    statements: ordered (kernel name, arg buffer names, global_size, local_size) calls
    bufs: buffer name -> (byte size, dtype, id-key) for every buffer used
    bufs_to_save: buffer name -> buffer object for data that must be serialized
  """
  functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
  for ji in run.jit_cache:
    fxn: Program = ji.prg.p
    functions[fxn.function_name] = fxn.src  # NOTE: this assumes all kernels with the same name are the same
    cargs = []
    for i,arg in enumerate(ji.bufs):
      # buffers are deduplicated by object identity across all cached kernel calls
      key = id(arg)
      if key not in bufs:
        if key in special_names:
          # model inputs/outputs keep their special 'inputN'/'outputN' names
          bufs[key] = (special_names[key], arg.size*arg.dtype.itemsize, arg.dtype, key)
        else:
          bufs[key] = (f"buf_{bufnum}", arg.size*arg.dtype.itemsize, arg.dtype, key)
          bufnum += 1
        if i > 0: bufs_to_save[bufs[key][0]] = arg  # if first usage of a buffer is not an output, and it's not a special name
      cargs.append(bufs[key][0])
    statements.append((fxn.function_name, cargs, fxn.global_size, fxn.local_size))
  return functions, statements, {name:(size, dtype, key) for (name,size,dtype,key) in bufs.values()}, bufs_to_save
def jit_model(model, *args) -> Tuple[TinyJit,Dict[int,str]]:
  """JIT the model on `args` and name its realized input/output buffers.

  Runs the jitted function twice (the second run populates the JIT cache),
  then maps id(realized buffer) -> 'inputN'/'outputN' for every model input
  and output. Returns (jitted runner, that id->name mapping).
  """
  assert hasattr(model, "forward") or callable(model), "model needs a forward function"
  @TinyJit
  def run(*x):
    out = model.forward(*x) if hasattr(model, "forward") else model(*x)
    assert isinstance(out, tuple) or isinstance(out, list) or isinstance(out, Tensor), "model output must be a Tensor, tuple, or a list of Tensors for export"
    out = [out] if isinstance(out, Tensor) else out
    return [o.realize() for o in out]
  # twice to run the JIT
  for _ in range(2): the_output = run(*args)
  special_names = {}
  # hack to put the inputs back: rewire the cached kernels to the realized input
  # buffers so the exported graph reads from named inputs
  for (j,i),idx in run.input_replace.items():
    realized_input = args[idx].lazydata.base.realized
    run.jit_cache[j].bufs[i] = realized_input
    special_names[id(realized_input)] = f'input{idx}'
  # TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret)
  for i, output in enumerate(the_output):
    special_names[id(output.lazydata.base.realized)] = f'output{i}'
  return run, special_names
  60. def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,int,int]], bufs:Dict[str,Tuple[str,int,int]], bufs_to_save:Dict[str,Tensor], input_names:List[str], output_names:List[str]) -> str:
  61. cprog = ["#include <tgmath.h>"]
  62. for name,cl in bufs_to_save.items():
  63. weight = ''.join(["\\x%02X"%x for x in bytes(cl._buf)])
  64. cprog.append(f"unsigned char {name}_data[] = \"{weight}\";")
  65. inputs = ", ".join([f'float* {input}' for input in input_names])
  66. outputs = ", ".join([f'float* {output}' for output in output_names])
  67. cprog += [f"float {name}[{len}];" if name not in bufs_to_save else f"float *{name} = (float *){name}_data;" for name,(len,dtype,_key) in bufs.items() if name not in ['input', 'outputs']]
  68. cprog += list(functions.values())
  69. cprog += [f"void net({inputs}, {outputs}) {{"] + [f"{name}({', '.join(args)});" for (name, args, _global_size, _local_size) in statements] + ["}"]
  70. return '\n'.join(cprog)
def export_model_webgl(functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names) -> str:
  """Generate JavaScript source defining `setupNet(gl, safetensor)` for WebGL2.

  Every buffer becomes a single-channel float texture (R16F for half, R32F for
  float; weights are initialized from the safetensors blob), every kernel
  becomes a fragment shader, and the returned JS function uploads the inputs,
  runs all kernels in order, and reads back `output0`.
  """
  # JS prelude: shader compile/link helpers, vertex/texture setup, readback,
  # and the safetensors readers from `web_utils`.
  header = f"""
function setupNet(gl, safetensor) {{
function createShaderProgram(gl, code) {{
const vertexShader = loadShader(gl, gl.VERTEX_SHADER, '#version 300 es\\nin vec2 in_position;in vec2 in_uv;out vec2 uv;void main(){{gl_Position=vec4(in_position,0.0,1.0);uv=in_uv;}}');
const fragmentShader = loadShader(gl, gl.FRAGMENT_SHADER, code);
const shaderProgram = gl.createProgram();
gl.attachShader(shaderProgram, vertexShader);
gl.attachShader(shaderProgram, fragmentShader);
gl.linkProgram(shaderProgram);
if (!gl.getProgramParameter(shaderProgram, gl.LINK_STATUS)) {{
console.log(`Unable to initialize the shader program: ${{gl.getProgramInfoLog(shaderProgram)}}`);
return null;
}}
return shaderProgram;
}}
function loadShader(gl, type, source) {{
const shader = gl.createShader(type);
gl.shaderSource(shader, source);
gl.compileShader(shader);
if (!gl.getShaderParameter(shader, gl.COMPILE_STATUS)) {{
console.log(`An error occurred compiling the shaders: ${{gl.getShaderInfoLog(shader)}}`);
gl.deleteShader(shader);
return null;
}}
return shader;
}}
function setupVertexData(gl, program, vertices) {{
let vao = gl.createVertexArray();
gl.bindVertexArray(vao);
let vertexBuffer = gl.createBuffer();
gl.bindBuffer(gl.ARRAY_BUFFER, vertexBuffer);
gl.bufferData(gl.ARRAY_BUFFER, new Float32Array(vertices), gl.STATIC_DRAW);
const positionLocation = gl.getAttribLocation(program, 'in_position');
const uvLocation = gl.getAttribLocation(program, 'in_uv');
gl.enableVertexAttribArray(positionLocation);
gl.vertexAttribPointer(positionLocation, 2, gl.FLOAT, false, 4 * 4, 0);
gl.enableVertexAttribArray(uvLocation);
gl.vertexAttribPointer(uvLocation, 2, gl.FLOAT, false, 4 * 4, 2 * 4);
gl.bindVertexArray(null);
return vao;
}}
function runProgram(gl, kernelName, program, textures) {{
let framebuffer = gl.createFramebuffer();
gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer);
gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, textures[0].tex, 0);
gl.useProgram(program);
gl.uniform1i(gl.getUniformLocation(program, "width"), textures[0].width);
const vao = setupVertexData(gl, program, [-1, 1, 0, 1, -1, -1, 0, 0, 1, 1, 1, 1, 1, -1, 1, 0]);
gl.bindVertexArray(vao);
// Texture 0 is the framebuffer texture, so we skip that
for (let i = 1; i < textures.length; i++) {{
gl.activeTexture(gl.TEXTURE0 + i-1);
gl.bindTexture(gl.TEXTURE_2D, textures[i].tex);
gl.uniform1i(gl.getUniformLocation(program, 'data' + i), i-1);
}}
gl.viewport(0, 0, textures[0].width, textures[0].height);
gl.drawArrays(gl.TRIANGLE_STRIP, 0, 4);
gl.bindFramebuffer(gl.FRAMEBUFFER, null);
for (let i = 1; i < textures.length; i++) {{
gl.activeTexture(gl.TEXTURE0 + i-1);
gl.bindTexture(gl.TEXTURE_2D, null);
}}
console.log("Finished running: " + kernelName);
}}
function limitTextureDims(size, threshold) {{
if (size <= threshold) {{ return [size, 1] }};
for (let i = 2; i < threshold + 1; i++) {{
if ((size % i == 0) && (Math.floor(size / i) <= threshold)) {{
return [Math.floor(size / i), i];
}}
}}
return [size, 1];
}}
function updateTextureData(gl, texture, data, isHalf) {{
gl.bindTexture(gl.TEXTURE_2D, texture.tex);
gl.texSubImage2D(gl.TEXTURE_2D, 0, 0, 0, texture.width, texture.height, gl.RED, (isHalf) ? gl.HALF_FLOAT : gl.FLOAT, data);
gl.bindTexture(gl.TEXTURE_2D, null);
}}
function readTextureData(gl, texture) {{
const framebuffer = gl.createFramebuffer();
gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer);
gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture.tex, 0);
if (gl.checkFramebufferStatus(gl.FRAMEBUFFER) !== gl.FRAMEBUFFER_COMPLETE) {{
throw new Error('Framebuffer not complete');
}}
let data = new Float32Array(texture.width * texture.height);
gl.readPixels(0, 0, texture.width, texture.height, gl.RED, gl.FLOAT, data);
gl.bindFramebuffer(gl.FRAMEBUFFER, null);
gl.deleteFramebuffer(framebuffer);
return data;
}}
function createTexture(gl, size, isHalf, tensorBuffer) {{
const texture = gl.createTexture();
gl.bindTexture(gl.TEXTURE_2D, texture);
const internalFormat = gl.RGBA;
const texSize = limitTextureDims(size, gl.getParameter(gl.MAX_TEXTURE_SIZE));
let weights;
if (tensorBuffer != null) {{
if (!isHalf)
weights = new Float32Array(tensorBuffer.buffer, tensorBuffer.byteOffset, tensorBuffer.byteLength / Float32Array.BYTES_PER_ELEMENT);
else
weights = new Uint16Array(tensorBuffer.buffer, tensorBuffer.byteOffset, tensorBuffer.byteLength / Uint16Array.BYTES_PER_ELEMENT);
}} else {{
if (!isHalf)
weights = new Float32Array(size).fill(0.0);
else
weights = new Uint16Array(size).fill(0.0);
}}
if (size != weights.length)
console.log("Weights length: " + weights.length + ", texsize: " + texSize[0]*texSize[1]);
gl.texImage2D(gl.TEXTURE_2D, 0, (isHalf) ? gl.R16F : gl.R32F, texSize[0], texSize[1], 0, gl.RED, (isHalf) ? gl.HALF_FLOAT : gl.FLOAT, weights);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
gl.bindTexture(gl.TEXTURE_2D, null);
return {{ tex: texture, width: texSize[0], height: texSize[1] }};
}}
{web_utils["getTensorBuffer"]}
{web_utils["getTensorMetadata"]}
const metadata = getTensorMetadata(safetensor);
"""
  # one texture per buffer: byte sizes are divided down to element counts (2 bytes
  # per half, 4 per float); weight textures load their bytes from the safetensors blob.
  # NOTE(review): true division emits float literals (e.g. 16.0) into the JS and the
  # non-weight branch ends with ');' + ';' -> ';;' -- both harmless in JS, but confirm.
  textures = '\n '.join([f"const {name} = " + (f"createTexture(gl, {size/(2 if dtype == dtypes.half else 4)}, {'true' if dtype == dtypes.half else 'false'});" if _key not in weight_names else f"createTexture(gl, {size/(2 if dtype == dtypes.half else 4)}, {'true' if dtype == dtypes.half else 'false'}, getTensorBuffer(safetensor, metadata['{weight_names[_key]}']))") + ";" for name,(size,dtype,_key) in bufs.items()])
  # each kernel source becomes a JS template literal; its entry point is renamed to
  # main and the GLSL version pinned to 300 es for WebGL2
  kernels = '\n\n'.join([f"const {key} = `{code.replace(key, 'main').replace('version 330', 'version 300 es')}`;" for key, code in functions.items()])
  kernel_names = ', '.join([name for (name, _args, _global_size, _local_size) in statements])
  # one runProgram call per statement; programs[i] lines up with kernel_names order
  kernel_calls = '\n '.join([f"runProgram(gl, '{name}', programs[{i}], [{', '.join(args)}]);" for i, (name, args, _global_size, _local_size) in enumerate(statements) ])
  # upload each 'inputN' argument (JS params are prefixed with '_') into its texture
  copy_inputs = "\n".join([f'updateTextureData(gl, {name}, _{name}, {"true" if dtype == dtypes.half else "false"});' for name,(size,dtype,_key) in bufs.items() if "input" in name])
  # the function returned to the caller: upload inputs, run kernels, read back output0
  entry_point = f"""
return function({",".join([f"_{name}" for name,(size,dtype,_key) in bufs.items() if "input" in name])}) {{
const ext = gl.getExtension('EXT_color_buffer_float');
{copy_inputs}
{kernel_calls}
return readTextureData(gl, output0);
}}
"""
  programs = f"let programs = [{kernel_names}].map((code) => createShaderProgram(gl, code));"
  # trailing '}}' closes the setupNet function opened in `header`
  return f"{header}\n{kernels}\n{textures}\n{programs}\n{entry_point}}}"
def export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names) -> Tuple[str,int,int]:
  """Generate JavaScript source defining async `setupNet(device, safetensor)` and `loadNet`.

  Every buffer becomes a GPU storage buffer (weights are mapped at creation and
  filled from the safetensors blob); every kernel becomes a compute pipeline,
  dispatched in statement order inside one command encoder.
  NOTE(review): the declared return type Tuple[str,int,int] does not match the
  single string actually returned -- confirm against callers.
  """
  # rename each kernel's entry point to main so all shader modules share one entryPoint
  kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
  kernel_names = ', '.join([name for (name, _args, _global_size, _local_size) in statements])
  # 'piplines' (sic) deliberately matches the misspelled array declared in the template below
  kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
  # one storage buffer per net buffer; weights are initialized from the safetensors blob
  _bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weight_names else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weight_names[_key]}']))") + ";" for name,(size,dtype,_key) in bufs.items()])
  # staging buffers for copying host input arrays into the input storage buffers
  gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:{input_name}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,input_name in enumerate(input_names)])
  input_writers = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set(" + f'_{inp_name});' + f"\n gpuWriteBuffer{i}.unmap();\n commandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, {inp_name}, 0, gpuWriteBuffer{i}.size);" for i,inp_name in enumerate(input_names)])
  # staging buffers for reading the outputs back to the host
  gpu_read_bufs = '\n '.join([f"const gpuReadBuffer{i} = device.createBuffer({{size:{output_name}.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});" for i,output_name in enumerate(output_names)])
  outbuf_copies = '\n '.join([f"commandEncoder.copyBufferToBuffer({output_name}, 0, gpuReadBuffer{i}, 0, output{i}.size);" for i,output_name in enumerate(output_names)])
  output_readers = '\n '.join([f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new Float32Array(gpuReadBuffer{i}.size);\n resultBuffer{i}.set(new Float32Array(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();" for i in range(len(output_names))])
  output_return = '[{}]'.format(",".join([f'resultBuffer{i}' for i in range(len(output_names))]))
  return f"""
{web_utils["getTensorBuffer"]}
{web_utils["getTensorMetadata"]}
const createEmptyBuf = (device, size) => {{
return device.createBuffer({{size, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST }});
}};
const createWeightBuf = (device, size, data) => {{
const buf = device.createBuffer({{ mappedAtCreation: true, size, usage: GPUBufferUsage.STORAGE }});
new Uint8Array(buf.getMappedRange()).set(data);
buf.unmap();
return buf;
}};
const addComputePass = (device, commandEncoder, pipeline, bufs, workgroup) => {{
const bindGroup = device.createBindGroup({{layout: pipeline.getBindGroupLayout(0), entries: bufs.map((buffer, index) => ({{ binding: index, resource: {{ buffer }} }}))}});
const passEncoder = commandEncoder.beginComputePass();
passEncoder.setPipeline(pipeline);
passEncoder.setBindGroup(0, bindGroup);
passEncoder.dispatchWorkgroups(...workgroup);
passEncoder.end();
}};
{kernel_code}
const setupNet = async (device, safetensor) => {{
const metadata = getTensorMetadata(safetensor);
{_bufs}
{gpu_write_bufs}
{gpu_read_bufs}
const kernels = [{kernel_names}];
const piplines = await Promise.all(kernels.map(name => device.createComputePipelineAsync({{layout: "auto", compute: {{ module: device.createShaderModule({{ code: name }}), entryPoint: "main" }}}})));
return async ({",".join([f"_{input_name}" for input_name in input_names])}) => {{
const commandEncoder = device.createCommandEncoder();
{input_writers}
{kernel_calls}
{outbuf_copies}
const gpuCommands = commandEncoder.finish();
device.queue.submit([gpuCommands]);
{output_readers}
return {output_return};
}}
}}
""" + f"\n\nconst loadNet = async (device) => {{ return await fetch('net.safetensors').then(x => x.arrayBuffer()).then(x => setupNet(device, new Uint8Array(x))); }}"
  260. def export_model(model, target:str, *inputs):
  261. assert Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, "only WEBGPU, WEBGL, CLANG, CUDA, GPU, METAL are supported"
  262. with Context(JIT=2): run,special_names = jit_model(model, *inputs)
  263. functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
  264. state = get_state_dict(model)
  265. weight_names = {id(x.lazydata.base.realized): name for name, x in state.items()}
  266. input_names = [name for _,name in special_names.items() if "input" in name]
  267. output_names = [name for _,name in special_names.items() if "output" in name]
  268. prg = ""
  269. if target == "clang":
  270. prg = export_model_clang(functions, statements, bufs, bufs_to_save, input_names, output_names)
  271. elif target == "webgpu":
  272. prg = export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names)
  273. elif target == "webgl":
  274. prg = export_model_webgl(functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names)
  275. else:
  276. prg = json.dumps({
  277. "backend": Device.DEFAULT,
  278. "inputs": [{
  279. "size": bufs[name][0],
  280. "dtype": bufs[name][1].name
  281. } for name in input_names],
  282. "outputs": [{
  283. "size": bufs[name][0],
  284. "dtype": bufs[name][1].name
  285. } for name in output_names],
  286. "functions": functions,
  287. "statements": [{
  288. "kernel": kernel,
  289. "args": args,
  290. "global_size": global_size,
  291. "local_size": local_size
  292. } for (kernel, args, global_size, local_size) in statements],
  293. "buffers": {
  294. name: {
  295. "size": size,
  296. "dtype": dtype.name,
  297. "id": weight_names[_key] if _key in weight_names else ""
  298. } for name, (size,dtype,_key) in bufs.items() if name not in ["input", "outputs"]
  299. }
  300. })
  301. return prg, {input:bufs[input][0] for input in input_names}, {output:bufs[output][0] for output in output_names}, state