-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtest_main_scripts.py
97 lines (73 loc) · 4.53 KB
/
test_main_scripts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import shlex
import subprocess
import sys
import unittest
class TestScripts(unittest.TestCase):
scripts_path = "./scripts/"
def run_script(self, script_name, **kwargs):
script = f"python3 {script_name}.py"
# check if in the kwargs there is dirct key, for this one just add the value to the command
if 'direct' in kwargs:
script += " " + kwargs['direct']
del kwargs['direct']
params = ' '.join(f"{(len(k) > 1 and '--' or '-') + k} {v}" for k, v in kwargs.items())
command = f"{script} {params}"
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
output, error = process.communicate()
# Log the output and errors
print(f"\n{'-' * 60}\nRunning: {command}")
print(f"Output: {output.decode()}")
if error:
print(f"Error: {error.decode()}")
# Assert that return code is 0
self.assertEqual(process.returncode, 0, f"{command} failed with return code {process.returncode}")
def test_01_download_models(self):
self.run_script("./download_models")
def test_02_process_models(self):
self.run_script("./process_models")
def test_03_text_to_image(self):
self.run_script(f"{self.scripts_path}text_to_image",
prompt="\"character, chibi, waifu, side scrolling, white background, centered\"",
num_images=1)
# python3 ./scripts/txt2img.py --num_images 2 --prompt 'A purple rainbow, filled with grass'
def test_04_1_text_to_image(self):
self.run_script(f"{self.scripts_path}txt2img", prompt="\"A purple rainbow, filled with grass\"", num_images=1)
def test_05_embed_prompts(self):
self.run_script(f"{self.scripts_path}embed_prompts",
prompts="\"A painting of a computer virus, An old photo of a computer scientist, A computer drawing a computer\"")
# Add other test cases similarly...
def test_06_generate_images_from_embeddings(self):
self.run_script(f"{self.scripts_path}generate_images_from_embeddings", temperature=1.2,
ddim_eta=0.2)
def test_07_generate_images_from_distributions(self):
self.run_script(f"{self.scripts_path}generate_images_from_distributions", d=4, params_steps=1,
params_range='0.49 0.54',
num_seeds=4, temperature=1.2, ddim_eta=1.2)
def test_08_generate_images_from_temperature_range(self):
self.run_script(f"{self.scripts_path}generate_images_from_temperature_range", d=4, params_range='0.49 0.54',
params_steps=1,
temperature_steps=1, temperature_range='0.8 0.9')
def test_09_generate_images_and_encodings(self):
self.run_script(f"{self.scripts_path}generate_images_and_encodings",
prompt="\"An oil painting of a computer generated image of a geometric pattern\"",
num_iterations=1)
def test_10_embed_prompts_and_generate_images(self):
self.run_script(f"{self.scripts_path}embed_prompts_and_generate_images", num_iterations=1)
def test_12_grid_generator(self):
self.run_script("./utility/scripts/grid_generator", input_path="./test/test_images/clip_segmentation",
output_path="./tmp", rows=3, columns=2, img_size=256)
def test_14_chad_sort(self):
self.run_script(f"{self.scripts_path}chad_score",
direct="--model-path=\"input/model/chad_score/chad-score-v1.pth\" --image-path=\"./test/test_images/test_img.jpg\"")
def test_15_chad_sort(self):
self.run_script(f"{self.scripts_path}chad_sort",
direct="--dataset-path=\"./test/test_zip_files/test-generated-dataset-correct-format.zip\" --output-path=\"./output/chad_sort/\"")
def test_16_run_generation_task(self):
self.run_script(f"{self.scripts_path}run_generation_task",
task_path="\"./test/test_generation_task/generate_images_from_random_prompt_v1.json\"")
def test_17_run_prompts_ga(self):
self.run_script(f"{self.scripts_path}prompts_ga", generations=2)
def test_18_generate_images_from_prompt_generator(self):
self.run_script(f"{self.scripts_path}generate_images_from_prompt_generator",
checkpoint_path="\"./input/model/sd/v1-5-pruned-emaonly/v1-5-pruned-emaonly.safetensors\"", cfg_scale=7, num_images=1, num_phrases=12,
output="\"./output/\"")
if __name__ == "__main__":
    # Executed directly (not imported): discover and run every test case
    # defined in this module.
    unittest.main(module="__main__")