-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathnfloat.t
396 lines (342 loc) · 11.1 KB
/
nfloat.t
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
-- SPDX-FileCopyrightText: 2024 René Hiemstra <[email protected]>
-- SPDX-FileCopyrightText: 2024 Torsten Keßler <[email protected]>
--
-- SPDX-License-Identifier: MIT
require "terralibext"
-- Detect the host OS by running `uname`. The pipe is closed as soon as its
-- output has been read, and a failure to start the command is reported
-- explicitly instead of as an index-on-nil error.
local uname_pipe = assert(io.popen("uname", "r"), "Could not run 'uname'.")
local uname = uname_pipe:read("*a")
uname_pipe:close()
-- Wrap FLINT without inlines
local flint = terralib.includec("flint/nfloat.h", {"-DNFLOAT_INLINES_C=1"})
local gr = terralib.includec("flint/gr.h", {"-DGR_INLINES_C=1"})
-- Link the shared FLINT library under its platform-specific file name.
if uname == "Darwin\n" then
    terralib.linklibrary("libflint.dylib")
elseif uname == "Linux\n" then
    terralib.linklibrary("libflint.so")
else
    error("Not implemented for this OS.")
end
import "terraform"
local C = terralib.includec("stdio.h")
local base = require("base")
local tmath = require("tmath")
local concepts = require("concepts")
-- Precisions (in bits) for which a FixedFloat type is offered.
local suffix = {64, 128, 192, 256, 384, 512, 1024, 2048, 4096}
-- float_type[N]: the FLINT struct named nfloat<N>_struct.
local float_type = {}
-- context[N]: the global FLINT context used by the float type of precision N.
local context = {}
for _, N in ipairs(suffix) do
    float_type[N] = flint[("nfloat%d_struct"):format(N)]
    -- Meta information on fixed precision floats is stored in a context.
    -- Mathematically, they represent rings.
    -- The contexts are kept as global variables in a table so that
    -- each float type has exactly one context it will use.
    local ctx_global = global(flint.gr_ctx_t)
    context[N] = ctx_global
    -- Call clean_context() to release memory allocated by nfloat_ctx_init
    flint.nfloat_ctx_init(ctx_global:get(), N, 0)
