-
Notifications
You must be signed in to change notification settings - Fork 347
/
Copy pathdynamic_proxy.go
172 lines (141 loc) · 3.87 KB
/
dynamic_proxy.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
package proxy
import (
"bytes"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/token"
"strings"
"text/template"
)
func generate(file string) (string, error) {
fset := token.NewFileSet() // positions are relative to fset
f, err := parser.ParseFile(fset, file, nil, parser.ParseComments)
if err != nil {
return "", err
}
// 获取代理需要的数据
data := proxyData{
Package: f.Name.Name,
}
// 构建注释和 node 的关系
cmap := ast.NewCommentMap(fset, f, f.Comments)
for node, group := range cmap {
// 从注释 @proxy 接口名,获取接口名称
name := getProxyInterfaceName(group)
if name == "" {
continue
}
// 获取代理的类名
data.ProxyStructName = node.(*ast.GenDecl).Specs[0].(*ast.TypeSpec).Name.Name
// 从文件中查找接口
obj := f.Scope.Lookup(name)
// 类型转换,注意: 这里没有对断言进行判断,可能会导致 panic
t := obj.Decl.(*ast.TypeSpec).Type.(*ast.InterfaceType)
for _, field := range t.Methods.List {
fc := field.Type.(*ast.FuncType)
// 代理的方法
method := &proxyMethod{
Name: field.Names[0].Name,
}
// 获取方法的参数和返回值
method.Params, method.ParamNames = getParamsOrResults(fc.Params)
method.Results, method.ResultNames = getParamsOrResults(fc.Results)
data.Methods = append(data.Methods, method)
}
}
// 生成文件
tpl, err := template.New("").Parse(proxyTpl)
if err != nil {
return "", err
}
buf := &bytes.Buffer{}
if err := tpl.Execute(buf, data); err != nil {
return "", err
}
// 使用 go fmt 对生成的代码进行格式化
src, err := format.Source(buf.Bytes())
if err != nil {
return "", err
}
return string(src), nil
}
// getParamsOrResults 获取参数或者是返回值
// 返回带类型的参数,以及不带类型的参数,以逗号间隔
func getParamsOrResults(fields *ast.FieldList) (string, string) {
var (
params []string
paramNames []string
)
for i, param := range fields.List {
// 循环获取所有的参数名
var names []string
for _, name := range param.Names {
names = append(names, name.Name)
}
if len(names) == 0 {
names = append(names, fmt.Sprintf("r%d", i))
}
paramNames = append(paramNames, names...)
// 参数名加参数类型组成完整的参数
param := fmt.Sprintf("%s %s",
strings.Join(names, ","),
param.Type.(*ast.Ident).Name,
)
params = append(params, strings.TrimSpace(param))
}
return strings.Join(params, ","), strings.Join(paramNames, ",")
}
func getProxyInterfaceName(groups []*ast.CommentGroup) string {
for _, commentGroup := range groups {
for _, comment := range commentGroup.List {
if strings.Contains(comment.Text, "@proxy") {
interfaceName := strings.TrimLeft(comment.Text, "// @proxy ")
return strings.TrimSpace(interfaceName)
}
}
}
return ""
}
// 生成代理类的文件模板
const proxyTpl = `
package {{.Package}}
type {{ .ProxyStructName }}Proxy struct {
child *{{ .ProxyStructName }}
}
func New{{ .ProxyStructName }}Proxy(child *{{ .ProxyStructName }}) *{{ .ProxyStructName }}Proxy {
return &{{ .ProxyStructName }}Proxy{child: child}
}
{{ range .Methods }}
func (p *{{$.ProxyStructName}}Proxy) {{ .Name }} ({{ .Params }}) ({{ .Results }}) {
// before 这里可能会有一些统计的逻辑
start := time.Now()
{{ .ResultNames }} = p.child.{{ .Name }}({{ .ParamNames }})
// after 这里可能也有一些监控统计的逻辑
log.Printf("user login cost time: %s", time.Now().Sub(start))
return {{ .ResultNames }}
}
{{ end }}
`
type proxyData struct {
// 包名
Package string
// 需要代理的类名
ProxyStructName string
// 需要代理的方法
Methods []*proxyMethod
}
// proxyMethod 代理的方法
type proxyMethod struct {
// 方法名
Name string
// 参数,含参数类型
Params string
// 参数名
ParamNames string
// 返回值
Results string
// 返回值名
ResultNames string
}