from tinygrad import Tensor, dtypes, TinyJit
from tinygrad.helpers import fetch
from tinygrad.nn.state import safe_load, load_state_dict, get_state_dict
from examples.stable_diffusion import AutoencoderKL, get_alphas_cumprod
from examples.sdxl import DPMPP2MSampler, append_dims, LegacyDDPMDiscretization
from extra.models.unet import UNetModel
from extra.models.clip import FrozenOpenClipEmbedder

from typing import Dict
import argparse, tempfile, os
from pathlib import Path
from PIL import Image

class DiffusionModel:
  def __init__(self, model:UNetModel):
    self.diffusion_model = model
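
# The UNet forward pass plus the output scale/skip combination is wrapped in
# TinyJit, so after warm-up each sampling step replays the compiled kernels.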
@TinyJit
def run(model, x, tms, ctx, c_out, add):
  return (model(x, tms, ctx)*c_out + add).realize()

# https://github.com/Stability-AI/stablediffusion/blob/cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf/ldm/models/diffusion/ddpm.py#L521
class StableDiffusionV2:
  def __init__(self, unet_config:Dict, cond_stage_config:Dict, parameterization:str="v"):
    self.model = DiffusionModel(UNetModel(**unet_config))
    self.first_stage_model = AutoencoderKL()
    self.cond_stage_model = FrozenOpenClipEmbedder(**cond_stage_config)
    self.alphas_cumprod = get_alphas_cumprod()
    self.parameterization = parameterization

    self.discretization = LegacyDDPMDiscretization()
    self.sigmas = self.discretization(1000, flip=True)
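
  # The sampler passes in continuous sigma values; sigma_to_idx snaps each one
  # to the nearest of the 1000 discrete DDPM timesteps these sigmas came from.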
  def denoise(self, x:Tensor, sigma:Tensor, cond:Dict) -> Tensor:
    def sigma_to_idx(s:Tensor) -> Tensor:
      dists = s - self.sigmas.unsqueeze(1)
      return dists.abs().argmin(axis=0).view(*s.shape)

    sigma = self.sigmas[sigma_to_idx(sigma)]
    sigma_shape = sigma.shape
    sigma = append_dims(sigma, x)
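
    # scaling terms for v-prediction (parameterization="v"): the denoised sample
    # is reconstructed inside run() as model(x*c_in, t, ctx)*c_out + x*c_skip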
    c_skip = 1.0 / (sigma**2 + 1.0)
    c_out = -sigma / (sigma**2 + 1.0) ** 0.5
    c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
    c_noise = sigma_to_idx(sigma.reshape(sigma_shape))

    def prep(*tensors:Tensor):
      return tuple(t.cast(dtypes.float16).realize() for t in tensors)

    return run(self.model.diffusion_model, *prep(x*c_in, c_noise, cond["crossattn"], c_out, x*c_skip))

  def decode(self, x:Tensor, height:int, width:int) -> Tensor:
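    # undo the standard SD latent scale factor (0.18215) before the VAE decoder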
    x = self.first_stage_model.post_quant_conv(1/0.18215 * x)
    x = self.first_stage_model.decoder(x)

    # make image correct size and scale
    x = (x + 1.0) / 2.0
    x = x.reshape(3,height,width).permute(1,2,0).clip(0,1).mul(255).cast(dtypes.uint8)
    return x
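
# Hyperparameters matching the SD 2.x 768-v checkpoints: a UNet with 1024-dim
# cross-attention context and the 24-layer OpenCLIP text encoder.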
params: Dict = {
  "unet_config": {
    "adm_in_ch": None,
    "in_ch": 4,
    "out_ch": 4,
    "model_ch": 320,
    "attention_resolutions": [4, 2, 1],
    "num_res_blocks": 2,
    "channel_mult": [1, 2, 4, 4],
    "d_head": 64,
    "transformer_depth": [1, 1, 1, 1],
    "ctx_dim": 1024,
    "use_linear": True,
  },
  "cond_stage_config": {
    "dims": 1024,
    "n_heads": 16,
    "layers": 24,
    "return_pooled": False,
    "ln_penultimate": True,
  }
}

if __name__ == "__main__":
  default_prompt = "a horse sized cat eating a bagel"

  parser = argparse.ArgumentParser(description='Run Stable Diffusion v2.X', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument('--steps', type=int, default=10, help="The number of diffusion steps")
  parser.add_argument('--prompt', type=str, default=default_prompt, help="Description of image to generate")
  parser.add_argument('--out', type=str, default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename")
  parser.add_argument('--seed', type=int, help="Set the random latent seed")
  parser.add_argument('--guidance', type=float, default=7.5, help="Prompt strength")
  parser.add_argument('--width', type=int, default=768, help="The output image width")
  parser.add_argument('--height', type=int, default=768, help="The output image height")
  parser.add_argument('--weights-fn', type=str, help="Filename of weights to use")
  parser.add_argument('--weights-url', type=str, help="Custom URL to download weights from")
  parser.add_argument('--timing', action='store_true', help="Print timing per step")
  parser.add_argument('--noshow', action='store_true', help="Don't show the image")
  parser.add_argument('--fp16', action='store_true', help="Cast the weights to float16")
  args = parser.parse_args()
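
  # latent geometry: batch size N, latent channels C, and the VAE's spatial
  # downscale factor F, so a width x height image is a (width//F, height//F) latent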
  N = 1
  C = 4
  F = 8
  assert args.width % F == 0, f"img_width must be multiple of {F}, got {args.width}"
  assert args.height % F == 0, f"img_height must be multiple of {F}, got {args.height}"

  Tensor.no_grad = True
  if args.seed is not None:
    Tensor.manual_seed(args.seed)

  model = StableDiffusionV2(**params)

  default_weights_url = 'https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors'
  weights_fn = args.weights_fn
  if not weights_fn:
    weights_url = args.weights_url if args.weights_url else default_weights_url
    weights_fn = fetch(weights_url, os.path.basename(str(weights_url)))
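  # strict=False tolerates checkpoint keys that have no counterpart in this model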
  load_state_dict(model, safe_load(weights_fn), strict=False)
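
  # optionally cast just the UNet ("model.*" keys) to half precision; the VAE
  # and the remaining buffers stay in float32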
  if args.fp16:
    for k,v in get_state_dict(model).items():
      if k.startswith("model"):
        v.replace(v.cast(dtypes.float16).realize())
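
  # classifier-free guidance needs two conditionings: the prompt embedding (c)
  # and an empty-prompt unconditional embedding (uc); the text encoder can be
  # freed once both are computed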
  c = { "crossattn": model.cond_stage_model(args.prompt) }
  uc = { "crossattn": model.cond_stage_model("") }
  del model.cond_stage_model
  print("created conditioning")
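
  # sample in latent space: start from Gaussian noise and run the DPM++(2M)
  # sampler for args.steps steps with guidance scale args.guidance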
  shape = (N, C, args.height // F, args.width // F)
  randn = Tensor.randn(shape)

  sampler = DPMPP2MSampler(args.guidance)
  z = sampler(model.denoise, randn, c, uc, args.steps, timing=args.timing)
  print("created samples")
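
  # decode the final latent through the VAE into a (height, width, 3) uint8 image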
  x = model.decode(z, args.height, args.width).realize()
  print("decoded samples")
  print(x.shape)

  im = Image.fromarray(x.numpy())
  print(f"saving {args.out}")
  im.save(args.out)

  if not args.noshow:
    im.show()