-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdsl_synthesis.py
222 lines (174 loc) · 6.95 KB
/
dsl_synthesis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
import os
import json
import tqdm
import itertools
from betamark import arc_agi
from random import sample
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, Normalize
# locations of the ARC-AGI challenge/solution JSON files
test_challenges_path = "data/arc-agi_test_challenges.json"
train_challenges_path = "data/arc-agi_training_challenges.json"
train_solutions_path = "data/arc-agi_training_solutions.json"


def _read_json(path):
    """Read and parse a single JSON file from *path*."""
    with open(path) as fh:
        return json.load(fh)


test_challenges = _read_json(test_challenges_path)
train_challenges = _read_json(train_challenges_path)
train_solutions = _read_json(train_solutions_path)
def plot_task(task):
    """Plot a task's training examples: inputs on the top row, outputs below.

    Parameters
    ----------
    task : dict
        An ARC task with a "train" key holding a list of
        {"input": grid, "output": grid} examples.
    """
    examples = task["train"]
    n_examples = len(examples)
    # the 10-color ARC palette, indexed by cell value 0-9
    cmap = ListedColormap(
        [
            "#000",
            "#0074D9",
            "#FF4136",
            "#2ECC40",
            "#FFDC00",
            "#AAAAAA",
            "#F012BE",
            "#FF851B",
            "#7FDBFF",
            "#870C25",
        ]
    )
    norm = Normalize(vmin=0, vmax=9)
    # squeeze=False keeps `axes` 2-D even when there is a single example,
    # so the axes[row, column] indexing below cannot break (the original
    # raised for tasks with exactly one training example)
    figure, axes = plt.subplots(
        2, n_examples, figsize=(n_examples * 4, 8), squeeze=False
    )
    for column, example in enumerate(examples):
        axes[0, column].imshow(example["input"], cmap=cmap, norm=norm)
        axes[1, column].imshow(example["output"], cmap=cmap, norm=norm)
        axes[0, column].axis("off")
        axes[1, column].axis("off")
    plt.show()
# defining a handful of basic primitives
def tophalf(grid):
    """Return the rows in the upper half of the grid."""
    midpoint = len(grid) // 2
    return grid[:midpoint]
def rot90(grid):
    """Clockwise rotation by 90 degrees.

    Returns a list of lists rather than a list of tuples: ARC grids loaded
    from JSON are lists of lists, and the search compares results with
    `==`, so the original tuple-of-rows return value could never match a
    target grid — any program containing rot90 was silently unusable.
    """
    return [list(row) for row in zip(*grid[::-1])]
def hmirror(grid):
    """Flip the grid upside down (mirror across the horizontal axis)."""
    return list(reversed(grid))
def compress(grid):
    """Remove every fully-uniform row and column ("frontiers") from the grid."""
    # indices of rows/columns whose cells all hold a single value
    uniform_rows = {i for i, row in enumerate(grid) if len(set(row)) == 1}
    uniform_cols = {j for j, col in enumerate(zip(*grid)) if len(set(col)) == 1}
    kept = []
    for i, row in enumerate(grid):
        if i in uniform_rows:
            continue
        kept.append([cell for j, cell in enumerate(row) if j not in uniform_cols])
    return kept
def trim(grid):
    """Strip the outermost one-cell border from the grid."""
    inner_rows = grid[1:-1]
    return [row[1:-1] for row in inner_rows]
# the DSL is just the set of primitive grid transforms defined above
DSL_primitives = {tophalf, rot90, hmirror, compress, trim}
primitive_names = {p.__name__ for p in DSL_primitives}
print(f"DSL consists of {len(DSL_primitives)} primitives: {primitive_names}")
# deepest composition of primitives to enumerate
MAX_DEPTH = 6
# build the source string of every program expressible as a composition
# of at most MAX_DEPTH primitives
program_strings = []
for depth in range(1, MAX_DEPTH + 1):
    for chain in itertools.product(primitive_names, repeat=depth):
        opening = "".join(f"{name}(" for name in chain)
        closing = ")" * depth
        program_strings.append(f"lambda grid: {opening}grid{closing}")
