-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcodegen.py
243 lines (191 loc) · 6.78 KB
/
codegen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
from copy import deepcopy
from typing import Tuple, List
import black
import ast
import astor
# Load and parse the hand-written sans-I/O request builders once, at import
# time; get_classes() walks this tree.  The path is relative to the current
# working directory.  ``with`` ensures the file handle is closed.
with open("spoffy/modules/sansio.py", encoding="utf-8") as _sansio_file:
    source = _sansio_file.read()
parsed = ast.parse(source)
def to_source(node) -> str:
    """Render an AST *node* back to Python source text using astor.

    astor's ``pretty_source`` hook is replaced by a plain concatenation of
    the generated source fragments.
    """

    def _concat(fragments):
        return "".join(fragments)

    return astor.to_source(node, pretty_source=_concat)
def get_classes() -> List[ast.ClassDef]:
    """Return every class definition found in the module-level ``parsed`` tree.

    ``ast.walk`` visits the whole tree, so nested classes are included too.
    """
    found: List[ast.ClassDef] = []
    for candidate in ast.walk(parsed):
        if isinstance(candidate, ast.ClassDef):
            found.append(candidate)
    return found
def get_module_classes() -> List[ast.ClassDef]:
    """Return the classes whose first base class is ``RequestBuilder``.

    Only a bare-name base (``class X(RequestBuilder)``) matches.  Classes
    whose first base is some other expression (``abc.ABC``, ``Generic[T]``,
    ...) are skipped instead of crashing on a missing ``.id`` attribute.
    """
    return [
        c
        for c in get_classes()
        if c.bases
        and isinstance(c.bases[0], ast.Name)
        and c.bases[0].id == "RequestBuilder"
    ]
def get_class_methods(klass: ast.ClassDef) -> List[ast.FunctionDef]:
    """Return all (non-async) ``def`` nodes anywhere inside *klass*.

    ``ast.walk`` is recursive, so functions nested inside methods are
    returned as well; ``async def`` nodes are not.
    """
    return list(
        filter(
            lambda node: isinstance(node, ast.FunctionDef),
            ast.walk(klass),
        )
    )
def get_attributes(klass: ast.ClassDef) -> List[ast.Attribute]:
    """Return every ``ast.Attribute`` node (``a.b``) anywhere inside *klass*."""
    attributes: List[ast.Attribute] = []
    for node in ast.walk(klass):
        if isinstance(node, ast.Attribute):
            attributes.append(node)
    return attributes
def get_request_methods(klass: ast.ClassDef) -> List[ast.FunctionDef]:
    """Unimplemented placeholder: always returns ``None``.

    Nothing in this module calls it.
    """
    return None  # TODO: implement or remove
def build_ret(method, rtype_arg):
    """Build matching sync and async ``return`` statements for *method*.

    The generated statement has the shape::

        return self._make_request(self.b.<name>(<p>=<p>, ...), <rtype>)

    where ``<name>`` is ``method.name``, the keyword arguments come from
    ``signature_to_keywords(method)`` and ``<rtype>`` is *rtype_arg*.

    :param method: function definition whose name and signature are mirrored
    :param rtype_arg: AST expression used as the second argument of
        ``_make_request``
    :return: ``(sync_return, async_return)``; the async variant wraps the
        ``_make_request(...)`` call in ``await``.  The two statements share
        no Call nodes.
    """
    # Parse a concrete example statement and patch its pieces; simpler than
    # assembling the nested Call nodes by hand.
    tmpl = (
        "return self._make_request(self.b.artist(artist_id=artist_id), Artist)"
    )
    ret = ast.parse(tmpl).body[0]
    ret.value.args[1] = rtype_arg  # type: ignore
    builder_call = ret.value.args[0]  # type: ignore
    # All parameters are forwarded by keyword; signature_to_keywords already
    # drops the leading ``self`` parameter.
    builder_call.args = []
    builder_call.keywords = signature_to_keywords(method)
    builder_call.func.attr = method.name
    async_ret = deepcopy(ret)
    # Wrap the *copied* call, so the sync and async trees do not alias.
    async_ret.value = ast.Await(async_ret.value)  # type: ignore
    return ret, async_ret
