-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcompile.mc
263 lines (246 loc) · 7.69 KB
/
compile.mc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
include "sys.mc"
include "option.mc"
include "log.mc"
include "mexpr/ast-builder.mc"
include "mexpr/utils.mc"
include "mexpr/symbolize.mc"
include "mexpr/cse.mc"
include "mexpr/const-transformer.mc"
include "./dae.mc"
include "./desugar.mc"
include "./dae-arg.mc"
-- Name of the environment variable that must point to the root directory of
-- the PEADAE sources (read by `daeSrcPathExn` below).
let peadaeNameSpace : String = "PEADAE"
-- Compilation of a PEADAE DAE (`TmDAERec`) into an MExpr program linked
-- together with the PEADAE runtime (`<PEADAE>/runtime.mc`).
lang DAECompile =
  DAE + PEvalLetInline + MExprFindSym + MExprSubstitute + BootParser + CSE

  -- Returns the root directory of the PEADAE sources, as given by the
  -- environment variable named `peadaeNameSpace`. Raises an error if the
  -- variable is unset.
  sem daeSrcPathExn : () -> String
  sem daeSrcPathExn =
  | () ->
    optionGetOrElse
      (lam.
        error
          (strJoin " " [
            peadaeNameSpace,
            "is unset. Please set it to point to the root of PEADAE source"
          ]))
      (sysGetEnv peadaeNameSpace)

  -- Compiles the DAE `daer` to an MExpr program according to `options`.
  --
  -- Pipeline:
  --  1. type check the DAE;
  --  2. optional CSE (`options.cse`);
  --  3. structural analysis, index reduction and order reduction to a
  --     first-order system;
  --  4. generate the initial-value, residual, and output functions. They are
  --     partially evaluated unless `options.disablePeval`, and their lets are
  --     inlined again if `options.constantFold`;
  --  5. generate the Jacobian functions;
  --  6. link the result with the PEADAE runtime.
  --
  -- The entry point of the resulting program is selected by
  -- `options.benchmarkResidual` / `options.benchmarkJacobian`.
  --
  -- Raises an error if the PEADAE environment variable is unset, or if both
  -- benchmark options are enabled at once.
  sem daeCompile : Options -> TmDAERec -> Expr
  sem daeCompile options =
  | daer ->
    -- Debug logging prefixed with "daeCompile:<head>:". `msg` is a thunk, so
    -- it is only evaluated when debug logging is enabled.
    let logDebug = lam head. lam msg.
      logDebug (lam. strJoin "\n" [join ["daeCompile:", head, ":"], msg ()])
    in
    match typeCheck (symbolize (TmDAE daer)) with TmDAE daer then

      -- Setup runtime
      let runtime =
        parseMCoreFile {
          defaultBootParserParseMCoreFileArg with
          eliminateDeadCode = true,
          allowFree = true
        } (join [daeSrcPathExn (), "/runtime.mc"])
      in
      let runtime = symbolize runtime in
      -- Identifiers in the runtime that the generated program refers to.
      let runtimeStrs = [
        "daeRuntimeRun",
        "daeRuntimeBenchmarkRes",
        "daeRuntimeBenchmarkJac",
        "sin",
        "cos",
        "exp",
        "pow",
        "sqrt",
        "arrayGet",
        "cArray1Set",
        "sundialsMatrixDenseSet",
        "sundialsMatrixDenseUpdate"
      ] in
      -- Map from runtime identifier to its bound name in `runtime`.
      -- NOTE(review): presumably `findNamesOfStrings` yields an `Option Name`
      -- per string, so unresolved identifiers are omitted from the map and
      -- only fail when looked up with `mapFindExn` -- confirm.
      let runtimeNames =
        foldl2
          (lam acc. lam str. lam name. mapUpdate str (lam. name) acc)
          (mapEmpty cmpString)
          runtimeStrs
          (findNamesOfStrings runtimeStrs runtime)
      in

      -- Compile DAE
      let daer = daeAnnotDVars daer in
      let daer = if options.cse then daeCSE daer else daer in
      logDebug "analysis"
        (lam.
          strJoin " " ["number of equations:", int2string (length daer.eqns)]);
      let analysis = daeStructuralAnalysis daer in
      logDebug "analysis"
        (lam. strJoin " " [
          "max equation offset",
          -- Report 0 for a system without equations instead of aborting the
          -- compilation from inside a debug message.
          int2string (maxOrElse (lam. 0) subi analysis.eqnsOffset)
        ]);
      logDebug "analysis"
        (lam. strJoin " " [
          "number differentiated equations",
          int2string (length (filter (neqi 0) analysis.eqnsOffset))
        ]);
      let daer = daeIndexReduce analysis.eqnsOffset daer in
      let state = daeFirstOrderState analysis.varOffset in
      let isdiffvars = daeIsDiffVars state in
      let daer = daeOrderReduce state (nameSym "y") (nameSym "yp") daer in
      let ts = [
        daeGenInitExpr state daer,
        daeGenResExpr daer,
        daeGenOutExpr daer
      ] in
      let pevalInlineLets = pevalInlineLets (sideEffectEnvEmpty ()) in
      match
        if options.disablePeval then ts
        else map (lam t. pevalInlineLets (peval t)) ts
      with [iexpr, rexpr, oexpr] in
      match
        if options.constantFold then
          (pevalInlineLets iexpr, pevalInlineLets rexpr, pevalInlineLets oexpr)
        else (iexpr, rexpr, oexpr)
      with (iexpr, rexpr, oexpr) in

      -- An absolute threshold is turned into a fraction of the number of
      -- equations, clamped to [0, 1].
      let jacSpecThreshold =
        match options.jacSpecThresholdAbsolute with Some n then
          maxf (minf (divf (int2float n) (int2float (length daer.eqns))) 1.) 0.
        else options.jacSpecThreshold
      in
      match
        -- With a numeric Jacobian, only placeholders (whose bodies are
        -- `never`) are generated for the Jacobian functions.
        if options.numericJac then (ulam_ "" never_, ulam_ "" never_)
        else
          if options.disablePeval then
            (daeGenMixedJacY 0. daer, daeGenMixedJacYp 0. daer)
          else
            (daeGenMixedJacY jacSpecThreshold daer,
             daeGenMixedJacYp jacSpecThreshold daer)
      with (jacY, jacYp) in

      -- Generate runtime
      let _varids = nameSym "varids" in
      let _initVals = nameSym "initVals" in
      let _resf = nameSym "resf" in
      let _jacYf = nameSym "jacYf" in
      let _jacYpf = nameSym "jacYpf" in
      let _outf = nameSym "outf" in
      let n = length isdiffvars in
      let t =
        switch (options.benchmarkResidual, options.benchmarkJacobian)
        case (true, false) then
          bind_ (nulet_ _resf rexpr)
            (appSeq_
              (nvar_ (mapFindExn "daeRuntimeBenchmarkRes" runtimeNames))
              [int_ n, nvar_ _resf])
        case (false, true) then
          bindall_ [
            nulet_ _jacYf jacY,
            nulet_ _jacYpf jacYp,
            appSeq_
              (nvar_ (mapFindExn "daeRuntimeBenchmarkJac" runtimeNames))
              [int_ n, nvar_ _jacYf, nvar_ _jacYpf]
          ]
        case (true, true) then
          error "daeCompile: benchmarking both the residual and the Jacobian at once is unimplemented"
        case (false, false) then
          bindall_ [
            nulet_ _varids (seq_ (map bool_ isdiffvars)),
            nulet_ _initVals iexpr,
            nulet_ _resf rexpr,
            nulet_ _jacYf jacY,
            nulet_ _jacYpf jacYp,
            nulet_ _outf oexpr,
            appSeq_ (nvar_ (mapFindExn "daeRuntimeRun" runtimeNames))
              (cons
                (bool_ options.numericJac)
                (map nvar_
                  [_varids, _initVals, _resf, _jacYf, _jacYpf, _outf]))
          ]
        end
      in
      -- Replace the DAE builtin constants by the corresponding runtime
      -- definitions. This must happen for every entry point: the benchmark
      -- programs embed the same generated residual/Jacobian functions as the
      -- full simulation program.
      let env =
        foldl
          (lam env. lam x. match x with (str, c) in
            mapInsert c (mapFindExn str runtimeNames) env)
          (mapEmpty cmpConst)
          (daeBuiltin ())
      in
      bind_ runtime (constTransformConstsToVars env t)
    else error "impossible"
