Stable Diffusion Inference (low-level)
Create an image using the Diffusers library.
Install and Import Libraries
!pip install -qq diffusers transformers scipy ftfy accelerate
import torch
= "cuda" if torch.cuda.is_available() else "cpu"
device device
'cuda'
= ["a photograph of an astronaut riding a horse"]
prompt
= 512 # default height of Stable Diffusion
height = 512 # default width of Stable Diffusion
width
= 50 # Number of denoising steps
num_inference_steps
= 7.5 # Scale for classifier-free guidance
guidance_scale
= torch.manual_seed(256) # Seed generator to create the inital latent noise
generator
= 2 batch_size
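Since the VAE operates on an 8x-downsampled latent grid (used when we create the latents below), height and width should be multiples of 8; a quick guard, as a sketch:

# Sketch: Stable Diffusion's VAE downsamples by a factor of 8,
# so the requested image dimensions must be divisible by 8.
assert height % 8 == 0 and width % 8 == 0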
from diffusers import LMSDiscreteScheduler, StableDiffusionPipeline
scheduler = LMSDiscreteScheduler.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="scheduler"
)
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", scheduler=scheduler
)
pipe
StableDiffusionPipeline {
"_class_name": "StableDiffusionPipeline",
"_diffusers_version": "0.20.2",
"_name_or_path": "CompVis/stable-diffusion-v1-4",
"feature_extractor": [
"transformers",
"CLIPImageProcessor"
],
"requires_safety_checker": true,
"safety_checker": [
"stable_diffusion",
"StableDiffusionSafetyChecker"
],
"scheduler": [
"diffusers",
"LMSDiscreteScheduler"
],
"text_encoder": [
"transformers",
"CLIPTextModel"
],
"tokenizer": [
"transformers",
"CLIPTokenizer"
],
"unet": [
"diffusers",
"UNet2DConditionModel"
],
"vae": [
"diffusers",
"AutoencoderKL"
]
}
pipe = pipe.to(device)
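If GPU memory is tight, a common variant (not used in this notebook) is to load the weights in half precision via the standard torch_dtype argument; a sketch:

# Optional memory-saving variant (assumption: not used in this notebook):
# load the pipeline weights in fp16.
pipe_fp16 = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", scheduler=scheduler, torch_dtype=torch.float16
).to(device)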
Low-level
# 1. Load the autoencoder model which will be used to decode the latents into image space.
vae = pipe.vae

# 2. Load the tokenizer and text encoder to tokenize and encode the text.
tokenizer = pipe.tokenizer
text_encoder = pipe.text_encoder

# 3. The UNet model for generating the latents.
unet = pipe.unet
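To get a feel for the relative sizes of the three components inspected below, one can count their parameters; a minimal sketch:

# Rough parameter counts per component (sketch; exact figures depend on the checkpoint).
for name, model in [("vae", vae), ("text_encoder", text_encoder), ("unet", unet)]:
    print(f"{name}: {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M parameters")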
Models
vae
AutoencoderKL(
(encoder): Encoder(
(conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(down_blocks): ModuleList(
(0): DownEncoderBlock2D(
(resnets): ModuleList(
(0-1): 2 x ResnetBlock2D(
(norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
(conv1): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
)
(downsamplers): ModuleList(
(0): Downsample2D(
(conv): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(2, 2))
)
)
)
(1): DownEncoderBlock2D(
(resnets): ModuleList(
(0): ResnetBlock2D(
(norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
(conv1): LoRACompatibleConv(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): LoRACompatibleConv(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): LoRACompatibleConv(128, 256, kernel_size=(1, 1), stride=(1, 1))
)
(1): ResnetBlock2D(
(norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
(conv1): LoRACompatibleConv(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): LoRACompatibleConv(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
)
(downsamplers): ModuleList(
(0): Downsample2D(
(conv): LoRACompatibleConv(256, 256, kernel_size=(3, 3), stride=(2, 2))
)
)
)
(2): DownEncoderBlock2D(
(resnets): ModuleList(
(0): ResnetBlock2D(
(norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
(conv1): LoRACompatibleConv(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): LoRACompatibleConv(256, 512, kernel_size=(1, 1), stride=(1, 1))
)
(1): ResnetBlock2D(
(norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
(conv1): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
)
(downsamplers): ModuleList(
(0): Downsample2D(
(conv): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(2, 2))
)
)
)
(3): DownEncoderBlock2D(
(resnets): ModuleList(
(0-1): 2 x ResnetBlock2D(
(norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
(conv1): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
)
)
)
(mid_block): UNetMidBlock2D(
(attentions): ModuleList(
(0): Attention(
(group_norm): GroupNorm(32, 512, eps=1e-06, affine=True)
(to_q): Linear(in_features=512, out_features=512, bias=True)
(to_k): Linear(in_features=512, out_features=512, bias=True)
(to_v): Linear(in_features=512, out_features=512, bias=True)
(to_out): ModuleList(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
)
(resnets): ModuleList(
(0-1): 2 x ResnetBlock2D(
(norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
(conv1): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
)
)
(conv_norm_out): GroupNorm(32, 512, eps=1e-06, affine=True)
(conv_act): SiLU()
(conv_out): Conv2d(512, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(decoder): Decoder(
(conv_in): Conv2d(4, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(up_blocks): ModuleList(
(0-1): 2 x UpDecoderBlock2D(
(resnets): ModuleList(
(0-2): 3 x ResnetBlock2D(
(norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
(conv1): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
)
(upsamplers): ModuleList(
(0): Upsample2D(
(conv): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
)
(2): UpDecoderBlock2D(
(resnets): ModuleList(
(0): ResnetBlock2D(
(norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
(conv1): LoRACompatibleConv(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): LoRACompatibleConv(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): LoRACompatibleConv(512, 256, kernel_size=(1, 1), stride=(1, 1))
)
(1-2): 2 x ResnetBlock2D(
(norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
(conv1): LoRACompatibleConv(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): LoRACompatibleConv(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
)
(upsamplers): ModuleList(
(0): Upsample2D(
(conv): LoRACompatibleConv(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
)
(3): UpDecoderBlock2D(
(resnets): ModuleList(
(0): ResnetBlock2D(
(norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
(conv1): LoRACompatibleConv(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): LoRACompatibleConv(256, 128, kernel_size=(1, 1), stride=(1, 1))
)
(1-2): 2 x ResnetBlock2D(
(norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
(conv1): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
)
)
)
(mid_block): UNetMidBlock2D(
(attentions): ModuleList(
(0): Attention(
(group_norm): GroupNorm(32, 512, eps=1e-06, affine=True)
(to_q): Linear(in_features=512, out_features=512, bias=True)
(to_k): Linear(in_features=512, out_features=512, bias=True)
(to_v): Linear(in_features=512, out_features=512, bias=True)
(to_out): ModuleList(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
)
(resnets): ModuleList(
(0-1): 2 x ResnetBlock2D(
(norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
(conv1): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
)
)
(conv_norm_out): GroupNorm(32, 128, eps=1e-06, affine=True)
(conv_act): SiLU()
(conv_out): Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(quant_conv): Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1))
(post_quant_conv): Conv2d(4, 4, kernel_size=(1, 1), stride=(1, 1))
)
tokenizer
CLIPTokenizer(name_or_path='/root/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/snapshots/133a221b8aa7292a167afc5127cb63fb5005638b/tokenizer', vocab_size=49408, model_max_length=77, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<|startoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True)
text_encoder
CLIPTextModel(
(text_model): CLIPTextTransformer(
(embeddings): CLIPTextEmbeddings(
(token_embedding): Embedding(49408, 768)
(position_embedding): Embedding(77, 768)
)
(encoder): CLIPEncoder(
(layers): ModuleList(
(0-11): 12 x CLIPEncoderLayer(
(self_attn): CLIPAttention(
(k_proj): Linear(in_features=768, out_features=768, bias=True)
(v_proj): Linear(in_features=768, out_features=768, bias=True)
(q_proj): Linear(in_features=768, out_features=768, bias=True)
(out_proj): Linear(in_features=768, out_features=768, bias=True)
)
(layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): CLIPMLP(
(activation_fn): QuickGELUActivation()
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(fc2): Linear(in_features=3072, out_features=768, bias=True)
)
(layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
)
)
(final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
)
unet
UNet2DConditionModel(
(conv_in): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_proj): Timesteps()
(time_embedding): TimestepEmbedding(
(linear_1): Linear(in_features=320, out_features=1280, bias=True)
(act): SiLU()
(linear_2): Linear(in_features=1280, out_features=1280, bias=True)
)
(down_blocks): ModuleList(
(0): CrossAttnDownBlock2D(
(attentions): ModuleList(
(0-1): 2 x Transformer2DModel(
(norm): GroupNorm(32, 320, eps=1e-06, affine=True)
(proj_in): LoRACompatibleConv(320, 320, kernel_size=(1, 1), stride=(1, 1))
(transformer_blocks): ModuleList(
(0): BasicTransformerBlock(
(norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
(attn1): Attention(
(to_q): Linear(in_features=320, out_features=320, bias=False)
(to_k): Linear(in_features=320, out_features=320, bias=False)
(to_v): Linear(in_features=320, out_features=320, bias=False)
(to_out): ModuleList(
(0): Linear(in_features=320, out_features=320, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
(norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
(attn2): Attention(
(to_q): Linear(in_features=320, out_features=320, bias=False)
(to_k): Linear(in_features=768, out_features=320, bias=False)
(to_v): Linear(in_features=768, out_features=320, bias=False)
(to_out): ModuleList(
(0): Linear(in_features=320, out_features=320, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
(norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
(ff): FeedForward(
(net): ModuleList(
(0): GEGLU(
(proj): LoRACompatibleLinear(in_features=320, out_features=2560, bias=True)
)
(1): Dropout(p=0.0, inplace=False)
(2): LoRACompatibleLinear(in_features=1280, out_features=320, bias=True)
)
)
)
)
(proj_out): LoRACompatibleConv(320, 320, kernel_size=(1, 1), stride=(1, 1))
)
)
(resnets): ModuleList(
(0-1): 2 x ResnetBlock2D(
(norm1): GroupNorm(32, 320, eps=1e-05, affine=True)
(conv1): LoRACompatibleConv(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=320, bias=True)
(norm2): GroupNorm(32, 320, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): LoRACompatibleConv(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
)
(downsamplers): ModuleList(
(0): Downsample2D(
(conv): LoRACompatibleConv(320, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
)
)
)
(1): CrossAttnDownBlock2D(
(attentions): ModuleList(
(0-1): 2 x Transformer2DModel(
(norm): GroupNorm(32, 640, eps=1e-06, affine=True)
(proj_in): LoRACompatibleConv(640, 640, kernel_size=(1, 1), stride=(1, 1))
(transformer_blocks): ModuleList(
(0): BasicTransformerBlock(
(norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
(attn1): Attention(
(to_q): Linear(in_features=640, out_features=640, bias=False)
(to_k): Linear(in_features=640, out_features=640, bias=False)
(to_v): Linear(in_features=640, out_features=640, bias=False)
(to_out): ModuleList(
(0): Linear(in_features=640, out_features=640, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
(norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
(attn2): Attention(
(to_q): Linear(in_features=640, out_features=640, bias=False)
(to_k): Linear(in_features=768, out_features=640, bias=False)
(to_v): Linear(in_features=768, out_features=640, bias=False)
(to_out): ModuleList(
(0): Linear(in_features=640, out_features=640, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
(norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
(ff): FeedForward(
(net): ModuleList(
(0): GEGLU(
(proj): LoRACompatibleLinear(in_features=640, out_features=5120, bias=True)
)
(1): Dropout(p=0.0, inplace=False)
(2): LoRACompatibleLinear(in_features=2560, out_features=640, bias=True)
)
)
)
)
(proj_out): LoRACompatibleConv(640, 640, kernel_size=(1, 1), stride=(1, 1))
)
)
(resnets): ModuleList(
(0): ResnetBlock2D(
(norm1): GroupNorm(32, 320, eps=1e-05, affine=True)
(conv1): LoRACompatibleConv(320, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=640, bias=True)
(norm2): GroupNorm(32, 640, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): LoRACompatibleConv(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): LoRACompatibleConv(320, 640, kernel_size=(1, 1), stride=(1, 1))
)
(1): ResnetBlock2D(
(norm1): GroupNorm(32, 640, eps=1e-05, affine=True)
(conv1): LoRACompatibleConv(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=640, bias=True)
(norm2): GroupNorm(32, 640, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): LoRACompatibleConv(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
)
(downsamplers): ModuleList(
(0): Downsample2D(
(conv): LoRACompatibleConv(640, 640, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
)
)
)
(2): CrossAttnDownBlock2D(
(attentions): ModuleList(
(0-1): 2 x Transformer2DModel(
(norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
(proj_in): LoRACompatibleConv(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
(transformer_blocks): ModuleList(
(0): BasicTransformerBlock(
(norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
(attn1): Attention(
(to_q): Linear(in_features=1280, out_features=1280, bias=False)
(to_k): Linear(in_features=1280, out_features=1280, bias=False)
(to_v): Linear(in_features=1280, out_features=1280, bias=False)
(to_out): ModuleList(
(0): Linear(in_features=1280, out_features=1280, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
(norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
(attn2): Attention(
(to_q): Linear(in_features=1280, out_features=1280, bias=False)
(to_k): Linear(in_features=768, out_features=1280, bias=False)
(to_v): Linear(in_features=768, out_features=1280, bias=False)
(to_out): ModuleList(
(0): Linear(in_features=1280, out_features=1280, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
(norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
(ff): FeedForward(
(net): ModuleList(
(0): GEGLU(
(proj): LoRACompatibleLinear(in_features=1280, out_features=10240, bias=True)
)
(1): Dropout(p=0.0, inplace=False)
(2): LoRACompatibleLinear(in_features=5120, out_features=1280, bias=True)
)
)
)
)
(proj_out): LoRACompatibleConv(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
)
)
(resnets): ModuleList(
(0): ResnetBlock2D(
(norm1): GroupNorm(32, 640, eps=1e-05, affine=True)
(conv1): LoRACompatibleConv(640, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
(norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): LoRACompatibleConv(640, 1280, kernel_size=(1, 1), stride=(1, 1))
)
(1): ResnetBlock2D(
(norm1): GroupNorm(32, 1280, eps=1e-05, affine=True)
(conv1): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
(norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
)
(downsamplers): ModuleList(
(0): Downsample2D(
(conv): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
)
)
)
(3): DownBlock2D(
(resnets): ModuleList(
(0-1): 2 x ResnetBlock2D(
(norm1): GroupNorm(32, 1280, eps=1e-05, affine=True)
(conv1): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
(norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
)
)
)
(up_blocks): ModuleList(
(0): UpBlock2D(
(resnets): ModuleList(
(0-2): 3 x ResnetBlock2D(
(norm1): GroupNorm(32, 2560, eps=1e-05, affine=True)
(conv1): LoRACompatibleConv(2560, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
(norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): LoRACompatibleConv(2560, 1280, kernel_size=(1, 1), stride=(1, 1))
)
)
(upsamplers): ModuleList(
(0): Upsample2D(
(conv): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
)
(1): CrossAttnUpBlock2D(
(attentions): ModuleList(
(0-2): 3 x Transformer2DModel(
(norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
(proj_in): LoRACompatibleConv(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
(transformer_blocks): ModuleList(
(0): BasicTransformerBlock(
(norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
(attn1): Attention(
(to_q): Linear(in_features=1280, out_features=1280, bias=False)
(to_k): Linear(in_features=1280, out_features=1280, bias=False)
(to_v): Linear(in_features=1280, out_features=1280, bias=False)
(to_out): ModuleList(
(0): Linear(in_features=1280, out_features=1280, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
(norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
(attn2): Attention(
(to_q): Linear(in_features=1280, out_features=1280, bias=False)
(to_k): Linear(in_features=768, out_features=1280, bias=False)
(to_v): Linear(in_features=768, out_features=1280, bias=False)
(to_out): ModuleList(
(0): Linear(in_features=1280, out_features=1280, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
(norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
(ff): FeedForward(
(net): ModuleList(
(0): GEGLU(
(proj): LoRACompatibleLinear(in_features=1280, out_features=10240, bias=True)
)
(1): Dropout(p=0.0, inplace=False)
(2): LoRACompatibleLinear(in_features=5120, out_features=1280, bias=True)
)
)
)
)
(proj_out): LoRACompatibleConv(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
)
)
(resnets): ModuleList(
(0-1): 2 x ResnetBlock2D(
(norm1): GroupNorm(32, 2560, eps=1e-05, affine=True)
(conv1): LoRACompatibleConv(2560, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
(norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): LoRACompatibleConv(2560, 1280, kernel_size=(1, 1), stride=(1, 1))
)
(2): ResnetBlock2D(
(norm1): GroupNorm(32, 1920, eps=1e-05, affine=True)
(conv1): LoRACompatibleConv(1920, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
(norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): LoRACompatibleConv(1920, 1280, kernel_size=(1, 1), stride=(1, 1))
)
)
(upsamplers): ModuleList(
(0): Upsample2D(
(conv): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
)
(2): CrossAttnUpBlock2D(
(attentions): ModuleList(
(0-2): 3 x Transformer2DModel(
(norm): GroupNorm(32, 640, eps=1e-06, affine=True)
(proj_in): LoRACompatibleConv(640, 640, kernel_size=(1, 1), stride=(1, 1))
(transformer_blocks): ModuleList(
(0): BasicTransformerBlock(
(norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
(attn1): Attention(
(to_q): Linear(in_features=640, out_features=640, bias=False)
(to_k): Linear(in_features=640, out_features=640, bias=False)
(to_v): Linear(in_features=640, out_features=640, bias=False)
(to_out): ModuleList(
(0): Linear(in_features=640, out_features=640, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
(norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
(attn2): Attention(
(to_q): Linear(in_features=640, out_features=640, bias=False)
(to_k): Linear(in_features=768, out_features=640, bias=False)
(to_v): Linear(in_features=768, out_features=640, bias=False)
(to_out): ModuleList(
(0): Linear(in_features=640, out_features=640, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
(norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
(ff): FeedForward(
(net): ModuleList(
(0): GEGLU(
(proj): LoRACompatibleLinear(in_features=640, out_features=5120, bias=True)
)
(1): Dropout(p=0.0, inplace=False)
(2): LoRACompatibleLinear(in_features=2560, out_features=640, bias=True)
)
)
)
)
(proj_out): LoRACompatibleConv(640, 640, kernel_size=(1, 1), stride=(1, 1))
)
)
(resnets): ModuleList(
(0): ResnetBlock2D(
(norm1): GroupNorm(32, 1920, eps=1e-05, affine=True)
(conv1): LoRACompatibleConv(1920, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=640, bias=True)
(norm2): GroupNorm(32, 640, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): LoRACompatibleConv(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): LoRACompatibleConv(1920, 640, kernel_size=(1, 1), stride=(1, 1))
)
(1): ResnetBlock2D(
(norm1): GroupNorm(32, 1280, eps=1e-05, affine=True)
(conv1): LoRACompatibleConv(1280, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=640, bias=True)
(norm2): GroupNorm(32, 640, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): LoRACompatibleConv(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): LoRACompatibleConv(1280, 640, kernel_size=(1, 1), stride=(1, 1))
)
(2): ResnetBlock2D(
(norm1): GroupNorm(32, 960, eps=1e-05, affine=True)
(conv1): LoRACompatibleConv(960, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=640, bias=True)
(norm2): GroupNorm(32, 640, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): LoRACompatibleConv(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): LoRACompatibleConv(960, 640, kernel_size=(1, 1), stride=(1, 1))
)
)
(upsamplers): ModuleList(
(0): Upsample2D(
(conv): LoRACompatibleConv(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
)
(3): CrossAttnUpBlock2D(
(attentions): ModuleList(
(0-2): 3 x Transformer2DModel(
(norm): GroupNorm(32, 320, eps=1e-06, affine=True)
(proj_in): LoRACompatibleConv(320, 320, kernel_size=(1, 1), stride=(1, 1))
(transformer_blocks): ModuleList(
(0): BasicTransformerBlock(
(norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
(attn1): Attention(
(to_q): Linear(in_features=320, out_features=320, bias=False)
(to_k): Linear(in_features=320, out_features=320, bias=False)
(to_v): Linear(in_features=320, out_features=320, bias=False)
(to_out): ModuleList(
(0): Linear(in_features=320, out_features=320, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
(norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
(attn2): Attention(
(to_q): Linear(in_features=320, out_features=320, bias=False)
(to_k): Linear(in_features=768, out_features=320, bias=False)
(to_v): Linear(in_features=768, out_features=320, bias=False)
(to_out): ModuleList(
(0): Linear(in_features=320, out_features=320, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
(norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
(ff): FeedForward(
(net): ModuleList(
(0): GEGLU(
(proj): LoRACompatibleLinear(in_features=320, out_features=2560, bias=True)
)
(1): Dropout(p=0.0, inplace=False)
(2): LoRACompatibleLinear(in_features=1280, out_features=320, bias=True)
)
)
)
)
(proj_out): LoRACompatibleConv(320, 320, kernel_size=(1, 1), stride=(1, 1))
)
)
(resnets): ModuleList(
(0): ResnetBlock2D(
(norm1): GroupNorm(32, 960, eps=1e-05, affine=True)
(conv1): LoRACompatibleConv(960, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=320, bias=True)
(norm2): GroupNorm(32, 320, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): LoRACompatibleConv(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): LoRACompatibleConv(960, 320, kernel_size=(1, 1), stride=(1, 1))
)
(1-2): 2 x ResnetBlock2D(
(norm1): GroupNorm(32, 640, eps=1e-05, affine=True)
(conv1): LoRACompatibleConv(640, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=320, bias=True)
(norm2): GroupNorm(32, 320, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): LoRACompatibleConv(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): LoRACompatibleConv(640, 320, kernel_size=(1, 1), stride=(1, 1))
)
)
)
)
(mid_block): UNetMidBlock2DCrossAttn(
(attentions): ModuleList(
(0): Transformer2DModel(
(norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
(proj_in): LoRACompatibleConv(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
(transformer_blocks): ModuleList(
(0): BasicTransformerBlock(
(norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
(attn1): Attention(
(to_q): Linear(in_features=1280, out_features=1280, bias=False)
(to_k): Linear(in_features=1280, out_features=1280, bias=False)
(to_v): Linear(in_features=1280, out_features=1280, bias=False)
(to_out): ModuleList(
(0): Linear(in_features=1280, out_features=1280, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
(norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
(attn2): Attention(
(to_q): Linear(in_features=1280, out_features=1280, bias=False)
(to_k): Linear(in_features=768, out_features=1280, bias=False)
(to_v): Linear(in_features=768, out_features=1280, bias=False)
(to_out): ModuleList(
(0): Linear(in_features=1280, out_features=1280, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
(norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
(ff): FeedForward(
(net): ModuleList(
(0): GEGLU(
(proj): LoRACompatibleLinear(in_features=1280, out_features=10240, bias=True)
)
(1): Dropout(p=0.0, inplace=False)
(2): LoRACompatibleLinear(in_features=5120, out_features=1280, bias=True)
)
)
)
)
(proj_out): LoRACompatibleConv(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
)
)
(resnets): ModuleList(
(0-1): 2 x ResnetBlock2D(
(norm1): GroupNorm(32, 1280, eps=1e-05, affine=True)
(conv1): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
(norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
)
)
(conv_norm_out): GroupNorm(32, 320, eps=1e-05, affine=True)
(conv_act): SiLU()
(conv_out): Conv2d(320, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
Inference
text_input = tokenizer(
    prompt * batch_size,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
).input_ids.to(device)
text_input.shape, text_input
(torch.Size([2, 77]),
tensor([[49406, 320, 8853, 539, 550, 18376, 6765, 320, 4558, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407],
[49406, 320, 8853, 539, 550, 18376, 6765, 320, 4558, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407]], device='cuda:0'))
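The prompt tokens occupy the first positions; the remainder is padded with the end-of-text token (id 49407) up to model_max_length = 77. As an optional sanity check, the ids can be decoded back to text:

# Decode the first row of token ids back into text (sanity check).
print(tokenizer.decode(text_input[0], skip_special_tokens=True))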
with torch.no_grad():
    text_embeddings = text_encoder(text_input)[0]
text_embeddings.shape, text_embeddings
(torch.Size([2, 77, 768]),
tensor([[[-0.3884, 0.0229, -0.0522, ..., -0.4899, -0.3066, 0.0675],
[ 0.0290, -1.3258, 0.3085, ..., -0.5257, 0.9768, 0.6652],
[ 0.4595, 0.5617, 1.6663, ..., -1.9515, -1.2307, 0.0104],
...,
[-3.0421, -0.0656, -0.1793, ..., 0.3943, -0.0190, 0.7664],
[-3.0551, -0.1036, -0.1936, ..., 0.4236, -0.0190, 0.7575],
[-2.9854, -0.0832, -0.1715, ..., 0.4355, 0.0095, 0.7485]],
[[-0.3884, 0.0229, -0.0522, ..., -0.4899, -0.3066, 0.0675],
[ 0.0290, -1.3258, 0.3085, ..., -0.5257, 0.9768, 0.6652],
[ 0.4595, 0.5617, 1.6663, ..., -1.9515, -1.2307, 0.0104],
...,
[-3.0421, -0.0656, -0.1793, ..., 0.3943, -0.0190, 0.7664],
[-3.0551, -0.1036, -0.1936, ..., 0.4236, -0.0190, 0.7575],
[-2.9854, -0.0832, -0.1715, ..., 0.4355, 0.0095, 0.7485]]],
device='cuda:0'))
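Each token is encoded into a 768-dimensional hidden state, which is exactly what the UNet's cross-attention layers consume (the to_k/to_v projections with in_features=768 in the unet repr above); a quick check:

# The CLIP hidden size must match the UNet's cross-attention input dimension.
assert text_embeddings.shape[-1] == text_encoder.config.hidden_size == 768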
uncond_input = tokenizer(
    [""] * batch_size,
    padding="max_length",
    max_length=text_input.shape[-1],
    return_tensors="pt",
).input_ids.to(device)
uncond_input.shape, uncond_input
(torch.Size([2, 77]),
tensor([[49406, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407],
[49406, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407]], device='cuda:0'))
with torch.no_grad():
    uncond_embeddings = text_encoder(uncond_input)[0]
uncond_embeddings.shape, uncond_embeddings
(torch.Size([2, 77, 768]),
tensor([[[-0.3884, 0.0229, -0.0522, ..., -0.4899, -0.3066, 0.0675],
[-0.3711, -1.4497, -0.3401, ..., 0.9489, 0.1867, -1.1034],
[-0.5107, -1.4629, -0.2926, ..., 1.0419, 0.0701, -1.0284],
...,
[ 0.5006, -0.9552, -0.6610, ..., 1.6013, -1.0622, -0.2191],
[ 0.4988, -0.9451, -0.6656, ..., 1.6467, -1.0858, -0.2088],
[ 0.4923, -0.8124, -0.4912, ..., 1.6108, -1.0174, -0.2484]],
[[-0.3884, 0.0229, -0.0522, ..., -0.4899, -0.3066, 0.0675],
[-0.3711, -1.4497, -0.3401, ..., 0.9489, 0.1867, -1.1034],
[-0.5107, -1.4629, -0.2926, ..., 1.0419, 0.0701, -1.0284],
...,
[ 0.5006, -0.9552, -0.6610, ..., 1.6013, -1.0622, -0.2191],
[ 0.4988, -0.9451, -0.6656, ..., 1.6467, -1.0858, -0.2088],
[ 0.4923, -0.8124, -0.4912, ..., 1.6108, -1.0174, -0.2484]]],
device='cuda:0'))
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
text_embeddings.shape, text_embeddings
(torch.Size([4, 77, 768]),
tensor([[[-0.3884, 0.0229, -0.0522, ..., -0.4899, -0.3066, 0.0675],
[-0.3711, -1.4497, -0.3401, ..., 0.9489, 0.1867, -1.1034],
[-0.5107, -1.4629, -0.2926, ..., 1.0419, 0.0701, -1.0284],
...,
[ 0.5006, -0.9552, -0.6610, ..., 1.6013, -1.0622, -0.2191],
[ 0.4988, -0.9451, -0.6656, ..., 1.6467, -1.0858, -0.2088],
[ 0.4923, -0.8124, -0.4912, ..., 1.6108, -1.0174, -0.2484]],
[[-0.3884, 0.0229, -0.0522, ..., -0.4899, -0.3066, 0.0675],
[-0.3711, -1.4497, -0.3401, ..., 0.9489, 0.1867, -1.1034],
[-0.5107, -1.4629, -0.2926, ..., 1.0419, 0.0701, -1.0284],
...,
[ 0.5006, -0.9552, -0.6610, ..., 1.6013, -1.0622, -0.2191],
[ 0.4988, -0.9451, -0.6656, ..., 1.6467, -1.0858, -0.2088],
[ 0.4923, -0.8124, -0.4912, ..., 1.6108, -1.0174, -0.2484]],
[[-0.3884, 0.0229, -0.0522, ..., -0.4899, -0.3066, 0.0675],
[ 0.0290, -1.3258, 0.3085, ..., -0.5257, 0.9768, 0.6652],
[ 0.4595, 0.5617, 1.6663, ..., -1.9515, -1.2307, 0.0104],
...,
[-3.0421, -0.0656, -0.1793, ..., 0.3943, -0.0190, 0.7664],
[-3.0551, -0.1036, -0.1936, ..., 0.4236, -0.0190, 0.7575],
[-2.9854, -0.0832, -0.1715, ..., 0.4355, 0.0095, 0.7485]],
[[-0.3884, 0.0229, -0.0522, ..., -0.4899, -0.3066, 0.0675],
[ 0.0290, -1.3258, 0.3085, ..., -0.5257, 0.9768, 0.6652],
[ 0.4595, 0.5617, 1.6663, ..., -1.9515, -1.2307, 0.0104],
...,
[-3.0421, -0.0656, -0.1793, ..., 0.3943, -0.0190, 0.7664],
[-3.0551, -0.1036, -0.1936, ..., 0.4236, -0.0190, 0.7575],
[-2.9854, -0.0832, -0.1715, ..., 0.4355, 0.0095, 0.7485]]],
device='cuda:0'))
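The order of the concatenation matters: the first batch_size rows are the unconditional embeddings and the last are the conditional ones, which is what noise_pred.chunk(2) in the denoising loop below relies on. A sketch of the check:

# The batch is stacked as [uncond, cond]; chunk(2) below unpacks in the same order.
assert torch.equal(text_embeddings[:batch_size], uncond_embeddings)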
latents = torch.randn(
    (batch_size, unet.in_channels, height // 8, width // 8),
    generator=generator,
).to(device)
latents.shape, latents
(torch.Size([2, 4, 64, 64]),
tensor([[[[ 0.1884, -0.6394, 0.1089, ..., -0.9887, -0.7133, -1.1545],
[ 0.4124, 1.5587, -0.3407, ..., 2.1968, -0.0356, -0.0810],
[-1.8912, 0.0528, -0.4425, ..., 1.3110, 0.7100, 0.6802],
...,
[-1.3443, -0.1747, -0.6298, ..., 0.4572, -0.8584, -0.1284],
[-1.7920, -0.6554, -0.0439, ..., 0.5436, 2.2266, -0.5003],
[ 0.6213, -1.3155, 0.7470, ..., -0.2354, 0.7097, 0.6170]],
[[-0.5007, -1.4418, 0.2598, ..., -0.2586, 2.3239, -1.3245],
[ 0.8540, -0.4135, 0.5658, ..., -1.9556, 2.0454, -0.2454],
[-0.3212, -1.9329, -1.1598, ..., 0.7156, -0.7228, -0.6992],
...,
[ 0.0180, -0.7993, 2.3330, ..., 0.2594, -0.0333, -0.0826],
[-1.2569, -0.8219, 1.3467, ..., 0.4792, 1.8265, -0.6156],
[-1.9367, -0.0949, 0.0720, ..., 0.0806, 0.2966, -1.0284]],
[[ 0.2291, -0.0936, -1.3283, ..., 1.4995, -0.1965, -0.2879],
[-1.0226, -1.2896, 1.6202, ..., -0.3910, -0.3834, 0.5519],
[ 0.5424, 0.2685, 0.4912, ..., 0.9773, -0.8260, 1.1552],
...,
[-1.5280, -0.2530, -1.3748, ..., -1.4948, 1.3661, -1.1294],
[ 0.4241, -0.2996, 1.8231, ..., 0.6968, 0.8247, -0.0279],
[-3.3711, -0.7468, -1.3212, ..., -0.4128, 0.4621, 2.6297]],
[[-0.7510, -0.7452, -0.8998, ..., -1.6957, -0.4004, -0.2596],
[-1.2092, -1.8881, -0.5828, ..., -1.0428, -0.6500, 0.3601],
[-0.4254, 0.9478, 1.3083, ..., -0.0259, -0.4542, 0.4353],
...,
[-0.1918, 0.4858, 0.0666, ..., 0.8505, -0.6606, -0.3193],
[ 1.3620, 0.2283, 0.6292, ..., -0.9271, 1.7018, 0.2161],
[-0.3891, -1.8911, -0.7501, ..., -0.2330, -1.0460, 0.4121]]],
[[[ 0.3649, -1.3183, -1.3308, ..., -0.5548, -1.3610, -1.9329],
[-0.0071, 0.1977, 1.5517, ..., -1.6664, 1.6551, 0.1798],
[-1.0404, 0.6524, 0.4654, ..., -0.5947, -1.0871, 2.2230],
...,
[-0.6844, 0.1692, -0.2559, ..., 0.5511, 0.9734, 0.7936],
[-1.1951, 0.5016, 0.8089, ..., 0.2337, -0.2213, -1.1724],
[-0.5055, -0.7491, -1.4940, ..., -2.1332, 0.9120, 0.2057]],
[[ 1.3668, -1.1680, -0.8574, ..., -0.0635, -1.9132, -0.6023],
[ 1.0974, -0.9654, 1.2987, ..., 1.3187, -0.0241, -0.5427],
[-2.0427, -1.4358, -0.7115, ..., 0.1088, 0.0764, 0.7254],
...,
[ 1.0957, 1.4058, -0.0178, ..., 0.5748, 0.0953, 0.7550],
[ 0.4080, 0.8792, 0.6801, ..., -0.7215, 1.1261, 0.0551],
[-0.3183, -2.3306, 0.7155, ..., 0.4291, -0.2074, -1.1237]],
[[-0.2401, 0.9229, 0.0212, ..., 0.2128, -0.4705, -0.3262],
[ 0.1108, 0.8909, 0.5309, ..., -1.7175, -1.6657, -1.7706],
[-0.1654, -0.4582, -1.2832, ..., 0.5297, -0.8363, 1.0293],
...,
[-1.3526, 2.1482, 0.5417, ..., -2.2156, -1.9940, -0.9745],
[-0.5821, 0.0492, 0.6693, ..., -0.8610, 0.5864, -0.6040],
[ 1.0180, 1.4447, 0.9563, ..., 0.9034, 0.7988, -1.7119]],
[[-1.6146, 0.0868, 0.6415, ..., 0.2083, 0.4058, 0.2813],
[ 0.1969, -0.3334, -0.6526, ..., -1.4639, -1.6302, -0.6036],
[ 0.1556, -0.0859, -0.0230, ..., -0.7900, -0.3481, 0.8767],
...,
[ 0.6056, 0.8374, -0.3834, ..., -0.6636, -0.4814, 0.8244],
[ 0.6982, -0.4884, -1.3777, ..., 0.5876, -2.0944, 0.0853],
[ 0.0388, -0.5761, -0.5116, ..., -1.6645, 0.1752, -0.1923]]]],
device='cuda:0'))
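The latent grid is 64x64 for a 512x512 image because the VAE encoder halves the resolution three times (the three stride-2 Downsample2D blocks in the vae repr above), with 4 latent channels; a quick check:

# 512 // 8 == 64: three stride-2 downsamples in the VAE encoder, 4 latent channels.
assert latents.shape == (batch_size, 4, height // 8, width // 8)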
scheduler.set_timesteps(num_inference_steps)
scheduler.timesteps.shape, pipe.scheduler.timesteps
(torch.Size([50]),
tensor([999.0000, 978.6122, 958.2245, 937.8367, 917.4490, 897.0612, 876.6735,
856.2857, 835.8980, 815.5102, 795.1224, 774.7347, 754.3469, 733.9592,
713.5714, 693.1837, 672.7959, 652.4082, 632.0204, 611.6327, 591.2449,
570.8571, 550.4694, 530.0816, 509.6939, 489.3061, 468.9184, 448.5306,
428.1429, 407.7551, 387.3673, 366.9796, 346.5918, 326.2041, 305.8163,
285.4286, 265.0408, 244.6531, 224.2653, 203.8776, 183.4898, 163.1020,
142.7143, 122.3265, 101.9388, 81.5510, 61.1633, 40.7755, 20.3878,
0.0000], dtype=torch.float64))
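The LMS scheduler works in a sigma (noise level) parameterization; before the first denoising step the standard-normal latents are scaled by scheduler.init_noise_sigma (derived from the largest sigma), which is done just before the loop below. To inspect the schedule, as a sketch:

# Per-step noise levels; init_noise_sigma corresponds to the highest noise level.
print(scheduler.init_noise_sigma)
print(scheduler.sigmas[:5])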
Classifier-free guidance (CFG)
\tilde{\boldsymbol{\epsilon}}_\theta(\mathbf{z}_t, \mathbf{c}) = w\boldsymbol{\epsilon}_\theta(\mathbf{z}_t, \mathbf{c}) + (1-w)\boldsymbol{\epsilon}_{\theta}(\mathbf{z}_t).
Here, \boldsymbol{\epsilon}_\theta(\mathbf{z}_t, \mathbf{c}) and \boldsymbol{\epsilon}_{\theta}(\mathbf{z}_t) are the conditional and unconditional \boldsymbol{\epsilon}-predictions, given by \boldsymbol{\epsilon}_\theta := (\mathbf{z}_t - \alpha_t\hat{\mathbf{x}}_\theta)/\sigma_t, and w is the guidance weight. Setting w = 1 disables classifier-free guidance, while increasing w > 1 strengthens the effect of guidance.[1]
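As a minimal sketch of this combination (the denoising loop below inlines the same arithmetic, with w playing the role of guidance_scale):

def cfg_combine(eps_uncond, eps_cond, w):
    # epsilon_tilde = w * eps(z_t, c) + (1 - w) * eps(z_t)
    return w * eps_cond + (1 - w) * eps_uncond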
from tqdm.auto import tqdm
latents = latents * scheduler.init_noise_sigma
for t in tqdm(scheduler.timesteps):
    # expand the latents since we are doing classifier-free guidance, to avoid doing two forward passes
    latent_model_input = torch.cat([latents] * 2)
    latent_model_input = scheduler.scale_model_input(latent_model_input, t)

    # predict the noise residual
    with torch.no_grad():
        noise_pred = unet(
            latent_model_input, t, encoder_hidden_states=text_embeddings
        ).sample

    # perform guidance (chunk order matches the [uncond, cond] concatenation above)
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = (
        guidance_scale * noise_pred_text + (1 - guidance_scale) * noise_pred_uncond
    )

    # compute the previous noisy sample x_t -> x_{t-1}
    latents = scheduler.step(noise_pred, t, latents).prev_sample
# scale and decode the image latents with the VAE (0.18215 is the SD v1 latent scaling factor)
latents = 1 / 0.18215 * latents
with torch.no_grad():
    image = vae.decode(latents).sample
from PIL import Image

image = (image / 2 + 0.5).clamp(0, 1)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
for pil_image in pil_images:
    display(pil_image)
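To keep the results, the PIL images can also be written to disk (illustrative filenames):

# Save each generated image (hypothetical filenames).
for i, pil_image in enumerate(pil_images):
    pil_image.save(f"astronaut_rides_horse_{i}.png")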
References
- Patil et al. (2022) Stable Diffusion with 🧨 Diffusers, https://huggingface.co/blog/stable_diffusion
Footnotes
1. Saharia et al. (2022) Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding, https://arxiv.org/abs/2205.11487