def signature_to_keywords(method):
    """
    Convert a method signature to
    keyword arguments ast objects

    Every named parameter after the first one (``self``) becomes a
    ``name=name`` keyword.  This covers both regular and keyword-only
    parameters; keyword-only ones were previously dropped.  A ``**kwargs``
    parameter is forwarded as ``**kwargs``.  ``*args`` cannot be passed by
    keyword and is not forwarded.

    :param method: ``ast.FunctionDef`` / ``ast.AsyncFunctionDef``
    :return: list of ``ast.keyword`` nodes in signature order
    """
    params = method.args.args[1:] + method.args.kwonlyargs
    keywords = [
        ast.keyword(param.arg, ast.Name(param.arg, ast.Load()))
        for param in params
    ]
    if method.args.kwarg:
        # Forward **kwargs to builder call if method has them
        kwarg_name = method.args.kwarg.arg
        keywords.append(  # type: ignore
            ast.keyword(None, ast.Name(kwarg_name, ast.Load()))
        )
    return keywords
def build_method_def(
    method: ast.FunctionDef,
) -> Tuple[ast.FunctionDef, ast.AsyncFunctionDef]:
    """Turn one decorated builder method into sync and async API methods.

    The first argument of the method's first decorator becomes both the
    return annotation and the type argument passed to ``build_ret``.
    Decorators are dropped.  The body becomes the ``return`` statement from
    ``build_ret``, preceded by the original docstring statement if the
    method has one.  *method* itself is not mutated (it is deep-copied
    first).

    :return: ``(sync_def, async_def)``; ``async_def`` is an
        ``ast.AsyncFunctionDef`` whose final statement is the awaiting
        return.
    """
    sync_def = deepcopy(method)
    return_type = sync_def.decorator_list[0].args[0]  # type: ignore
    sync_def.decorator_list = []
    sync_def.returns = return_type
    sync_ret, async_ret = build_ret(sync_def, return_type)

    if ast.get_docstring(method):
        sync_def.body = [method.body[0], sync_ret]
    else:
        sync_def.body = [sync_ret]

    # Same fields as the sync definition, but an ``async def`` node.
    async_def = ast.AsyncFunctionDef(  # type: ignore
        **deepcopy(sync_def).__dict__
    )
    async_def.body[-1] = async_ret
    return sync_def, async_def  # type: ignore
def add_mixins(name, klass: ast.ClassDef):
    """Add ``mixins.<name>Mixin`` to *klass*'s bases when that mixin exists.

    ``spoffy/modules/mixins.py`` (relative to the working directory) is
    parsed.  Each top-level class in it named ``<name>Mixin`` is appended to
    ``klass.bases`` as the expression ``mixins.<name>Mixin``.

    :param name: class name used to look up the mixin
    :param klass: class definition whose ``bases`` list is extended in place
    """
    # ``with`` closes the handle (the previous open().read() leaked it).
    with open("spoffy/modules/mixins.py", encoding="utf-8") as mixins_file:
        mixintree = ast.parse(mixins_file.read())
    wanted = name + "Mixin"
    for node in mixintree.body:
        if not isinstance(node, ast.ClassDef):
            continue
        if node.name != wanted:
            continue
        klass.bases.append(
            ast.Attribute(
                ast.Name("mixins", ast.Load()), node.name, ast.Load()
            )
        )
def find_method(klass, name):
    """Return the first direct method (sync or async) of *klass* called *name*.

    Only ``klass.body`` is searched, not nested scopes.  Returns ``None``
    when no such method exists.
    """
    matches = (
        item
        for item in klass.body
        if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef))
        and item.name == name
    )
    return next(matches, None)
def add_extras(klass: ast.ClassDef):
    """Add token-storing convenience methods to ``*Auth`` classes.

    For each ``(new name, mirrored method, docstring)`` entry below, the
    mirrored ``get_token_from_*`` method of *klass* is copied and renamed.
    The copy's body becomes ``self._assign_token(self.<mirrored>(...))``,
    forwarding every parameter by keyword, preceded by the docstring.  Its
    return annotation is removed.  Classes whose name does not end in
    ``Auth`` are left untouched.

    :param klass: generated class, modified in place
    :raises ValueError: if *klass* lacks one of the mirrored methods
    """
    if not klass.name.endswith("Auth"):
        return
    assign_methods = [
        (
            "authorize_user",
            "get_token_from_code",
            _doc_authorize_user,
        ),
        (
            "authorize_client",
            "get_token_from_client_credentials",
            _doc_authorize_client,
        ),
        (
            "refresh_authorization",
            "get_token_from_refresh_token",
            _doc_refresh_authorization,
        ),
    ]
    tmpl = "self._assign_token(self.method_name())"
    for methname, mirroredname, docstring in assign_methods:
        mirrored = find_method(klass, mirroredname)
        if mirrored is None:
            # Fail with a clear message instead of AttributeError on None.
            raise ValueError(
                f"{klass.name} has no method {mirroredname!r} "
                f"to mirror as {methname!r}"
            )
        meth = deepcopy(mirrored)
        meth.name = methname
        assignment = ast.parse(tmpl).body[0]
        inner_call = assignment.value.args[0]  # type: ignore
        inner_call.func.attr = mirrored.name
        inner_call.keywords = signature_to_keywords(mirrored)
        if isinstance(meth, ast.AsyncFunctionDef):
            # NOTE(review): this awaits the value returned by _assign_token,
            # while the inner get_token_from_* call is passed un-awaited;
            # confirm that AsyncApiModule._assign_token expects that.
            assignment = ast.Expr(ast.Await(assignment.value))  # type: ignore
        body = [assignment]
        if docstring:
            # ast.Str was removed in Python 3.12; since 3.8 it already
            # produced an ast.Constant, so this is the same node.
            body.insert(0, ast.Expr(ast.Constant(docstring)))
        meth.body = body
        meth.returns = None
        klass.body.append(meth)
