From ba74a8be7a620da0558f27802a19736627e9e64a Mon Sep 17 00:00:00 2001
From: SkyTNT
Date: Fri, 21 Oct 2022 01:26:04 +0800
Subject: [PATCH] [Community Pipelines] Fix pad_tokens_and_weights in
 lpw_stable_diffusion (#925)

[Community Pipelines] fix pad_tokens_and_weights in lpw_stable_diffusion
---
 examples/community/lpw_stable_diffusion.py | 73 ++++++++++++++-----
 .../community/lpw_stable_diffusion_onnx.py | 51 +++++++++----
 2 files changed, 90 insertions(+), 34 deletions(-)

diff --git a/examples/community/lpw_stable_diffusion.py b/examples/community/lpw_stable_diffusion.py
index 3d4ec23e3aea..4980f8c8be06 100644
--- a/examples/community/lpw_stable_diffusion.py
+++ b/examples/community/lpw_stable_diffusion.py
@@ -132,6 +132,7 @@ def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_len
     """
     tokens = []
     weights = []
+    truncated = False
     for text in prompt:
         texts_and_weights = parse_prompt_attention(text)
         text_token = []
@@ -140,21 +141,21 @@ def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_len
             # tokenize and discard the starting and the ending token
             token = pipe.tokenizer(word).input_ids[1:-1]
             text_token += token
-
             # copy the weight by length of token
             text_weight += [weight] * len(token)
-
             # stop if the text is too long (longer than truncation limit)
             if len(text_token) > max_length:
+                truncated = True
                 break
-
         # truncate
         if len(text_token) > max_length:
+            truncated = True
             text_token = text_token[:max_length]
             text_weight = text_weight[:max_length]
-
         tokens.append(text_token)
         weights.append(text_weight)
+    if truncated:
+        logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
     return tokens, weights
@@ -173,9 +174,9 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd
         if len(weights[i]) == 0:
             w = [1.0] * weights_length
         else:
-            for j in range((len(weights[i]) - 1) // chunk_length + 1):
+            for j in range(max_embeddings_multiples):
                 w.append(1.0)  # weight for starting token in this chunk
-                w += weights[i][j * chunk_length : min(len(weights[i]), (j + 1) * chunk_length)]
+                w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
                 w.append(1.0)  # weight for ending token in this chunk
             w += [1.0] * (weights_length - len(w))
         weights[i] = w[:]
@@ -184,7 +185,10 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd
 
 
 def get_unweighted_text_embeddings(
-    pipe: DiffusionPipeline, text_input: torch.Tensor, chunk_length: int, no_boseos_middle: Optional[bool] = True
+    pipe: DiffusionPipeline,
+    text_input: torch.Tensor,
+    chunk_length: int,
+    no_boseos_middle: Optional[bool] = True,
 ):
     """
     When the length of tokens is a multiple of the capacity of the text encoder,
@@ -285,7 +289,8 @@ def get_weighted_text_embeddings(
             max_length = max(max_length, max([len(token) for token in uncond_tokens]))
 
     max_embeddings_multiples = min(
-        max_embeddings_multiples, (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1
+        max_embeddings_multiples,
+        (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
     )
     max_embeddings_multiples = max(1, max_embeddings_multiples)
     max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
@@ -317,12 +322,18 @@ def get_weighted_text_embeddings(
 
     # get the embeddings
     text_embeddings = get_unweighted_text_embeddings(
-        pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
+        pipe,
+        prompt_tokens,
+        pipe.tokenizer.model_max_length,
+        no_boseos_middle=no_boseos_middle,
     )
     prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
     if uncond_prompt is not None:
         uncond_embeddings = get_unweighted_text_embeddings(
-            pipe, uncond_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
+            pipe,
+            uncond_tokens,
+            pipe.tokenizer.model_max_length,
+            no_boseos_middle=no_boseos_middle,
         )
         uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
@@ -632,16 +643,29 @@ def __call__(
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (
+            batch_size * num_images_per_prompt,
+            self.unet.in_channels,
+            height // 8,
+            width // 8,
+        )
         if latents is None:
             if self.device.type == "mps":
                 # randn does not exist on mps
-                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
-                    self.device
-                )
+                latents = torch.randn(
+                    latents_shape,
+                    generator=generator,
+                    device="cpu",
+                    dtype=latents_dtype,
+                ).to(self.device)
             else:
-                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+                latents = torch.randn(
+                    latents_shape,
+                    generator=generator,
+                    device=self.device,
+                    dtype=latents_dtype,
+                )
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
@@ -684,11 +708,19 @@ def __call__(
             # add noise to latents using the timesteps
             if self.device.type == "mps":
                 # randn does not exist on mps
-                noise = torch.randn(init_latents.shape, generator=generator, device="cpu", dtype=latents_dtype).to(
-                    self.device
-                )
+                noise = torch.randn(
+                    init_latents.shape,
+                    generator=generator,
+                    device="cpu",
+                    dtype=latents_dtype,
+                ).to(self.device)
             else:
-                noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
+                noise = torch.randn(
+                    init_latents.shape,
+                    generator=generator,
+                    device=self.device,
+                    dtype=latents_dtype,
+                )
             latents = self.scheduler.add_noise(init_latents, noise, timesteps)
 
             t_start = max(num_inference_steps - init_timestep + offset, 0)
@@ -741,7 +773,8 @@ def __call__(
                 self.device
             )
             image, has_nsfw_concept = self.safety_checker(
-                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+                images=image,
+                clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype),
             )
         else:
             has_nsfw_concept = None
diff --git a/examples/community/lpw_stable_diffusion_onnx.py b/examples/community/lpw_stable_diffusion_onnx.py
index b5b9a2a65fb0..4ca37c0c4ad4 100644
--- a/examples/community/lpw_stable_diffusion_onnx.py
+++ b/examples/community/lpw_stable_diffusion_onnx.py
@@ -130,6 +130,7 @@ def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
     """
     tokens = []
     weights = []
+    truncated = False
     for text in prompt:
         texts_and_weights = parse_prompt_attention(text)
         text_token = []
@@ -138,21 +139,21 @@ def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
             # tokenize and discard the starting and the ending token
             token = pipe.tokenizer(word, return_tensors="np").input_ids[0, 1:-1]
             text_token += list(token)
-
             # copy the weight by length of token
             text_weight += [weight] * len(token)
-
             # stop if the text is too long (longer than truncation limit)
             if len(text_token) > max_length:
+                truncated = True
                 break
-
         # truncate
         if len(text_token) > max_length:
+            truncated = True
             text_token = text_token[:max_length]
             text_weight = text_weight[:max_length]
-
         tokens.append(text_token)
         weights.append(text_weight)
+    if truncated:
+        logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
     return tokens, weights
@@ -171,9 +172,9 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd
         if len(weights[i]) == 0:
             w = [1.0] * weights_length
         else:
-            for j in range((len(weights[i]) - 1) // chunk_length + 1):
+            for j in range(max_embeddings_multiples):
                 w.append(1.0)  # weight for starting token in this chunk
-                w += weights[i][j * chunk_length : min(len(weights[i]), (j + 1) * chunk_length)]
+                w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
                 w.append(1.0)  # weight for ending token in this chunk
             w += [1.0] * (weights_length - len(w))
         weights[i] = w[:]
@@ -182,7 +183,10 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd
 
 
 def get_unweighted_text_embeddings(
-    pipe, text_input: np.array, chunk_length: int, no_boseos_middle: Optional[bool] = True
+    pipe,
+    text_input: np.array,
+    chunk_length: int,
+    no_boseos_middle: Optional[bool] = True,
 ):
     """
     When the length of tokens is a multiple of the capacity of the text encoder,
@@ -276,7 +280,10 @@ def get_weighted_text_embeddings(
         uncond_tokens = [
             token[1:-1]
             for token in pipe.tokenizer(
-                uncond_prompt, max_length=max_length, truncation=True, return_tensors="np"
+                uncond_prompt,
+                max_length=max_length,
+                truncation=True,
+                return_tensors="np",
             ).input_ids
         ]
         uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
@@ -287,7 +294,8 @@ def get_weighted_text_embeddings(
             max_length = max(max_length, max([len(token) for token in uncond_tokens]))
 
     max_embeddings_multiples = min(
-        max_embeddings_multiples, (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1
+        max_embeddings_multiples,
+        (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
     )
     max_embeddings_multiples = max(1, max_embeddings_multiples)
     max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
@@ -319,12 +327,18 @@ def get_weighted_text_embeddings(
 
     # get the embeddings
     text_embeddings = get_unweighted_text_embeddings(
-        pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
+        pipe,
+        prompt_tokens,
+        pipe.tokenizer.model_max_length,
+        no_boseos_middle=no_boseos_middle,
    )
     prompt_weights = np.array(prompt_weights, dtype=text_embeddings.dtype)
     if uncond_prompt is not None:
         uncond_embeddings = get_unweighted_text_embeddings(
-            pipe, uncond_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
+            pipe,
+            uncond_tokens,
+            pipe.tokenizer.model_max_length,
+            no_boseos_middle=no_boseos_middle,
         )
         uncond_weights = np.array(uncond_weights, dtype=uncond_embeddings.dtype)
@@ -559,7 +573,12 @@ def __call__(
         noise = None
         if init_image is None:
-            latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
+            latents_shape = (
+                batch_size * num_images_per_prompt,
+                4,
+                height // 8,
+                width // 8,
+            )
 
             if latents is None:
                 latents = generator.randn(*latents_shape).astype(latents_dtype)
@@ -625,7 +644,9 @@ def __call__(
             # predict the noise residual
             noise_pred = self.unet(
-                sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings
+                sample=latent_model_input,
+                timestep=np.array([t]),
+                encoder_hidden_states=text_embeddings,
             )
             noise_pred = noise_pred[0]
@@ -640,7 +661,9 @@ def __call__(
             if mask is not None:
                 # masking
                 init_latents_proper = self.scheduler.add_noise(
-                    torch.from_numpy(init_latents_orig), torch.from_numpy(noise), torch.tensor([t])
+                    torch.from_numpy(init_latents_orig),
+                    torch.from_numpy(noise),
+                    torch.tensor([t]),
                 ).numpy()
                 latents = (init_latents_proper * mask) + (latents * (1 - mask))
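
Why the stride changed: get_unweighted_text_embeddings encodes a long prompt in
chunks that each hold chunk_length - 2 content tokens framed by their own BOS and
EOS token, so the rebuilt weight array must advance through the raw weights in
steps of chunk_length - 2 as well. The old j * chunk_length indexing advanced too
far on every chunk after the first, dropping weights and misaligning the rest
against their tokens. A minimal standalone sketch of the corrected loop (the
pad_weights helper and the toy chunk_length=6 are illustrative assumptions, not
code from the patch):

    # Sketch of the corrected per-chunk weight padding; toy chunk_length=6
    # stands in for CLIP's model_max_length of 77.
    def pad_weights(weights, chunk_length, max_embeddings_multiples):
        w = []
        for j in range(max_embeddings_multiples):
            w.append(1.0)  # weight for this chunk's starting (BOS) token
            # fixed stride: each chunk consumes chunk_length - 2 content weights
            w += weights[j * (chunk_length - 2) : (j + 1) * (chunk_length - 2)]
            w.append(1.0)  # weight for this chunk's ending (EOS) token
        # pad the tail of the last chunk with neutral weights
        w += [1.0] * (chunk_length * max_embeddings_multiples - len(w))
        return w

    print(pad_weights([1.1] * 6, chunk_length=6, max_embeddings_multiples=2))
    # [1.0, 1.1, 1.1, 1.1, 1.1, 1.0, 1.0, 1.1, 1.1, 1.0, 1.0, 1.0]

With the old stride the second chunk would have read weights[6:12], which is
empty here, so its content positions would all have fallen back to 1.0 even
though the corresponding tokens still carried the user's emphasis. The new
truncated flag is a separate, smaller change: it only warns when a prompt
exceeds max_length and points the user at max_embeddings_multiples, the knob
that raises that limit.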