| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511 |
- <!DOCTYPE html>
- <html lang="en">
- <head>
- <meta charset="UTF-8">
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
- <title>tinygrad has WebGPU</title>
- <style>
- body {
- font-family: 'Arial', sans-serif;
- text-align: center;
- padding: 30px;
- }
- a {
- text-decoration: none;
- color: #4A90E2;
- }
- h1 {
- font-size: 36px;
- font-weight: normal;
- margin-bottom: 20px;
- }
- #mybox {
- display: flex;
- flex-direction: column;
- align-items: center;
- gap: 20px;
- width: 50%;
- margin: 0 auto;
- }
- #promptText, #stepRange, #btnRunNet, #guidanceRange {
- font-size: 18px;
- width: 100%;
- }
- #result {
- font-size: 48px;
- }
- #time {
- font-size: 16px;
- color: grey;
- }
- canvas {
- margin-top: 20px;
- border: 1px solid #000;
- }
- label {
- display: flex;
- align-items: center;
- gap: 10px;
- width: 100%;
- }
-
- #sliderValue {
- margin-right: 10px;
- }
- </style>
- <script type="module">
- import ClipTokenizer from 'https://softwired.nyc3.cdn.digitaloceanspaces.com/sd/clip_tokenizer.js';
- window.clipTokenizer = new ClipTokenizer();
- </script>
- <script src="./f16_to_f32.js"></script>
- <script src="./net.js"></script>
- </head>
- <body>
- <h1 id="wgpuError" style="display: none; color: red;">WebGPU is not supported in this browser</h1>
- <h1 id="sdTitle">StableDiffusion by <a href="https://github.com/tinygrad/tinygrad" target="_blank">tinygrad</a> WebGPU</h1>
- <div id="mybox">
- <input id="promptText" type="text" placeholder="Enter your prompt here" value="a horse sized cat eating a bagel">
-
- <label>
- Steps: <span id="stepValue">8</span>
- <input id="stepRange" type="range" min="5" max="20" value="8" step="1">
- </label>
- <label>
- Guidance: <span id="guidanceValue">7.5</span>
- <input id="guidanceRange" type="range" min="3" max="15" value="7.5" step="0.1">
- </label>
-
- <input id="btnRunNet" type="button" value="Run" disabled>
- <div id="divModelDl" style="display: flex; align-items: center; width: 100%; gap: 10px;">
- <span id="modelDlTitle">Downloading model</span>
- <progress id="modelDlProgressBar" value="0" max="100" style="flex-grow: 1;"></progress>
- <span id="modelDlProgressValue"></span>
- </div>
- <div id="divStepProgress" style="display: none; align-items: center; width: 100%; gap: 10px;">
- <progress id="progressBar" value="0" max="100" style="flex-grow: 1;"></progress>
- <span id="progressFraction"></span>
- </div>
- </div>
- <canvas id="canvas" width="512" height="512"></canvas>
- <script>
/**
 * Opens (creating on first use) the 'tinydb' IndexedDB database, version 1,
 * and makes sure the 'tensors' object store (keyPath 'id') exists.
 *
 * Opening is best-effort: the returned promise never rejects.
 *
 * @returns {Promise<IDBDatabase|null>} the open database, or null when
 *   IndexedDB is unavailable or the open request fails.
 */
function initDb() {
  return new Promise((resolve) => {
    // Without IndexedDB there is simply no cache; callers treat null as that.
    if (typeof indexedDB === 'undefined' || indexedDB == null) {
      console.error('Database error:', 'IndexedDB is not available');
      resolve(null);
      return;
    }

    let request;
    try {
      request = indexedDB.open('tinydb', 1);
    } catch (err) {
      // open() can throw synchronously; degrade to "no cache" like the
      // asynchronous error path below instead of rejecting.
      console.error('Database error:', err);
      resolve(null);
      return;
    }

    request.onerror = (event) => {
      console.error('Database error:', event.target.error);
      resolve(null);
    };
    request.onsuccess = (event) => {
      console.log("Db initialized.");
      resolve(event.target.result);
    };
    // Runs before onsuccess when the database is new or its version is older.
    request.onupgradeneeded = (event) => {
      const upgradingDb = event.target.result;
      if (!upgradingDb.objectStoreNames.contains('tensors')) {
        upgradingDb.createObjectStore('tensors', { keyPath: 'id' });
      }
    };
  });
}
/**
 * Stores `tensor` in the 'tensors' object store under key `id`.
 *
 * Saving is best-effort: every failure is logged and resolves to null, the
 * promise never rejects.
 *
 * @param {IDBDatabase|null} db - open database, or null when caching is off.
 * @param {string} id - record key (the store's keyPath is 'id').
 * @param {*} tensor - value stored in the record's `content` field.
 * @returns {Promise<undefined|null>} undefined once the put request succeeds,
 *   null when there is no database or the save failed.
 */
