-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathInference.cs
150 lines (136 loc) · 6.83 KB
/
Inference.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
namespace hm_infer_cs;
using static Utilities;
/// <summary>
/// Hindley-Milner style type inference for the small Scheme-like language parsed into
/// <c>SExpr</c> trees.
/// </summary>
public static class Inference
{
    /// <summary>
    /// Infers the type of <paramref name="e"/>, starting from the initial environment <c>Env</c>
    /// brought in via <c>using static Utilities</c>.
    /// </summary>
    /// <remarks>
    /// Recognised forms: <c>(let ((x e) ...) body)</c>, <c>(let* ((x e) ...) body)</c>,
    /// <c>(letrec ((f e) ...) body)</c>, <c>(lambda (x ...) body)</c>, application <c>(f a ...)</c>,
    /// int/bool/string atoms, and symbols. Multi-binding <c>let*</c>, multi-parameter <c>lambda</c>
    /// and multi-argument application are desugared into nested single-step forms before typing.
    /// <c>TypeVariable.ResetChar()</c> is called first; it touches static state on
    /// <c>TypeVariable</c> (presumably the fresh-name counter - confirm).
    /// </remarks>
    /// <param name="e">The expression to type.</param>
    /// <returns>The inferred type of <paramref name="e"/>.</returns>
    /// <exception cref="Exception">
    /// Thrown for unknown symbols, malformed special forms or applications, and unsupported
    /// objects. Errors raised by <c>Expect</c> and <c>Unify</c> propagate unchanged.
    /// </exception>
    public static BaseType Analyze(this SExpr e)
    {
        TypeVariable.ResetChar();

        // env:  symbol name -> type currently in scope.
        // ngen: the "non-generic" type variables (lambda parameters and letrec bindings that are
        //       still being defined). It is handed to Duplicate on every symbol lookup; presumably
        //       Duplicate keeps these variables shared instead of freshening them - confirm in BaseType.
        BaseType Aux(SExpr expr, Dictionary<string, BaseType> env, List<BaseType> ngen)
        {
            BaseType FindSymbol(string s)
            {
                return env.TryGetValue(s, out var res)
                    ? res
                    : throw new Exception($"Unknown symbol '{s}', available are {string.Join(", ", env.Keys)}");
            }

            // Re-enters the analysis with the current environment and non-generic set.
            BaseType Self(SExpr node)
            {
                return Aux(node, env, ngen);
            }

            // Reports a too-short form with a descriptive message instead of an index-out-of-range.
            void RequireCount(int actual, int minimum, string message)
            {
                if (actual < minimum)
                    throw new Exception(message);
            }

            switch (expr)
            {
                case SList(var vals):
                    RequireCount(vals.Count, 1, "Cannot type an empty list '()'");
                    switch (vals[0])
                    {
                        case SSymbol("let"):
                        {
                            RequireCount(vals.Count, 3, "let: expected (let (bindings...) body)");
                            // Parallel binding: every right-hand side is typed in the outer env.
                            var newenv = new Dictionary<string, BaseType>(env);
                            foreach (var binding in vals[1].Expect<SList>("let: expected list of bindings").Values
                                         .Select(v => v.Expect<SList>("let: expected binding").Values))
                            {
                                if (binding.Count != 2)
                                    throw new Exception("let binding: expected pair");
                                newenv[binding[0].Expect<SSymbol>("let binding: expected identifier").Value] =
                                    Self(binding[1]);
                            }
                            return Aux(vals[2], newenv, ngen);
                        }
                        case SSymbol("let*"):
                        {
                            RequireCount(vals.Count, 3, "let*: expected (let* (bindings...) body)");
                            var bindings = vals[1].Expect<SList>("let*: expected list of bindings").Values
                                .ToArray();
                            var body = vals[2];
                            // (let* () body) is just body, matching how let treats an empty binding list.
                            if (bindings.Length == 0)
                                return Self(body);
                            var first = bindings[0];
                            // Sequential binding:
                            // (let* (b1 b2 ...) body) == (let* (b1) (let* (b2 ...) body)).
                            if (bindings.Length > 1)
                                return Self(
                                    new SList(new List<SExpr>
                                    {
                                        new SSymbol("let*"), new SList(new List<SExpr> { first }),
                                        new SList(new List<SExpr>
                                        {
                                            new SSymbol("let*"),
                                            new SList(bindings[1..].ToList()), body
                                        })
                                    }));
                            var binding = first.Expect<SList>("let*: expected binding").Values;
                            if (binding.Count != 2)
                                throw new Exception("let* binding: expected pair");
                            var newenv = new Dictionary<string, BaseType>(env)
                            {
                                [binding[0].Expect<SSymbol>("let* binding: expected identifier").Value] =
                                    Self(binding[1])
                            };
                            return Aux(body, newenv, ngen);
                        }
                        case SSymbol("letrec"):
                        {
                            RequireCount(vals.Count, 3, "letrec: expected (letrec (bindings...) body)");
                            var newenv = new Dictionary<string, BaseType>(env);
                            var ftypes = new List<(SExpr Value, BaseType Type)>();
                            // Bind every name to a fresh type variable first, so definitions can refer
                            // to themselves and to each other.
                            foreach (var binding in vals[1].Expect<SList>("letrec: expected list of bindings").Values
                                         .Select(v => v.Expect<SList>("letrec: expected binding").Values))
                            {
                                if (binding.Count != 2)
                                    throw new Exception("letrec binding: expected pair");
                                ftypes.Add((binding[1],
                                    newenv[binding[0].Expect<SSymbol>("letrec binding: expected identifier").Value] =
                                        new TypeVariable()));
                            }
                            // While the definitions are typed, their own variables are non-generic IN
                            // ADDITION to everything already non-generic in the enclosing scope (e.g.
                            // parameters of an enclosing lambda). Replacing ngen outright would let those
                            // outer variables be wrongly generalised inside the definitions.
                            var recngen = new List<BaseType>(ngen);
                            recngen.AddRange(ftypes.Select(f => f.Type));
                            foreach (var (val, type) in ftypes)
                                type.Unify(Aux(val, newenv, recngen));
                            return Aux(vals[2], newenv, ngen);
                        }
                        case SSymbol("lambda"):
                        {
                            RequireCount(vals.Count, 3, "lambda: expected (lambda (params...) body)");
                            var pnames = vals[1].Expect<SList>("lambda: expected list of parameters").Values
                                .ToArray();
                            // A nullary lambda has no parameter to type; report it instead of indexing
                            // past the end of the parameter list.
                            if (pnames.Length == 0)
                                throw new Exception("lambda: expected at least one parameter");
                            var param = pnames[0];
                            var body = vals[2];
                            // Currying: (lambda (a b ...) body) == (lambda (a) (lambda (b ...) body)).
                            if (pnames.Length > 1)
                                return Self(
                                    new SList(new List<SExpr>
                                    {
                                        new SSymbol("lambda"), new SList(new List<SExpr> { param }),
                                        new SList(new List<SExpr>
                                        {
                                            new SSymbol("lambda"),
                                            new SList(pnames[1..].ToList()), body
                                        })
                                    }));
                            var ptype = new TypeVariable();
                            var newenv = new Dictionary<string, BaseType>(env)
                            {
                                [param.Expect<SSymbol>("lambda: expected identifier").Value] = ptype
                            };
                            // The parameter's type must stay monomorphic inside the body.
                            var newngen = new List<BaseType>(ngen) { ptype };
                            return Ft(ptype, Aux(body, newenv, newngen));
                        }
                        default:
                        {
                            RequireCount(vals.Count, 2, "application: expected at least one argument");
                            var f = vals[0];
                            var arg = vals[1];
                            // Currying: (f a b c) == ((f a) b c); apply one argument per step.
                            if (vals.Count > 2)
                                return Self(new SList(new[]
                                {
                                    new SList(new List<SExpr> { f, arg })
                                }.Concat(vals.Skip(2)).ToList()));
                            var val = Self(f);
                            var rettype = new TypeVariable();
                            var argtype = Self(arg);
                            // f must be a function from the argument's type to some fresh result type.
                            var functype = Ft(argtype, rettype);
                            functype.Unify(val);
                            return rettype;
                        }
                    }
                case SAtom<int>(_):
                    return Int;
                case SAtom<bool>(_):
                    return Bool;
                case SAtom<string>(_):
                    return Str;
                case SSymbol(var name):
                    // Each use of a symbol gets a copy of its type; ngen controls what stays shared.
                    return FindSymbol(name).Duplicate(new Dictionary<BaseType, BaseType>(), ngen);
                default:
                    throw new Exception($"Unknown type for Scheme object '{expr}'");
            }
        }

        return Aux(e, Env, new List<BaseType>());
    }
}