Skip to content

Commit

Permalink
Use the new <- syntax to reduce indentation throughout the prelude.
Browse files Browse the repository at this point in the history
  • Loading branch information
axch committed Nov 7, 2022
1 parent 7627d9e commit 626ee4d
Showing 1 changed file with 73 additions and 69 deletions.
142 changes: 73 additions & 69 deletions lib/prelude.dx
Original file line number Diff line number Diff line change
Expand Up @@ -1179,9 +1179,9 @@ def with_alloc {a b} [Storable a] (n:Nat) (action: Ptr a -> {IO} b) : {IO} b =
result

def with_table_ptr {a b n} [Storable a] (xs:n=>a) (action : Ptr a -> {IO} b) : {IO} b =
with_alloc (size n) \ptr.
for i. store (ptr +>> ordinal i) xs.i
action ptr
ptr <- with_alloc (size n)
for i. store (ptr +>> ordinal i) xs.i
action ptr

def table_from_ptr {a} [Storable a] (n:Type) [Ix n] (ptr:Ptr a) : {IO} n=>a =
for i. load $ ptr +>> ordinal i
Expand Down Expand Up @@ -1284,18 +1284,18 @@ TODO: Move this to be with reductions?
It's a kind of `scan`.

def cumsum {n a} [Add a] (xs: n=>a) : n=>a =
with_state zero \total.
for i.
newTotal = get total + xs.i
total := newTotal
newTotal
total <- with_state zero
for i.
newTotal = get total + xs.i
total := newTotal
newTotal

def cumsum_low {n a} [Add a] (xs: n=>a) : n=>a =
with_state zero \total.
for i.
oldTotal = get total
total := oldTotal + xs.i
oldTotal
total <- with_state zero
for i.
oldTotal = get total
total := oldTotal + xs.i
oldTotal

'## Automatic differentiation

Expand Down Expand Up @@ -1535,16 +1535,18 @@ data DynBuffer a =
def with_dynamic_buffer {a b} [Storable a]
(action: DynBuffer a -> {IO} b) : {IO} b =
initMaxSize = 256
with_alloc 1 \sizePtr. with_alloc 1 \maxSizePtr. with_alloc 1 \bufferPtr.
store sizePtr 0
store maxSizePtr initMaxSize
store bufferPtr $ malloc initMaxSize
result = action $ MkDynBuffer {
size = sizePtr
, maxSize = maxSizePtr
, buffer = bufferPtr }
free $ load bufferPtr
result
sizePtr <- with_alloc 1
store sizePtr 0
maxSizePtr <- with_alloc 1
store maxSizePtr initMaxSize
bufferPtr <- with_alloc 1
store bufferPtr $ malloc initMaxSize
result = action $ MkDynBuffer {
size = sizePtr
, maxSize = maxSizePtr
, buffer = bufferPtr }
free $ load bufferPtr
result

def maybe_increase_buffer_size {a} [Storable a]
((MkDynBuffer db): DynBuffer a) (sizeDelta:Nat) : {IO} Unit =
Expand Down Expand Up @@ -1768,7 +1770,8 @@ def lift_state {a b c h eff} (ref: Ref h c) (f:a -> {|eff} b) (x:a) : {State h|e

-- A little iteration combinator
def iter {a eff} (body: Nat -> {|eff} IterResult a) : {|eff} a =
result = yield_state Nothing \resultRef. with_state 0 \i.
result = yield_state Nothing \resultRef.
i <- with_state 0
while do
continue = is_nothing $ get resultRef
if continue then
Expand Down Expand Up @@ -1817,16 +1820,16 @@ def fread (stream:Stream ReadMode) : {IO} String =
(MkStream stream') = stream
-- TODO: allow reading longer files!
n = 4096
with_alloc n \ptr:(Ptr Char).
with_dynamic_buffer \buf.
iter \_.
(MkPtr rawPtr) = ptr
numRead = i_to_w32 $ i64_to_i $ freadFFI rawPtr (i_to_i64 1) (n_to_i64 n) stream'
extend_dynamic_buffer buf $ string_from_char_ptr numRead ptr
if numRead == n_to_w32 n
then Continue
else Done ()
load_dynamic_buffer buf
ptr:(Ptr Char) <- with_alloc n
buf <- with_dynamic_buffer
iter \_.
(MkPtr rawPtr) = ptr
numRead = i_to_w32 $ i64_to_i $ freadFFI rawPtr (i_to_i64 1) (n_to_i64 n) stream'
extend_dynamic_buffer buf $ string_from_char_ptr numRead ptr
if numRead == n_to_w32 n
then Continue
else Done ()
load_dynamic_buffer buf

'### Print

Expand Down Expand Up @@ -2262,34 +2265,35 @@ def is_power_of_2 (x:Nat) : Bool =
else 0 == %and x' (%isub x' (1::NatRep))

def natlog2 (x:Nat) : Nat =
tmp = yield_state 0 \ansRef.
run_state 1 \cmpRef.
while do
if x >= (get cmpRef)
then
ansRef := (get ansRef) + 1
cmpRef := rep_to_nat $ %shl (nat_to_rep $ get cmpRef) (1 :: NatRep)
True
else
False
tmp = yield_state 0 \ans.
cmp <- run_state 1
while do
if x >= (get cmp)
then
ans := (get ans) + 1
cmp := rep_to_nat $ %shl (nat_to_rep $ get cmp) (1 :: NatRep)
True
else
False
unsafe_nat_diff tmp 1 -- TODO: something less horrible

def general_integer_power {a} (times:a->a->a) (one:a) (base:a) (power:Nat) : a =
-- Implements exponentiation by squaring.
-- This could be nicer if there were a way to explicitly
-- specify which typelcass instance to use for Mul.
yield_state one \ans.
with_state power \pow. with_state base \z.
while do
if get pow > 0
then
if is_odd (get pow)
then ans := times (get ans) (get z)
z := times (get z) (get z)
pow := intdiv2 (get pow)
True
else
False
pow <- with_state power
z <- with_state base
while do
if get pow > 0
then
if is_odd (get pow)
then ans := times (get ans) (get z)
z := times (get z) (get z)
pow := intdiv2 (get pow)
True
else
False

def intpow {a} [Mul a] (base:a) (power:Nat) : a =
general_integer_power (*) one base power
Expand Down Expand Up @@ -2318,20 +2322,20 @@ def list_length {a} ((AsList n _):List a) : Nat = n
def concat {n a} (lists:n=>(List a)) : List a =
totalSize = sum for i. list_length lists.i
AsList _ $ with_state 0 \listIdx.
with_state 0 \eltIdx.
for i:(Fin totalSize).
while do
continue = get eltIdx >= list_length (lists.((get listIdx)@_))
if continue
then
eltIdx := 0
listIdx := get listIdx + 1
else ()
continue
(AsList _ xs) = lists.((get listIdx)@_)
eltIdxVal = get eltIdx
eltIdx := eltIdxVal + 1
xs.(eltIdxVal@_)
eltIdx <- with_state 0
for i:(Fin totalSize).
while do
continue = get eltIdx >= list_length (lists.((get listIdx)@_))
if continue
then
eltIdx := 0
listIdx := get listIdx + 1
else ()
continue
(AsList _ xs) = lists.((get listIdx)@_)
eltIdxVal = get eltIdx
eltIdx := eltIdxVal + 1
xs.(eltIdxVal@_)

def cat_maybes {a n} (xs:n=>Maybe a) : List a =
(num_res, res_inds) = yield_state (0::Nat, for i:n. Nothing) \ref.
Expand Down

0 comments on commit 626ee4d

Please sign in to comment.