def build_class(klass: ast.ClassDef) -> Tuple[ast.ClassDef, ast.ClassDef]:
    """Generate the sync and async API module classes for one builder class.

    Both generated classes start with
    ``__builder_class__ = builders.<BuilderName>``, followed by the methods
    produced by ``build_method_def``:

    * The sync class keeps the builder's name, with its first base renamed
      to ``ApiModule``.
    * The async class is named ``Async<BuilderName>``, with first base
      ``AsyncApiModule``.

    Matching mixins and (for ``*Auth`` classes) token helpers are then added.
    *klass* itself is not mutated.

    :return: ``(sync_class, async_class)``
    """
    tmpl = "__builder_class__ = builders.Artists"
    klasscopy = deepcopy(klass)
    methods = [
        build_method_def(method) for method in get_class_methods(klasscopy)
    ]
    # Point __builder_class__ at builders.<BuilderName>.
    builder_class = ast.parse(tmpl).body[0]
    builder_class.value.attr = klass.name  # type: ignore
    klasscopy.body = [builder_class]  # type: ignore
    klasscopy.bases[0].id = "ApiModule"  # type: ignore
    async_class = deepcopy(klasscopy)
    async_class.name = "Async" + async_class.name
    async_class.bases[0].id = "AsyncApiModule"  # type: ignore
    # Unzip with comprehensions: ``zip(*methods)`` yields nothing for a
    # class without methods, so two-name unpacking raised ValueError.
    klasscopy.body += [sync_meth for sync_meth, _ in methods]
    async_class.body += [async_meth for _, async_meth in methods]
    add_mixins(klasscopy.name, klasscopy)
    add_mixins(klasscopy.name, async_class)
    add_extras(klasscopy)
    add_extras(async_class)
    return klasscopy, async_class
def build_classes():
    """Generate API module classes for every ``RequestBuilder`` subclass.

    :return: flat list ``[Sync1, Async1, Sync2, Async2, ...]`` in the order
        returned by ``get_module_classes``
    """
    generated: List[ast.ClassDef] = []
    for builder in get_module_classes():
        sync_cls, async_cls = build_class(builder)
        generated.extend((sync_cls, async_cls))
    return generated
# Docstring texts for the helper methods generated by add_extras(); each
# string is inserted as the first statement of the corresponding method.
# Used by: authorize_client
_doc_authorize_client = """
Authorize this API instance using
its client ID and client Secret
"""
# Used by: authorize_user
_doc_authorize_user = """
Authorize this API instance using a response code
from oauth login
"""
# Used by: refresh_authorization
_doc_refresh_authorization = """
:param refresh_token: Optional refresh token to use
instead of the token stored on this instance
"""
if __name__ == "__main__":
    # Regenerate spoffy/modules/modules.py: keep its existing import
    # statements, replace everything else with freshly generated classes,
    # format with black, write it back and echo it to stdout.
    classes = build_classes()
    with open(
        "spoffy/modules/modules.py", "r+", encoding="utf-8"
    ) as output_file:
        out_source = output_file.read()
        out_tree = ast.parse(out_source)
        out_body = [
            node
            for node in out_tree.body
            if isinstance(node, (ast.Import, ast.ImportFrom))
        ]
        out_body += classes
        # type_ignores is a required Module field on Python 3.8+.
        out_tree = ast.Module(body=out_body, type_ignores=[])
        source = to_source(out_tree)
        source = black.format_str(source, mode=black.Mode(line_length=79))
        # Rewrite from the start and drop any leftover old bytes; without
        # truncate() a shorter result would leave stale trailing code.
        output_file.seek(0)
        output_file.write(source)
        output_file.truncate()
    print(source)