Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Do not merge] Script to compare safety checkers #219

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from .safety_checker import StableDiffusionSafetyChecker
from .safety import SafetyChecker

# Module-level instance of the reference SafetyChecker, built with its default
# arguments when this module is imported; __call__ runs it on the generated
# images so its scores can be compared with the diffusers safety checker.
original_checker = SafetyChecker()


class StableDiffusionPipeline(DiffusionPipeline):
Expand Down Expand Up @@ -149,9 +152,23 @@ def __call__(
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()

original_result = original_checker(self.numpy_to_pil(image))

# run safety checker
safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(torch_device)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
image, has_nsfw_concept, result = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)

def check_values_the_same(dict_1, dict_2, name):
    """Print both score mappings when their values under ``name`` disagree.

    The values of ``dict_1[name]`` and ``dict_2[name]`` are compared
    positionally (in dict iteration order, not by key) with
    ``torch.allclose`` using ``atol=1e-3``.

    Args:
        dict_1: Result dict from the original checker.
        dict_2: Result dict from the diffusers checker.
        name: Key of the score mapping to compare in both dicts.

    Returns:
        None. Output is printed only when the values differ.
    """
    dict_1_values = dict_1[name].values()
    dict_2_values = dict_2[name].values()
    the_same = torch.allclose(
        torch.tensor(list(dict_1_values)),
        torch.tensor(list(dict_2_values)),
        atol=1e-3,
    )
    if not the_same:
        print("Original", dict_1[name])
        print("Diffusers", dict_2[name])

for dict_1, dict_2 in zip(original_result, result):
for name in ['special_scores', 'concept_scores']:
check_values_the_same(dict_1, dict_2, name)

if output_type == "pil":
image = self.numpy_to_pil(image)
Expand Down
94 changes: 94 additions & 0 deletions src/diffusers/pipelines/stable_diffusion/safety.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#!/usr/bin/env python3
from cgitb import reset
from typing import OrderedDict
import torch, torch.nn as nn
import open_clip
import numpy as np
import yaml

from open_clip import create_model_and_transforms

# Created once at import time; SafetyChecker instances reuse this model and
# preprocessing transform. The middle return value is discarded.
model, _, preprocess = create_model_and_transforms("ViT-L-14", "openai")

def normalized(a, axis=-1, order=2):
    """Normalize the given array along the specified axis to unit norm.

    Args:
        a: Numeric array to normalize.
        axis: Axis along which the norm is computed.
        order: Norm order forwarded to ``np.linalg.norm``.

    Returns:
        ``a`` divided by its norm along ``axis``. Slices whose norm is zero
        are divided by 1 and therefore returned unchanged.
    """
    l2 = np.atleast_1d(np.linalg.norm(a, order, axis))
    # Guard against division by zero for all-zero slices.
    l2[l2 == 0] = 1
    return a / np.expand_dims(l2, axis)

def pw_cosine_distance(input_a, input_b):
    """Return the pairwise cosine similarity between the rows of two 2-D tensors.

    Despite the name, the result is a similarity, not a distance: entry
    ``[i, j]`` is the dot product of the L2-normalized row ``i`` of
    ``input_a`` and row ``j`` of ``input_b``.

    Args:
        input_a: Tensor of shape ``(n, d)``.
        input_b: Tensor of shape ``(m, d)``.

    Returns:
        Tensor of shape ``(n, m)``.
    """
    normalized_input_a = torch.nn.functional.normalize(input_a)
    normalized_input_b = torch.nn.functional.normalize(input_b)
    return torch.mm(normalized_input_a, normalized_input_b.T)

