Stable Diffusion Inference with Diffusers (low-level)

deep learning · python · english

Author: Mingyu Jeon

Published: September 10, 2023

Stable Diffusion Inference (low-level)

Create an image using the Diffusers library, calling each pipeline component by hand.

Install and Import Libraries

!pip install -qq diffusers transformers scipy ftfy accelerate
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
device
'cuda'
prompt = ["a photograph of an astronaut riding a horse"]

height = 512  # default height of Stable Diffusion
width = 512  # default width of Stable Diffusion

num_inference_steps = 50  # Number of denoising steps

guidance_scale = 7.5  # Scale for classifier-free guidance

generator = torch.manual_seed(256)  # Seed generator to create the initial latent noise

batch_size = 2
from diffusers import LMSDiscreteScheduler, StableDiffusionPipeline

scheduler = LMSDiscreteScheduler.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="scheduler"
)
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", scheduler=scheduler
)
pipe
StableDiffusionPipeline {
  "_class_name": "StableDiffusionPipeline",
  "_diffusers_version": "0.20.2",
  "_name_or_path": "CompVis/stable-diffusion-v1-4",
  "feature_extractor": [
    "transformers",
    "CLIPImageProcessor"
  ],
  "requires_safety_checker": true,
  "safety_checker": [
    "stable_diffusion",
    "StableDiffusionSafetyChecker"
  ],
  "scheduler": [
    "diffusers",
    "LMSDiscreteScheduler"
  ],
  "text_encoder": [
    "transformers",
    "CLIPTextModel"
  ],
  "tokenizer": [
    "transformers",
    "CLIPTokenizer"
  ],
  "unet": [
    "diffusers",
    "UNet2DConditionModel"
  ],
  "vae": [
    "diffusers",
    "AutoencoderKL"
  ]
}
pipe = pipe.to(device)
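
If GPU memory is tight, the same pipeline can be loaded in half precision; this is a minimal variation on the cells above (torch_dtype is a standard from_pretrained argument), while the rest of this post keeps float32.

# optional: float16 weights roughly halve GPU memory usage
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    scheduler=scheduler,
    torch_dtype=torch.float16,
).to(device)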

Low-level

# 1. Load the autoencoder model which will be used to decode the latents into image space.
vae = pipe.vae

# 2. Load the tokenizer and text encoder to tokenize and encode the text.
tokenizer = pipe.tokenizer
text_encoder = pipe.text_encoder

# 3. The UNet model that predicts the noise used to denoise the latents.
unet = pipe.unet
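
Equivalently, each component can be loaded on its own from the checkpoint's subfolders, without building the full pipeline first (a sketch that downloads the same weights):

from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer

vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="text_encoder")
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")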

Models

vae
AutoencoderKL(
  (encoder): Encoder(
    (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down_blocks): ModuleList(
      (0): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0-1): 2 x ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
          )
        )
        (downsamplers): ModuleList(
          (0): Downsample2D(
            (conv): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(2, 2))
          )
        )
      )
      (1): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0): ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): LoRACompatibleConv(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): LoRACompatibleConv(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
            (conv_shortcut): LoRACompatibleConv(128, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (1): ResnetBlock2D(
            (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
            (conv1): LoRACompatibleConv(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): LoRACompatibleConv(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
          )
        )
        (downsamplers): ModuleList(
          (0): Downsample2D(
            (conv): LoRACompatibleConv(256, 256, kernel_size=(3, 3), stride=(2, 2))
          )
        )
      )
      (2): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0): ResnetBlock2D(
            (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
            (conv1): LoRACompatibleConv(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
            (conv_shortcut): LoRACompatibleConv(256, 512, kernel_size=(1, 1), stride=(1, 1))
          )
          (1): ResnetBlock2D(
            (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
            (conv1): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
          )
        )
        (downsamplers): ModuleList(
          (0): Downsample2D(
            (conv): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(2, 2))
          )
        )
      )
      (3): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0-1): 2 x ResnetBlock2D(
            (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
            (conv1): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
          )
        )
      )
    )
    (mid_block): UNetMidBlock2D(
      (attentions): ModuleList(
        (0): Attention(
          (group_norm): GroupNorm(32, 512, eps=1e-06, affine=True)
          (to_q): Linear(in_features=512, out_features=512, bias=True)
          (to_k): Linear(in_features=512, out_features=512, bias=True)
          (to_v): Linear(in_features=512, out_features=512, bias=True)
          (to_out): ModuleList(
            (0): Linear(in_features=512, out_features=512, bias=True)
            (1): Dropout(p=0.0, inplace=False)
          )
        )
      )
      (resnets): ModuleList(
        (0-1): 2 x ResnetBlock2D(
          (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
          (conv1): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
        )
      )
    )
    (conv_norm_out): GroupNorm(32, 512, eps=1e-06, affine=True)
    (conv_act): SiLU()
    (conv_out): Conv2d(512, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (decoder): Decoder(
    (conv_in): Conv2d(4, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (up_blocks): ModuleList(
      (0-1): 2 x UpDecoderBlock2D(
        (resnets): ModuleList(
          (0-2): 3 x ResnetBlock2D(
            (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
            (conv1): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
          )
        )
        (upsamplers): ModuleList(
          (0): Upsample2D(
            (conv): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
        )
      )
      (2): UpDecoderBlock2D(
        (resnets): ModuleList(
          (0): ResnetBlock2D(
            (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
            (conv1): LoRACompatibleConv(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): LoRACompatibleConv(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
            (conv_shortcut): LoRACompatibleConv(512, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (1-2): 2 x ResnetBlock2D(
            (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
            (conv1): LoRACompatibleConv(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): LoRACompatibleConv(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
          )
        )
        (upsamplers): ModuleList(
          (0): Upsample2D(
            (conv): LoRACompatibleConv(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
        )
      )
      (3): UpDecoderBlock2D(
        (resnets): ModuleList(
          (0): ResnetBlock2D(
            (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
            (conv1): LoRACompatibleConv(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
            (conv_shortcut): LoRACompatibleConv(256, 128, kernel_size=(1, 1), stride=(1, 1))
          )
          (1-2): 2 x ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
          )
        )
      )
    )
    (mid_block): UNetMidBlock2D(
      (attentions): ModuleList(
        (0): Attention(
          (group_norm): GroupNorm(32, 512, eps=1e-06, affine=True)
          (to_q): Linear(in_features=512, out_features=512, bias=True)
          (to_k): Linear(in_features=512, out_features=512, bias=True)
          (to_v): Linear(in_features=512, out_features=512, bias=True)
          (to_out): ModuleList(
            (0): Linear(in_features=512, out_features=512, bias=True)
            (1): Dropout(p=0.0, inplace=False)
          )
        )
      )
      (resnets): ModuleList(
        (0-1): 2 x ResnetBlock2D(
          (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
          (conv1): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
        )
      )
    )
    (conv_norm_out): GroupNorm(32, 128, eps=1e-06, affine=True)
    (conv_act): SiLU()
    (conv_out): Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (quant_conv): Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1))
  (post_quant_conv): Conv2d(4, 4, kernel_size=(1, 1), stride=(1, 1))
)
tokenizer
CLIPTokenizer(name_or_path='/root/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/snapshots/133a221b8aa7292a167afc5127cb63fb5005638b/tokenizer', vocab_size=49408, model_max_length=77, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<|startoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True)
text_encoder
CLIPTextModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 768)
      (position_embedding): Embedding(77, 768)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
    (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
)
unet
UNet2DConditionModel(
  (conv_in): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): Linear(in_features=320, out_features=1280, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=1280, out_features=1280, bias=True)
  )
  (down_blocks): ModuleList(
    (0): CrossAttnDownBlock2D(
      (attentions): ModuleList(
        (0-1): 2 x Transformer2DModel(
          (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
          (proj_in): LoRACompatibleConv(320, 320, kernel_size=(1, 1), stride=(1, 1))
          (transformer_blocks): ModuleList(
            (0): BasicTransformerBlock(
              (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
              (attn1): Attention(
                (to_q): Linear(in_features=320, out_features=320, bias=False)
                (to_k): Linear(in_features=320, out_features=320, bias=False)
                (to_v): Linear(in_features=320, out_features=320, bias=False)
                (to_out): ModuleList(
                  (0): Linear(in_features=320, out_features=320, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
              (attn2): Attention(
                (to_q): Linear(in_features=320, out_features=320, bias=False)
                (to_k): Linear(in_features=768, out_features=320, bias=False)
                (to_v): Linear(in_features=768, out_features=320, bias=False)
                (to_out): ModuleList(
                  (0): Linear(in_features=320, out_features=320, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
              (ff): FeedForward(
                (net): ModuleList(
                  (0): GEGLU(
                    (proj): LoRACompatibleLinear(in_features=320, out_features=2560, bias=True)
                  )
                  (1): Dropout(p=0.0, inplace=False)
                  (2): LoRACompatibleLinear(in_features=1280, out_features=320, bias=True)
                )
              )
            )
          )
          (proj_out): LoRACompatibleConv(320, 320, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (resnets): ModuleList(
        (0-1): 2 x ResnetBlock2D(
          (norm1): GroupNorm(32, 320, eps=1e-05, affine=True)
          (conv1): LoRACompatibleConv(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=320, bias=True)
          (norm2): GroupNorm(32, 320, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): LoRACompatibleConv(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
        )
      )
      (downsamplers): ModuleList(
        (0): Downsample2D(
          (conv): LoRACompatibleConv(320, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        )
      )
    )
    (1): CrossAttnDownBlock2D(
      (attentions): ModuleList(
        (0-1): 2 x Transformer2DModel(
          (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
          (proj_in): LoRACompatibleConv(640, 640, kernel_size=(1, 1), stride=(1, 1))
          (transformer_blocks): ModuleList(
            (0): BasicTransformerBlock(
              (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
              (attn1): Attention(
                (to_q): Linear(in_features=640, out_features=640, bias=False)
                (to_k): Linear(in_features=640, out_features=640, bias=False)
                (to_v): Linear(in_features=640, out_features=640, bias=False)
                (to_out): ModuleList(
                  (0): Linear(in_features=640, out_features=640, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
              (attn2): Attention(
                (to_q): Linear(in_features=640, out_features=640, bias=False)
                (to_k): Linear(in_features=768, out_features=640, bias=False)
                (to_v): Linear(in_features=768, out_features=640, bias=False)
                (to_out): ModuleList(
                  (0): Linear(in_features=640, out_features=640, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
              (ff): FeedForward(
                (net): ModuleList(
                  (0): GEGLU(
                    (proj): LoRACompatibleLinear(in_features=640, out_features=5120, bias=True)
                  )
                  (1): Dropout(p=0.0, inplace=False)
                  (2): LoRACompatibleLinear(in_features=2560, out_features=640, bias=True)
                )
              )
            )
          )
          (proj_out): LoRACompatibleConv(640, 640, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (resnets): ModuleList(
        (0): ResnetBlock2D(
          (norm1): GroupNorm(32, 320, eps=1e-05, affine=True)
          (conv1): LoRACompatibleConv(320, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=640, bias=True)
          (norm2): GroupNorm(32, 640, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): LoRACompatibleConv(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): LoRACompatibleConv(320, 640, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): ResnetBlock2D(
          (norm1): GroupNorm(32, 640, eps=1e-05, affine=True)
          (conv1): LoRACompatibleConv(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=640, bias=True)
          (norm2): GroupNorm(32, 640, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): LoRACompatibleConv(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
        )
      )
      (downsamplers): ModuleList(
        (0): Downsample2D(
          (conv): LoRACompatibleConv(640, 640, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        )
      )
    )
    (2): CrossAttnDownBlock2D(
      (attentions): ModuleList(
        (0-1): 2 x Transformer2DModel(
          (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
          (proj_in): LoRACompatibleConv(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
          (transformer_blocks): ModuleList(
            (0): BasicTransformerBlock(
              (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
              (attn1): Attention(
                (to_q): Linear(in_features=1280, out_features=1280, bias=False)
                (to_k): Linear(in_features=1280, out_features=1280, bias=False)
                (to_v): Linear(in_features=1280, out_features=1280, bias=False)
                (to_out): ModuleList(
                  (0): Linear(in_features=1280, out_features=1280, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
              (attn2): Attention(
                (to_q): Linear(in_features=1280, out_features=1280, bias=False)
                (to_k): Linear(in_features=768, out_features=1280, bias=False)
                (to_v): Linear(in_features=768, out_features=1280, bias=False)
                (to_out): ModuleList(
                  (0): Linear(in_features=1280, out_features=1280, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
              (ff): FeedForward(
                (net): ModuleList(
                  (0): GEGLU(
                    (proj): LoRACompatibleLinear(in_features=1280, out_features=10240, bias=True)
                  )
                  (1): Dropout(p=0.0, inplace=False)
                  (2): LoRACompatibleLinear(in_features=5120, out_features=1280, bias=True)
                )
              )
            )
          )
          (proj_out): LoRACompatibleConv(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (resnets): ModuleList(
        (0): ResnetBlock2D(
          (norm1): GroupNorm(32, 640, eps=1e-05, affine=True)
          (conv1): LoRACompatibleConv(640, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
          (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): LoRACompatibleConv(640, 1280, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): ResnetBlock2D(
          (norm1): GroupNorm(32, 1280, eps=1e-05, affine=True)
          (conv1): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
          (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
        )
      )
      (downsamplers): ModuleList(
        (0): Downsample2D(
          (conv): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        )
      )
    )
    (3): DownBlock2D(
      (resnets): ModuleList(
        (0-1): 2 x ResnetBlock2D(
          (norm1): GroupNorm(32, 1280, eps=1e-05, affine=True)
          (conv1): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
          (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
        )
      )
    )
  )
  (up_blocks): ModuleList(
    (0): UpBlock2D(
      (resnets): ModuleList(
        (0-2): 3 x ResnetBlock2D(
          (norm1): GroupNorm(32, 2560, eps=1e-05, affine=True)
          (conv1): LoRACompatibleConv(2560, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
          (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): LoRACompatibleConv(2560, 1280, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (upsamplers): ModuleList(
        (0): Upsample2D(
          (conv): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
    )
    (1): CrossAttnUpBlock2D(
      (attentions): ModuleList(
        (0-2): 3 x Transformer2DModel(
          (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
          (proj_in): LoRACompatibleConv(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
          (transformer_blocks): ModuleList(
            (0): BasicTransformerBlock(
              (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
              (attn1): Attention(
                (to_q): Linear(in_features=1280, out_features=1280, bias=False)
                (to_k): Linear(in_features=1280, out_features=1280, bias=False)
                (to_v): Linear(in_features=1280, out_features=1280, bias=False)
                (to_out): ModuleList(
                  (0): Linear(in_features=1280, out_features=1280, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
              (attn2): Attention(
                (to_q): Linear(in_features=1280, out_features=1280, bias=False)
                (to_k): Linear(in_features=768, out_features=1280, bias=False)
                (to_v): Linear(in_features=768, out_features=1280, bias=False)
                (to_out): ModuleList(
                  (0): Linear(in_features=1280, out_features=1280, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
              (ff): FeedForward(
                (net): ModuleList(
                  (0): GEGLU(
                    (proj): LoRACompatibleLinear(in_features=1280, out_features=10240, bias=True)
                  )
                  (1): Dropout(p=0.0, inplace=False)
                  (2): LoRACompatibleLinear(in_features=5120, out_features=1280, bias=True)
                )
              )
            )
          )
          (proj_out): LoRACompatibleConv(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (resnets): ModuleList(
        (0-1): 2 x ResnetBlock2D(
          (norm1): GroupNorm(32, 2560, eps=1e-05, affine=True)
          (conv1): LoRACompatibleConv(2560, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
          (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): LoRACompatibleConv(2560, 1280, kernel_size=(1, 1), stride=(1, 1))
        )
        (2): ResnetBlock2D(
          (norm1): GroupNorm(32, 1920, eps=1e-05, affine=True)
          (conv1): LoRACompatibleConv(1920, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
          (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): LoRACompatibleConv(1920, 1280, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (upsamplers): ModuleList(
        (0): Upsample2D(
          (conv): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
    )
    (2): CrossAttnUpBlock2D(
      (attentions): ModuleList(
        (0-2): 3 x Transformer2DModel(
          (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
          (proj_in): LoRACompatibleConv(640, 640, kernel_size=(1, 1), stride=(1, 1))
          (transformer_blocks): ModuleList(
            (0): BasicTransformerBlock(
              (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
              (attn1): Attention(
                (to_q): Linear(in_features=640, out_features=640, bias=False)
                (to_k): Linear(in_features=640, out_features=640, bias=False)
                (to_v): Linear(in_features=640, out_features=640, bias=False)
                (to_out): ModuleList(
                  (0): Linear(in_features=640, out_features=640, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
              (attn2): Attention(
                (to_q): Linear(in_features=640, out_features=640, bias=False)
                (to_k): Linear(in_features=768, out_features=640, bias=False)
                (to_v): Linear(in_features=768, out_features=640, bias=False)
                (to_out): ModuleList(
                  (0): Linear(in_features=640, out_features=640, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
              (ff): FeedForward(
                (net): ModuleList(
                  (0): GEGLU(
                    (proj): LoRACompatibleLinear(in_features=640, out_features=5120, bias=True)
                  )
                  (1): Dropout(p=0.0, inplace=False)
                  (2): LoRACompatibleLinear(in_features=2560, out_features=640, bias=True)
                )
              )
            )
          )
          (proj_out): LoRACompatibleConv(640, 640, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (resnets): ModuleList(
        (0): ResnetBlock2D(
          (norm1): GroupNorm(32, 1920, eps=1e-05, affine=True)
          (conv1): LoRACompatibleConv(1920, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=640, bias=True)
          (norm2): GroupNorm(32, 640, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): LoRACompatibleConv(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): LoRACompatibleConv(1920, 640, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): ResnetBlock2D(
          (norm1): GroupNorm(32, 1280, eps=1e-05, affine=True)
          (conv1): LoRACompatibleConv(1280, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=640, bias=True)
          (norm2): GroupNorm(32, 640, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): LoRACompatibleConv(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): LoRACompatibleConv(1280, 640, kernel_size=(1, 1), stride=(1, 1))
        )
        (2): ResnetBlock2D(
          (norm1): GroupNorm(32, 960, eps=1e-05, affine=True)
          (conv1): LoRACompatibleConv(960, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=640, bias=True)
          (norm2): GroupNorm(32, 640, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): LoRACompatibleConv(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): LoRACompatibleConv(960, 640, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (upsamplers): ModuleList(
        (0): Upsample2D(
          (conv): LoRACompatibleConv(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
    )
    (3): CrossAttnUpBlock2D(
      (attentions): ModuleList(
        (0-2): 3 x Transformer2DModel(
          (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
          (proj_in): LoRACompatibleConv(320, 320, kernel_size=(1, 1), stride=(1, 1))
          (transformer_blocks): ModuleList(
            (0): BasicTransformerBlock(
              (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
              (attn1): Attention(
                (to_q): Linear(in_features=320, out_features=320, bias=False)
                (to_k): Linear(in_features=320, out_features=320, bias=False)
                (to_v): Linear(in_features=320, out_features=320, bias=False)
                (to_out): ModuleList(
                  (0): Linear(in_features=320, out_features=320, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
              (attn2): Attention(
                (to_q): Linear(in_features=320, out_features=320, bias=False)
                (to_k): Linear(in_features=768, out_features=320, bias=False)
                (to_v): Linear(in_features=768, out_features=320, bias=False)
                (to_out): ModuleList(
                  (0): Linear(in_features=320, out_features=320, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
              (ff): FeedForward(
                (net): ModuleList(
                  (0): GEGLU(
                    (proj): LoRACompatibleLinear(in_features=320, out_features=2560, bias=True)
                  )
                  (1): Dropout(p=0.0, inplace=False)
                  (2): LoRACompatibleLinear(in_features=1280, out_features=320, bias=True)
                )
              )
            )
          )
          (proj_out): LoRACompatibleConv(320, 320, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (resnets): ModuleList(
        (0): ResnetBlock2D(
          (norm1): GroupNorm(32, 960, eps=1e-05, affine=True)
          (conv1): LoRACompatibleConv(960, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=320, bias=True)
          (norm2): GroupNorm(32, 320, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): LoRACompatibleConv(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): LoRACompatibleConv(960, 320, kernel_size=(1, 1), stride=(1, 1))
        )
        (1-2): 2 x ResnetBlock2D(
          (norm1): GroupNorm(32, 640, eps=1e-05, affine=True)
          (conv1): LoRACompatibleConv(640, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=320, bias=True)
          (norm2): GroupNorm(32, 320, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): LoRACompatibleConv(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): LoRACompatibleConv(640, 320, kernel_size=(1, 1), stride=(1, 1))
        )
      )
    )
  )
  (mid_block): UNetMidBlock2DCrossAttn(
    (attentions): ModuleList(
      (0): Transformer2DModel(
        (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
        (proj_in): LoRACompatibleConv(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
        (transformer_blocks): ModuleList(
          (0): BasicTransformerBlock(
            (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
            (attn1): Attention(
              (to_q): Linear(in_features=1280, out_features=1280, bias=False)
              (to_k): Linear(in_features=1280, out_features=1280, bias=False)
              (to_v): Linear(in_features=1280, out_features=1280, bias=False)
              (to_out): ModuleList(
                (0): Linear(in_features=1280, out_features=1280, bias=True)
                (1): Dropout(p=0.0, inplace=False)
              )
            )
            (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
            (attn2): Attention(
              (to_q): Linear(in_features=1280, out_features=1280, bias=False)
              (to_k): Linear(in_features=768, out_features=1280, bias=False)
              (to_v): Linear(in_features=768, out_features=1280, bias=False)
              (to_out): ModuleList(
                (0): Linear(in_features=1280, out_features=1280, bias=True)
                (1): Dropout(p=0.0, inplace=False)
              )
            )
            (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
            (ff): FeedForward(
              (net): ModuleList(
                (0): GEGLU(
                  (proj): LoRACompatibleLinear(in_features=1280, out_features=10240, bias=True)
                )
                (1): Dropout(p=0.0, inplace=False)
                (2): LoRACompatibleLinear(in_features=5120, out_features=1280, bias=True)
              )
            )
          )
        )
        (proj_out): LoRACompatibleConv(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
      )
    )
    (resnets): ModuleList(
      (0-1): 2 x ResnetBlock2D(
        (norm1): GroupNorm(32, 1280, eps=1e-05, affine=True)
        (conv1): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
        (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (conv2): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (nonlinearity): SiLU()
      )
    )
  )
  (conv_norm_out): GroupNorm(32, 320, eps=1e-05, affine=True)
  (conv_act): SiLU()
  (conv_out): Conv2d(320, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

Inference

text_input = tokenizer(
    prompt * batch_size,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
).input_ids.to(device)

text_input.shape, text_input
(torch.Size([2, 77]),
 tensor([[49406,   320,  8853,   539,   550, 18376,  6765,   320,  4558, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407],
         [49406,   320,  8853,   539,   550, 18376,  6765,   320,  4558, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407]], device='cuda:0'))
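
Decoding the first row back to text confirms the layout: token 49406 is <|startoftext|>, then the prompt tokens, then <|endoftext|> (49407) repeated as padding out to length 77. A quick check cell, not part of the original run:

# map the ids back to text; the padding shows up as repeated <|endoftext|>
print(tokenizer.decode(text_input[0]))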
with torch.no_grad():
    text_embeddings = text_encoder(text_input)[0]

text_embeddings.shape, text_embeddings
(torch.Size([2, 77, 768]),
 tensor([[[-0.3884,  0.0229, -0.0522,  ..., -0.4899, -0.3066,  0.0675],
          [ 0.0290, -1.3258,  0.3085,  ..., -0.5257,  0.9768,  0.6652],
          [ 0.4595,  0.5617,  1.6663,  ..., -1.9515, -1.2307,  0.0104],
          ...,
          [-3.0421, -0.0656, -0.1793,  ...,  0.3943, -0.0190,  0.7664],
          [-3.0551, -0.1036, -0.1936,  ...,  0.4236, -0.0190,  0.7575],
          [-2.9854, -0.0832, -0.1715,  ...,  0.4355,  0.0095,  0.7485]],
 
         [[-0.3884,  0.0229, -0.0522,  ..., -0.4899, -0.3066,  0.0675],
          [ 0.0290, -1.3258,  0.3085,  ..., -0.5257,  0.9768,  0.6652],
          [ 0.4595,  0.5617,  1.6663,  ..., -1.9515, -1.2307,  0.0104],
          ...,
          [-3.0421, -0.0656, -0.1793,  ...,  0.3943, -0.0190,  0.7664],
          [-3.0551, -0.1036, -0.1936,  ...,  0.4236, -0.0190,  0.7575],
          [-2.9854, -0.0832, -0.1715,  ...,  0.4355,  0.0095,  0.7485]]],
        device='cuda:0'))
uncond_input = tokenizer(
    [""] * batch_size,
    padding="max_length",
    max_length=text_input.shape[-1],
    return_tensors="pt",
).input_ids.to(device)

uncond_input.shape, uncond_input
(torch.Size([2, 77]),
 tensor([[49406, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407],
         [49406, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407]], device='cuda:0'))
with torch.no_grad():
    uncond_embeddings = text_encoder(uncond_input)[0]

uncond_embeddings.shape, uncond_embeddings
(torch.Size([2, 77, 768]),
 tensor([[[-0.3884,  0.0229, -0.0522,  ..., -0.4899, -0.3066,  0.0675],
          [-0.3711, -1.4497, -0.3401,  ...,  0.9489,  0.1867, -1.1034],
          [-0.5107, -1.4629, -0.2926,  ...,  1.0419,  0.0701, -1.0284],
          ...,
          [ 0.5006, -0.9552, -0.6610,  ...,  1.6013, -1.0622, -0.2191],
          [ 0.4988, -0.9451, -0.6656,  ...,  1.6467, -1.0858, -0.2088],
          [ 0.4923, -0.8124, -0.4912,  ...,  1.6108, -1.0174, -0.2484]],
 
         [[-0.3884,  0.0229, -0.0522,  ..., -0.4899, -0.3066,  0.0675],
          [-0.3711, -1.4497, -0.3401,  ...,  0.9489,  0.1867, -1.1034],
          [-0.5107, -1.4629, -0.2926,  ...,  1.0419,  0.0701, -1.0284],
          ...,
          [ 0.5006, -0.9552, -0.6610,  ...,  1.6013, -1.0622, -0.2191],
          [ 0.4988, -0.9451, -0.6656,  ...,  1.6467, -1.0858, -0.2088],
          [ 0.4923, -0.8124, -0.4912,  ...,  1.6108, -1.0174, -0.2484]]],
        device='cuda:0'))
# concatenate with the unconditional embeddings first; noise_pred.chunk(2) below relies on this order
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

text_embeddings.shape, text_embeddings
(torch.Size([4, 77, 768]),
 tensor([[[-0.3884,  0.0229, -0.0522,  ..., -0.4899, -0.3066,  0.0675],
          [-0.3711, -1.4497, -0.3401,  ...,  0.9489,  0.1867, -1.1034],
          [-0.5107, -1.4629, -0.2926,  ...,  1.0419,  0.0701, -1.0284],
          ...,
          [ 0.5006, -0.9552, -0.6610,  ...,  1.6013, -1.0622, -0.2191],
          [ 0.4988, -0.9451, -0.6656,  ...,  1.6467, -1.0858, -0.2088],
          [ 0.4923, -0.8124, -0.4912,  ...,  1.6108, -1.0174, -0.2484]],
 
         [[-0.3884,  0.0229, -0.0522,  ..., -0.4899, -0.3066,  0.0675],
          [-0.3711, -1.4497, -0.3401,  ...,  0.9489,  0.1867, -1.1034],
          [-0.5107, -1.4629, -0.2926,  ...,  1.0419,  0.0701, -1.0284],
          ...,
          [ 0.5006, -0.9552, -0.6610,  ...,  1.6013, -1.0622, -0.2191],
          [ 0.4988, -0.9451, -0.6656,  ...,  1.6467, -1.0858, -0.2088],
          [ 0.4923, -0.8124, -0.4912,  ...,  1.6108, -1.0174, -0.2484]],
 
         [[-0.3884,  0.0229, -0.0522,  ..., -0.4899, -0.3066,  0.0675],
          [ 0.0290, -1.3258,  0.3085,  ..., -0.5257,  0.9768,  0.6652],
          [ 0.4595,  0.5617,  1.6663,  ..., -1.9515, -1.2307,  0.0104],
          ...,
          [-3.0421, -0.0656, -0.1793,  ...,  0.3943, -0.0190,  0.7664],
          [-3.0551, -0.1036, -0.1936,  ...,  0.4236, -0.0190,  0.7575],
          [-2.9854, -0.0832, -0.1715,  ...,  0.4355,  0.0095,  0.7485]],
 
         [[-0.3884,  0.0229, -0.0522,  ..., -0.4899, -0.3066,  0.0675],
          [ 0.0290, -1.3258,  0.3085,  ..., -0.5257,  0.9768,  0.6652],
          [ 0.4595,  0.5617,  1.6663,  ..., -1.9515, -1.2307,  0.0104],
          ...,
          [-3.0421, -0.0656, -0.1793,  ...,  0.3943, -0.0190,  0.7664],
          [-3.0551, -0.1036, -0.1936,  ...,  0.4236, -0.0190,  0.7575],
          [-2.9854, -0.0832, -0.1715,  ...,  0.4355,  0.0095,  0.7485]]],
        device='cuda:0'))
# the VAE downsamples by a factor of 8, so the latents are 4 x (height/8) x (width/8)
latents = torch.randn(
    (batch_size, unet.config.in_channels, height // 8, width // 8),
    generator=generator,
).to(device)

latents.shape, latents
(torch.Size([2, 4, 64, 64]),
 tensor([[[[ 0.1884, -0.6394,  0.1089,  ..., -0.9887, -0.7133, -1.1545],
           [ 0.4124,  1.5587, -0.3407,  ...,  2.1968, -0.0356, -0.0810],
           [-1.8912,  0.0528, -0.4425,  ...,  1.3110,  0.7100,  0.6802],
           ...,
           [-1.3443, -0.1747, -0.6298,  ...,  0.4572, -0.8584, -0.1284],
           [-1.7920, -0.6554, -0.0439,  ...,  0.5436,  2.2266, -0.5003],
           [ 0.6213, -1.3155,  0.7470,  ..., -0.2354,  0.7097,  0.6170]],
 
          [[-0.5007, -1.4418,  0.2598,  ..., -0.2586,  2.3239, -1.3245],
           [ 0.8540, -0.4135,  0.5658,  ..., -1.9556,  2.0454, -0.2454],
           [-0.3212, -1.9329, -1.1598,  ...,  0.7156, -0.7228, -0.6992],
           ...,
           [ 0.0180, -0.7993,  2.3330,  ...,  0.2594, -0.0333, -0.0826],
           [-1.2569, -0.8219,  1.3467,  ...,  0.4792,  1.8265, -0.6156],
           [-1.9367, -0.0949,  0.0720,  ...,  0.0806,  0.2966, -1.0284]],
 
          [[ 0.2291, -0.0936, -1.3283,  ...,  1.4995, -0.1965, -0.2879],
           [-1.0226, -1.2896,  1.6202,  ..., -0.3910, -0.3834,  0.5519],
           [ 0.5424,  0.2685,  0.4912,  ...,  0.9773, -0.8260,  1.1552],
           ...,
           [-1.5280, -0.2530, -1.3748,  ..., -1.4948,  1.3661, -1.1294],
           [ 0.4241, -0.2996,  1.8231,  ...,  0.6968,  0.8247, -0.0279],
           [-3.3711, -0.7468, -1.3212,  ..., -0.4128,  0.4621,  2.6297]],
 
          [[-0.7510, -0.7452, -0.8998,  ..., -1.6957, -0.4004, -0.2596],
           [-1.2092, -1.8881, -0.5828,  ..., -1.0428, -0.6500,  0.3601],
           [-0.4254,  0.9478,  1.3083,  ..., -0.0259, -0.4542,  0.4353],
           ...,
           [-0.1918,  0.4858,  0.0666,  ...,  0.8505, -0.6606, -0.3193],
           [ 1.3620,  0.2283,  0.6292,  ..., -0.9271,  1.7018,  0.2161],
           [-0.3891, -1.8911, -0.7501,  ..., -0.2330, -1.0460,  0.4121]]],
 
 
         [[[ 0.3649, -1.3183, -1.3308,  ..., -0.5548, -1.3610, -1.9329],
           [-0.0071,  0.1977,  1.5517,  ..., -1.6664,  1.6551,  0.1798],
           [-1.0404,  0.6524,  0.4654,  ..., -0.5947, -1.0871,  2.2230],
           ...,
           [-0.6844,  0.1692, -0.2559,  ...,  0.5511,  0.9734,  0.7936],
           [-1.1951,  0.5016,  0.8089,  ...,  0.2337, -0.2213, -1.1724],
           [-0.5055, -0.7491, -1.4940,  ..., -2.1332,  0.9120,  0.2057]],
 
          [[ 1.3668, -1.1680, -0.8574,  ..., -0.0635, -1.9132, -0.6023],
           [ 1.0974, -0.9654,  1.2987,  ...,  1.3187, -0.0241, -0.5427],
           [-2.0427, -1.4358, -0.7115,  ...,  0.1088,  0.0764,  0.7254],
           ...,
           [ 1.0957,  1.4058, -0.0178,  ...,  0.5748,  0.0953,  0.7550],
           [ 0.4080,  0.8792,  0.6801,  ..., -0.7215,  1.1261,  0.0551],
           [-0.3183, -2.3306,  0.7155,  ...,  0.4291, -0.2074, -1.1237]],
 
          [[-0.2401,  0.9229,  0.0212,  ...,  0.2128, -0.4705, -0.3262],
           [ 0.1108,  0.8909,  0.5309,  ..., -1.7175, -1.6657, -1.7706],
           [-0.1654, -0.4582, -1.2832,  ...,  0.5297, -0.8363,  1.0293],
           ...,
           [-1.3526,  2.1482,  0.5417,  ..., -2.2156, -1.9940, -0.9745],
           [-0.5821,  0.0492,  0.6693,  ..., -0.8610,  0.5864, -0.6040],
           [ 1.0180,  1.4447,  0.9563,  ...,  0.9034,  0.7988, -1.7119]],
 
          [[-1.6146,  0.0868,  0.6415,  ...,  0.2083,  0.4058,  0.2813],
           [ 0.1969, -0.3334, -0.6526,  ..., -1.4639, -1.6302, -0.6036],
           [ 0.1556, -0.0859, -0.0230,  ..., -0.7900, -0.3481,  0.8767],
           ...,
           [ 0.6056,  0.8374, -0.3834,  ..., -0.6636, -0.4814,  0.8244],
           [ 0.6982, -0.4884, -1.3777,  ...,  0.5876, -2.0944,  0.0853],
           [ 0.0388, -0.5761, -0.5116,  ..., -1.6645,  0.1752, -0.1923]]]],
        device='cuda:0'))
scheduler.set_timesteps(num_inference_steps)

scheduler.timesteps.shape, scheduler.timesteps
(torch.Size([50]),
 tensor([999.0000, 978.6122, 958.2245, 937.8367, 917.4490, 897.0612, 876.6735,
         856.2857, 835.8980, 815.5102, 795.1224, 774.7347, 754.3469, 733.9592,
         713.5714, 693.1837, 672.7959, 652.4082, 632.0204, 611.6327, 591.2449,
         570.8571, 550.4694, 530.0816, 509.6939, 489.3061, 468.9184, 448.5306,
         428.1429, 407.7551, 387.3673, 366.9796, 346.5918, 326.2041, 305.8163,
         285.4286, 265.0408, 244.6531, 224.2653, 203.8776, 183.4898, 163.1020,
         142.7143, 122.3265, 101.9388,  81.5510,  61.1633,  40.7755,  20.3878,
           0.0000], dtype=torch.float64))
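
LMS works in sigma space: after set_timesteps there is one noise level per step plus a trailing zero, and init_noise_sigma (used below to scale the initial latents) is derived from the largest sigma. A quick inspection cell, assuming the LMSDiscreteScheduler attributes sigmas and init_noise_sigma:

# one sigma per inference step, plus a final 0 (torch.Size([51]) here)
print(scheduler.sigmas.shape)
print(scheduler.init_noise_sigma)  # the initial latents are scaled by this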

Classifier-free guidance (CFG)

$$\tilde{\boldsymbol{\epsilon}}_\theta(\mathbf{z}_t, \mathbf{c}) = w\boldsymbol{\epsilon}_\theta(\mathbf{z}_t, \mathbf{c}) + (1-w)\boldsymbol{\epsilon}_{\theta}(\mathbf{z}_t).$$

Here, $\boldsymbol{\epsilon}_\theta(\mathbf{z}_t, \mathbf{c})$ and $\boldsymbol{\epsilon}_{\theta}(\mathbf{z}_t)$ are the conditional and unconditional $\boldsymbol{\epsilon}$-predictions, given by $\boldsymbol{\epsilon}_\theta := (\mathbf{z}_t - \alpha_t\hat{\mathbf{x}}_\theta)/\sigma_t$, and $w$ is the guidance weight. Setting $w = 1$ disables classifier-free guidance, while increasing $w > 1$ strengthens the effect of guidance.¹
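
The same update is often written as $\boldsymbol{\epsilon}_{\text{uncond}} + w(\boldsymbol{\epsilon}_{\text{cond}} - \boldsymbol{\epsilon}_{\text{uncond}})$; expanding the product shows the two forms are identical, which is why either can appear in the loop below. A quick numerical check on random tensors (illustrative only):

# w * cond + (1 - w) * uncond  ==  uncond + w * (cond - uncond)
cond, uncond = torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8)
w = 7.5
assert torch.allclose(w * cond + (1 - w) * uncond, uncond + w * (cond - uncond), atol=1e-4)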

from tqdm.auto import tqdm

# scale the initial noise by the scheduler's starting sigma (required by k-diffusion-style schedulers such as LMS)
latents = latents * scheduler.init_noise_sigma

for t in tqdm(scheduler.timesteps):
    # duplicate the latents so the unconditional and conditional predictions run in a single UNet forward pass
    latent_model_input = torch.cat([latents] * 2)

    latent_model_input = scheduler.scale_model_input(latent_model_input, t)

    # predict the noise residual
    with torch.no_grad():
        noise_pred = unet(
            latent_model_input, t, encoder_hidden_states=text_embeddings
        ).sample

    # perform guidance
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = (
        guidance_scale * noise_pred_text + (1 - guidance_scale) * noise_pred_uncond
    )

    # compute the previous noisy sample x_t -> x_t-1
    latents = scheduler.step(noise_pred, t, latents).prev_sample
# scale and decode the image latents with the vae (1/0.18215 undoes the SD v1 VAE scaling factor)
latents = 1 / 0.18215 * latents

with torch.no_grad():
    image = vae.decode(latents).sample
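
The constant 0.18215 above is the scaling factor the Stable Diffusion v1 VAE was trained with. Recent Diffusers releases expose it on the model config, so an equivalent decode that avoids the magic number looks like this (a sketch to replace the cell above, not to run after it, assuming vae.config.scaling_factor is populated as it is for this checkpoint):

# equivalent to the cell above: undo the VAE scaling via the config value
latents = latents / vae.config.scaling_factor  # 0.18215 for SD v1.4
with torch.no_grad():
    image = vae.decode(latents).sample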
from PIL import Image

image = (image / 2 + 0.5).clamp(0, 1)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
pil_images = [Image.fromarray(img) for img in images]
for pil_image in pil_images:
    display(pil_image)
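
Outside a notebook, display is unavailable; the images can simply be saved, and since batch_size = 2 produced two of them, a small grid makes side-by-side comparison easier. A minimal sketch using only PIL (file names are arbitrary):

# save each image, then paste both into one horizontal grid
for i, pil_image in enumerate(pil_images):
    pil_image.save(f"astronaut_{i}.png")

grid = Image.new("RGB", (width * len(pil_images), height))
for i, pil_image in enumerate(pil_images):
    grid.paste(pil_image, (i * width, 0))
grid.save("astronaut_grid.png")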

References

Footnotes

  1. Saharia et al. (2022). Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding. https://arxiv.org/abs/2205.11487