-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathutils.py
100 lines (87 loc) · 4.1 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import numpy as np
import torch
from PIL import Image
import cv2
from typing import Tuple
def aggregate_attention(prompts, attention_store, res: int, from_where, is_cross: bool, select: int, is_cpu=True):
    """Average the stored attention maps of one prompt at a single resolution.

    Args:
        prompts: the prompt batch; only ``len(prompts)`` is used, to split the
            leading axis of each stored map into per-prompt groups.
        attention_store: object whose ``get_average_attention()`` returns a
            mapping keyed by ``"{location}_cross"`` / ``"{location}_self"``,
            each value being an iterable of tensors.
        res: spatial side length; only maps with ``res ** 2`` entries along
            dimension 1 are used.
        from_where: iterable of location names (the ``{location}`` key prefix).
        is_cross: read the ``_cross`` entries if True, else the ``_self`` ones.
        select: index of the prompt whose maps are kept.
        is_cpu: move the result to CPU before returning.

    Returns:
        Tensor of shape ``(res, res, item.shape[-1])``: the mean over every
        leading-axis slice (presumably attention heads) of every matching map.

    Raises:
        RuntimeError: if no map under the requested locations has ``res ** 2``
            entries along dimension 1.
    """
    kind = "cross" if is_cross else "self"
    num_pixels = res ** 2
    attention_maps = attention_store.get_average_attention()
    out = []
    for location in from_where:
        for item in attention_maps[f"{location}_{kind}"]:
            # Layers run at several resolutions; keep only those whose second
            # axis matches the requested res * res grid.
            if item.shape[1] == num_pixels:
                # Split the leading axis per prompt, then keep prompt `select`.
                maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
                out.append(maps)
    if not out:
        # torch.cat([]) would fail with an unhelpful message; say what is missing.
        raise RuntimeError(
            f"aggregate_attention: no {kind}-attention maps with {num_pixels} "
            f"positions (res={res}) found in the requested locations"
        )
    out = torch.cat(out, dim=0)
    out = out.sum(0) / out.shape[0]
    return out.cpu() if is_cpu else out
def show_cross_attention(prompts, tokenizer, attention_store, res: int, from_where, select: int = 0, save_path=None):
    """Render one captioned heat-map tile per token of ``prompts[select]``.

    The prompt is tokenized with ``tokenizer.encode``. For token position
    ``i``, channel ``i`` of the averaged cross-attention map (see
    ``aggregate_attention``) is scaled to 0..255 and colour-mapped with
    OpenCV's BONE map. It is then resized to 256x256 and captioned with
    ``tokenizer.decode`` of that token id. All tiles are stacked and handed to
    ``view_images``, which saves the grid to ``save_path`` when given.

    Args:
        prompts: the prompt batch the attention store was recorded for.
        tokenizer: object providing ``encode(str)`` and ``decode(int)``.
        attention_store: passed through to ``aggregate_attention``.
        res: attention-map resolution to visualise.
        from_where: location names passed through to ``aggregate_attention``.
        select: index of the prompt to visualise.
        save_path: optional output path for the tiled image.
    """
    tokens = tokenizer.encode(prompts[select])
    decoder = tokenizer.decode
    attention_maps = aggregate_attention(prompts, attention_store, res, from_where, True, select)
    images = []
    for i, token in enumerate(tokens):
        image = attention_maps[:, :, i]
        max_val = image.max()
        # An all-zero map would give 0/0 = NaN, whose uint8 cast is undefined;
        # leave it as a black tile instead.
        if max_val > 0:
            image = 255 * image / max_val
        image = image.unsqueeze(-1).expand(*image.shape, 3)
        image = image.detach().cpu().numpy().astype(np.uint8)
        # applyColorMap returns BGR; convert to RGB for PIL. (The three input
        # channels are identical, so no pre-conversion is needed.)
        image = cv2.applyColorMap(image, cv2.COLORMAP_BONE)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = np.array(Image.fromarray(image).resize((256, 256)))
        image = text_under_image(image, decoder(int(token)))
        images.append(image)
    view_images(np.stack(images, axis=0), save_path=save_path)