function saveTensorToDb(db, id, tensor) {
  return new Promise((resolve) => {
    if (db == null) {
      resolve(null);
      // Stop here: without this return the code below dereferences null.
      return;
    }

    let transaction;
    let request;
    try {
      transaction = db.transaction(['tensors'], 'readwrite');
      request = transaction.objectStore('tensors').put({ id: id, content: tensor });
    } catch (err) {
      // transaction()/put() can throw synchronously; treat that like the
      // asynchronous failures below rather than rejecting.
      console.error('Tensor save failed:', err);
      resolve(null);
      return;
    }

    transaction.onabort = (event) => {
      console.log("Transaction error while saving tensor: " + event.target.error);
      resolve(null);
    };
    request.onsuccess = () => {
      console.log('Tensor saved successfully.');
      resolve();
    };
    request.onerror = (event) => {
      console.error('Tensor save failed:', event.target.error);
      resolve(null);
    };
  });
}
/**
 * Looks up the record stored under `id` in the 'tensors' object store.
 *
 * Reading is best-effort: every failure is logged and resolves to null, the
 * promise never rejects.
 *
 * @param {IDBDatabase|null} db - open database, or null when caching is off.
 * @param {string} id - record key to look up.
 * @returns {Promise<Object|null>} the stored record itself (not unwrapped;
 *   saveTensorToDb writes records shaped `{ id, content }`), or null on a
 *   cache miss, a missing database, or any error.
 */