end
-- Language used by the tests below: DAE parsing/analysis, desugaring, and
-- compilation.
lang TestLang =
  DAEParseAnalysis
  + DAEParseDesugar
  + DAECompile
end
mexpr

use TestLang in

-- Parses the DAE program `src`, desugars it, and returns the type-checked DAE
-- record.
let parseDae = lam src.
  let ast = daeParseExn "internal" src in
  logMsg logLevel.debug
    (lam. strJoin "\n" ["Input program:", daeProgToString ast]);
  let desugared = daeDesugarProg ast in
  let checked = typeCheck (symbolize (TmDAE desugared)) in
  match checked with TmDAE dae then dae
  else error "impossible"
in

-------------------
-- Test Pendulum --
-------------------

logSetLogLevel logLevel.error;

let pendulum = parseDae "
let mul = lam x. lam y. x*y end
let pow2 = lam x. mul x x end
variables
x, y, h : Float
init
x = 1.;
x' = 2.;
y'' = 0. - 1.
equations
x'' = mul x h;
y'' = mul y h - 1.;
pow2 x + pow2 y = pow2 1.
output
{x, x', x''}
"
in

-- The compiled program, including the runtime, must type check.
let compiled = daeCompile defaultOptions pendulum in
logMsg logLevel.debug
  (lam. strJoin "\n" ["Output program:", expr2str compiled]);
utest typeCheck (symbolize compiled); true with true in

()