index.html 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511
  1. <!DOCTYPE html>
  2. <html lang="en">
  3. <head>
  4. <meta charset="UTF-8">
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0">
  6. <title>tinygrad has WebGPU</title>
  7. <style>
  8. body {
  9. font-family: 'Arial', sans-serif;
  10. text-align: center;
  11. padding: 30px;
  12. }
  13. a {
  14. text-decoration: none;
  15. color: #4A90E2;
  16. }
  17. h1 {
  18. font-size: 36px;
  19. font-weight: normal;
  20. margin-bottom: 20px;
  21. }
  22. #mybox {
  23. display: flex;
  24. flex-direction: column;
  25. align-items: center;
  26. gap: 20px;
  27. width: 50%;
  28. margin: 0 auto;
  29. }
  30. #promptText, #stepRange, #btnRunNet, #guidanceRange {
  31. font-size: 18px;
  32. width: 100%;
  33. }
  34. #result {
  35. font-size: 48px;
  36. }
  37. #time {
  38. font-size: 16px;
  39. color: grey;
  40. }
  41. canvas {
  42. margin-top: 20px;
  43. border: 1px solid #000;
  44. }
  45. label {
  46. display: flex;
  47. align-items: center;
  48. gap: 10px;
  49. width: 100%;
  50. }
  /* Spacing for a slider read-out element.
     NOTE(review): no element with id "sliderValue" appears in this page's
     markup — confirm whether this rule is still needed. */
  #sliderValue {
    margin-right: 10px;
  }
  54. </style>
  55. <script type="module">
  56. import ClipTokenizer from 'https://softwired.nyc3.cdn.digitaloceanspaces.com/sd/clip_tokenizer.js';
  57. window.clipTokenizer = new ClipTokenizer();
  58. </script>
  59. <script src="./f16_to_f32.js"></script>
  60. <script src="./net.js"></script>
  61. </head>
  62. <body>
  63. <h1 id="wgpuError" style="display: none; color: red;">WebGPU is not supported in this browser</h1>
  64. <h1 id="sdTitle">StableDiffusion by <a href="https://github.com/tinygrad/tinygrad" target="_blank">tinygrad</a> WebGPU</h1>
  65. <div id="mybox">
  66. <input id="promptText" type="text" placeholder="Enter your prompt here" value="a horse sized cat eating a bagel">
  67. <label>
  68. Steps: <span id="stepValue">8</span>
  69. <input id="stepRange" type="range" min="5" max="20" value="8" step="1">
  70. </label>
  71. <label>
  72. Guidance: <span id="guidanceValue">7.5</span>
  73. <input id="guidanceRange" type="range" min="3" max="15" value="7.5" step="0.1">
  74. </label>
  75. <input id="btnRunNet" type="button" value="Run" disabled>
  76. <div id="divModelDl" style="display: flex; align-items: center; width: 100%; gap: 10px;">
  77. <span id="modelDlTitle">Downloading model</span>
  78. <progress id="modelDlProgressBar" value="0" max="100" style="flex-grow: 1;"></progress>
  79. <span id="modelDlProgressValue"></span>
  80. </div>
  81. <div id="divStepProgress" style="display: none; align-items: center; width: 100%; gap: 10px;">
  82. <progress id="progressBar" value="0" max="100" style="flex-grow: 1;"></progress>
  83. <span id="progressFraction"></span>
  84. </div>
  85. </div>
  86. <canvas id="canvas" width="512" height="512"></canvas>
  87. <script>
  /**
   * Opens the 'tinydb' IndexedDB database (version 1), creating the 'tensors'
   * object store (keyPath 'id') on first use.
   *
   * The database is only a cache, so this never rejects: any failure to open
   * it, including a synchronous throw from `indexedDB.open` or a missing
   * `indexedDB` global, is logged and reported as `null`.
   *
   * @returns {Promise<IDBDatabase|null>} the opened database, or null when it
   *   could not be opened.
   */
  function initDb() {
    return new Promise((resolve) => {
      let request;
      try {
        request = indexedDB.open('tinydb', 1);
      } catch (error) {
        // Without this guard the throw would reject the promise and abort the
        // caller, even though running without a cache is fine.
        console.error('Database error:', error);
        resolve(null);
        return;
      }
      request.onerror = (event) => {
        console.error('Database error:', event.target.error);
        resolve(null);
      };
      request.onsuccess = (event) => {
        console.log("Db initialized.");
        resolve(event.target.result);
      };
      request.onupgradeneeded = (event) => {
        const upgradingDb = event.target.result;
        if (!upgradingDb.objectStoreNames.contains('tensors')) {
          upgradingDb.createObjectStore('tensors', { keyPath: 'id' });
        }
      };
    });
  }
  /**
   * Stores `tensor` under key `id` in the 'tensors' object store.
   *
   * Saving is best-effort and never rejects: failures are logged and reported
   * as `null`.
   *
   * @param {IDBDatabase|null} db - database from initDb(); null/undefined means
   *   caching is unavailable and the call is a no-op.
   * @param {string} id - key of the record.
   * @param {*} tensor - value stored in the record's `content` field.
   * @returns {Promise<undefined|null>} undefined once the write transaction
   *   has completed, null if nothing was saved.
   */
  function saveTensorToDb(db, id, tensor) {
    return new Promise((resolve) => {
      if (db == null) {
        resolve(null);
        return; // Without this the code below dereferences a null db.
      }
      let transaction;
      let request;
      try {
        transaction = db.transaction(['tensors'], 'readwrite');
        request = transaction.objectStore('tensors').put({ id: id, content: tensor });
      } catch (error) {
        // transaction()/put() can throw synchronously; keep the cache optional
        // instead of rejecting into the caller.
        console.error('Tensor save failed:', error);
        resolve(null);
        return;
      }
      transaction.onabort = (event) => {
        console.log("Transaction error while saving tensor: " + event.target.error);
        resolve(null);
      };
      // Report success only when the transaction completes: the put request
      // can succeed and the transaction still abort afterwards.
      transaction.oncomplete = () => {
        console.log('Tensor saved successfully.');
        resolve();
      };
      request.onerror = (event) => {
        console.error('Tensor save failed:', event.target.error);
        resolve(null);
      };
    });
  }
  /**
   * Looks up the record stored under `id` in the 'tensors' object store.
   *
   * Reading is best-effort and never rejects: a missing database, a cache
   * miss, or any error resolves `null`.
   *
   * @param {IDBDatabase|null} db - database from initDb(); null/undefined means
   *   caching is unavailable.
   * @param {string} id - key of the record to fetch.
   * @returns {Promise<Object|null>} the stored record (the whole record, not
   *   just its `content` field), or null when nothing could be read.
   */
  function readTensorFromDb(db, id) {
    return new Promise((resolve) => {
      if (db == null) {
        resolve(null);
        return; // Without this the code below dereferences a null db.
      }
      let transaction;
      let request;
      try {
        transaction = db.transaction(['tensors'], 'readonly');
        request = transaction.objectStore('tensors').get(id);
      } catch (error) {
        // Treat a synchronous failure like any other read failure.
        console.error('Tensor retrieve failed: ', error);
        resolve(null);
        return;
      }
      transaction.onabort = (event) => {
        console.log("Transaction error while reading tensor: " + event.target.error);
        resolve(null);
      };
      request.onsuccess = (event) => {
        const record = event.target.result;
        if (record) {
          console.log("Cache hit: " + id);
          resolve(record);
        } else {
          console.log("Cache miss: " + id);
          resolve(null);
        }
      };
      request.onerror = (event) => {
        console.error('Tensor retrieve failed: ', event.target.error);
        resolve(null);
      };
    });
  }
  159. window.addEventListener('load', async function() {
  160. if (!navigator.gpu) {
  161. document.getElementById("wgpuError").style.display = "";
  162. document.getElementById("sdTitle").style.display = "none";
  163. return;
  164. }
  165. let db = await initDb();
  166. const ctx = document.getElementById("canvas").getContext("2d", { willReadFrequently: true });
  167. let labels, nets, safetensorParts;
  /**
   * Requests a WebGPU device whose buffer limits are raised to 1 GiB
   * (1073741824 bytes) for both `maxStorageBufferBindingSize` and
   * `maxBufferSize`.
   *
   * @returns {Promise<GPUDevice>} the requested device.
   * @throws {Error} when no WebGPU adapter is available. Errors from
   *   `requestDevice` itself propagate unchanged.
   */
  const getDevice = async () => {
    const adapter = await navigator.gpu.requestAdapter();
    if (adapter == null) {
      // requestAdapter() resolves null when no adapter matches; fail with a
      // clear message instead of a TypeError on adapter.requestDevice.
      throw new Error("No suitable WebGPU adapter found");
    }
    const maxBufferSizeInSDModel = 1073741824;
    const requiredLimits = {
      maxStorageBufferBindingSize: maxBufferSizeInSDModel,
      maxBufferSize: maxBufferSizeInSDModel,
    };
    return await adapter.requestDevice({
      requiredLimits
    });
  };
  /**
   * Awaits `func()`, logs how long it took as "<ms> ms <label>" (one decimal
   * place), and returns whatever `func` produced.
   *
   * @param {Function} func - zero-argument function; may return a promise.
   * @param {string} [label=""] - text appended to the log line.
   * @returns {Promise<*>} the awaited result of `func`.
   */
  const timer = async (func, label = "") => {
    const startedAt = performance.now();
    const result = await func();
    const elapsedMs = performance.now() - startedAt;
    console.log(`${elapsedMs.toFixed(1)} ms ${label}`);
    return result;
  };
  /**
   * Downloads `part` and returns its body as an ArrayBuffer, reporting
   * progress as the body streams in.
   *
   * @param {string} part - URL to fetch.
   * @param {Function} progressCallback - called as
   *   `progressCallback(part, bytesInThisChunk, totalBytes)` for every chunk
   *   read. `totalBytes` comes from the Content-Length header and is 0 when
   *   that header is missing or not a number.
   * @returns {Promise<ArrayBuffer>} the complete response body.
   * @throws {Error} when the server answers with a non-2xx status.
   */
  const getProgressDlForPart = async (part, progressCallback) => {
    const response = await fetch(part);
    if (!response.ok) {
      // An error page must not be returned as model bytes.
      throw new Error(`Failed to download ${part}: HTTP ${response.status} ${response.statusText}`);
    }
    const contentLength = Number.parseInt(response.headers.get('content-length'), 10);
    // A missing header parses to NaN, which would poison every running total.
    const total = Number.isFinite(contentLength) ? contentLength : 0;
    const res = new Response(new ReadableStream({
      async start(controller) {
        const reader = response.body.getReader();
        for (;;) {
          const { done, value } = await reader.read();
          if (done) break;
          progressCallback(part, value.byteLength, total);
          controller.enqueue(value);
        }
        controller.close();
      },
    }));
    return res.arrayBuffer();
  };
  /**
   * Obtains the model weights and expands them into four byte arrays laid out
   * at the hard-coded `partOffsets` ranges.
   *
   * Steps:
   *  1. Read "net.f16" and "net.text" from the IndexedDB cache.
   *  2. On a "net.f16" miss, download the four `net_partN.safetensors` files
   *     from `window.MODEL_BASE_URL`, concatenate them and cache the result.
   *     The text model file is downloaded and cached whenever it is not cached.
   *  3. Copy the 8-byte length prefix and the JSON header into part 0, and the
   *     text model bytes into part 3 at `textModelOffset`.
   *  4. Convert the f16 payload in chunks of `decodeChunkSize` bytes via
   *     `f16tof32GPU` and write each result at twice its f16 offset, splitting
   *     a chunk across two parts when it straddles a part boundary.
   *
   * @param {GPUDevice} device - device passed through to `f16tof32GPU`.
   * @param {Function} progress - called as `progress(done, total)`, first with
   *   byte counts during download, then with chunk counts during decoding.
   * @returns {Promise<Uint8Array[]>} the four populated parts.
   */
  const getAndDecompressF16Safetensors = async (device, progress) => {
    let totalLoaded = 0;
    let totalSize = 0;
    const countedParts = new Set();
    const progressCallback = (part, loaded, total) => {
      totalLoaded += loaded;
      // Add each part's total size to the grand total only once.
      if (!countedParts.has(part)) {
        totalSize += total;
        countedParts.add(part);
      }
      progress(totalLoaded, totalSize);
    };

    let combinedBuffer = await readTensorFromDb(db, "net.f16");
    let textModelU8 = await readTensorFromDb(db, "net.text");
    let textModelFetched = false;

    if (combinedBuffer == null) {
      const dlParts = [
        getProgressDlForPart(window.MODEL_BASE_URL + '/net_part0.safetensors', progressCallback),
        getProgressDlForPart(window.MODEL_BASE_URL + '/net_part1.safetensors', progressCallback),
        getProgressDlForPart(window.MODEL_BASE_URL + '/net_part2.safetensors', progressCallback),
        getProgressDlForPart(window.MODEL_BASE_URL + '/net_part3.safetensors', progressCallback)
      ];
      if (textModelU8 == null) {
        dlParts.push(getProgressDlForPart(window.MODEL_BASE_URL + '/net_textmodel.safetensors', progressCallback));
      }
      const buffers = await Promise.all(dlParts);

      // Combine everything except for the text model, since that's already f32
      const f16Buffers = buffers.slice(0, 4);
      const totalLength = f16Buffers.reduce((acc, buffer) => acc + buffer.byteLength, 0);
      combinedBuffer = new Uint8Array(totalLength);
      let offset = 0;
      for (const buffer of f16Buffers) {
        combinedBuffer.set(new Uint8Array(buffer), offset);
        offset += buffer.byteLength;
      }
      await saveTensorToDb(db, "net.f16", combinedBuffer);

      if (textModelU8 == null) {
        textModelFetched = true;
        textModelU8 = new Uint8Array(buffers[4]);
        await saveTensorToDb(db, "net.text", textModelU8);
      }
    } else {
      // Cache hit: the record wraps the bytes in its `content` field.
      combinedBuffer = combinedBuffer.content;
    }

    if (textModelU8 == null) {
      // f16 weights were cached but the text model was not.
      textModelU8 = new Uint8Array(await getProgressDlForPart(window.MODEL_BASE_URL + '/net_textmodel.safetensors', progressCallback));
      await saveTensorToDb(db, "net.text", textModelU8);
    } else if (!textModelFetched) {
      // Still holding the cache record; unwrap it.
      textModelU8 = textModelU8.content;
    }

    document.getElementById("modelDlTitle").innerHTML = "Decompressing model";

    const textModelOffset = 3772703308;
    // Respect the view's own offset/length in case it does not start at byte 0
    // of its backing buffer.
    const headerView = new DataView(combinedBuffer.buffer, combinedBuffer.byteOffset, combinedBuffer.byteLength);
    const metadataLength = Number(headerView.getBigUint64(0, true));
    // NOTE(review): `metadata` is parsed but not used below — presumably kept
    // as a sanity check that the header is valid JSON; confirm.
    const metadata = JSON.parse(new TextDecoder("utf8").decode(combinedBuffer.subarray(8, 8 + metadataLength)));

    const allToDecomp = combinedBuffer.byteLength - (8 + metadataLength);
    const decodeChunkSize = 67107840;
    const numChunks = Math.ceil(allToDecomp/decodeChunkSize);
    console.log(allToDecomp + " bytes to decompress");
    console.log("Will be decompressed in " + numChunks+ " chunks");

    const partOffsets = [{start: 0, end: 1131408336}, {start: 1131408336, end: 2227518416}, {start: 2227518416, end: 3308987856}, {start: 3308987856, end: 4265298864}];
    const parts = [];
    for (const offsets of partOffsets) {
      parts.push(new Uint8Array(offsets.end-offsets.start));
    }

    // Header (length prefix + JSON) is copied through unchanged.
    parts[0].set(new Uint8Array(new BigUint64Array([BigInt(metadataLength)]).buffer), 0);
    parts[0].set(combinedBuffer.subarray(8, 8 + metadataLength), 8);
    parts[3].set(textModelU8, textModelOffset+8+metadataLength - partOffsets[3].start);

    const start = Date.now();
    let cursor = 0;

    for (let i = 0; i < numChunks; i++) {
      progress(i, numChunks);
      const chunkStartF16 = 8 + metadataLength + (decodeChunkSize * i);
      const chunkEndF16 = chunkStartF16 + decodeChunkSize;
      let chunk = combinedBuffer.subarray(chunkStartF16, chunkEndF16);
      const payloadBytesF16 = chunk.byteLength;

      if (chunk.byteLength %4 != 0) {
        // Zero-pad the chunk up to a multiple of 4 bytes before handing it to
        // f16tof32GPU.
        const paddingBytes = 4 - (chunk.byteLength % 4);
        const alignedView = new Uint8Array(new ArrayBuffer(chunk.byteLength + paddingBytes));
        alignedView.set(new Uint8Array(chunk));
        chunk = alignedView;
      }

      const result = await f16tof32GPU(device, chunk);
      let resultUint8 = new Uint8Array(result.buffer);

      // Drop output produced from the alignment padding. Otherwise the extra
      // bytes are written past the end of the real payload, where the text
      // model was already copied. The offsets in this loop fix the ratio at
      // two output bytes per input byte; assumes f16tof32GPU emits its output
      // in input order — TODO confirm against f16_to_f32.js.
      const payloadBytesF32 = payloadBytesF16 * 2;
      if (resultUint8.byteLength > payloadBytesF32) {
        resultUint8 = resultUint8.subarray(0, payloadBytesF32);
      }

      const chunkStartF32 = 8 + metadataLength + (decodeChunkSize * i * 2);
      const chunkEndF32 = chunkStartF32 + resultUint8.byteLength;
      const offsetInPart = chunkStartF32 - partOffsets[cursor].start;

      if (chunkEndF32 < partOffsets[cursor].end || cursor === parts.length - 1) {
        // Chunk fits in the current part (or there is no further part).
        parts[cursor].set(resultUint8, offsetInPart);
      } else {
        // Chunk straddles a part boundary: fill the current part, then put
        // the remainder at the start of the next one.
        const spaceLeftInCurrentPart = partOffsets[cursor].end - chunkStartF32;
        parts[cursor].set(resultUint8.subarray(0, spaceLeftInCurrentPart), offsetInPart);
        cursor++;
        if (cursor < parts.length) {
          const nextPartOffset = spaceLeftInCurrentPart;
          const nextPartLength = resultUint8.length - nextPartOffset;
          parts[cursor].set(resultUint8.subarray(nextPartOffset, nextPartOffset + nextPartLength), 0);
        }
      }
    }

    combinedBuffer = null;
    const end = Date.now();
    console.log("Decoding took: " + ((end - start) / 1000) + " s");
    console.log("Avarage " + ((end - start) / numChunks) + " ms per chunk");
    return parts;
  };
  /**
   * Loads the model end to end and updates the download UI as it goes:
   * acquires a GPU device, obtains the decompressed weights (stored in the
   * enclosing `safetensorParts`), sets up the text model, diffusor and decoder
   * (stored in the enclosing `nets` under "textModel", "diffusor" and
   * "decoder"), then enables the Run button. One second after success the
   * download progress widgets are hidden and the step progress row is shown.
   *
   * On failure the title is switched to "Model failed to load" and the error
   * is rethrown, so the returned promise still rejects.
   *
   * @returns {Promise<void>}
   */
  const loadNet = async () => {
    const modelDlTitle = document.getElementById("modelDlTitle");

    const progress = (loaded, total) => {
      const percent = (loaded/total) * 100;
      // total can be 0 or NaN (e.g. unknown download size). Assigning a
      // non-finite number to <progress>.value throws, so skip such updates.
      if (!Number.isFinite(percent)) {
        return;
      }
      const shown = Math.min(100, percent);
      document.getElementById("modelDlProgressBar").value = shown;
      document.getElementById("modelDlProgressValue").innerHTML = Math.trunc(shown) + "%";
    };

    try {
      const device = await getDevice();
      safetensorParts = await getAndDecompressF16Safetensors(device, progress);

      modelDlTitle.innerHTML = "Compiling model";

      const models = ["textModel", "diffusor", "decoder"];
      nets = await timer(() => Promise.all([
        textModel().setup(device, safetensorParts),
        diffusor().setup(device, safetensorParts),
        decoder().setup(device, safetensorParts)
      ]).then((loadedModels) => loadedModels.reduce((acc, model, index) => { acc[models[index]] = model; return acc; }, {})), "(compilation)");
    } catch (error) {
      // Show the failure instead of leaving the last progress message up.
      modelDlTitle.innerHTML = "Model failed to load";
      throw error;
    }

    progress(1, 1);
    modelDlTitle.innerHTML = "Model ready";
    setTimeout(() => {
      document.getElementById("modelDlProgressBar").style.display = "none";
      document.getElementById("modelDlProgressValue").style.display = "none";
      document.getElementById("divStepProgress").style.display = "flex";
    }, 1000);
    document.getElementById("btnRunNet").disabled = false;
  };
  /**
   * Runs one text-to-image generation with the loaded networks.
   *
   * Encodes `prompt` and the empty prompt with `nets["textModel"]`, builds the
   * timestep schedule (starting at 1, spaced 1000/steps apart, below 1000),
   * starts from a random normal latent of 4*64*64 floats, applies
   * `nets["diffusor"]` once per timestep from the largest down to the
   * smallest, and finally decodes the latent with `nets["decoder"]`. The
   * "progressBar" and "progressFraction" elements are updated after each step.
   *
   * @param {string} prompt - text prompt.
   * @param {number|string} steps - number of denoising steps; numeric strings
   *   (as produced by a range input) are accepted.
   * @param {number|string} guidance - guidance scale passed to the diffusor.
   * @returns {Promise<*>} whatever `nets["decoder"]` returns for the final
   *   latent.
   * @throws {RangeError} (as a rejection) when `steps` is not a positive
   *   finite number.
   */
  async function runStableDiffusion(prompt, steps, guidance) {
    const stepCount = Number(steps);
    if (!Number.isFinite(stepCount) || stepCount <= 0) {
      // A negative step count would make the schedule loop below run forever.
      throw new RangeError("steps must be a positive number, got: " + steps);
    }

    const context = await timer(() => nets["textModel"](clipTokenizer.encodeForCLIP(prompt)));
    const unconditional_context = await timer(() => nets["textModel"](clipTokenizer.encodeForCLIP("")));

    const timesteps = [];
    for (let i = 1; i < 1000; i += (1000/stepCount)) {
      timesteps.push(i);
    }
    console.log("Timesteps: " + timesteps);

    const alphasCumprod = getWeight(safetensorParts,"alphas_cumprod");
    // Timesteps may be fractional; index the table with the floored value.
    const alphas = timesteps.map((t) => alphasCumprod[Math.floor(t)]);
    // alphas_prev[i] is the alpha of the preceding timestep, 1.0 for the first.
    const alphas_prev = [1.0, ...alphas.slice(0, -1)];

    const inpSize = 4*64*64;
    let latent = new Float32Array(inpSize);
    for (let i = 0; i < inpSize; i++) {
      // Box-Muller transform. 1 - Math.random() lies in (0, 1], so Math.log
      // never receives 0.
      latent[i] = Math.sqrt(-2.0 * Math.log(1 - Math.random())) * Math.cos(2.0 * Math.PI * Math.random());
    }

    for (let i = timesteps.length - 1; i >= 0; i--) {
      const timestep = new Float32Array([timesteps[i]]);
      const x_prev = await timer(() => nets["diffusor"](unconditional_context, context, latent, timestep, new Float32Array([alphas[i]]), new Float32Array([alphas_prev[i]]), new Float32Array([guidance])));
      latent = x_prev;
      document.getElementById("progressBar").value = ((stepCount - i) / stepCount) * 100;
      document.getElementById("progressFraction").innerHTML = (stepCount - i) + "/" + stepCount;
    }

    return await timer(() => nets["decoder"](latent));
  }
  // Run button: generate an image from the current prompt/steps/guidance and
  // paint it on the canvas. The button is disabled while a run is in flight.
  document.getElementById("btnRunNet").addEventListener("click", async function(e) {
    const runButton = e.target;
    runButton.disabled = true;
    try {
      const image = await runStableDiffusion(
        document.getElementById("promptText").value,
        document.getElementById("stepRange").value,
        document.getElementById("guidanceRange").value
      );

      // `image` is read as 512*512 pixels with 3 values each; expand it to
      // RGBA with full opacity. Uint8ClampedArray rounds and clamps on write,
      // exactly as converting a plain array of the same numbers would.
      const pixelCount = 512 * 512;
      const rgba = new Uint8ClampedArray(pixelCount * 4);
      for (let pixel = 0; pixel < pixelCount; pixel++) {
        const src = pixel * 3;
        const dst = pixel * 4;
        rgba[dst] = image[src];
        rgba[dst + 1] = image[src + 1];
        rgba[dst + 2] = image[src + 2];
        rgba[dst + 3] = 255;
      }

      ctx.putImageData(new ImageData(rgba, 512, 512), 0, 0);
      console.log(image);
      console.log("Success");
    } catch (error) {
      console.error("Image generation failed:", error);
    } finally {
      // Re-enable the button even if the run failed, so the user can retry.
      runButton.disabled = false;
    }
  }, false);
  // Keep the number shown next to each range slider in sync with the slider
  // while the user drags it.
  for (const [sliderId, readoutId] of [['stepRange', 'stepValue'], ['guidanceRange', 'guidanceValue']]) {
    const slider = document.getElementById(sliderId);
    const readout = document.getElementById(readoutId);
    slider.addEventListener('input', () => {
      readout.textContent = slider.value;
    });
  }
  406. loadNet();
  407. });
  408. </script>
  409. </body>
  410. </html>