# show a small sample of the search space
print(f"Space to search consists of {len(program_strings)} programs:\n")
print("\n".join([*program_strings[:10], "..."]))
# compile each program string into a callable (eval of our own generated
# strings only — no untrusted input reaches eval here)
programs = {prog_str: eval(prog_str) for prog_str in program_strings}
# for each task, search over the programs and if a working program is found,
# remember it
guesses = dict()
# iterate over all tasks
for key, task in tqdm.tqdm(train_challenges.items()):
    train_inputs = [example["input"] for example in task["train"]]
    train_outputs = [example["output"] for example in task["train"]]
    hypotheses = []
    # iterate over all programs
    for program_string, program in programs.items():
        try:
            # a program is a hypothesis if it maps every training input
            # to its training output (generator avoids building a list)
            if all(program(i) == o for i, o in zip(train_inputs, train_outputs)):
                hypotheses.append(program_string)
        except Exception:
            # many programs are simply inapplicable to a given grid shape;
            # skip them.  Catch Exception, not a bare `except:`, so
            # KeyboardInterrupt/SystemExit still propagate.
            pass
    # select first program for making predictions
    if len(hypotheses) > 0:
        print(f"found {len(hypotheses)} candidate programs for task {key}!")
        guesses[key] = hypotheses[0]
print(f"\nMade guesses for {len(guesses)} tasks")
# make predictions and evaluate them
solved = dict()
# iterate over all tasks for which a guess exists
for key, program_string in guesses.items():
    test_inputs = [example["input"] for example in train_challenges[key]["test"]]
    # re-evaluate the stored program string into a callable
    program = eval(program_string)
    if all(program(i) == o for i, o in zip(test_inputs, train_solutions[key])):
        # mark prediction as correct if all test examples are solved by the program
        solved[key] = program_string
print(f"Predictions correct for {len(solved)}/{len(guesses)} tasks")
# visualize solved tasks
for key, program_string in solved.items():
    print(f'For task "{key}", found program "{program_string}"')
    plot_task(train_challenges[key])
# let's try to make a submission
submission = dict()
# number of test tasks for which at least one candidate program was found
n_found = 0
# iterate over all tasks
for key, task in tqdm.tqdm(test_challenges.items()):
    train_inputs = [example["input"] for example in task["train"]]
    train_outputs = [example["output"] for example in task["train"]]
    hypotheses = []
    # iterate over all programs
    for program_string, program in programs.items():
        try:
            if all(program(i) == o for i, o in zip(train_inputs, train_outputs)):
                # remember program if it explains all training examples
                hypotheses.append(program_string)
        except Exception:
            # program inapplicable to this grid shape; skip it
            pass
    # fall back to echoing the test inputs when no program fits
    predictions = [example["input"] for example in task["test"]]
    if len(hypotheses) > 0:
        print(f"found {len(hypotheses)} candidate programs for task {key}!")
        n_found += 1
        # select first program for making predictions
        program = eval(hypotheses[0])
        try:
            predictions = [program(example["input"]) for example in task["test"]]
        except Exception:
            # the program crashed on a test grid; keep the fallback predictions
            pass
    submission[key] = [{"attempt_1": grid, "attempt_2": grid} for grid in predictions]
# report the count for THIS phase; the original printed len(guesses),
# which is the train-phase dict and unrelated to the test challenges
print(f"\nMade guesses for {n_found} tasks")
def make_dsl_prediction(task):
    """Search the program space for a program that explains *task*'s training
    examples and apply it to the first test input.

    Parameters
    ----------
    task : dict
        An ARC task with "train" and "test" example lists.

    Returns
    -------
    list
        The predicted output grid for the first test example.  Falls back
        to the unmodified first test input when no program fits, and to
        [[0]] when the chosen program crashes on a test grid.  (The
        original fell through and returned None in the no-hypothesis case.)
    """
    train_inputs = [example["input"] for example in task["train"]]
    train_outputs = [example["output"] for example in task["train"]]
    hypotheses = []
    # iterate over all programs
    for program_string, program in programs.items():
        try:
            if all(program(i) == o for i, o in zip(train_inputs, train_outputs)):
                # remember program if it explains all training examples
                hypotheses.append(program_string)
        except Exception:
            # program inapplicable to this grid shape; skip it
            pass
    # default prediction: echo the test inputs unchanged
    predictions = [example["input"] for example in task["test"]]
    if hypotheses:
        # NOTE: the original printed the leaked global `key` here, which is
        # not this task's id, and picked hypotheses[-1] — take hypotheses[0]
        # for consistency with the rest of the file
        print(f"found {len(hypotheses)} candidate programs")
        program = eval(hypotheses[0])
        try:
            predictions = [program(example["input"]) for example in task["test"]]
        except Exception:
            return [[0]]
    # also covers the no-hypothesis case, which previously returned None
    return predictions[0]
# run the betamark ARC-AGI evaluation harness with the DSL-search predictor
result = arc_agi.run_eval(user_func=make_dsl_prediction)
print(result)
# to write the submission file instead, uncomment:
# with open("submission.json", "w") as fp:
#     json.dump(submission, fp)