123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147 |
- from tinygrad import Tensor, dtypes, TinyJit
- from tinygrad.helpers import fetch
- from tinygrad.nn.state import safe_load, load_state_dict, get_state_dict
- from examples.stable_diffusion import AutoencoderKL, get_alphas_cumprod
- from examples.sdxl import DPMPP2MSampler, append_dims, LegacyDDPMDiscretization
- from extra.models.unet import UNetModel
- from extra.models.clip import FrozenOpenClipEmbedder
- from typing import Dict
- import argparse, tempfile, os
- from pathlib import Path
- from PIL import Image
class DiffusionModel:
  """Thin namespace wrapper around the UNet.

  Holding the UNet as `self.diffusion_model` makes the state-dict keys come out as
  "model.diffusion_model.*" — presumably to match the checkpoint key layout loaded
  below with load_state_dict (verify against the safetensors file's keys).
  """
  def __init__(self, model:UNetModel):
    self.diffusion_model = model
@TinyJit
def run(model, x, tms, ctx, c_out, add):
  """Jitted single UNet evaluation.

  Computes model(x, tms, ctx) * c_out + add and realizes the result so the
  TinyJit capture sees a concrete output buffer.
  """
  prediction = model(x, tms, ctx)
  scaled = prediction * c_out + add
  return scaled.realize()
# https://github.com/Stability-AI/stablediffusion/blob/cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf/ldm/models/diffusion/ddpm.py#L521
class StableDiffusionV2:
  """Stable Diffusion v2.x pipeline: UNet denoiser, VAE (first stage), and OpenCLIP text encoder.

  parameterization defaults to "v"; the scalings hard-coded in denoise() are the
  v-prediction ones (see the linked ddpm.py reference).
  """
  def __init__(self, unet_config:Dict, cond_stage_config:Dict, parameterization:str="v"):
    self.model = DiffusionModel(UNetModel(**unet_config))
    self.first_stage_model = AutoencoderKL()                               # VAE: latent <-> pixel space
    self.cond_stage_model = FrozenOpenClipEmbedder(**cond_stage_config)    # text conditioning
    self.alphas_cumprod = get_alphas_cumprod()
    self.parameterization = parameterization
    self.discretization = LegacyDDPMDiscretization()
    # Sigmas for all 1000 DDPM steps; flip=True reverses the order (see LegacyDDPMDiscretization).
    self.sigmas = self.discretization(1000, flip=True)

  def denoise(self, x:Tensor, sigma:Tensor, cond:Dict) -> Tensor:
    """One denoising evaluation of the UNet at noise level `sigma`.

    cond must contain key "crossattn" (text-embedding context for the UNet).
    Inputs to the jitted UNet call are cast to float16 by prep().
    """
    def sigma_to_idx(s:Tensor) -> Tensor:
      # Index of the discrete schedule sigma closest to each entry of s.
      dists = s - self.sigmas.unsqueeze(1)
      return dists.abs().argmin(axis=0).view(*s.shape)

    # Snap the continuous sigma onto the discrete schedule, then expand it to x's rank
    # so the scalings below broadcast over the latent.
    sigma = self.sigmas[sigma_to_idx(sigma)]
    sigma_shape = sigma.shape
    sigma = append_dims(sigma, x)

    # v-parameterization scalings: output = model(...)*c_out + x*c_skip, input scaled by c_in.
    c_skip = 1.0 / (sigma**2 + 1.0)
    c_out = -sigma / (sigma**2 + 1.0) ** 0.5
    c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
    c_noise = sigma_to_idx(sigma.reshape(sigma_shape))  # timestep indices fed to the UNet

    def prep(*tensors:Tensor):
      # Half-precision cast + realize so the TinyJit'd `run` sees concrete fp16 buffers.
      return tuple(t.cast(dtypes.float16).realize() for t in tensors)

    return run(self.model.diffusion_model, *prep(x*c_in, c_noise, cond["crossattn"], c_out, x*c_skip))

  def decode(self, x:Tensor, height:int, width:int) -> Tensor:
    """Decode a latent to a (height, width, 3) uint8 image tensor.

    NOTE(review): the reshape(3, height, width) assumes batch size 1 — the script
    below always uses N=1.
    """
    x = self.first_stage_model.post_quant_conv(1/0.18215 * x)  # undo SD latent scale factor
    x = self.first_stage_model.decoder(x)

    # make image correct size and scale
    x = (x + 1.0) / 2.0  # [-1, 1] -> [0, 1]
    x = x.reshape(3,height,width).permute(1,2,0).clip(0,1).mul(255).cast(dtypes.uint8)  # CHW -> HWC uint8
    return x