def show_self_attention_comp(prompts, attention_store, res: int, from_where,
                             max_com=7, select: int = 0, save_path=None):
    """Visualise the leading SVD components of the averaged self-attention.

    The averaged self-attention of ``prompts[select]`` is reshaped to a
    ``(res**2, res**2)`` matrix and row-centred, then decomposed with
    ``np.linalg.svd``. The first ``max_com`` right-singular vectors (clamped to
    the number available) are each reshaped to ``res x res`` and min-max
    scaled to 0..255. Each is colour-mapped with OpenCV's BONE map and resized
    to 256x256. The tiles are concatenated horizontally into one image passed
    to ``view_images``.

    Args:
        prompts: the prompt batch the attention store was recorded for.
        attention_store: passed through to ``aggregate_attention``.
        res: attention-map resolution to decompose.
        from_where: location names passed through to ``aggregate_attention``.
        max_com: maximum number of components to render.
        select: index of the prompt to visualise.
        save_path: optional output path for the resulting image.
    """
    attention_maps = aggregate_attention(prompts, attention_store, res, from_where, False, select)
    # numpy cannot take bfloat16 tensors and np.linalg rejects float16, so
    # maps recorded in half precision are upcast before the SVD.
    if attention_maps.dtype in (torch.float16, torch.bfloat16):
        attention_maps = attention_maps.float()
    attention_maps = attention_maps.detach().numpy().reshape((res ** 2, res ** 2))
    _, _, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
    images = []
    # vh has only res**2 rows; asking for more components used to IndexError.
    for i in range(min(max_com, vh.shape[0])):
        image = vh[i].reshape(res, res)
        image = image - image.min()
        max_val = image.max()
        # A constant component would give 0/0 = NaN before the uint8 cast.
        if max_val > 0:
            image = 255 * image / max_val
        image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
        # applyColorMap returns BGR; convert to RGB for PIL.
        image = cv2.applyColorMap(image, cv2.COLORMAP_BONE)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = np.array(Image.fromarray(image).resize((256, 256)))
        images.append(image)
    view_images(np.concatenate(images, axis=1), save_path=save_path)
def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
    """Return a new uint8 image: ``image`` on top, a white caption strip below.

    The strip is ``int(0.2 * height)`` rows tall. ``text`` is drawn in it,
    horizontally centred, with OpenCV's Hershey Simplex font at scale 1 and
    thickness 2. The baseline sits half a text-height above the bottom edge.

    Args:
        image: ``(height, width, channels)`` array.
        text: caption to draw.
        text_color: colour tuple passed to ``cv2.putText``.
    """
    height, width, channels = image.shape
    strip = int(height * .2)
    canvas = np.full((height + strip, width, channels), 255, dtype=np.uint8)
    canvas[:height] = image
    font_face = cv2.FONT_HERSHEY_SIMPLEX
    (text_w, text_h), _baseline = cv2.getTextSize(text, font_face, 1, 2)
    origin = ((width - text_w) // 2, height + strip - text_h // 2)
    cv2.putText(canvas, text, origin, font_face, 1, text_color, 2)
    return canvas
def view_images(images, num_rows=1, offset_ratio=0.02, save_path=None, show=False):
    """Tile images into a grid separated by white gutters; show and/or save it.

    Args:
        images: a list or tuple of ``(h, w, 3)`` arrays, a 4-D array stacked
            along axis 0, or any other array, which is treated as one image.
            All tiles are assumed to share the first one's shape.
        num_rows: number of grid rows. If the image count is not a multiple of
            it, blank white tiles are appended so that every image is shown.
        offset_ratio: gutter width as a fraction of the tile height.
        save_path: if not None, the grid is saved there via PIL.
        show: if True, the grid is opened with ``PIL.Image.show``.
    """
    if isinstance(images, (list, tuple)):
        num_items = len(images)
    elif images.ndim == 4:
        num_items = images.shape[0]
    else:
        images = [images]
        num_items = 1
    # Pad up to the next multiple of num_rows. The old `num_items % num_rows`
    # under-padded for num_rows >= 3 (e.g. 5 images on 3 rows -> a 2x? grid
    # that silently dropped an image).
    num_empty = (-num_items) % num_rows
    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    tiles = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
    h, w, _ = tiles[0].shape
    offset = int(h * offset_ratio)
    num_cols = len(tiles) // num_rows
    grid = np.ones((h * num_rows + offset * (num_rows - 1),
                    w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
    for i in range(num_rows):
        for j in range(num_cols):
            top = i * (h + offset)
            left = j * (w + offset)
            grid[top:top + h, left:left + w] = tiles[i * num_cols + j]
    pil_img = Image.fromarray(grid)
    if show:
        pil_img.show()
    if save_path is not None:
        pil_img.save(save_path)