-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathappearance_transfer_model.py
180 lines (141 loc) · 8.15 KB
/
appearance_transfer_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
from typing import List, Optional, Callable
import torch
import torch.nn.functional as F
from config import RunConfig
from constants import OUT_INDEX, STRUCT_INDEX, STYLE_INDEX
from models.stable_diffusion import CrossImageAttentionStableDiffusionPipeline
from utils import attention_utils
from utils.adain import masked_adain, adain
from utils.model_utils import get_stable_diffusion_model
from utils.segmentation import Segmentor
class AppearanceTransferModel:
    """Drive a cross-image-attention Stable Diffusion pipeline for appearance transfer.

    On construction, every ``Attention`` module in the pipeline's UNet gets a
    custom processor (see ``register_attention_control``). That processor can
    overwrite the keys/values of batch entry ``OUT_INDEX`` in decoder
    self-attention layers with those of ``STRUCT_INDEX`` or ``STYLE_INDEX``.
    ``get_adain_callback`` returns a per-step callback that applies (optionally
    masked) AdaIN to the latents.
    """

    def __init__(self, config: RunConfig, pipe: Optional[CrossImageAttentionStableDiffusionPipeline] = None):
        """Create the model and install the attention processors.

        Args:
            config: Run configuration. This class reads ``prompt``,
                ``object_noun``, ``use_masked_adain``, ``adain_range`` (via
                ``.start`` / ``.end``) and ``contrast_strength`` from it.
            pipe: Pipeline to wrap. When ``None``, one is created with
                ``get_stable_diffusion_model()``.
        """
        self.config = config
        self.pipe = get_stable_diffusion_model() if pipe is None else pipe
        # The processors only read model state (step, enable_edit, segmentor, ...)
        # when they are invoked, so installing them before the attributes below
        # exist is safe.
        self.register_attention_control()
        self.segmentor = Segmentor(prompt=config.prompt, object_nouns=[config.object_noun])
        self.latents_app, self.latents_struct = None, None
        self.zs_app, self.zs_struct = None, None
        self.image_app_mask_32, self.image_app_mask_64 = None, None
        self.image_struct_mask_32, self.image_struct_mask_64 = None, None
        # Cross-image key/value mixing only happens while this flag is True.
        self.enable_edit = False
        # Current denoising step index; updated by the AdaIN callback.
        self.step = 0

    def set_latents(self, latents_app: torch.Tensor, latents_struct: torch.Tensor):
        """Store the appearance-image and structure-image latents on the model."""
        self.latents_app = latents_app
        self.latents_struct = latents_struct

    def set_noise(self, zs_app: torch.Tensor, zs_struct: torch.Tensor):
        """Store the appearance-image and structure-image noise tensors on the model."""
        self.zs_app = zs_app
        self.zs_struct = zs_struct

    def set_masks(self, masks: List[torch.Tensor]):
        """Store object masks.

        ``masks`` must contain exactly four tensors, in the order
        (app_32, struct_32, app_64, struct_64).
        """
        self.image_app_mask_32, self.image_struct_mask_32, self.image_app_mask_64, self.image_struct_mask_64 = masks

    def get_adain_callback(self) -> Callable[[int, int, torch.FloatTensor], None]:
        """Return a ``callback(st, timestep, latents)`` for the denoising loop.

        Each call records ``st`` in ``self.step``. Further behavior:

        * When ``config.use_masked_adain`` is set and the step equals
          ``adain_range.start``, object masks are computed with the segmentor
          and stored via ``set_masks``.
        * For steps in ``[adain_range.start, adain_range.end)``, ``latents[0]``
          is replaced in place by AdaIN of ``latents[0]`` w.r.t. ``latents[1]``.
          This uses the 64-resolution masks when masked AdaIN is enabled.

        NOTE(review): presumably index 0 is the output image and index 1 the
        appearance image (matching OUT_INDEX / STYLE_INDEX); confirm against
        ``constants``.
        """
        def callback(st: int, timestep: int, latents: torch.FloatTensor) -> None:
            self.step = st
            # The attention processors feed the segmentor while
            # step == adain_range.start - 1, so the masks can be built here.
            if self.config.use_masked_adain and self.step == self.config.adain_range.start:
                masks = self.segmentor.get_object_masks()
                self.set_masks(masks)
            # Apply AdaIN only inside the configured step range.
            if self.config.adain_range.start <= self.step < self.config.adain_range.end:
                if self.config.use_masked_adain:
                    latents[0] = masked_adain(latents[0], latents[1], self.image_struct_mask_64, self.image_app_mask_64)
                else:
                    latents[0] = adain(latents[0], latents[1])

        return callback

    def register_attention_control(self):
        """Install an ``AttentionProcessor`` on every attention layer of the UNet.

        Only the UNet's top-level children whose names contain ``"down"``,
        ``"up"`` or ``"mid"`` are searched, checked in that order. Each
        ``Attention`` module found is tagged ``"<place>_<k>"``, where ``k``
        numbers the attention modules within that child, starting at 1.
        """
        model_self = self

        class AttentionProcessor:
            """Scaled-dot-product attention processor with cross-image key/value mixing."""

            def __init__(self, place_in_unet: str):
                # Tag such as "up_3"; only "up" layers are eligible for mixing.
                self.place_in_unet = place_in_unet
                if not hasattr(F, "scaled_dot_product_attention"):
                    raise ImportError("AttnProcessor2_0 requires torch 2.0, to use it, please upgrade torch to 2.0.")

            def __call__(self,
                         attn,
                         hidden_states: torch.Tensor,
                         encoder_hidden_states: Optional[torch.Tensor] = None,
                         attention_mask=None,
                         temb=None,
                         perform_swap: bool = False):
                """Run attention for ``attn``.

                Key/value mixing is optionally applied in decoder self-attention.
                ``perform_swap`` defaults to False, so mixing only happens when
                the caller passes it explicitly.
                """
                residual = hidden_states

                if attn.spatial_norm is not None:
                    hidden_states = attn.spatial_norm(hidden_states, temb)

                input_ndim = hidden_states.ndim
                if input_ndim == 4:
                    # (batch, channel, height, width) -> (batch, height * width, channel)
                    batch_size, channel, height, width = hidden_states.shape
                    hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

                batch_size, sequence_length, _ = (
                    hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
                )

                if attention_mask is not None:
                    # NOTE(review): the prepared mask is never passed to
                    # compute_scaled_dot_product_attention below, so it is
                    # effectively ignored. Confirm this is intended.
                    attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
                    attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

                if attn.group_norm is not None:
                    hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

                query = attn.to_q(hidden_states)

                is_cross = encoder_hidden_states is not None
                if not is_cross:
                    encoder_hidden_states = hidden_states
                elif attn.norm_cross:
                    encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

                key = attn.to_k(encoder_hidden_states)
                value = attn.to_v(encoder_hidden_states)

                inner_dim = key.shape[-1]
                head_dim = inner_dim // attn.heads

                should_mix = False
                # Cross-image attention is only applied in self-attention layers
                # of the decoder ("up") part of the denoising network.
                if perform_swap and not is_cross and "up" in self.place_in_unet and model_self.enable_edit:
                    if attention_utils.should_mix_keys_and_values(model_self, hidden_states):
                        should_mix = True
                        if model_self.step % 5 == 0 and model_self.step < 40:
                            # Every 5th step before step 40: inject the structure image's keys/values.
                            key[OUT_INDEX] = key[STRUCT_INDEX]
                            value[OUT_INDEX] = value[STRUCT_INDEX]
                        else:
                            # Otherwise inject the appearance image's keys/values.
                            key[OUT_INDEX] = key[STYLE_INDEX]
                            value[OUT_INDEX] = value[STYLE_INDEX]

                # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
                query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

                # Compute attention; the contrast operation is applied when edit_map is truthy.
                hidden_states, attn_weight = attention_utils.compute_scaled_dot_product_attention(
                    query, key, value,
                    edit_map=perform_swap and model_self.enable_edit and should_mix,
                    is_cross=is_cross,
                    contrast_strength=model_self.config.contrast_strength,
                )

                # Feed attention maps to the segmentor one step before masked AdaIN starts.
                if model_self.config.use_masked_adain and model_self.step == model_self.config.adain_range.start - 1:
                    model_self.segmentor.update_attention(attn_weight, is_cross)

                hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
                hidden_states = hidden_states.to(query.dtype)

                # linear proj
                hidden_states = attn.to_out[0](hidden_states)
                # dropout
                hidden_states = attn.to_out[1](hidden_states)

                if input_ndim == 4:
                    hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

                if attn.residual_connection:
                    hidden_states = hidden_states + residual

                hidden_states = hidden_states / attn.rescale_output_factor

                return hidden_states

        def register_recr(net_, count, place_in_unet):
            # Depth-first walk. Each Attention module gets a processor (its own
            # children are not visited). Returns the running count, which is used
            # to number the processors within one top-level child.
            if net_.__class__.__name__ == 'Attention':
                net_.set_processor(AttentionProcessor(place_in_unet + f"_{count + 1}"))
                return count + 1
            elif hasattr(net_, 'children'):
                for net__ in net_.children():
                    count = register_recr(net__, count, place_in_unet)
            return count

        for name, child in self.pipe.unet.named_children():
            # First match wins: "down", then "up", then "mid".
            for place_in_unet in ("down", "up", "mid"):
                if place_in_unet in name:
                    register_recr(child, 0, place_in_unet)
                    break