# Hyperparameters for the SD v2.x 768 model: UNet backbone and OpenCLIP text encoder.
_unet_cfg: Dict = {
  "adm_in_ch": None,                   # no extra adm conditioning input
  "in_ch": 4,                          # latent channels in
  "out_ch": 4,                         # latent channels out
  "model_ch": 320,
  "attention_resolutions": [4, 2, 1],
  "num_res_blocks": 2,
  "channel_mult": [1, 2, 4, 4],
  "d_head": 64,
  "transformer_depth": [1, 1, 1, 1],
  "ctx_dim": 1024,                     # cross-attention context width (matches text-encoder dims)
  "use_linear": True,
}

_cond_stage_cfg: Dict = {
  "dims": 1024,
  "n_heads": 16,
  "layers": 24,
  "return_pooled": False,
  "ln_penultimate": True,
}

params: Dict = {
  "unet_config": _unet_cfg,
  "cond_stage_config": _cond_stage_cfg,
}
if __name__ == "__main__":
  # CLI entry point: text prompt -> sampled latent -> decoded PNG.
  default_prompt = "a horse sized cat eating a bagel"

  parser = argparse.ArgumentParser(description='Run Stable Diffusion v2.X', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument('--steps', type=int, default=10, help="The number of diffusion steps")
  parser.add_argument('--prompt', type=str, default=default_prompt, help="Description of image to generate")
  parser.add_argument('--out', type=str, default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename")
  parser.add_argument('--seed', type=int, help="Set the random latent seed")
  parser.add_argument('--guidance', type=float, default=7.5, help="Prompt strength")
  parser.add_argument('--width', type=int, default=768, help="The output image width")
  parser.add_argument('--height', type=int, default=768, help="The output image height")
  parser.add_argument('--weights-fn', type=str, help="Filename of weights to use")
  parser.add_argument('--weights-url', type=str, help="Custom URL to download weights from")
  parser.add_argument('--timing', action='store_true', help="Print timing per step")
  parser.add_argument('--noshow', action='store_true', help="Don't show the image")
  parser.add_argument('--fp16', action='store_true', help="Cast the weights to float16")
  args = parser.parse_args()

  # Latent geometry: batch size, latent channels, VAE spatial downscale factor.
  N = 1
  C = 4
  F = 8

  assert args.width % F == 0, f"img_width must be multiple of {F}, got {args.width}"
  assert args.height % F == 0, f"img_height must be multiple of {F}, got {args.height}"

  Tensor.no_grad = True  # inference only; no gradient tracking
  if args.seed is not None:
    Tensor.manual_seed(args.seed)

  model = StableDiffusionV2(**params)

  # Resolve weights: explicit file > custom URL > default SD 2.1 768-v checkpoint.
  default_weights_url = 'https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors'
  weights_fn = args.weights_fn
  if not weights_fn:
    weights_url = args.weights_url if args.weights_url else default_weights_url
    weights_fn = fetch(weights_url, os.path.basename(str(weights_url)))
  # strict=False: tolerate checkpoint keys this graph doesn't define — TODO confirm which keys differ
  load_state_dict(model, safe_load(weights_fn), strict=False)

  if args.fp16:
    # Only parameters whose key starts with "model" (the DiffusionModel/UNet) are cast;
    # VAE and text-encoder weights keep their loaded dtype.
    for k,v in get_state_dict(model).items():
      if k.startswith("model"):
        v.replace(v.cast(dtypes.float16).realize())

  # Conditional (prompt) and unconditional (empty prompt) embeddings for guided sampling.
  c = { "crossattn": model.cond_stage_model(args.prompt) }
  uc = { "crossattn": model.cond_stage_model("") }
  del model.cond_stage_model  # text encoder no longer needed after embedding; free it
  print("created conditioning")

  # Initial noise latent at 1/F of the output resolution.
  shape = (N, C, args.height // F, args.width // F)
  randn = Tensor.randn(shape)

  sampler = DPMPP2MSampler(args.guidance)
  z = sampler(model.denoise, randn, c, uc, args.steps, timing=args.timing)
  print("created samples")
  x = model.decode(z, args.height, args.width).realize()
  print("decoded samples")
  print(x.shape)

  im = Image.fromarray(x.numpy())
  print(f"saving {args.out}")
  im.save(args.out)
  if not args.noshow:
    im.show()
|