Commit ba74a8b

[Community Pipelines] Fix pad_tokens_and_weights in lpw_stable_diffusion (#925)
SkyTNT authored Oct 20, 2022
1 parent 6f6eef7 commit ba74a8b
Showing 2 changed files with 90 additions and 34 deletions.
73 changes: 53 additions & 20 deletions examples/community/lpw_stable_diffusion.py
@@ -132,6 +132,7 @@ def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_len
     """
     tokens = []
     weights = []
+    truncated = False
     for text in prompt:
         texts_and_weights = parse_prompt_attention(text)
         text_token = []
@@ -140,21 +141,21 @@ def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_len
             # tokenize and discard the starting and the ending token
             token = pipe.tokenizer(word).input_ids[1:-1]
             text_token += token
-
             # copy the weight by length of token
             text_weight += [weight] * len(token)
-
             # stop if the text is too long (longer than truncation limit)
             if len(text_token) > max_length:
+                truncated = True
                 break
-
         # truncate
         if len(text_token) > max_length:
+            truncated = True
             text_token = text_token[:max_length]
             text_weight = text_weight[:max_length]

         tokens.append(text_token)
         weights.append(text_weight)
+    if truncated:
+        logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
     return tokens, weights
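
Note, not part of the diff: truncation is now flagged in both the early-exit and the hard-truncate paths, and the warning fires once per call rather than silently dropping tokens. A hypothetical usage sketch, assuming a loaded pipeline `pipe` whose CLIP tokenizer has `model_max_length == 77`:

```python
# With max_embeddings_multiples == 1 the usable prompt length is 75 content
# tokens (77 minus BOS/EOS), so an over-long prompt now logs
# "Prompt was truncated. ..." instead of truncating silently.
long_prompt = "a photo of " + ", ".join(["ornament"] * 100)
tokens, weights = get_prompts_with_weights(pipe, [long_prompt], max_length=75)
assert len(tokens[0]) <= 75  # hard truncation still applies
```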


@@ -173,9 +174,9 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd
             if len(weights[i]) == 0:
                 w = [1.0] * weights_length
             else:
-                for j in range((len(weights[i]) - 1) // chunk_length + 1):
+                for j in range(max_embeddings_multiples):
                     w.append(1.0)  # weight for starting token in this chunk
-                    w += weights[i][j * chunk_length : min(len(weights[i]), (j + 1) * chunk_length)]
+                    w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
                     w.append(1.0)  # weight for ending token in this chunk
                 w += [1.0] * (weights_length - len(w))
             weights[i] = w[:]
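
Why this is the fix (an illustration with assumed numbers, not taken from the diff): with `chunk_length == 77`, each chunk fed to the text encoder is `[BOS] + 75 content tokens + [EOS]`, so the per-token weights must advance in steps of `chunk_length - 2`. The old code advanced in steps of `chunk_length`, misaligning the weights of every chunk after the first; the loop bound likewise changes to `max_embeddings_multiples`, the number of chunks actually encoded.

```python
chunk_length = 77
content_per_chunk = chunk_length - 2     # 75 content slots per chunk
weights = [0.5] * 75 + [1.5] * 75        # weights for 150 content tokens

# Old slicing: chunk 0 takes 77 weights, stealing 2 of chunk 1's and
# shifting every later weight relative to its token.
old_chunk0 = weights[0 * chunk_length : 1 * chunk_length]
assert len(old_chunk0) == 77

# Fixed slicing: chunk 0 gets exactly its 75 content-token weights.
new_chunk0 = weights[0 * content_per_chunk : 1 * content_per_chunk]
assert new_chunk0 == [0.5] * 75
```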
@@ -184,7 +185,10 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd


 def get_unweighted_text_embeddings(
-    pipe: DiffusionPipeline, text_input: torch.Tensor, chunk_length: int, no_boseos_middle: Optional[bool] = True
+    pipe: DiffusionPipeline,
+    text_input: torch.Tensor,
+    chunk_length: int,
+    no_boseos_middle: Optional[bool] = True,
 ):
     """
     When the length of tokens is a multiple of the capacity of the text encoder,
@@ -285,7 +289,8 @@ def get_weighted_text_embeddings(
         max_length = max(max_length, max([len(token) for token in uncond_tokens]))

     max_embeddings_multiples = min(
-        max_embeddings_multiples, (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1
+        max_embeddings_multiples,
+        (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
     )
     max_embeddings_multiples = max(1, max_embeddings_multiples)
     max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
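
A worked example of this arithmetic (assumed numbers, not from the diff): for CLIP, `pipe.tokenizer.model_max_length` is 77, so each extra multiple buys 75 content tokens.

```python
model_max_length = 77   # per-chunk limit, including BOS and EOS
longest_prompt = 100    # content tokens in the longest prompt in the batch
requested = 3           # user-supplied max_embeddings_multiples

multiples = min(requested, (longest_prompt - 1) // (model_max_length - 2) + 1)  # 99 // 75 + 1 == 2
multiples = max(1, multiples)
max_length = (model_max_length - 2) * multiples + 2  # 75 * 2 + 2 == 152
```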
@@ -317,12 +322,18 @@ def get_weighted_text_embeddings(

     # get the embeddings
     text_embeddings = get_unweighted_text_embeddings(
-        pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
+        pipe,
+        prompt_tokens,
+        pipe.tokenizer.model_max_length,
+        no_boseos_middle=no_boseos_middle,
     )
     prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
     if uncond_prompt is not None:
         uncond_embeddings = get_unweighted_text_embeddings(
-            pipe, uncond_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
+            pipe,
+            uncond_tokens,
+            pipe.tokenizer.model_max_length,
+            no_boseos_middle=no_boseos_middle,
         )
         uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)

@@ -632,16 +643,29 @@ def __call__(
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (
+            batch_size * num_images_per_prompt,
+            self.unet.in_channels,
+            height // 8,
+            width // 8,
+        )

         if latents is None:
             if self.device.type == "mps":
                 # randn does not exist on mps
-                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
-                    self.device
-                )
+                latents = torch.randn(
+                    latents_shape,
+                    generator=generator,
+                    device="cpu",
+                    dtype=latents_dtype,
+                ).to(self.device)
             else:
-                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+                latents = torch.randn(
+                    latents_shape,
+                    generator=generator,
+                    device=self.device,
+                    dtype=latents_dtype,
+                )
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
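
Both branches above follow one pattern; as a hypothetical helper (not in the pipeline) it could read as below. At the time of this commit `torch.randn` was unsupported on `mps`, so noise is sampled on CPU with the user's generator for reproducibility, then moved to the device.

```python
import torch

def randn_on(shape, generator, device, dtype):
    # Hypothetical helper mirroring the diff: sample on CPU when targeting
    # mps (randn was unsupported there), otherwise sample on the device.
    if device.type == "mps":
        return torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
    return torch.randn(shape, generator=generator, device=device, dtype=dtype)
```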
@@ -684,11 +708,19 @@ def __call__(
             # add noise to latents using the timesteps
             if self.device.type == "mps":
                 # randn does not exist on mps
-                noise = torch.randn(init_latents.shape, generator=generator, device="cpu", dtype=latents_dtype).to(
-                    self.device
-                )
+                noise = torch.randn(
+                    init_latents.shape,
+                    generator=generator,
+                    device="cpu",
+                    dtype=latents_dtype,
+                ).to(self.device)
             else:
-                noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
+                noise = torch.randn(
+                    init_latents.shape,
+                    generator=generator,
+                    device=self.device,
+                    dtype=latents_dtype,
+                )
             latents = self.scheduler.add_noise(init_latents, noise, timesteps)

         t_start = max(num_inference_steps - init_timestep + offset, 0)
@@ -741,7 +773,8 @@ def __call__(
                 self.device
             )
             image, has_nsfw_concept = self.safety_checker(
-                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+                images=image,
+                clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype),
             )
         else:
             has_nsfw_concept = None
51 changes: 37 additions & 14 deletions examples/community/lpw_stable_diffusion_onnx.py
@@ -130,6 +130,7 @@ def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
     """
     tokens = []
     weights = []
+    truncated = False
     for text in prompt:
         texts_and_weights = parse_prompt_attention(text)
         text_token = []
@@ -138,21 +139,21 @@ def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
             # tokenize and discard the starting and the ending token
             token = pipe.tokenizer(word, return_tensors="np").input_ids[0, 1:-1]
             text_token += list(token)
-
             # copy the weight by length of token
             text_weight += [weight] * len(token)
-
             # stop if the text is too long (longer than truncation limit)
             if len(text_token) > max_length:
+                truncated = True
                 break
-
         # truncate
         if len(text_token) > max_length:
+            truncated = True
             text_token = text_token[:max_length]
             text_weight = text_weight[:max_length]

         tokens.append(text_token)
         weights.append(text_weight)
+    if truncated:
+        logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
     return tokens, weights


@@ -171,9 +172,9 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd
             if len(weights[i]) == 0:
                 w = [1.0] * weights_length
             else:
-                for j in range((len(weights[i]) - 1) // chunk_length + 1):
+                for j in range(max_embeddings_multiples):
                     w.append(1.0)  # weight for starting token in this chunk
-                    w += weights[i][j * chunk_length : min(len(weights[i]), (j + 1) * chunk_length)]
+                    w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
                     w.append(1.0)  # weight for ending token in this chunk
                 w += [1.0] * (weights_length - len(w))
             weights[i] = w[:]
@@ -182,7 +183,10 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd


 def get_unweighted_text_embeddings(
-    pipe, text_input: np.array, chunk_length: int, no_boseos_middle: Optional[bool] = True
+    pipe,
+    text_input: np.array,
+    chunk_length: int,
+    no_boseos_middle: Optional[bool] = True,
 ):
     """
     When the length of tokens is a multiple of the capacity of the text encoder,
@@ -276,7 +280,10 @@ def get_weighted_text_embeddings(
         uncond_tokens = [
             token[1:-1]
             for token in pipe.tokenizer(
-                uncond_prompt, max_length=max_length, truncation=True, return_tensors="np"
+                uncond_prompt,
+                max_length=max_length,
+                truncation=True,
+                return_tensors="np",
             ).input_ids
         ]
         uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
@@ -287,7 +294,8 @@ def get_weighted_text_embeddings(
         max_length = max(max_length, max([len(token) for token in uncond_tokens]))

     max_embeddings_multiples = min(
-        max_embeddings_multiples, (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1
+        max_embeddings_multiples,
+        (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
     )
     max_embeddings_multiples = max(1, max_embeddings_multiples)
     max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
@@ -319,12 +327,18 @@ def get_weighted_text_embeddings(

     # get the embeddings
     text_embeddings = get_unweighted_text_embeddings(
-        pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
+        pipe,
+        prompt_tokens,
+        pipe.tokenizer.model_max_length,
+        no_boseos_middle=no_boseos_middle,
     )
     prompt_weights = np.array(prompt_weights, dtype=text_embeddings.dtype)
     if uncond_prompt is not None:
         uncond_embeddings = get_unweighted_text_embeddings(
-            pipe, uncond_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
+            pipe,
+            uncond_tokens,
+            pipe.tokenizer.model_max_length,
+            no_boseos_middle=no_boseos_middle,
         )
         uncond_weights = np.array(uncond_weights, dtype=uncond_embeddings.dtype)

@@ -559,7 +573,12 @@ def __call__(
         noise = None

         if init_image is None:
-            latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
+            latents_shape = (
+                batch_size * num_images_per_prompt,
+                4,
+                height // 8,
+                width // 8,
+            )

             if latents is None:
                 latents = generator.randn(*latents_shape).astype(latents_dtype)
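
Unlike the torch pipeline, the ONNX pipeline draws latents from a NumPy generator, so reproducibility is tied to NumPy seeding rather than torch; a minimal sketch under that assumption:

```python
import numpy as np

generator = np.random.RandomState(0)         # seeded NumPy RNG
latents_shape = (1, 4, 512 // 8, 512 // 8)   # batch, channels, height/8, width/8
latents = generator.randn(*latents_shape).astype(np.float32)
```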
@@ -625,7 +644,9 @@ def __call__(

             # predict the noise residual
             noise_pred = self.unet(
-                sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings
+                sample=latent_model_input,
+                timestep=np.array([t]),
+                encoder_hidden_states=text_embeddings,
             )
             noise_pred = noise_pred[0]

@@ -640,7 +661,9 @@ def __call__(
             if mask is not None:
                 # masking
                 init_latents_proper = self.scheduler.add_noise(
-                    torch.from_numpy(init_latents_orig), torch.from_numpy(noise), torch.tensor([t])
+                    torch.from_numpy(init_latents_orig),
+                    torch.from_numpy(noise),
+                    torch.tensor([t]),
                 ).numpy()
                 latents = (init_latents_proper * mask) + (latents * (1 - mask))
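
For reference, the masking step keeps the re-noised original latents where the mask is 1 and the current denoised latents elsewhere; a toy sketch with assumed shapes:

```python
import numpy as np

mask = np.zeros((1, 4, 64, 64), dtype=np.float32)
mask[..., :32] = 1.0                           # preserve the left half
init_latents_proper = np.full_like(mask, 0.7)  # re-noised original latents
latents = np.full_like(mask, -0.3)             # current denoised latents
latents = init_latents_proper * mask + latents * (1 - mask)
```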
