diff --git a/stable_diffusion/README.md b/stable_diffusion/README.md
index d7b983538..c8b735e59 100644
--- a/stable_diffusion/README.md
+++ b/stable_diffusion/README.md
@@ -32,6 +32,11 @@ understand the fundamentals of the Stable Diffusion model.
 - [Validation metrics](#validation-metrics)
   - [FID](#fid)
   - [CLIP](#clip)
+- [Quality](#quality)
+  - [Quality metric](#quality-metric)
+  - [Quality target](#quality-target)
+  - [Evaluation frequency](#evaluation-frequency)
+  - [Evaluation thoroughness](#evaluation-thoroughness)
 - [Reference runs](#reference-runs)
 - [Rules](#rules)
 - [BibTeX](#bibtex)
@@ -232,6 +237,21 @@ Further insights and an independent evaluation of the FID score can be found in
 ### CLIP
 CLIP is a reference free metric that can be used to evaluate the correlation between a caption for an image and the actual content of the image, it has been found to be highly correlated with human judgement. A higher CLIP Score implies that the caption matches closer to image.
 
+# Quality
+## Quality metric
+Both FID and CLIP are used to evaluate the model's quality.
+
+## Quality target
+FID<=90 and CLIP>=0.15
+
+## Evaluation frequency
+Every 512,000 images, or every `CEIL(512000 / global_batch_size)` steps if 512,000 is not divisible by the global batch size.
+
+Please refer to the benchmark [training rules](https://github.com/mlcommons/training_policies/blob/master/training_rules.adoc) for the exact evaluation requirements.
+
+## Evaluation thoroughness
+All the prompts in the [coco-2014](#coco-2014) validation subset.
+
 # Reference runs
 
 The benchmark is expected to have the following convergence profile:
diff --git a/stable_diffusion/configs/train_01x08x08.yaml b/stable_diffusion/configs/train_01x08x08.yaml
index 73db364bf..a7497cb61 100644
--- a/stable_diffusion/configs/train_01x08x08.yaml
+++ b/stable_diffusion/configs/train_01x08x08.yaml
@@ -38,7 +38,7 @@ model:
       enabled: True
       inception_weights_url: https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth
       cache_dir: /checkpoints/inception
-      gt_path: /datasets/coco2014/val2014_30k_stats.npz
+      gt_path: /datasets/coco2014/val2014_512x512_30k_stats.npz
     clip:
       enabled: True
       clip_version: "ViT-H-14"
diff --git a/stable_diffusion/configs/train_32x08x02.yaml b/stable_diffusion/configs/train_32x08x02.yaml
index aaa4a4681..5fd81da1a 100644
--- a/stable_diffusion/configs/train_32x08x02.yaml
+++ b/stable_diffusion/configs/train_32x08x02.yaml
@@ -38,7 +38,7 @@ model:
       enabled: True
       inception_weights_url: https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth
       cache_dir: /checkpoints/inception
-      gt_path: /datasets/coco2014/val2014_30k_stats.npz
+      gt_path: /datasets/coco2014/val2014_512x512_30k_stats.npz
     clip:
       enabled: True
       clip_version: "ViT-H-14"
diff --git a/stable_diffusion/configs/train_32x08x02_raw_images.yaml b/stable_diffusion/configs/train_32x08x02_raw_images.yaml
index 582147325..c36d2ca4f 100644
--- a/stable_diffusion/configs/train_32x08x02_raw_images.yaml
+++ b/stable_diffusion/configs/train_32x08x02_raw_images.yaml
@@ -38,7 +38,7 @@ model:
       enabled: True
       inception_weights_url: https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth
       cache_dir: /checkpoints/inception
-      gt_path: /datasets/coco2014/val2014_30k_stats.npz
+      gt_path: /datasets/coco2014/val2014_512x512_30k_stats.npz
     clip:
       enabled: True
       clip_version: "ViT-H-14"
diff --git a/stable_diffusion/configs/train_32x08x04.yaml b/stable_diffusion/configs/train_32x08x04.yaml
index 747aa3d50..f275f1e20 100644
--- a/stable_diffusion/configs/train_32x08x04.yaml
+++ b/stable_diffusion/configs/train_32x08x04.yaml
@@ -38,7 +38,7 @@ model:
       enabled: True
       inception_weights_url: https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth
      cache_dir: /checkpoints/inception
-      gt_path: /datasets/coco2014/val2014_30k_stats.npz
+      gt_path: /datasets/coco2014/val2014_512x512_30k_stats.npz
     clip:
       enabled: True
       clip_version: "ViT-H-14"
diff --git a/stable_diffusion/configs/train_32x08x08.yaml b/stable_diffusion/configs/train_32x08x08.yaml
index 166a5deed..5773fa8ef 100644
--- a/stable_diffusion/configs/train_32x08x08.yaml
+++ b/stable_diffusion/configs/train_32x08x08.yaml
@@ -38,7 +38,7 @@ model:
       enabled: True
       inception_weights_url: https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth
       cache_dir: /checkpoints/inception
-      gt_path: /datasets/coco2014/val2014_30k_stats.npz
+      gt_path: /datasets/coco2014/val2014_512x512_30k_stats.npz
     clip:
       enabled: True
       clip_version: "ViT-H-14"
diff --git a/stable_diffusion/ldm/models/diffusion/ddpm.py b/stable_diffusion/ldm/models/diffusion/ddpm.py
index 13da5acdb..c7f42f5df 100644
--- a/stable_diffusion/ldm/models/diffusion/ddpm.py
+++ b/stable_diffusion/ldm/models/diffusion/ddpm.py
@@ -211,17 +211,17 @@ def __init__(self,
         self.validation_run_clip = validation_config["clip"]["enabled"]
 
         if self.validation_save_images or self.validation_run_fid or self.validation_run_clip:
-                if self.validation_save_images:
-                    self.validation_base_output_dir = validation_config["save_images"]["base_output_dir"]
+            if self.validation_save_images:
+                self.validation_base_output_dir = validation_config["save_images"]["base_output_dir"]
 
-                if self.validation_run_fid:
-                    self.inception_weights_url = validation_config["fid"]["inception_weights_url"]
-                    self.inception_cache_dir = validation_config["fid"]["cache_dir"]
-                    self.fid_gt_path = validation_config["fid"]["gt_path"]
+            if self.validation_run_fid:
+                self.inception_weights_url = validation_config["fid"]["inception_weights_url"]
+                self.inception_cache_dir = validation_config["fid"]["cache_dir"]
+                self.fid_gt_path = validation_config["fid"]["gt_path"]
 
-                if self.validation_run_clip:
-                    self.clip_version = validation_config["clip"]["clip_version"]
-                    self.clip_cache_dir = validation_config["clip"]["cache_dir"]
+            if self.validation_run_clip:
+                self.clip_version = validation_config["clip"]["clip_version"]
+                self.clip_cache_dir = validation_config["clip"]["cache_dir"]
 
 
     def register_schedule(self,
diff --git a/stable_diffusion/requirements.txt b/stable_diffusion/requirements.txt
index f4cc1a909..26802b5ff 100644
--- a/stable_diffusion/requirements.txt
+++ b/stable_diffusion/requirements.txt
@@ -19,9 +19,7 @@ datasets==2.10.1
 colossalai==0.2.7
 invisible-watermark==0.1.5
 diffusers==0.14.0
-img2dataset==1.41.0
 cloudpathlib==0.13.0
 git+https://github.com/facebookresearch/xformers.git@5eb0dbf315d14b5f7b38ac2ff3d8379beca7df9b#egg=xformers
 bitsandbytes==0.37.2
-# TODO(ahmadki): use github.com:mlcommons/logging.git once the SD PR is merged
 git+https://github.com/mlcommons/logging.git@8405a08bbfc724f8888c419461c02d55a6ac960c
diff --git a/stable_diffusion/scripts/datasets/laion400m-download-dataset.sh b/stable_diffusion/scripts/datasets/laion400m-download-dataset.sh
index 707ef4124..16c436812 100755
--- a/stable_diffusion/scripts/datasets/laion400m-download-dataset.sh
+++ b/stable_diffusion/scripts/datasets/laion400m-download-dataset.sh
@@ -25,6 +25,7 @@ done
 
 mkdir -p ${OUTPUT_DIR}
 
+pip install img2dataset==1.41.0
 img2dataset \
     --url_list ${METADATA_DIR} \
     --input_format "parquet" \
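
The evaluation cadence and quality target introduced in the README hunk above reduce to simple arithmetic. Below is a minimal Python sketch of that computation, not part of the patch; the function names and the example batch size are illustrative, while the 512,000-image interval and the FID<=90 / CLIP>=0.15 thresholds come from the README text.

```python
import math

# Evaluation cadence from the README: evaluate every 512,000 training images,
# i.e. every CEIL(512000 / global_batch_size) optimizer steps.
SAMPLES_BETWEEN_EVALS = 512_000


def eval_interval_steps(global_batch_size: int) -> int:
    """Number of optimizer steps between two evaluations."""
    return math.ceil(SAMPLES_BETWEEN_EVALS / global_batch_size)


def reached_quality_target(fid: float, clip_score: float) -> bool:
    """Quality target from the README: FID <= 90 and CLIP >= 0.15."""
    return fid <= 90.0 and clip_score >= 0.15


# Example: 32 nodes x 8 GPUs x 8 images per GPU -> global batch size 2048,
# so evaluation runs every 250 steps.
print(eval_interval_steps(2048))             # 250
print(reached_quality_target(85.3, 0.16))    # True
```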
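The config changes point `gt_path` at pre-computed Inception statistics for the 512x512 COCO-2014 30k validation subset. The sketch below shows how such a statistics file can be consumed to compute FID, assuming the pytorch-fid convention of storing the reference mean and covariance under the keys `mu` and `sigma`; the random placeholder activations stand in for real Inception-V3 pool3 features of generated images.

```python
import numpy as np
from pytorch_fid.fid_score import calculate_frechet_distance

# Reference statistics for the validation subset (path taken from the configs).
with np.load("/datasets/coco2014/val2014_512x512_30k_stats.npz") as f:
    mu_ref, sigma_ref = f["mu"], f["sigma"]

# In a real run these would be Inception-V3 pool3 activations (dimension 2048)
# of the generated validation images; random values keep the sketch self-contained.
rng = np.random.default_rng(0)
activations = rng.standard_normal((5_000, 2048))
mu_gen = activations.mean(axis=0)
sigma_gen = np.cov(activations, rowvar=False)

fid = calculate_frechet_distance(mu_ref, sigma_ref, mu_gen, sigma_gen)
print(f"FID: {fid:.2f}")
```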
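The configs also keep `clip_version: "ViT-H-14"` as the CLIP backbone. A minimal sketch of a per-prompt CLIP score with `open_clip` follows; the pretrained tag, cache directory, image path, and caption are illustrative assumptions, and a benchmark run would average this score over all COCO-2014 validation prompts.

```python
import torch
import open_clip
from PIL import Image

# ViT-H-14 is the CLIP backbone named in the configs; the pretrained tag and
# cache directory here are assumptions made for the sketch.
model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-H-14", pretrained="laion2b_s32b_b79k", cache_dir="/checkpoints/clip"
)
tokenizer = open_clip.get_tokenizer("ViT-H-14")
model.eval()

image = preprocess(Image.open("generated_sample.png")).unsqueeze(0)
text = tokenizer(["a photo of a cat sitting on a couch"])  # the image's prompt

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    # Cosine similarity in [-1, 1]; the quality target requires >= 0.15 on average.
    clip_score = (image_features * text_features).sum(dim=-1)

print(clip_score.item())
```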