# compile.py — export tinygrad Stable Diffusion to WebGPU: compiles the text
# model, diffusor and decoder into WGSL kernels wrapped in JS (net.js), and
# optionally converts/splits the weights into browser-loadable safetensors.
  1. import os
  2. from extra.export_model import compile_net, jit_model
  3. from examples.stable_diffusion import StableDiffusion
  4. from tinygrad.nn.state import get_state_dict, safe_save, safe_load_metadata, torch_load, load_state_dict
  5. from tinygrad.tensor import Tensor
  6. from tinygrad import Device
  7. from tinygrad.helpers import fetch
  8. from typing import NamedTuple, Any, List
  9. from pathlib import Path
  10. import argparse
  11. import numpy as np
  12. def convert_f32_to_f16(input_file, output_file):
  13. with open(input_file, 'rb') as f:
  14. metadata_length_bytes = f.read(8)
  15. metadata_length = int.from_bytes(metadata_length_bytes, byteorder='little', signed=False)
  16. metadata_json_bytes = f.read(metadata_length)
  17. float32_values = np.fromfile(f, dtype=np.float32)
  18. first_text_model_offset = 3772703308
  19. num_elements = int((first_text_model_offset)/4)
  20. front_float16_values = float32_values[:num_elements].astype(np.float16)
  21. rest_float32_values = float32_values[num_elements:]
  22. with open(output_file, 'wb') as f:
  23. f.write(metadata_length_bytes)
  24. f.write(metadata_json_bytes)
  25. front_float16_values.tofile(f)
  26. rest_float32_values.tofile(f)
  27. def split_safetensor(fn):
  28. _, json_len, metadata = safe_load_metadata(fn)
  29. text_model_offset = 3772703308
  30. chunk_size = 536870912
  31. for k in metadata:
  32. # safetensor is in fp16, except for text moel
  33. if (metadata[k]["data_offsets"][0] < text_model_offset):
  34. metadata[k]["data_offsets"][0] = int(metadata[k]["data_offsets"][0]/2)
  35. metadata[k]["data_offsets"][1] = int(metadata[k]["data_offsets"][1]/2)
  36. last_offset = 0
  37. part_end_offsets = []
  38. for k in metadata:
  39. offset = metadata[k]['data_offsets'][0]
  40. if offset == text_model_offset:
  41. break
  42. part_offset = offset - last_offset
  43. if (part_offset >= chunk_size):
  44. part_end_offsets.append(8+json_len+offset)
  45. last_offset = offset
  46. text_model_start = int(text_model_offset/2)
  47. net_bytes = bytes(open(fn, 'rb').read())
  48. part_end_offsets.append(text_model_start+8+json_len)
  49. cur_pos = 0
  50. for i, end_pos in enumerate(part_end_offsets):
  51. with open(f'./net_part{i}.safetensors', "wb+") as f:
  52. f.write(net_bytes[cur_pos:end_pos])
  53. cur_pos = end_pos
  54. with open(f'./net_textmodel.safetensors', "wb+") as f:
  55. f.write(net_bytes[text_model_start+8+json_len:])
  56. return part_end_offsets