class SafetyChecker(nn.Module):
    """Score images against text concepts and flag those above a threshold.

    Concept texts and their thresholds are read from a YAML settings file
    containing ``nsfw.concepts`` and ``special.concepts`` mappings of
    concept text -> threshold. Text embeddings for every concept are
    computed once in ``__init__``.
    """

    def __init__(self, device='cuda', settings_path="/home/patrick/safety_settings.yml") -> None:
        """Load the thresholds and precompute the concept text embeddings.

        Args:
            device: Device the module-level model is moved to and on which
                inputs are placed.
            settings_path: Path of the YAML settings file. Defaults to the
                location that was previously hard-coded.
        """
        super().__init__()
        self.clip_model = model.to(device)
        self.preprocess = preprocess
        self.device = device
        # Use a context manager so the file handle is always closed.
        with open(settings_path, "r") as settings_file:
            safety_settings = yaml.safe_load(settings_file)
        self.concepts_dict = dict(safety_settings["nsfw"]["concepts"])
        self.special_care_dict = dict(safety_settings["special"]["concepts"])
        self.concept_embeds = self.get_text_embeds(list(self.concepts_dict.keys()))
        self.special_care_embeds = self.get_text_embeds(list(self.special_care_dict.keys()))

    def get_image_embeds(self, input):
        """Get embeddings for a list of images or for an image tensor.

        Args:
            input: Either a list of images accepted by ``self.preprocess``
                or an already preprocessed ``torch.Tensor``.

        Returns:
            The result of ``self.clip_model.encode_image`` on the
            half-precision batch.

        Raises:
            TypeError: If ``input`` is neither a list nor a tensor.
        """
        # Fail with a clear error instead of an unbound local variable.
        if not isinstance(input, (list, torch.Tensor)):
            raise TypeError(
                "input must be a list of images or a torch.Tensor, got "
                f"{type(input).__name__}"
            )
        with torch.cuda.amp.autocast():
            with torch.no_grad():
                if isinstance(input, list):
                    # Preprocess each image and batch the results.
                    img_tensor = torch.stack([self.preprocess(image) for image in input])
                else:
                    img_tensor = input
                return self.clip_model.encode_image(img_tensor.half().to(self.device))

    def get_text_embeds(self, input):
        """Get text embeddings for a list of text."""
        with torch.cuda.amp.autocast():
            with torch.no_grad():
                tokens = open_clip.tokenize(input).to(self.device)
                return self.clip_model.encode_text(tokens)

    @staticmethod
    def _score_image(special_sims, special_thresholds, concept_sims, concept_thresholds):
        """Build the result dict for one image from its similarity rows.

        Args:
            special_sims: Similarities to the special-care concepts, in the
                iteration order of ``special_thresholds``.
            special_thresholds: Mapping special-care concept -> threshold.
            concept_sims: Similarities to the concepts, in the iteration
                order of ``concept_thresholds``.
            concept_thresholds: Mapping concept -> threshold.

        Returns:
            Dict with ``special_scores``, ``special_care``,
            ``concept_scores`` and ``bad_concepts`` entries.
        """
        result_img = {
            "special_scores": {},
            "special_care": [],
            "concept_scores": {},
            "bad_concepts": [],
        }
        adjustment = 0.05
        for (concept_name, concept_threshold), concept_cos in zip(
            special_thresholds.items(), special_sims
        ):
            score = round(concept_cos - concept_threshold + adjustment, 3)
            result_img["special_scores"][concept_name] = score
            if score > 0:
                # NOTE(review): this appends a *set* {name, score}, which is
                # unordered; kept as-is because it is part of the returned
                # structure. Confirm whether a (name, score) pair was meant.
                result_img["special_care"].append({concept_name, score})
                # Once a special-care concept fires, the adjustment used for
                # all following scores of this image becomes 0.01.
                adjustment = 0.01
        for (concept_name, concept_threshold), concept_cos in zip(
            concept_thresholds.items(), concept_sims
        ):
            # NOTE(review), from the PR discussion: if adjustment >
            # concept_threshold the image will always be a bad image.
            # A reviewer also asked why the diffusers checker initializes
            # concept_embeds / special_care_embeds with torch.ones(...); the
            # author replied that it just means an image is blocked when its
            # cosine similarity is above a certain threshold.
            score = round(concept_cos - concept_threshold + adjustment, 3)
            result_img["concept_scores"][concept_name] = score
            if score > 0:
                result_img["bad_concepts"].append(concept_name)
        return result_img

    def forward(self, images):
        """Get embeddings for images and output nsfw and concept scores.

        Args:
            images: Input accepted by ``get_image_embeds``.

        Returns:
            A list with one result dict per image, as produced by
            ``_score_image``.
        """
        image_embeds = self.get_image_embeds(images)
        special_cos_dist = pw_cosine_distance(
            image_embeds, self.special_care_embeds
        ).cpu().numpy()
        cos_dist = pw_cosine_distance(image_embeds, self.concept_embeds).cpu().numpy()
        return [
            self._score_image(special_row, self.special_care_dict, concept_row, self.concepts_dict)
            for special_row, concept_row in zip(special_cos_dist, cos_dist)
        ]
2 changes: 1 addition & 1 deletion src/diffusers/pipelines/stable_diffusion/safety_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,4 @@ def forward(self, clip_input, images):
" Try again with a different prompt and/or seed."
)

return images, has_nsfw_concepts
return images, has_nsfw_concepts, result