end
--Names of the one-argument math functions. For every entry, FixedFloat looks
--up the FLINT function "nfloat_" .. name and registers a wrapper around it
--as a definition of tmath[name].
local unary_math = {
"abs",
"sqrt",
"floor",
"ceil",
"exp",
"expm1",
"log",
"log1p",
"sin",
"cos",
"tan",
"sinh",
"cosh",
"tanh",
"gamma",
}
--Names of the two-argument math functions, wrapped in the same way
--("nfloat_" .. name on the FLINT side, tmath[name] on the Terra side).
local binary_math = {
"pow",
}
--extract the exponent of an nfloat: the first word of its 'head' array
local exponent = macro(function(value)
return `value.data.head[0]
end)
--extract the sign of an nfloat as an integer: 1 when the second word of
--its 'head' array is zero and -1 otherwise
local sign = macro(function(value)
return `terralib.select(value.data.head[1]==0, 1, -1)
end)
--extract the most significant 64-bit limb of the mantissa of an nfloat,
--i.e. the last entry of the 'd' array
local significant_part_mantissa = macro(function(value)
    --'value' is a pointer to an nfloat; the pointee type carries the precision
    local precision = value:gettype().type.traits.precision
    local last = precision / 64 - 1
    return `value.data.d[last]
end)
--Convert the most significant 64-bit limb 'n' of a mantissa, together with
--the binary exponent 'e', into a double.
--The trailing zero bits of 'n' are shifted out before the conversion and the
--exponent passed to ldexp is adjusted to compensate.
local terra shiftandscale(n : uint64, e : int)
    --A zero limb would leave k = 0 below and shift a 64-bit value by 64 bits,
    --which is undefined. Zero scaled by any power of two is zero.
    if n == 0 then
        return 0.0
    end
    var res = n
    var k = 0
    --after the loop, k is 64 minus the number of trailing zero bits of n
    while n > 0 do
        k = k + 1
        n = n << 1
    end
    --NOTE(review): relies on '>>' binding looser than '-', i.e. the shift
    --amount is (64 - k) -- confirm against the Terra operator precedence
    return tmath.ldexp(double(res >> 64 - k), e-k)
end
local FixedFloat = terralib.memoize(function(N)
--float_type[N] stores the high-precision number using the following layout
--local M = N / 64
--struct float_type
-- head : uint64[2]
-- d : uint64[M]
--end
--here the 'head' stores the exponent and sign
-- head[0] --exponent
-- head[1] --sign
--and 'd' the mantissa:
-- d[0]
-- ...
-- d[M-1]
--here d[M-1] is the most significant part of the mantissa, which means that
--it encodes the first 64 bits of the floating point number. This is used,
--for example, in truncation to a double value.
--
--Note that the order of the mantissa is non-intuitive. The order is reversed
--as compared to how integers are typically stored. That's why bit shifting
--1 << 63 is needed in the following example
--
--example: N = 128, M = 2, representing the value 1
--x.data.head[0] = 1
--x.data.head[1] = 0
--x.data.d[0] = 0
--x.data.d[1] = uint64(1) << 63
--FLINT struct that stores a number of precision N; nil for unsupported N
local ctype = float_type[N]
assert(ctype, "No support for precision " .. N .. " in FixedFloat")
--get the context corresponding to precision N
local ctx = context[N]:get()
--the fixed precision float is a thin wrapper around 'ctype'
local struct nfloat {
data: ctype
}
--the type is displayed as FixedFloat(N)
function nfloat.metamethods.__typename()
return string.format("FixedFloat(%d)", N)
end
--NOTE(review): base.AbstractBase presumably creates the 'traits' and
--'staticmethods' tables that are filled in below -- confirm in base
base.AbstractBase(nfloat)
--type traits
nfloat.traits.precision = N
nfloat.traits.isfloat = true
local M = N / 64 --precision in quadwords
--Build the initializer {head, d} of a 'ctype' value for one of the constants
--0, 1 or "eps". Here 'head' is {exponent, sign} and 'd' holds the M mantissa
--limbs. Any other argument yields no result.
local function genfloat(value)
    --The order of the mantissa is opposite to how an integer is typically
    --stored in memory, so the leading bit is 1 shifted 63 bits to the left.
    local topbit = bit.lshift(1ULL, 63)
    --mantissa with all limbs cleared
    local limbs = {}
    for k = 1, M do
        limbs[k] = 0ULL
    end
    if value == 0 then
        --zero: cleared mantissa, 'topbit' stored in the exponent slot
        return {{topbit, 0ULL}, limbs}
    end
    if value == 1 or value == "eps" then
        --both constants have only the leading mantissa bit set and differ
        --in the exponent alone
        limbs[M] = topbit
        local e = (value == 1) and 1ULL or N*-1ULL
        return {{e, 0ULL}, limbs}
    end
end
--values of 0, 1 and eps, created once per precision from the raw
--{head, d} layout produced by genfloat
local zero = terralib.new(nfloat, {terralib.new(ctype, genfloat(0))})
local unit = terralib.new(nfloat, {terralib.new(ctype, genfloat(1))})
local eps = terralib.new(nfloat, {terralib.new(ctype, genfloat("eps"))})
--accessors returning the values created above as they are
function nfloat:__newzero() return zero end
function nfloat:__newunit() return unit end
function nfloat:__neweps() return eps end
--accessors returning the same values wrapped in constant(...)
function nfloat:zero() return constant(zero) end
function nfloat:unit() return constant(unit) end
--distance from 1.0 to next floating point value
--NOTE(review): genfloat("eps") uses exponent -N with only the leading
--mantissa bit set; since 1 is stored with exponent 1 and the same mantissa,
--this encodes 2^-(N+1) -- confirm that this is the intended epsilon
function nfloat:eps() return constant(eps) end
--create an nfloat whose storage has been initialized with nfloat_init
--in the context of precision N
local terra new()
    var res : nfloat
    flint.nfloat_init(&res.data, ctx)
    return res
end
--construct an nfloat from a double
local terra from_double(x: double)
    var res = new()
    flint.nfloat_set_d(&res.data, x, ctx)
    return res
end
--construct an nfloat from a C string; the status returned by
--nfloat_set_str is not inspected
local terra from_str(s: rawstring)
    var res = new()
    flint.nfloat_set_str(&res.data, s, ctx)
    return res
end
--'from' dispatches on the argument type to one of the two constructors
local from = terralib.overloadedfunction("from", {from_double, from_str})
--Macro that converts an nfloat into a C string with gr_get_str_n, asking
--for tmath.ndigits(N / 8) digits.
--The string is handed to the caller without being freed here, since the
--defer below is disabled.
--NOTE(review): presumably the caller has to release it with flint_free --
--confirm against the FLINT documentation
local to_str = macro(function(x)
local digits = tmath.ndigits(N / 8)
return quote
var str: rawstring
-- TODO: Fix memory leak
-- defer flint.flint_free(str)
gr.gr_get_str_n(&str, &x.data, [digits], ctx)
in
str
end
end)
--Cast metamethod: converts arithmetic values (through from_double) and
--C strings (through from_str) to nfloat.
--  from: source type, to: target type, exp: the expression being cast.
--Raises an error for every other combination of types.
function nfloat.metamethods.__cast(from, to, exp)
    --only casts *to* nfloat are supported
    if to ~= nfloat then
        error("Unknown type")
    end
    if from:isarithmetic() then
        return `from_double(exp)
    elseif from:ispointer() and from.type == int8 then
        return `from_str(exp)
    end
    --'from' and 'to' are type objects, not strings: convert them explicitly
    --so that building the message cannot itself fail on the '..' operator
    error("Cannot cast from " .. tostring(from) .. " to " .. tostring(to))
end
--Wrap a three-operand FLINT routine op(res, a, b, ctx) into a binary
--operator on nfloat values that returns a freshly created result.
local function arithmetic(op)
    return terra(self: nfloat, other: nfloat)
        var res = new()
        op(&res.data, &self.data, &other.data, ctx)
        return res
    end
end
--arithmetic operators +, *, - and /
nfloat.metamethods.__add = arithmetic(flint.nfloat_add)
nfloat.metamethods.__mul = arithmetic(flint.nfloat_mul)
nfloat.metamethods.__sub = arithmetic(flint.nfloat_sub)
nfloat.metamethods.__div = arithmetic(flint.nfloat_div)
--Remainder computed as value - floor(value / modulus) * modulus.
--Because the quotient is rounded down, a nonzero result has the sign of
--'modulus'.
local terra fmod(value : nfloat, modulus : nfloat)
    var tmp = new()
    flint.nfloat_div(&tmp.data, &value.data, &modulus.data, ctx)
    --hand FLINT the wrapped struct, as all other FLINT calls in this file do
    flint.nfloat_floor(&tmp.data, &tmp.data, ctx)
    flint.nfloat_mul(&tmp.data, &tmp.data, &modulus.data, ctx)
    flint.nfloat_sub(&tmp.data, &value.data, &tmp.data, ctx)
    return tmp
end
tmath["fmod"]:adddefinition(fmod)
--the % operator forwards to fmod
nfloat.metamethods.__mod = terra(self: nfloat, other: nfloat)
    return fmod(self, other)
end
--unary minus: negate into a freshly created result
nfloat.metamethods.__unm = terra(self: nfloat)
    var res = new()
    flint.nfloat_neg(&res.data, &self.data, ctx)
    return res
end
--Build a predicate on raw FLINT values that is true exactly when the result
--written by nfloat_cmp equals 'sign' (-1, 0 or 1 in the uses below).
local function cmp(sign)
    return terra(self: &ctype, other: &ctype, ctx: flint.gr_ctx_t)
        var outcome = 0
        flint.nfloat_cmp(&outcome, self, other, ctx)
        return outcome == sign
    end
end
--comparison operators ==, < and > on nfloat values
local comparisons = {
    __eq = cmp(0),
    __lt = cmp(-1),
    __gt = cmp(1)
}
for name, predicate in pairs(comparisons) do
    nfloat.metamethods[name] = terra(self: nfloat, other: nfloat)
        return predicate(&self.data, &other.data, ctx)
    end
end
--The remaining comparisons are composed from the <, > and == operators
--defined on nfloat.
--less than or equal
nfloat.metamethods.__le = terra(self: nfloat, other: nfloat)
return self < other or self == other
end
--greater than or equal
nfloat.metamethods.__ge = terra(self: nfloat, other: nfloat)
return self > other or self == other
end
--not equal
nfloat.metamethods.__ne = terra(self: nfloat, other: nfloat)
return not (self == other)
end
--Round to the nearest integer, computed as floor(value + 0.5); values that
--lie exactly halfway between two integers are therefore rounded up.
local terra round(value : nfloat)
    value = value + 0.5
    --hand FLINT the wrapped struct, as all other FLINT calls in this file do
    flint.nfloat_floor(&value.data, &value.data, ctx)
    return value
end
tmath["round"]:adddefinition(round)
--the constant pi at precision N, as computed by nfloat_pi
local terra pi()
    var res = new()
    --hand FLINT the wrapped struct, as all other FLINT calls in this file do
    flint.nfloat_pi(&res.data, ctx)
    return res
end
--Convert to a double using only the sign, the exponent and the most
--significant 64-bit limb of the mantissa.
terra nfloat:truncatetodouble()
    --zero is handled up front
    if @self == 0 then
        return 0.0
    end
    return sign(self) * shiftandscale(significant_part_mantissa(self), exponent(self))
end
--for now we format up to double precision: the value is truncated to a
--double and printed with the format string below.
--ToDo: specialized print.
local format = global(rawstring, "%0.2f")
--length of static buffer
--+1 for sign
--+1 for the terminating \0 character
local maxlen = tmath.ndigits(sizeof(double)) + 1 + 1
--register the conversion for nfloat with tmath.numtostr; snprintf is given
--'maxlen' as the size of the buffer
tmath.numtostr:adddefinition(
terra(v : nfloat)
var buffer : int8[maxlen]
C.snprintf(buffer, maxlen, format, v:truncatetodouble())
return buffer
end
)
--the format string is also stored in the table tmath.numtostr.format
tmath.numtostr.format[nfloat] = format
--register every one-argument math function: tmath[fname] gains a definition
--for nfloat that calls the FLINT function nfloat_<fname>
for _, fname in ipairs(unary_math) do
    local cname = "nfloat_" .. fname
    local terra wrapper(x: nfloat)
        --the result is written by the FLINT call
        var out: nfloat
        flint.[cname](&out.data, &x.data, ctx)
        return out
    end
    tmath[fname]:adddefinition(wrapper)
end
--register every two-argument math function: tmath[fname] gains a definition
--for nfloat that calls the FLINT function nfloat_<fname>
for _, fname in ipairs(binary_math) do
    local cname = "nfloat_" .. fname
    local terra wrapper(x: nfloat, y: nfloat)
        --the result is written by the FLINT call
        var out: nfloat
        flint.[cname](&out.data, &x.data, &y.data, ctx)
        return out
    end
    tmath[fname]:adddefinition(wrapper)
end
--minimum of two nfloats; 'y' is returned unless x < y
tmath.min:adddefinition(terra(x : nfloat, y : nfloat)
    if x < y then
        return x
    else
        return y
    end
end)
--maximum of two nfloats; 'y' is returned unless x > y
tmath.max:adddefinition(terra(x : nfloat, y : nfloat)
    if x > y then
        return x
    else
        return y
    end
end)
--an nfloat is treated as a real number: conj and real return the argument,
--imag returns zero
tmath.conj:adddefinition(terra(x: nfloat)
    return x
end)
tmath.real:adddefinition(terra(x: nfloat)
    return x
end)
tmath.imag:adddefinition(terra(x: nfloat)
    return [nfloat](0)
end)
--x * y + z, evaluated with the nfloat multiplication and addition operators
tmath.fusedmuladd:adddefinition(terra(x: nfloat, y: nfloat, z: nfloat)
    return x * y + z
end)
--expose the constructors, the string conversion and pi as static methods
nfloat.staticmethods.from = from
nfloat.staticmethods.tostr = to_str
nfloat.staticmethods.pi = pi
return nfloat
end)
--Clear the FLINT context of every supported precision. The body is
--generated at compile time: one gr_ctx_clear call per entry of 'suffix'.
local terra clean_context()
    escape
        for _, N in ipairs(suffix) do
            local ctx = context[N]:get()
            emit quote
                gr.gr_ctx_clear(ctx)
            end
        end
    end
end
--Module interface:
--  FixedFloat(N)   -- memoized constructor of the float type of precision N
--  clean_context() -- clears the FLINT contexts created when this module loads
return {
FixedFloat = FixedFloat,
clean_context = clean_context
}