# This file incorporates code from the following:
# Github Name | License | Link
# Stability-AI/generative-models | MIT | https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/LICENSE-CODE
# mlfoundations/open_clip | MIT | https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/LICENSE

from tinygrad import Tensor, TinyJit, dtypes
from tinygrad.nn import Conv2d, GroupNorm
from tinygrad.nn.state import safe_load, load_state_dict
from tinygrad.helpers import fetch, trange, colored, Timing, GlobalCounters
from extra.models.clip import Embedder, FrozenClosedClipEmbedder, FrozenOpenClipEmbedder
from extra.models.unet import UNetModel, Upsample, Downsample, timestep_embedding
from examples.stable_diffusion import ResnetBlock, Mid
import numpy as np
from typing import Dict, List, Callable, Optional, Any, Set, Tuple
import argparse, tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from PIL import Image

# configs:
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/configs/inference/sd_xl_base.yaml
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/configs/inference/sd_xl_refiner.yaml
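# Notes on the derived dimensions (consistent with the configs below):
#   ctx_dim 2048 = 768 (CLIP ViT-L token embeddings) + 1280 (OpenCLIP ViT-bigG token embeddings)
#   adm_in_ch 2816 (base)   = 1280 (pooled OpenCLIP) + 3 embedders * 2 values * 256 (ConcatTimestepEmbedderND)
#   adm_in_ch 2560 (refiner) = 1280 (pooled OpenCLIP) + (2 + 2 + 1 values) * 256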
configs: Dict = {
  "SDXL_Base": {
    "model": {"adm_in_ch": 2816, "in_ch": 4, "out_ch": 4, "model_ch": 320, "attention_resolutions": [4, 2], "num_res_blocks": 2, "channel_mult": [1, 2, 4], "d_head": 64, "transformer_depth": [1, 2, 10], "ctx_dim": 2048, "use_linear": True},
    "conditioner": {"concat_embedders": ["original_size_as_tuple", "crop_coords_top_left", "target_size_as_tuple"]},
    "first_stage_model": {"ch": 128, "in_ch": 3, "out_ch": 3, "z_ch": 4, "ch_mult": [1, 2, 4, 4], "num_res_blocks": 2, "resolution": 256},
    "denoiser": {"num_idx": 1000},
  },
  "SDXL_Refiner": {
    "model": {"adm_in_ch": 2560, "in_ch": 4, "out_ch": 4, "model_ch": 384, "attention_resolutions": [4, 2], "num_res_blocks": 2, "channel_mult": [1, 2, 4, 4], "d_head": 64, "transformer_depth": [4, 4, 4, 4], "ctx_dim": [1280, 1280, 1280, 1280], "use_linear": True},
    "conditioner": {"concat_embedders": ["original_size_as_tuple", "crop_coords_top_left", "aesthetic_score"]},
    "first_stage_model": {"ch": 128, "in_ch": 3, "out_ch": 3, "z_ch": 4, "ch_mult": [1, 2, 4, 4], "num_res_blocks": 2, "resolution": 256},
    "denoiser": {"num_idx": 1000},
  },
}

def tensor_identity(x:Tensor) -> Tensor:
  return x

class DiffusionModel:
  def __init__(self, *args, **kwargs):
    self.diffusion_model = UNetModel(*args, **kwargs)

class Embedder(ABC):
  input_key: str
  @abstractmethod
  def __call__(self, x:Tensor) -> Tensor:
    pass

# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L913
class ConcatTimestepEmbedderND(Embedder):
  def __init__(self, outdim:int, input_key:str):
    self.outdim = outdim
    self.input_key = input_key

  def __call__(self, x:Tensor):
    assert len(x.shape) == 2
    emb = timestep_embedding(x.flatten(), self.outdim)
    emb = emb.reshape((x.shape[0],-1))
    return emb

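# The Conditioner runs every embedder over the prompt/size batch and buckets each output by tensor
# rank: 2D embeddings become the pooled "vector" conditioning, 3D embeddings become the "crossattn"
# context for the UNet, and 4D/5D embeddings would go to "concat". For SDXL this means the CLIP ViT-L
# and OpenCLIP ViT-bigG hidden states are concatenated into the 2048-wide cross-attention context,
# while the pooled OpenCLIP vector plus the size/crop embeddings form the "vector" input.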
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L71
class Conditioner:
  OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
  KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}
  embedders: List[Embedder]

  def __init__(self, concat_embedders:List[str]):
    self.embedders = [
      FrozenClosedClipEmbedder(ret_layer_idx=11),
      FrozenOpenClipEmbedder(dims=1280, n_heads=20, layers=32, return_pooled=True),
    ]
    for input_key in concat_embedders:
      self.embedders.append(ConcatTimestepEmbedderND(256, input_key))

  def get_keys(self) -> Set[str]:
    return set(e.input_key for e in self.embedders)

  def __call__(self, batch:Dict, force_zero_embeddings:List=[]) -> Dict[str,Tensor]:
    output: Dict[str,Tensor] = {}
    for embedder in self.embedders:
      emb_out = embedder(batch[embedder.input_key])

      if isinstance(emb_out, Tensor):
        emb_out = [emb_out]
      else:
        assert isinstance(emb_out, (list, tuple))

      for emb in emb_out:
        if embedder.input_key in force_zero_embeddings:
          emb = Tensor.zeros_like(emb)

        out_key = self.OUTPUT_DIM2KEYS[len(emb.shape)]
        if out_key in output:
          output[out_key] = Tensor.cat(output[out_key], emb, dim=self.KEY2CATDIM[out_key])
        else:
          output[out_key] = emb

    return output

class FirstStage:
  """
  Namespace for First Stage Model components
  """

  # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/model.py#L487
  class Encoder:
    def __init__(self, ch:int, in_ch:int, out_ch:int, z_ch:int, ch_mult:List[int], num_res_blocks:int, resolution:int):
      self.conv_in = Conv2d(in_ch, ch, kernel_size=3, stride=1, padding=1)
      in_ch_mult = (1,) + tuple(ch_mult)

      class BlockEntry:
        def __init__(self, block:List[ResnetBlock], downsample):
          self.block = block
          self.downsample = downsample

      self.down: List[BlockEntry] = []
      for i_level in range(len(ch_mult)):
        block = []
        block_in  = ch * in_ch_mult[i_level]
        block_out = ch * ch_mult   [i_level]
        for _ in range(num_res_blocks):
          block.append(ResnetBlock(block_in, block_out))
          block_in = block_out

        downsample = tensor_identity if (i_level == len(ch_mult)-1) else Downsample(block_in)
        self.down.append(BlockEntry(block, downsample))

      self.mid = Mid(block_in)

      self.norm_out = GroupNorm(32, block_in)
      self.conv_out = Conv2d(block_in, 2*z_ch, kernel_size=3, stride=1, padding=1)

    def __call__(self, x:Tensor) -> Tensor:
      h = self.conv_in(x)
      for down in self.down:
        for block in down.block:
          h = block(h)
        h = down.downsample(h)

      h = h.sequential([self.mid.block_1, self.mid.attn_1, self.mid.block_2])
      h = h.sequential([self.norm_out, Tensor.swish, self.conv_out])
      return h

  # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/model.py#L604
  class Decoder:
    def __init__(self, ch:int, in_ch:int, out_ch:int, z_ch:int, ch_mult:List[int], num_res_blocks:int, resolution:int):
      block_in = ch * ch_mult[-1]
      curr_res = resolution // 2 ** (len(ch_mult) - 1)
      self.z_shape = (1, z_ch, curr_res, curr_res)

      self.conv_in = Conv2d(z_ch, block_in, kernel_size=3, stride=1, padding=1)
      self.mid = Mid(block_in)

      class BlockEntry:
        def __init__(self, block:List[ResnetBlock], upsample:Callable[[Any],Any]):
          self.block = block
          self.upsample = upsample

      self.up: List[BlockEntry] = []
      for i_level in reversed(range(len(ch_mult))):
        block = []
        block_out = ch * ch_mult[i_level]
        for _ in range(num_res_blocks + 1):
          block.append(ResnetBlock(block_in, block_out))
          block_in = block_out

        upsample = tensor_identity if i_level == 0 else Upsample(block_in)
        self.up.insert(0, BlockEntry(block, upsample))  # type: ignore

      self.norm_out = GroupNorm(32, block_in)
      self.conv_out = Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def __call__(self, z:Tensor) -> Tensor:
      h = z.sequential([self.conv_in, self.mid.block_1, self.mid.attn_1, self.mid.block_2])

      for up in self.up[::-1]:
        for block in up.block:
          h = block(h)
        h = up.upsample(h)

      h = h.sequential([self.norm_out, Tensor.swish, self.conv_out])
      return h

# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/models/autoencoder.py#L102
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/models/autoencoder.py#L437
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/models/autoencoder.py#L508
class FirstStageModel:
  def __init__(self, embed_dim:int=4, **kwargs):
    self.encoder = FirstStage.Encoder(**kwargs)
    self.decoder = FirstStage.Decoder(**kwargs)
    self.quant_conv = Conv2d(2*kwargs["z_ch"], 2*embed_dim, 1)
    self.post_quant_conv = Conv2d(embed_dim, kwargs["z_ch"], 1)

  def decode(self, z:Tensor) -> Tensor:
    return z.sequential([self.post_quant_conv, self.decoder])

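# NOTE: this text-to-image script only ever calls FirstStageModel.decode; the Encoder and quant_conv
# are instantiated so the pretrained checkpoint loads cleanly, but they are not used below.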
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/discretizer.py#L42
class LegacyDDPMDiscretization:
  def __init__(self, linear_start:float=0.00085, linear_end:float=0.0120, num_timesteps:int=1000):
    self.num_timesteps = num_timesteps
    betas = np.linspace(linear_start**0.5, linear_end**0.5, num_timesteps, dtype=np.float32) ** 2
    alphas = 1.0 - betas
    self.alphas_cumprod = np.cumprod(alphas, axis=0)

  def __call__(self, n:int, flip:bool=False) -> Tensor:
    if n < self.num_timesteps:
      timesteps = np.linspace(self.num_timesteps - 1, 0, n, endpoint=False).astype(int)[::-1]
      alphas_cumprod = self.alphas_cumprod[timesteps]
    elif n == self.num_timesteps:
      alphas_cumprod = self.alphas_cumprod
    else:
      raise ValueError(f"n={n} must not exceed num_timesteps={self.num_timesteps}")

    # sigma_t = sqrt((1 - alphas_cumprod_t) / alphas_cumprod_t), with sigma=0 prepended for the fully denoised endpoint
    sigmas = Tensor((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
    sigmas = Tensor.cat(Tensor.zeros((1,)), sigmas)
    return sigmas if flip else sigmas.flip(axis=0)  # sigmas is "pre-flipped", need to do opposite of flag

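# append_dims reshapes x with trailing singleton dims until it has the same rank as t, so a per-batch
# scalar (e.g. sigma) broadcasts cleanly against an (N,C,H,W) latent.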
def append_dims(x:Tensor, t:Tensor) -> Tensor:
  dims_to_append = len(t.shape) - len(x.shape)
  assert dims_to_append >= 0
  return x.reshape(x.shape + (1,)*dims_to_append)

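# The UNet forward pass is wrapped in TinyJit so its kernels are captured once and replayed on every
# sampling step; this requires input shapes and dtypes to stay fixed across calls, which holds here
# because batch size and image size are constant for the whole run.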
@TinyJit
def run(model, x, tms, ctx, y, c_out, add):
  return (model(x, tms, ctx, y)*c_out + add).realize()

# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/models/diffusion.py#L19
class SDXL:
  def __init__(self, config:Dict):
    self.conditioner = Conditioner(**config["conditioner"])
    self.first_stage_model = FirstStageModel(**config["first_stage_model"])
    self.model = DiffusionModel(**config["model"])

    self.discretization = LegacyDDPMDiscretization()
    self.sigmas = self.discretization(config["denoiser"]["num_idx"], flip=True)

  # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/inference/helpers.py#L173
  def create_conditioning(self, pos_prompt:str, img_width:int, img_height:int, aesthetic_score:float=5.0, N:int=1) -> Tuple[Dict,Dict]:
    batch_c : Dict = {
      "txt": pos_prompt,
      "original_size_as_tuple": Tensor([img_height,img_width]).repeat(N,1),
      "crop_coords_top_left": Tensor([0,0]).repeat(N,1),
      "target_size_as_tuple": Tensor([img_height,img_width]).repeat(N,1),
      "aesthetic_score": Tensor([aesthetic_score]).repeat(N,1),
    }
    batch_uc: Dict = {
      "txt": "",
      "original_size_as_tuple": Tensor([img_height,img_width]).repeat(N,1),
      "crop_coords_top_left": Tensor([0,0]).repeat(N,1),
      "target_size_as_tuple": Tensor([img_height,img_width]).repeat(N,1),
      "aesthetic_score": Tensor([aesthetic_score]).repeat(N,1),
    }
    return self.conditioner(batch_c), self.conditioner(batch_uc, force_zero_embeddings=["txt"])

  # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/denoiser.py#L42
  def denoise(self, x:Tensor, sigma:Tensor, cond:Dict) -> Tensor:

    def sigma_to_idx(s:Tensor) -> Tensor:
      dists = s - self.sigmas.unsqueeze(1)
      return dists.abs().argmin(axis=0).view(*s.shape)

    sigma = self.sigmas[sigma_to_idx(sigma)]
    sigma_shape = sigma.shape
    sigma = append_dims(sigma, x)

    # eps parameterization: denoised = x + c_out * UNet(c_in*x, t, ...), with c_out = -sigma, c_in = 1/sqrt(sigma^2 + 1)
    c_out = -sigma
    c_in = 1 / (sigma**2 + 1.0) ** 0.5
    c_noise = sigma_to_idx(sigma.reshape(sigma_shape))

    def prep(*tensors:Tensor):
      return tuple(t.cast(dtypes.float16).realize() for t in tensors)

    return run(self.model.diffusion_model, *prep(x*c_in, c_noise, cond["crossattn"], cond["vector"], c_out, x))

  def decode(self, x:Tensor) -> Tensor:
    # 0.13025 is the SDXL VAE scale factor; latents are scaled by it at encode time, so undo it before decoding
    return self.first_stage_model.decode(1.0 / 0.13025 * x)

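# Classifier-free guidance: the unconditional and conditional batches are run through the UNet together
# (concatenated along the batch dim), then the two predictions are recombined as uc + scale*(c - uc).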
class VanillaCFG:
  def __init__(self, scale:float):
    self.scale = scale

  def prepare_inputs(self, x:Tensor, s:Tensor, c:Dict, uc:Dict) -> Tuple[Tensor,Tensor,Dict]:
    c_out = {}
    for k in c:
      assert k in ["vector", "crossattn", "concat"]
      c_out[k] = Tensor.cat(uc[k], c[k], dim=0)
    return Tensor.cat(x, x), Tensor.cat(s, s), c_out

  def __call__(self, x:Tensor, sigma:Tensor) -> Tensor:
    x_u, x_c = x.chunk(2)
    x_pred = x_u + self.scale*(x_c - x_u)
    return x_pred

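# DPM++ 2M is a second-order multistep solver that works in negative-log-sigma space (t = -log(sigma)):
# each step combines the current denoised estimate with the previous one (old_denoised) to cancel
# first-order error, falling back to a plain first-order step on the first iteration and on the final
# step where next_sigma is zero.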
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/sampling.py#L21
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/sampling.py#L287
class DPMPP2MSampler:
  def __init__(self, cfg_scale:float):
    self.discretization = LegacyDDPMDiscretization()
    self.guider = VanillaCFG(cfg_scale)

  def sampler_step(self, old_denoised:Optional[Tensor], prev_sigma:Optional[Tensor], sigma:Tensor, next_sigma:Tensor, denoiser, x:Tensor, c:Dict, uc:Dict) -> Tuple[Tensor,Tensor]:
    denoised = denoiser(*self.guider.prepare_inputs(x, sigma, c, uc))
    denoised = self.guider(denoised, sigma)

    t, t_next = sigma.log().neg(), next_sigma.log().neg()
    h = t_next - t
    r = None if prev_sigma is None else (t - prev_sigma.log().neg()) / h

    mults = [t_next.neg().exp()/t.neg().exp(), (-h).exp().sub(1)]
    if r is not None:
      mults.extend([1 + 1/(2*r), 1/(2*r)])
    mults = [append_dims(m, x) for m in mults]

    x_standard = mults[0]*x - mults[1]*denoised
    if (old_denoised is None) or (next_sigma.sum().numpy().item() < 1e-14):
      return x_standard, denoised

    denoised_d = mults[2]*denoised - mults[3]*old_denoised
    x_advanced = mults[0]*x - mults[1]*denoised_d
    x = Tensor.where(append_dims(next_sigma, x) > 0.0, x_advanced, x_standard)
    return x, denoised

  def __call__(self, denoiser, x:Tensor, c:Dict, uc:Dict, num_steps:int, timing=False) -> Tensor:
    sigmas = self.discretization(num_steps)
    x *= Tensor.sqrt(1.0 + sigmas[0] ** 2.0)
    num_sigmas = len(sigmas)

    old_denoised = None
    for i in trange(num_sigmas - 1):
      with Timing("step in ", enabled=timing, on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"):
        x, old_denoised = self.sampler_step(
          old_denoised=old_denoised,
          prev_sigma=(None if i==0 else sigmas[i-1].reshape(x.shape[0])),
          sigma=sigmas[i].reshape(x.shape[0]),
          next_sigma=sigmas[i+1].reshape(x.shape[0]),
          denoiser=denoiser,
          x=x,
          c=c,
          uc=uc,
        )
        x.realize()
        old_denoised.realize()

    return x

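# Pipeline: build the text/size conditioning, sample a latent with DPM++ 2M under classifier-free
# guidance, decode the latent with the VAE decoder, then rescale to an 8-bit RGB image.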
if __name__ == "__main__":
  default_prompt = "a horse sized cat eating a bagel"
  parser = argparse.ArgumentParser(description="Run SDXL", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument('--steps', type=int, default=10, help="The number of diffusion steps")
  parser.add_argument('--prompt', type=str, default=default_prompt, help="Description of image to generate")
  parser.add_argument('--out', type=str, default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename")
  parser.add_argument('--seed', type=int, help="Set the random latent seed")
  parser.add_argument('--guidance', type=float, default=6.0, help="Prompt strength")
  parser.add_argument('--width', type=int, default=1024, help="The output image width")
  parser.add_argument('--height', type=int, default=1024, help="The output image height")
  parser.add_argument('--weights', type=str, help="Custom path to weights")
  parser.add_argument('--timing', action='store_true', help="Print timing per step")
  parser.add_argument('--noshow', action='store_true', help="Don't show the image")
  args = parser.parse_args()

  Tensor.no_grad = True
  if args.seed is not None:
    Tensor.manual_seed(args.seed)

  model = SDXL(configs["SDXL_Base"])

  default_weight_url = 'https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors'
  weights = args.weights if args.weights else fetch(default_weight_url, 'sd_xl_base_1.0.safetensors')
  load_state_dict(model, safe_load(weights), strict=False)

  N = 1  # batch size (one image per run)
  C = 4  # latent channels
  F = 8  # VAE spatial downscale factor
  assert args.width % F == 0, f"img_width must be multiple of {F}, got {args.width}"
  assert args.height % F == 0, f"img_height must be multiple of {F}, got {args.height}"

  c, uc = model.create_conditioning(args.prompt, args.width, args.height)
  del model.conditioner
  for v in c .values(): v.realize()
  for v in uc.values(): v.realize()
  print("created batch")

  # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/inference/helpers.py#L101
  shape = (N, C, args.height // F, args.width // F)
  randn = Tensor.randn(shape)

  sampler = DPMPP2MSampler(args.guidance)
  z = sampler(model.denoise, randn, c, uc, args.steps, timing=args.timing)
  print("created samples")
  x = model.decode(z).realize()
  print("decoded samples")

  # make image correct size and scale
  x = (x + 1.0) / 2.0
  x = x.reshape(3,args.height,args.width).permute(1,2,0).clip(0,1).mul(255).cast(dtypes.uint8)
  print(x.shape)

  im = Image.fromarray(x.numpy())
  print(f"saving {args.out}")
  im.save(args.out)

  if not args.noshow:
    im.show()

  # validation!
  if args.prompt == default_prompt and args.steps == 10 and args.seed == 0 and args.guidance == 6.0 and args.width == args.height == 1024 \
      and not args.weights:
    ref_image = Tensor(np.array(Image.open(Path(__file__).parent / "sdxl_seed0.png")))
    distance = (((x - ref_image).cast(dtypes.float) / ref_image.max())**2).mean().item()
    assert distance < 2e-3, colored(f"validation failed with {distance=}", "red")
    print(colored(f"output validated with {distance=}", "green"))