function readTensorFromDb(db, id) {
  return new Promise((resolve) => {
    if (db == null) {
      resolve(null);
      // Stop here: without this return the code below dereferences null.
      return;
    }

    let transaction;
    let request;
    try {
      transaction = db.transaction(['tensors'], 'readonly');
      request = transaction.objectStore('tensors').get(id);
    } catch (err) {
      // transaction()/get() can throw synchronously; report it as a miss
      // like the asynchronous failures below rather than rejecting.
      console.error('Tensor retrieve failed: ', err);
      resolve(null);
      return;
    }

    transaction.onabort = (event) => {
      console.log("Transaction error while reading tensor: " + event.target.error);
      resolve(null);
    };
    request.onsuccess = (event) => {
      const record = event.target.result;
      if (record) {
        console.log("Cache hit: " + id);
        resolve(record);
      } else {
        console.log("Cache miss: " + id);
        resolve(null);
      }
    };
    request.onerror = (event) => {
      console.error('Tensor retrieve failed: ', event.target.error);
      resolve(null);
    };
  });
}
- window.addEventListener('load', async function() {
- if (!navigator.gpu) {
- document.getElementById("wgpuError").style.display = "";
- document.getElementById("sdTitle").style.display = "none";
- return;
- }
- let db = await initDb();
- const ctx = document.getElementById("canvas").getContext("2d", { willReadFrequently: true });
- let labels, nets, safetensorParts;
/**
 * Requests a WebGPU device whose buffer limits are raised to 1 GiB
 * (1073741824 bytes), the value the original code names
 * `maxBufferSizeInSDModel`.
 *
 * @returns {Promise<GPUDevice>} the configured device.
 * @throws {Error} when no WebGPU adapter is available; errors from
 *   `requestDevice` itself propagate unchanged.
 */
const getDevice = async () => {
  const adapter = await navigator.gpu.requestAdapter();
  // requestAdapter() resolves to null when no suitable adapter exists; fail
  // with a clear message instead of a TypeError on the next line.
  if (adapter == null) {
    throw new Error("No WebGPU adapter available");
  }

  const maxBufferSizeInSDModel = 1073741824;
  const requiredLimits = {
    maxStorageBufferBindingSize: maxBufferSizeInSDModel,
    maxBufferSize: maxBufferSizeInSDModel,
  };

  return adapter.requestDevice({ requiredLimits });
};
/**
 * Awaits `func()`, logs how long it took as "<ms> ms <label>" (one decimal
 * place), and hands back whatever `func` produced.
 *
 * @param {Function} func - zero-argument callable; may return a promise.
 * @param {string} [label=""] - text appended to the logged duration.
 * @returns {Promise<*>} the awaited result of `func()`.
 */
const timer = async (func, label = "") => {
  const startedAt = performance.now();
  const result = await func();
  const elapsedMs = (performance.now() - startedAt).toFixed(1);
  console.log(`${elapsedMs} ms ${label}`);
  return result;
};
/**
 * Downloads `part` and reports progress after every received chunk.
 *
 * @param {string} part - URL to fetch; also passed back to the callback as
 *   the part identifier.
 * @param {Function} progressCallback - called as
 *   `(part, chunkByteLength, total)`, where `total` is the parsed
 *   Content-Length header or 0 when that header is missing or unparsable.
 * @returns {Promise<ArrayBuffer>} the complete response body.
 * @throws {Error} when the server answers with a non-2xx status.
 */
const getProgressDlForPart = async (part, progressCallback) => {
  const response = await fetch(part);
  // Refuse error pages: otherwise their body would be handed to the caller
  // (and cached) as if it were model data.
  if (!response.ok) {
    throw new Error(`Failed to download ${part}: HTTP ${response.status}`);
  }

  const contentLength = response.headers.get('content-length');
  const parsedLength = Number.parseInt(contentLength, 10);
  // A missing header parses to NaN; report 0 so callers can still sum totals.
  const total = Number.isNaN(parsedLength) ? 0 : parsedLength;

  // No stream to observe (empty body): fall back to a plain read.
  if (response.body == null) {
    return response.arrayBuffer();
  }

  // Re-wrap the body in a stream that forwards each chunk untouched while
  // telling the callback how many bytes just arrived.
  const res = new Response(new ReadableStream({
    async start(controller) {
      const reader = response.body.getReader();
      for (;;) {
        const { done, value } = await reader.read();
        if (done) break;
        progressCallback(part, value.byteLength, total);
        controller.enqueue(value);
      }

      controller.close();
    },
  }));

  return res.arrayBuffer();
};
-
/**
 * Obtains the model weights (from the IndexedDB cache when present, otherwise
 * by downloading them from `window.MODEL_BASE_URL`), expands the f16 payload
 * to f32 on the GPU chunk by chunk, and lays the result out in four
 * pre-sized byte arrays.
 *
 * @param {GPUDevice} device - device handed to `f16tof32GPU` for each chunk.
 * @param {Function} progress - called as `(done, total)`: first with
 *   downloaded/expected bytes, later with decoded/total chunk counts.
 * @returns {Promise<Uint8Array[]>} the four filled parts.
 */
const getAndDecompressF16Safetensors = async (device, progress) => {
  // The first four downloads are the f16 weight shards; an optional fifth is
  // the text model.
  const f16PartCount = 4;

  let totalLoaded = 0;
  let totalSize = 0;
  const countedParts = new Set();
  const progressCallback = (part, loaded, total) => {
    totalLoaded += loaded;
    // Each part contributes its expected size exactly once.
    if (!countedParts.has(part)) {
      countedParts.add(part);
      // A part without a usable size must not turn the running sum into NaN.
      if (Number.isFinite(total)) {
        totalSize += total;
      }
    }

    progress(totalLoaded, totalSize);
  };

  // Cache lookups yield either null or a record whose payload is `.content`.
  let combinedBuffer = await readTensorFromDb(db, "net.f16");
  let textModelU8 = await readTensorFromDb(db, "net.text");
  let textModelFetched = false;
  if (combinedBuffer == null) {
    const dlParts = [
      getProgressDlForPart(window.MODEL_BASE_URL + '/net_part0.safetensors', progressCallback),
      getProgressDlForPart(window.MODEL_BASE_URL + '/net_part1.safetensors', progressCallback),
      getProgressDlForPart(window.MODEL_BASE_URL + '/net_part2.safetensors', progressCallback),
      getProgressDlForPart(window.MODEL_BASE_URL + '/net_part3.safetensors', progressCallback)
    ];
    if (textModelU8 == null) {
      dlParts.push(getProgressDlForPart(window.MODEL_BASE_URL + '/net_textmodel.safetensors', progressCallback));
    }
    const buffers = await Promise.all(dlParts);
    // Combine everything except for the text model, since that's already f32
    const f16Buffers = buffers.slice(0, f16PartCount);
    const totalLength = f16Buffers.reduce((acc, buffer) => acc + buffer.byteLength, 0);
    combinedBuffer = new Uint8Array(totalLength);
    let offset = 0;
    for (const buffer of f16Buffers) {
      combinedBuffer.set(new Uint8Array(buffer), offset);
      offset += buffer.byteLength;
    }
    await saveTensorToDb(db, "net.f16", combinedBuffer);
    if (textModelU8 == null) {
      textModelFetched = true;
      textModelU8 = new Uint8Array(buffers[f16PartCount]);
      await saveTensorToDb(db, "net.text", textModelU8);
    }
  } else {
    combinedBuffer = combinedBuffer.content;
  }
  if (textModelU8 == null) {
    // Weights were cached but the text model was not: fetch it on its own.
    textModelU8 = new Uint8Array(await getProgressDlForPart(window.MODEL_BASE_URL + '/net_textmodel.safetensors', progressCallback));
    await saveTensorToDb(db, "net.text", textModelU8);
  } else if (!textModelFetched) {
    // Still holding the cache record; unwrap its payload.
    textModelU8 = textModelU8.content;
  }

  document.getElementById("modelDlTitle").textContent = "Decompressing model";

  // NOTE(review): the offsets below are hard-coded; they presumably describe
  // the f32 layout of this specific model export - confirm before reuse.
  const textModelOffset = 3772703308;
  // The buffer starts with an 8-byte little-endian length, followed by that
  // many bytes of JSON metadata, followed by the f16 payload.
  const headerView = new DataView(combinedBuffer.buffer, combinedBuffer.byteOffset, combinedBuffer.byteLength);
  const metadataLength = Number(headerView.getBigUint64(0, true));
  // Parsing also validates the header; the parsed value itself is not used here.
  const metadata = JSON.parse(new TextDecoder("utf8").decode(combinedBuffer.subarray(8, 8 + metadataLength)));
  const allToDecomp = combinedBuffer.byteLength - (8 + metadataLength);
  const decodeChunkSize = 67107840;
  const numChunks = Math.ceil(allToDecomp / decodeChunkSize);
  console.log(allToDecomp + " bytes to decompress");
  console.log("Will be decompressed in " + numChunks+ " chunks");

  const partOffsets = [{start: 0, end: 1131408336}, {start: 1131408336, end: 2227518416}, {start: 2227518416, end: 3308987856}, {start: 3308987856, end: 4265298864}];
  const parts = [];
  for (const offsets of partOffsets) {
    parts.push(new Uint8Array(offsets.end - offsets.start));
  }
  // Part 0 begins with the same length-prefixed metadata header.
  parts[0].set(new Uint8Array(new BigUint64Array([BigInt(metadataLength)]).buffer), 0);
  parts[0].set(combinedBuffer.subarray(8, 8 + metadataLength), 8);
  // The text model bytes are copied as-is (no f16 expansion) into part 3.
  parts[3].set(textModelU8, textModelOffset + 8 + metadataLength - partOffsets[3].start);

  const start = Date.now();
  let cursor = 0;
  for (let i = 0; i < numChunks; i++) {
    progress(i, numChunks);
    const chunkStartF16 = 8 + metadataLength + (decodeChunkSize * i);
    const chunkEndF16 = chunkStartF16 + decodeChunkSize;
    let chunk = combinedBuffer.subarray(chunkStartF16, chunkEndF16);
    if (chunk.byteLength % 4 != 0) {
      // Zero-pad the chunk to a multiple of 4 bytes before sending it to the GPU.
      const paddingBytes = 4 - (chunk.byteLength % 4);
      const alignedView = new Uint8Array(new ArrayBuffer(chunk.byteLength + paddingBytes));
      alignedView.set(chunk);
      chunk = alignedView;
    }
    const result = await f16tof32GPU(device, chunk);
    const resultUint8 = new Uint8Array(result.buffer);
    // Output positions advance twice as fast as input positions (f16 -> f32).
    const chunkStartF32 = 8 + metadataLength + (decodeChunkSize * i * 2);
    const chunkEndF32 = chunkStartF32 + resultUint8.byteLength;
    const offsetInPart = chunkStartF32 - partOffsets[cursor].start;
    if (chunkEndF32 < partOffsets[cursor].end || cursor === parts.length - 1) {
      parts[cursor].set(resultUint8, offsetInPart);
    } else {
      // The chunk straddles a part boundary: fill the rest of the current
      // part, then spill the remainder to the start of the next one.
      const spaceLeftInCurrentPart = partOffsets[cursor].end - chunkStartF32;
      parts[cursor].set(resultUint8.subarray(0, spaceLeftInCurrentPart), offsetInPart);
      cursor++;
      if (cursor < parts.length) {
        const nextPartOffset = spaceLeftInCurrentPart;
        const nextPartLength = resultUint8.length - nextPartOffset;
        parts[cursor].set(resultUint8.subarray(nextPartOffset, nextPartOffset + nextPartLength), 0);
      }
    }
  }
  // Drop the reference to the f16 buffer before returning.
  combinedBuffer = null;
  const end = Date.now();
  console.log("Decoding took: " + ((end - start) / 1000) + " s");
  console.log("Avarage " + ((end - start) / numChunks) + " ms per chunk");
  return parts;
};
/**
 * Loads and compiles the three networks, driving the download/progress UI
 * along the way.
 *
 * On success the enclosing `safetensorParts` and `nets` variables are filled
 * (`nets` has the keys "textModel", "diffusor" and "decoder"), the download
 * progress widgets are hidden after one second, and the Run button is
 * enabled.
 *
 * @returns {Promise<void>}
 * @throws rethrows any loading/compilation error after showing it in the
 *   title element, so the page does not stay on "Downloading model".
 */
const loadNet = async () => {
  const modelDlTitle = document.getElementById("modelDlTitle");
  const progress = (loaded, total) => {
    // Assigning a non-finite number to <progress>.value throws, so guard
    // against an unknown (0 or NaN) total and clamp to the bar's range.
    const percent = total > 0 ? Math.min((loaded / total) * 100, 100) : 0;
    document.getElementById("modelDlProgressBar").value = percent;
    document.getElementById("modelDlProgressValue").textContent = Math.trunc(percent) + "%";
  };

  try {
    const device = await getDevice();
    safetensorParts = await getAndDecompressF16Safetensors(device, progress);
    modelDlTitle.textContent = "Compiling model";
    // The three setups run concurrently; the results come back in call order.
    const [textModelNet, diffusorNet, decoderNet] = await timer(() => Promise.all([
      textModel().setup(device, safetensorParts),
      diffusor().setup(device, safetensorParts),
      decoder().setup(device, safetensorParts)
    ]), "(compilation)");
    nets = { textModel: textModelNet, diffusor: diffusorNet, decoder: decoderNet };
  } catch (err) {
    const reason = err instanceof Error ? err.message : String(err);
    modelDlTitle.textContent = "Failed to load model: " + reason;
    throw err;
  }

  progress(1, 1);
  modelDlTitle.textContent = "Model ready";
  // Leave the finished bar visible briefly, then swap in the step progress row.
  setTimeout(() => {
    document.getElementById("modelDlProgressBar").style.display = "none";
    document.getElementById("modelDlProgressValue").style.display = "none";
    document.getElementById("divStepProgress").style.display = "flex";
  }, 1000);
  document.getElementById("btnRunNet").disabled = false;
};
/**
 * Runs the full pipeline for one prompt: text encoding (prompt and empty
 * prompt), the denoising loop over the selected timesteps, and decoding.
 * Updates the "progressBar" / "progressFraction" elements after each step.
 *
 * @param {string} prompt - text to encode.
 * @param {number|string} steps - number of denoising steps; numeric strings
 *   (e.g. a slider's `.value`) are accepted.
 * @param {number|string} guidance - guidance value forwarded to the diffusor;
 *   numeric strings are accepted.
 * @returns {Promise<*>} whatever `nets["decoder"]` produces for the final
 *   latent.
 * @throws {RangeError} when `steps` is not a positive finite number.
 */
async function runStableDiffusion(prompt, steps, guidance) {
  const stepCount = Number(steps);
  // A non-positive step would make the timestep loop below never terminate.
  if (!Number.isFinite(stepCount) || stepCount <= 0) {
    throw new RangeError("steps must be a positive number, got: " + steps);
  }
  const guidanceScale = Number(guidance);

  const context = await timer(() => nets["textModel"](clipTokenizer.encodeForCLIP(prompt)));
  const unconditionalContext = await timer(() => nets["textModel"](clipTokenizer.encodeForCLIP("")));

  // Evenly spaced (possibly fractional) timesteps in [1, 1000).
  const timesteps = [];
  for (let i = 1; i < 1000; i += (1000 / stepCount)) {
    timesteps.push(i);
  }
  console.log("Timesteps: " + timesteps);

  const alphasCumprod = getWeight(safetensorParts, "alphas_cumprod");
  const alphas = timesteps.map((t) => alphasCumprod[Math.floor(t)]);
  // alphasPrev[i] is the alpha of the previous timestep, 1.0 for the first.
  const alphasPrev = [1.0, ...alphas.slice(0, -1)];

  // Start from Gaussian noise (Box-Muller). Using 1 - random() keeps the
  // argument of log() in (0, 1], so a draw of exactly 0 cannot yield Infinity.
  const inpSize = 4 * 64 * 64;
  let latent = new Float32Array(inpSize);
  for (let i = 0; i < inpSize; i++) {
    latent[i] = Math.sqrt(-2.0 * Math.log(1 - Math.random())) * Math.cos(2.0 * Math.PI * Math.random());
  }

  // Walk the timesteps from the largest down to the smallest.
  const totalSteps = timesteps.length;
  for (let i = totalSteps - 1; i >= 0; i--) {
    const timestep = new Float32Array([timesteps[i]]);
    latent = await timer(() => nets["diffusor"](
      unconditionalContext,
      context,
      latent,
      timestep,
      new Float32Array([alphas[i]]),
      new Float32Array([alphasPrev[i]]),
      new Float32Array([guidanceScale])
    ));
    const completed = totalSteps - i;
    document.getElementById("progressBar").value = (completed / totalSteps) * 100;
    document.getElementById("progressFraction").textContent = completed + "/" + totalSteps;
  }

  return timer(() => nets["decoder"](latent));
}
// Run button: generate an image for the current prompt/slider settings and
// paint it onto the 512x512 canvas. The button is disabled while a run is in
// flight and always re-enabled afterwards, including when the run fails.
document.getElementById("btnRunNet").addEventListener("click", async function(e) {
  const button = e.target;
  button.disabled = true;
  try {
    const image = await runStableDiffusion(
      document.getElementById("promptText").value,
      Number(document.getElementById("stepRange").value),
      Number(document.getElementById("guidanceRange").value)
    );

    // `image` is read as consecutive R, G, B values per pixel; the canvas
    // needs RGBA, so add a fully opaque alpha channel. Writing straight into
    // the clamped array avoids building a multi-million-element temporary.
    const pixelCount = 512 * 512;
    const rgba = new Uint8ClampedArray(pixelCount * 4);
    for (let p = 0; p < pixelCount; p++) {
      rgba[p * 4] = image[p * 3];
      rgba[p * 4 + 1] = image[p * 3 + 1];
      rgba[p * 4 + 2] = image[p * 3 + 2];
      rgba[p * 4 + 3] = 255;
    }

    ctx.putImageData(new ImageData(rgba, 512, 512), 0, 0);
    console.log(image);
    console.log("Success");
  } catch (err) {
    console.error("Image generation failed:", err);
  } finally {
    button.disabled = false;
  }
}, false);
// Mirror each slider's current value into the label next to it.
const stepSlider = document.getElementById('stepRange');
const stepValue = document.getElementById('stepValue');
stepSlider.addEventListener('input', () => {
  stepValue.textContent = stepSlider.value;
});
const guidanceSlider = document.getElementById('guidanceRange');
const guidanceValue = document.getElementById('guidanceValue');
guidanceSlider.addEventListener('input', () => {
  guidanceValue.textContent = guidanceSlider.value;
});

// Start loading the model right away. A failure must not go unhandled:
// log it and show it instead of leaving the page on "Downloading model".
loadNet().catch((err) => {
  console.error("Model loading failed:", err);
  const reason = err instanceof Error ? err.message : String(err);
  document.getElementById("modelDlTitle").textContent = "Failed to load model: " + reason;
});
- });
- </script>
- </body>
- </html>
|