if __name__ == "__main__":
  # --remoteweights serves the pre-split safetensors from Huggingface instead
  # of producing them locally next to the generated net.js.
  parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument('--remoteweights', action='store_true', help="Use safetensors from Huggingface, or from local")
  args = parser.parse_args()

  # Compile against the WebGPU backend so the captured kernels are WGSL.
  Device.DEFAULT = "WEBGPU"
  Tensor.no_grad = True
  model = StableDiffusion()

  # load in weights
  load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], strict=False)

  class Step(NamedTuple):
    # One exported sub-model: `name` becomes the generated JS function name,
    # `input` are example tensors used for JIT tracing, `forward` the callable.
    # NOTE(review): the [] default is shared across instances (mutable default)
    # but is never mutated here, so it is harmless.
    name: str = ""
    input: List[Tensor] = []
    forward: Any = None

  # The three pieces of the pipeline, each exported as its own JS function.
  sub_steps = [
    Step(name = "textModel", input = [Tensor.randn(1, 77)], forward = model.cond_stage_model.transformer.text_model),
    Step(name = "diffusor", input = [Tensor.randn(1, 77, 768), Tensor.randn(1, 77, 768), Tensor.randn(1,4,64,64), Tensor.rand(1), Tensor.randn(1), Tensor.randn(1), Tensor.randn(1)], forward = model),
    Step(name = "decoder", input = [Tensor.randn(1,4,64,64)], forward = model.decode)
  ]

  prg = ""

  def compile_step(model, step: Step):
    """Trace one Step and render it as a self-contained JS module string."""
    # JIT the step to capture the kernel list, launch statements and buffers.
    run, special_names = jit_model(step, *step.input)
    functions, statements, bufs, _ = compile_net(run, special_names)
    # Map realized weight buffers back to their state-dict names so the JS
    # loads them from the safetensors parts instead of allocating empty bufs.
    state = get_state_dict(model)
    weights = {id(x.lazydata.base.realized): name for name, x in state.items()}
    # WGSL sources as template literals; each kernel's entry point -> "main".
    kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
    kernel_names = ', '.join([name for (name, _, _, _) in statements])
    # One compute pass per captured kernel call, in execution order.
    # (NOTE(review): "piplines" is a typo, but it is consistent throughout the
    # generated JS, so it is kept to leave the emitted file byte-identical.)
    kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
    # Weight buffers are mapped from the safetensor; everything else is empty.
    bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weights else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weights[_key]}'], '{weights[_key]}'))") + ";" for name,(size,dtype,_key) in bufs.items()])
    # Staging buffers + copy code for each non-output (i.e. input) binding.
    gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:input{i}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,(_,value) in enumerate(special_names.items()) if "output" not in value])
    input_writer = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set(" + f'data{i});' + f"\n gpuWriteBuffer{i}.unmap();\ncommandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, input{i}, 0, gpuWriteBuffer{i}.size);" for i,(_,value) in enumerate(special_names.items()) if value != "output0"])
    # The generated module: setup() builds pipelines/buffers once, then returns
    # an async function that uploads inputs, dispatches all kernels, and reads
    # back output0 as a Float32Array.
    return f"""\n var {step.name} = function() {{
{kernel_code}
return {{
"setup": async (device, safetensor) => {{
const metadata = getTensorMetadata(safetensor[0]);
{bufs}
{gpu_write_bufs}
const gpuReadBuffer = device.createBuffer({{ size: output0.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});
const kernels = [{kernel_names}];
const piplines = await Promise.all(kernels.map(name => device.createComputePipelineAsync({{layout: "auto", compute: {{ module: device.createShaderModule({{ code: name }}), entryPoint: "main" }}}})));
return async ({",".join([f'data{i}' for i,(k,v) in enumerate(special_names.items()) if v != "output0"])}) => {{
const commandEncoder = device.createCommandEncoder();
{input_writer}
{kernel_calls}
commandEncoder.copyBufferToBuffer(output0, 0, gpuReadBuffer, 0, output0.size);
const gpuCommands = commandEncoder.finish();
device.queue.submit([gpuCommands]);
await gpuReadBuffer.mapAsync(GPUMapMode.READ);
const resultBuffer = new Float32Array(gpuReadBuffer.size/4);
resultBuffer.set(new Float32Array(gpuReadBuffer.getMappedRange()));
gpuReadBuffer.unmap();
return resultBuffer;
}}
}}
}}
}}
"""

  for step in sub_steps:
    print(f'Executing step={step.name}')
    prg += compile_step(model, step)
    if step.name == "diffusor":
      if args.remoteweights:
        base_url = "https://huggingface.co/wpmed/tinygrad-sd-f16/resolve/main"
      else:
        # Save fp32 weights, downcast everything before the text model to
        # fp16, split into browser-loadable parts, then drop the intermediates.
        state = get_state_dict(model)
        safe_save(state, os.path.join(os.path.dirname(__file__), "net.safetensors"))
        convert_f32_to_f16("./net.safetensors", "./net_conv.safetensors")
        split_safetensor("./net_conv.safetensors")
        os.remove("net.safetensors")
        os.remove("net_conv.safetensors")
        base_url = "."

  # Shared JS runtime helpers prepended to the generated per-step modules.
  # NOTE(review): partStartOffsets below are hard-coded byte boundaries that
  # must match the parts produced by split_safetensor for the SD v1.4 weights.
  prekernel = f"""
window.MODEL_BASE_URL= "{base_url}";
const getTensorMetadata = (safetensorBuffer) => {{
const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
return Object.fromEntries(Object.entries(metadata).filter(([k, v]) => k !== "__metadata__").map(([k, v]) => [k, {{...v, data_offsets: v.data_offsets.map(x => 8 + metadataLength + x)}}]));
}};
const getTensorBuffer = (safetensorParts, tensorMetadata, key) => {{
let selectedPart = 0;
let counter = 0;
let partStartOffsets = [1131408336, 2227518416, 3308987856, 4265298864];
let correctedOffsets = tensorMetadata.data_offsets;
let prev_offset = 0;
for (let start of partStartOffsets) {{
prev_offset = (counter == 0) ? 0 : partStartOffsets[counter-1];
if (tensorMetadata.data_offsets[0] < start) {{
selectedPart = counter;
correctedOffsets = [correctedOffsets[0]-prev_offset, correctedOffsets[1]-prev_offset];
break;
}}
counter++;
}}
let allZero = true;
let out = safetensorParts[selectedPart].subarray(...correctedOffsets);
for (let i = 0; i < out.length; i++) {{
if (out[i] !== 0) {{
allZero = false;
break;
}}
}}
if (allZero) {{
console.log("Error: weight '" + key + "' is all zero.");
}}
return safetensorParts[selectedPart].subarray(...correctedOffsets);
}}
const getWeight = (safetensors, key) => {{
let uint8Data = getTensorBuffer(safetensors, getTensorMetadata(safetensors[0])[key], key);
return new Float32Array(uint8Data.buffer, uint8Data.byteOffset, uint8Data.byteLength / Float32Array.BYTES_PER_ELEMENT);
}}
const createEmptyBuf = (device, size) => {{
return device.createBuffer({{size, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST }});
}};
const createWeightBuf = (device, size, data) => {{
const buf = device.createBuffer({{ mappedAtCreation: true, size, usage: GPUBufferUsage.STORAGE }});
new Uint8Array(buf.getMappedRange()).set(data);
buf.unmap();
return buf;
}};
const addComputePass = (device, commandEncoder, pipeline, bufs, workgroup) => {{
const bindGroup = device.createBindGroup({{layout: pipeline.getBindGroupLayout(0), entries: bufs.map((buffer, index) => ({{ binding: index, resource: {{ buffer }} }}))}});
const passEncoder = commandEncoder.beginComputePass();
passEncoder.setPipeline(pipeline);
passEncoder.setBindGroup(0, bindGroup);
passEncoder.dispatchWorkgroups(...workgroup);
passEncoder.end();
}};"""

  # Write the helpers plus all generated step modules next to this script.
  with open(os.path.join(os.path.dirname(__file__), "net.js"), "w") as text_file:
    text_file.write(prekernel + prg)