Skip to content

Commit

Permalink
[SD] a small indentation fix (#681)
Browse files Browse the repository at this point in the history
* [SD] a small indentation fix

* [SD] fixed the validation file name to match the README

* [SD] Added Quality section to the README

* [SD] requirements cleanup

* [SD] Removed img2dataset installation from Docker to the download script
  • Loading branch information
ahmadki authored Sep 7, 2023
1 parent 2f4a93f commit 00f04c5
Show file tree
Hide file tree
Showing 9 changed files with 35 additions and 16 deletions.
20 changes: 20 additions & 0 deletions stable_diffusion/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ understand the fundamentals of the Stable Diffusion model.
- [Validation metrics](#validation-metrics)
- [FID](#fid)
- [CLIP](#clip)
- [Quality](#quality)
- [Quality metric](#quality-metric)
- [Quality target](#quality-target)
- [Evaluation frequency](#evaluation-frequency)
- [Evaluation thoroughness](#evaluation-thoroughness)
- [Reference runs](#reference-runs)
- [Rules](#rules)
- [BibTeX](#bibtex)
Expand Down Expand Up @@ -232,6 +237,21 @@ Further insights and an independent evaluation of the FID score can be found in
### CLIP
CLIP is a reference free metric that can be used to evaluate the correlation between a caption for an image and the actual content of the image, it has been found to be highly correlated with human judgement. A higher CLIP Score implies that the caption matches closer to image.

# Quality
## Quality metric
Both FID and CLIP are used to evaulte the model's quality.

## Quality target
FID<=90 and CLIP>=0.15

## Evaluation frequency
Every 512,000 images, or `CEIL(512000 / global_batch_size)` if 512,000 is not divisible by GBS.

Please refer to the benchmark rules [here](https://github.com/mlcommons/training_policies/blob/master/training_rules.adoc) for the exact evaluation rules.

## Evaluation thoroughness
All the prompts in the [coco-2014](#coco-2014) validation subset.

# Reference runs
The benchmark is expected to have the following convergence profile:

Expand Down
2 changes: 1 addition & 1 deletion stable_diffusion/configs/train_01x08x08.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ model:
enabled: True
inception_weights_url: https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth
cache_dir: /checkpoints/inception
gt_path: /datasets/coco2014/val2014_30k_stats.npz
gt_path: /datasets/coco2014/val2014_512x512_30k_stats.npz
clip:
enabled: True
clip_version: "ViT-H-14"
Expand Down
2 changes: 1 addition & 1 deletion stable_diffusion/configs/train_32x08x02.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ model:
enabled: True
inception_weights_url: https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth
cache_dir: /checkpoints/inception
gt_path: /datasets/coco2014/val2014_30k_stats.npz
gt_path: /datasets/coco2014/val2014_512x512_30k_stats.npz
clip:
enabled: True
clip_version: "ViT-H-14"
Expand Down
2 changes: 1 addition & 1 deletion stable_diffusion/configs/train_32x08x02_raw_images.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ model:
enabled: True
inception_weights_url: https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth
cache_dir: /checkpoints/inception
gt_path: /datasets/coco2014/val2014_30k_stats.npz
gt_path: /datasets/coco2014/val2014_512x512_30k_stats.npz
clip:
enabled: True
clip_version: "ViT-H-14"
Expand Down
2 changes: 1 addition & 1 deletion stable_diffusion/configs/train_32x08x04.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ model:
enabled: True
inception_weights_url: https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth
cache_dir: /checkpoints/inception
gt_path: /datasets/coco2014/val2014_30k_stats.npz
gt_path: /datasets/coco2014/val2014_512x512_30k_stats.npz
clip:
enabled: True
clip_version: "ViT-H-14"
Expand Down
2 changes: 1 addition & 1 deletion stable_diffusion/configs/train_32x08x08.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ model:
enabled: True
inception_weights_url: https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth
cache_dir: /checkpoints/inception
gt_path: /datasets/coco2014/val2014_30k_stats.npz
gt_path: /datasets/coco2014/val2014_512x512_30k_stats.npz
clip:
enabled: True
clip_version: "ViT-H-14"
Expand Down
18 changes: 9 additions & 9 deletions stable_diffusion/ldm/models/diffusion/ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,17 +211,17 @@ def __init__(self,
self.validation_run_clip = validation_config["clip"]["enabled"]

if self.validation_save_images or self.validation_run_fid or self.validation_run_clip:
if self.validation_save_images:
self.validation_base_output_dir = validation_config["save_images"]["base_output_dir"]
if self.validation_save_images:
self.validation_base_output_dir = validation_config["save_images"]["base_output_dir"]

if self.validation_run_fid:
self.inception_weights_url = validation_config["fid"]["inception_weights_url"]
self.inception_cache_dir = validation_config["fid"]["cache_dir"]
self.fid_gt_path = validation_config["fid"]["gt_path"]
if self.validation_run_fid:
self.inception_weights_url = validation_config["fid"]["inception_weights_url"]
self.inception_cache_dir = validation_config["fid"]["cache_dir"]
self.fid_gt_path = validation_config["fid"]["gt_path"]

if self.validation_run_clip:
self.clip_version = validation_config["clip"]["clip_version"]
self.clip_cache_dir = validation_config["clip"]["cache_dir"]
if self.validation_run_clip:
self.clip_version = validation_config["clip"]["clip_version"]
self.clip_cache_dir = validation_config["clip"]["cache_dir"]


def register_schedule(self,
Expand Down
2 changes: 0 additions & 2 deletions stable_diffusion/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ datasets==2.10.1
colossalai==0.2.7
invisible-watermark==0.1.5
diffusers==0.14.0
img2dataset==1.41.0
cloudpathlib==0.13.0
git+https://github.com/facebookresearch/xformers.git@5eb0dbf315d14b5f7b38ac2ff3d8379beca7df9b#egg=xformers
bitsandbytes==0.37.2
# TODO(ahmadki): use github.com:mlcommons/logging.git once the SD PR is merged
git+https://github.com/mlcommons/logging.git@8405a08bbfc724f8888c419461c02d55a6ac960c
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ done

mkdir -p ${OUTPUT_DIR}

pip install img2dataset==1.41.0
img2dataset \
--url_list ${METADATA_DIR} \
--input_format "parquet" \
Expand Down

0 comments on commit 00f04c5

Please sign in to comment.