A Compressed Stable Diffusion for Efficient Text-to-Image Generation [ICCV'23 Demo] [ICML'23 Workshop]

Nota-NetsPresso/BK-SDM

Block-removed Knowledge-distilled Stable Diffusion

Official codebase for BK-SDM: Architecturally Compressed Stable Diffusion for Efficient Text-to-Image Generation [ArXiv] [ICCV 2023 Demo Track] [ICML 2023 Workshop on ES-FoMo].

BK-SDMs are lightweight text-to-image (T2I) synthesis models:

  • Certain residual & attention blocks are eliminated from the U-Net of SD.
  • Distillation pretraining uses very limited data, yet it (surprisingly) remains effective.

⚡Quick Links: KD Pretraining | Evaluation on MS-COCO | DreamBooth Finetuning | Demo

Notice

Model Description

Installation

conda create -n bk-sdm python=3.8
conda activate bk-sdm
git clone https://github.com/Nota-NetsPresso/BK-SDM.git
cd BK-SDM
pip install -r requirements.txt

Note on the torch versions we've used:

  • torch 1.13.1 for MS-COCO evaluation & DreamBooth finetuning on a single 24GB RTX3090
  • torch 2.0.1 for KD pretraining on a single 80GB A100
    • If pretraining with a total batch size of 256 on an A100 runs out of GPU memory, check your torch version and consider upgrading to torch>2.0.0.
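Before launching pretraining, a quick pre-flight check can tell whether the installed torch satisfies torch>2.0.0. The sketch below is a minimal illustration, not part of this repo; the helper names are hypothetical, and it only parses the leading numeric `major.minor.patch` part of a version string, so local tags such as `+cu117` are ignored.

```python
import re
from typing import Tuple


def parse_version(version: str) -> Tuple[int, int, int]:
    """Parse the leading 'major.minor.patch' of a version string
    (e.g. '2.0.1+cu117') into ints; missing parts default to 0."""
    match = re.match(r"(\d+)(?:\.(\d+))?(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor, patch = (int(part or 0) for part in match.groups())
    return major, minor, patch


def is_torch_newer_than(version: str, minimum: str = "2.0.0") -> bool:
    """Return True if `version` is strictly newer than `minimum`."""
    return parse_version(version) > parse_version(minimum)
```

For example, calling `is_torch_newer_than(torch.__version__)` would return False for the torch 1.13.1 setup above and True for torch 2.0.1.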

Minimal Example with 🤗Diffusers

With the default PNDM scheduler and 50 denoising steps:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("nota-ai/bk-sdm-small", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompt = "a golden vase with different flowers"
image = pipe(prompt).images[0]

image.save("example.png")

Equivalent code that replaces only the U-Net of SD-v1.4 while keeping its Text Encoder and Image Decoder:
import torch
from diffusers import StableDiffusionPipeline, UNet2DConditionModel

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
pipe.unet = UNet2DConditionModel.from_pretrained("nota-ai/bk-sdm-small", subfolder="unet", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompt = "a golden vase with different flowers"
image = pipe(prompt).images[0]

image.save("example.png")

Distillation Pretraining

Our code is based on train_text_to_image.py from Diffusers 0.15.0. The latest version of that script is available in the Diffusers repository.

[Optional] Toy example to check runnability

bash scripts/get_laion_data.sh preprocessed_11k
bash scripts/kd_train_toy.sh
Note
  • A toy dataset (11K img-txt pairs) is downloaded at ./data/laion_aes/preprocessed_11k (1.7GB in tar.gz; 1.8GB data folder).
  • The toy script can be used to verify that the code runs and to find a batch size that fits your GPU. With a batch size of 8 (=4×2), training BK-SDM-Base for 20 iterations takes about 5 minutes and 22GB of GPU memory.

Single-gpu training for BK-SDM-{Base, Small, Tiny}

bash scripts/get_laion_data.sh preprocessed_212k
bash scripts/kd_train.sh
Note
  • The dataset with 212K (=0.22M) pairs is downloaded at ./data/laion_aes/preprocessed_212k (18GB tar.gz; 20GB data folder).
  • With a batch size of 256 (=4×64), training BK-SDM-Base for 50K iterations takes about 300 hours and 53GB GPU memory. With a batch size of 64 (=4×16), it takes 60 hours and 28GB GPU memory.
  • Training BK-SDM-{Small, Tiny} reduces GPU memory usage by 5∼10%.
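The total batch sizes above are written as gradient accumulation steps × per-step batch, matching the `--gradient_accumulation_steps 4` and `--train_batch_size 64` hyperparameters listed under "Note on training code". A small sketch of that arithmetic, plus a rough count of passes over the data; the function names are hypothetical, and 212_000 only approximates the "212K" pair count.

```python
def effective_batch_size(train_batch_size: int,
                         gradient_accumulation_steps: int,
                         num_gpus: int = 1) -> int:
    """Samples that contribute to one optimizer update:
    per-step batch x accumulation steps x number of GPUs."""
    return train_batch_size * gradient_accumulation_steps * num_gpus


def epochs_seen(iterations: int, batch_size: int, dataset_size: int) -> float:
    """Approximate number of passes over the dataset after `iterations` updates."""
    return iterations * batch_size / dataset_size


# Values from the notes above (212_000 is an approximation of "212K").
total_bs = effective_batch_size(train_batch_size=64, gradient_accumulation_steps=4)
print(total_bs)                                          # 256
print(round(epochs_seen(50_000, total_bs, 212_000), 1))  # 60.4
```

Under these assumptions, 50K iterations at a total batch size of 256 correspond to about 12.8M samples, or roughly 60 passes over the 212K-pair set.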

Single-gpu training for BK-SDM-{Base-2M, Small-2M, Tiny-2M}

bash scripts/get_laion_data.sh preprocessed_2256k
bash scripts/kd_train_2m.sh
Note
  • The dataset with 2256K (=2.3M) pairs is downloaded at ./data/laion_aes/preprocessed_2256k (182GB tar.gz; 204GB data folder).
  • Apart from the dataset, kd_train_2m.sh is the same as kd_train.sh; for the same number of iterations, the training computation is identical.

Multi-gpu training

bash scripts/kd_train_toy_ddp.sh
Note
  • Multi-GPU training is supported (sample results: link), although all experiments for our paper were conducted using a single GPU. Thanks @youngwanLEE for sharing the script :)
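When moving to several GPUs, a common convention for data-parallel training with gradient accumulation is: total batch = number of GPUs × per-GPU batch × accumulation steps. Assuming the DDP script follows that convention (check scripts/kd_train_toy_ddp.sh for the values it actually sets), this hypothetical helper finds the per-GPU batch that reproduces the single-GPU total of 256:

```python
def per_gpu_batch_size(total_batch_size: int, num_gpus: int,
                       gradient_accumulation_steps: int) -> int:
    """Per-GPU batch such that
    num_gpus * per_gpu * gradient_accumulation_steps == total_batch_size."""
    if num_gpus < 1 or gradient_accumulation_steps < 1:
        raise ValueError("num_gpus and gradient_accumulation_steps must be >= 1")
    denom = num_gpus * gradient_accumulation_steps
    if total_batch_size % denom != 0:
        raise ValueError(
            f"total batch {total_batch_size} is not divisible by "
            f"{num_gpus} GPUs x {gradient_accumulation_steps} accumulation steps")
    return total_batch_size // denom


print(per_gpu_batch_size(256, num_gpus=1, gradient_accumulation_steps=4))  # 64
print(per_gpu_batch_size(256, num_gpus=4, gradient_accumulation_steps=4))  # 16
```

Keeping the total batch fixed this way makes a multi-GPU run comparable to the single-GPU setting used in the paper, though sampling order and random seeds will still differ.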

Compression of SD-v2 with BK-SDM

bash scripts/kd_train_v2-base-im512.sh
bash scripts/kd_train_v2-im768.sh

# For inference, see: 'scripts/generate_with_trained_unet.sh'

Note on training code

Key segments for KD training
  • Define Student U-Net by adjusting config.json [link]
  • Initialize Student U-Net by copying Teacher U-Net's weights [link]
  • Define hook locations for feature KD [link]
  • Define losses for feature-and-output KD [link]
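The weight-copying step above can be illustrated with a simplified sketch: copy each teacher tensor into the student when the (optionally renamed) parameter name exists in the teacher with the same shape. This is not the repo's implementation; BK-SDM uses its own block correspondence (see the linked code), and the `key_map` renames and toy key names below are hypothetical. Values are treated opaquely, so the function accepts PyTorch `state_dict()` mappings as well as plain dicts.

```python
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple


def copy_weights_from_teacher(
    student_state: Mapping[str, Any],
    teacher_state: Mapping[str, Any],
    key_map: Optional[Mapping[str, str]] = None,
    shape_of: Callable[[Any], Any] = lambda t: getattr(t, "shape", None),
) -> Tuple[Dict[str, Any], List[str]]:
    """Return (new_student_state, copied_keys).

    key_map: hypothetical {student_prefix: teacher_prefix} renames, e.g. when
    removing a block shifts the indices of the blocks that follow it.
    """
    key_map = key_map or {}
    new_state = dict(student_state)
    copied = []
    for s_key, s_val in student_state.items():
        # Translate the student name into the teacher's naming, if a rename applies.
        t_key = s_key
        for s_prefix, t_prefix in key_map.items():
            if s_key.startswith(s_prefix):
                t_key = t_prefix + s_key[len(s_prefix):]
                break
        t_val = teacher_state.get(t_key)
        # Copy only when a teacher tensor exists and its shape matches.
        if t_val is not None and shape_of(t_val) == shape_of(s_val):
            new_state[s_key] = t_val
            copied.append(s_key)
    return new_state, copied
```

With PyTorch modules, one would pass `student.state_dict()` and `teacher.state_dict()` and then call `student.load_state_dict(new_state)`.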
Key learning hyperparams
--unet_config_name "bk_small" # option: ["bk_base", "bk_small", "bk_tiny"]
--use_copy_weight_from_teacher # initialize student unet with teacher weights
--learning_rate 5e-05
--train_batch_size 64
--gradient_accumulation_steps 4
--lambda_sd 1.0
--lambda_kd_output 1.0
--lambda_kd_feat 1.0
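The paper describes the training objective as a denoising (task) loss plus output-level and feature-level KD losses, and the three lambda flags above appear to weight those terms: L = λ_sd·L_task + λ_kd_output·L_outKD + λ_kd_feat·L_featKD. Below is a plain-Python sketch of that weighted sum, assuming each term is a mean squared error, the feature term is summed over the hooked layers, and the task target is the added noise (as for an epsilon-prediction model such as SD-v1.4). Lists stand in for tensors and the function names are hypothetical.

```python
from typing import Sequence


def mse(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error between two equal-length flat sequences."""
    if len(pred) != len(target) or not pred:
        raise ValueError("inputs must be non-empty and equal in length")
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def kd_total_loss(noise: Sequence[float],
                  student_out: Sequence[float],
                  teacher_out: Sequence[float],
                  student_feats: Sequence[Sequence[float]],
                  teacher_feats: Sequence[Sequence[float]],
                  lambda_sd: float = 1.0,
                  lambda_kd_output: float = 1.0,
                  lambda_kd_feat: float = 1.0) -> float:
    """Weighted sum of task, output-KD, and feature-KD losses."""
    if len(student_feats) != len(teacher_feats):
        raise ValueError("need one teacher feature per hooked student feature")
    loss_sd = mse(student_out, noise)               # student vs. true noise
    loss_kd_output = mse(student_out, teacher_out)  # student vs. teacher output
    loss_kd_feat = sum(mse(s, t) for s, t in zip(student_feats, teacher_feats))
    return (lambda_sd * loss_sd
            + lambda_kd_output * loss_kd_output
            + lambda_kd_feat * loss_kd_feat)
```

With all three weights at 1.0, as in the listed values, the three terms simply add up; lowering a lambda reduces how strongly the student is pulled toward that target.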

Evaluation on MS-COCO Benchmark

We used the following code to obtain the MS-COCO results. After generating 512×512 images with the PNDM scheduler and 25 denoising steps, we downsampled them to 256×256 before computing the scores.

Generation with released models (using BK-SDM-Small as default)

On a single 3090 GPU, '(2)' takes ~10 hours per model, and '(3)' takes a few minutes.

  • (1) Download metadata.csv and real_im256.npz:

    bash scripts/get_mscoco_files.sh
    
    # ./data/mscoco_val2014_30k/metadata.csv: 30K prompts from the MS-COCO validation set (used in '(2)')  
    # ./data/mscoco_val2014_41k_full/real_im256.npz: FID statistics of 41K real images (used in '(3)')
    Note on 'real_im256.npz'
    • Following the evaluation protocol [DALL·E, Imagen], the FID stat for real images was computed over the full validation set (41K images) of MS-COCO. A precomputed stat file is downloaded via '(1)' at ./data/mscoco_val2014_41k_full/real_im256.npz.
    • Alternatively, real_im256.npz can be computed with python3 src/get_stat_mscoco_val2014.py, which downloads all the validation images, resizes them to 256×256, and computes the FID stat.
  • (2) Generate 512×512 images over 30K prompts from the MS-COCO validation set → Resize them to 256×256:

    python3 src/generate.py 
    
    # python3 src/generate.py --model_id nota-ai/bk-sdm-base --save_dir ./results/bk-sdm-base
    # python3 src/generate.py --model_id nota-ai/bk-sdm-tiny --save_dir ./results/bk-sdm-tiny  

    [Batched generation] Increase --batch_sz (default: 1) for faster inference at the cost of higher VRAM usage. Thanks @Godofnothing for providing this feature :)

    Inference cost details:
    • Setup: BK-SDM-Small on MS-COCO 30K image generation

    • We used an eval batch size of 1 for our paper results. Different batch sizes affect the sampling of random latent codes, resulting in slightly different generation scores.

      Eval Batch Size    1         2         4         8
      GPU Memory         4.9GB     6.3GB     11.3GB    19.6GB
      Generation Time    9.4h      7.9h      7.6h      7.3h
      FID                16.98     17.01     17.16     16.97
      IS                 31.68     31.20     31.62     31.22
      CLIP Score         0.2677    0.2679    0.2677    0.2675
  • (3) Compute FID, IS, and CLIP score:

    bash scripts/eval_scores.sh
    
    # For the other models, modify the `./results/bk-sdm-*` path in the scripts to specify different models.
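For reference when reading the FID from step (3): FID fits a Gaussian to Inception features of each image set and compares the two, and real_im256.npz holds the real-image side of that comparison. The standard definition, with mean μ and covariance Σ for the real (r) and generated (g) sets, is:

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\left( \Sigma_r + \Sigma_g - 2 \left( \Sigma_r \Sigma_g \right)^{1/2} \right)
```

Lower is better; here r covers the 41K real validation images and g the 30K generated images.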

[After training] Generation with a trained U-Net

bash scripts/get_mscoco_files.sh
bash scripts/generate_with_trained_unet.sh

Results on Zero-shot MS-COCO 256×256 30K

See Results in MODEL_CARD.md

DreamBooth Finetuning with 🤗PEFT

Our lightweight SD backbones can be used for efficient personalized generation. DreamBooth finetunes a text-to-image diffusion model on a small number of images of a subject. Combining DreamBooth with LoRA can drastically reduce the finetuning cost.
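To see why LoRA cuts the cost: LoRA freezes a pretrained weight W of shape d_out × d_in and trains a low-rank update ΔW = BA, with B of shape d_out × r and A of shape r × d_in. The trainable parameter count per adapted matrix therefore drops from d_out·d_in to r·(d_out + d_in). A quick sketch of that arithmetic; the 320×320 layer size and rank 4 are illustrative assumptions, not the settings in scripts/finetune_lora.sh.

```python
def full_finetune_params(d_out: int, d_in: int) -> int:
    """Trainable parameters when a d_out x d_in weight is updated directly."""
    return d_out * d_in


def lora_params(d_out: int, d_in: int, rank: int) -> int:
    """Trainable parameters for a rank-r LoRA update B @ A,
    with B of shape (d_out, rank) and A of shape (rank, d_in)."""
    return rank * (d_out + d_in)


# Hypothetical example: a 320 x 320 projection adapted with rank 4.
print(full_finetune_params(320, 320))  # 102400
print(lora_params(320, 320, rank=4))   # 2560
```

In this example the adapter trains 40× fewer parameters than full finetuning of the same matrix, which is where most of the memory and storage savings come from.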

DreamBooth dataset

The dataset is downloaded at ./data/dreambooth/dataset [folder tree]: 30 subjects, each with 4∼6 images and 25 prompts.

git clone https://github.com/google/dreambooth ./data/dreambooth

DreamBooth finetuning (using BK-SDM-Base as default)

Our code is based on train_dreambooth.py from PEFT 0.1.0. The latest version of that script is available in the PEFT repository.

  • (1) without LoRA — full finetuning & used in our paper
    bash scripts/finetune_full.sh # learning rate 1e-6
    bash scripts/generate_after_full_ft.sh
  • (2) with LoRA — parameter-efficient finetuning
    bash scripts/finetune_lora.sh # learning rate 1e-4
    bash scripts/generate_after_lora_ft.sh  
  • On a single 3090 GPU, finetuning takes 10~20 minutes per subject.

Results of Personalized Generation

See DreamBooth Results in MODEL_CARD.md

Gradio Demo

Check out our Gradio demo and its code (main: app.py)!

[Aug/01/2023] Featured in Hugging Face Spaces of the Week 🔥

Core ML Weights

For iOS or macOS applications, we have converted our models to Core ML format. They are available at 🤗Hugging Face Models (nota-ai/coreml-bk-sdm) and can be used with Apple's Core ML Stable Diffusion library.

  • 4-sec inference on iPhone 14 (with 10 denoising steps): results

License

This project, along with its weights, is subject to the CreativeML Open RAIL-M license, which aims to mitigate any potential negative effects arising from the use of highly advanced machine learning systems. A summary of this license is as follows.

1. You can't use the model to deliberately produce nor share illegal or harmful outputs or content,
2. We claim no rights on the outputs you generate, you are free to use them and are accountable for their use which should not go against the provisions set in the license, and
3. You may re-distribute the weights and use the model commercially and/or as a service. If you do, please be aware you have to include the same use restrictions as the ones in the license and share a copy of the CreativeML OpenRAIL-M to all your users.

Acknowledgments

Citation

@article{kim2023architectural,
  title={BK-SDM: A Lightweight, Fast, and Cheap Version of Stable Diffusion},
  author={Kim, Bo-Kyeong and Song, Hyoung-Kyu and Castells, Thibault and Choi, Shinkook},
  journal={arXiv preprint arXiv:2305.15798},
  year={2023},
  url={https://arxiv.org/abs/2305.15798}
}
@article{kim2023bksdm,
  title={BK-SDM: Architecturally Compressed Stable Diffusion for Efficient Text-to-Image Generation},
  author={Kim, Bo-Kyeong and Song, Hyoung-Kyu and Castells, Thibault and Choi, Shinkook},
  journal={ICML Workshop on Efficient Systems for Foundation Models (ES-FoMo)},
  year={2023},
  url={https://openreview.net/forum?id=bOVydU0XKC}
}