Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce indentation hell. #1147

Merged
merged 2 commits into from
Nov 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 73 additions & 69 deletions lib/prelude.dx
Original file line number Diff line number Diff line change
Expand Up @@ -1179,9 +1179,9 @@ def with_alloc {a b} [Storable a] (n:Nat) (action: Ptr a -> {IO} b) : {IO} b =
result

def with_table_ptr {a b n} [Storable a] (xs:n=>a) (action : Ptr a -> {IO} b) : {IO} b =
with_alloc (size n) \ptr.
for i. store (ptr +>> ordinal i) xs.i
action ptr
ptr <- with_alloc (size n)
for i. store (ptr +>> ordinal i) xs.i
action ptr

def table_from_ptr {a} [Storable a] (n:Type) [Ix n] (ptr:Ptr a) : {IO} n=>a =
for i. load $ ptr +>> ordinal i
Expand Down Expand Up @@ -1284,18 +1284,18 @@ TODO: Move this to be with reductions?
It's a kind of `scan`.

def cumsum {n a} [Add a] (xs: n=>a) : n=>a =
with_state zero \total.
for i.
newTotal = get total + xs.i
total := newTotal
newTotal
total <- with_state zero
for i.
newTotal = get total + xs.i
total := newTotal
newTotal

def cumsum_low {n a} [Add a] (xs: n=>a) : n=>a =
with_state zero \total.
for i.
oldTotal = get total
total := oldTotal + xs.i
oldTotal
total <- with_state zero
for i.
oldTotal = get total
total := oldTotal + xs.i
oldTotal

'## Automatic differentiation

Expand Down Expand Up @@ -1535,16 +1535,18 @@ data DynBuffer a =
def with_dynamic_buffer {a b} [Storable a]
(action: DynBuffer a -> {IO} b) : {IO} b =
initMaxSize = 256
with_alloc 1 \sizePtr. with_alloc 1 \maxSizePtr. with_alloc 1 \bufferPtr.
store sizePtr 0
store maxSizePtr initMaxSize
store bufferPtr $ malloc initMaxSize
result = action $ MkDynBuffer {
size = sizePtr
, maxSize = maxSizePtr
, buffer = bufferPtr }
free $ load bufferPtr
result
sizePtr <- with_alloc 1
store sizePtr 0
maxSizePtr <- with_alloc 1
store maxSizePtr initMaxSize
bufferPtr <- with_alloc 1
store bufferPtr $ malloc initMaxSize
result = action $ MkDynBuffer {
size = sizePtr
, maxSize = maxSizePtr
, buffer = bufferPtr }
free $ load bufferPtr
result

def maybe_increase_buffer_size {a} [Storable a]
((MkDynBuffer db): DynBuffer a) (sizeDelta:Nat) : {IO} Unit =
Expand Down Expand Up @@ -1768,7 +1770,8 @@ def lift_state {a b c h eff} (ref: Ref h c) (f:a -> {|eff} b) (x:a) : {State h|e

-- A little iteration combinator
def iter {a eff} (body: Nat -> {|eff} IterResult a) : {|eff} a =
result = yield_state Nothing \resultRef. with_state 0 \i.
result = yield_state Nothing \resultRef.
i <- with_state 0
while do
continue = is_nothing $ get resultRef
if continue then
Expand Down Expand Up @@ -1817,16 +1820,16 @@ def fread (stream:Stream ReadMode) : {IO} String =
(MkStream stream') = stream
-- TODO: allow reading longer files!
n = 4096
with_alloc n \ptr:(Ptr Char).
with_dynamic_buffer \buf.
iter \_.
(MkPtr rawPtr) = ptr
numRead = i_to_w32 $ i64_to_i $ freadFFI rawPtr (i_to_i64 1) (n_to_i64 n) stream'
extend_dynamic_buffer buf $ string_from_char_ptr numRead ptr
if numRead == n_to_w32 n
then Continue
else Done ()
load_dynamic_buffer buf
ptr:(Ptr Char) <- with_alloc n
buf <- with_dynamic_buffer
iter \_.
(MkPtr rawPtr) = ptr
numRead = i_to_w32 $ i64_to_i $ freadFFI rawPtr (i_to_i64 1) (n_to_i64 n) stream'
extend_dynamic_buffer buf $ string_from_char_ptr numRead ptr
if numRead == n_to_w32 n
then Continue
else Done ()
load_dynamic_buffer buf

'### Print

Expand Down Expand Up @@ -2262,34 +2265,35 @@ def is_power_of_2 (x:Nat) : Bool =
else 0 == %and x' (%isub x' (1::NatRep))

def natlog2 (x:Nat) : Nat =
tmp = yield_state 0 \ansRef.
run_state 1 \cmpRef.
while do
if x >= (get cmpRef)
then
ansRef := (get ansRef) + 1
cmpRef := rep_to_nat $ %shl (nat_to_rep $ get cmpRef) (1 :: NatRep)
True
else
False
tmp = yield_state 0 \ans.
cmp <- run_state 1
while do
if x >= (get cmp)
then
ans := (get ans) + 1
cmp := rep_to_nat $ %shl (nat_to_rep $ get cmp) (1 :: NatRep)
True
else
False
unsafe_nat_diff tmp 1 -- TODO: something less horrible

def general_integer_power {a} (times:a->a->a) (one:a) (base:a) (power:Nat) : a =
-- Implements exponentiation by squaring.
-- This could be nicer if there were a way to explicitly
-- specify which typelcass instance to use for Mul.
yield_state one \ans.
with_state power \pow. with_state base \z.
while do
if get pow > 0
then
if is_odd (get pow)
then ans := times (get ans) (get z)
z := times (get z) (get z)
pow := intdiv2 (get pow)
True
else
False
pow <- with_state power
z <- with_state base
while do
if get pow > 0
then
if is_odd (get pow)
then ans := times (get ans) (get z)
z := times (get z) (get z)
pow := intdiv2 (get pow)
True
else
False

def intpow {a} [Mul a] (base:a) (power:Nat) : a =
general_integer_power (*) one base power
Expand Down Expand Up @@ -2318,20 +2322,20 @@ def list_length {a} ((AsList n _):List a) : Nat = n
def concat {n a} (lists:n=>(List a)) : List a =
totalSize = sum for i. list_length lists.i
AsList _ $ with_state 0 \listIdx.
with_state 0 \eltIdx.
for i:(Fin totalSize).
while do
continue = get eltIdx >= list_length (lists.((get listIdx)@_))
if continue
then
eltIdx := 0
listIdx := get listIdx + 1
else ()
continue
(AsList _ xs) = lists.((get listIdx)@_)
eltIdxVal = get eltIdx
eltIdx := eltIdxVal + 1
xs.(eltIdxVal@_)
eltIdx <- with_state 0
for i:(Fin totalSize).
while do
continue = get eltIdx >= list_length (lists.((get listIdx)@_))
if continue
then
eltIdx := 0
listIdx := get listIdx + 1
else ()
continue
(AsList _ xs) = lists.((get listIdx)@_)
eltIdxVal = get eltIdx
eltIdx := eltIdxVal + 1
xs.(eltIdxVal@_)

def cat_maybes {a n} (xs:n=>Maybe a) : List a =
(num_res, res_inds) = yield_state (0::Nat, for i:n. Nothing) \ref.
Expand Down
6 changes: 6 additions & 0 deletions src/lib/AbstractSyntax.hs
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ effectOpDef (v, Just rp, rhs) = UEffectOpDef rp (fromString v) <$> block rhs
decl :: LetAnn -> CDecl -> SyntaxM (UDecl VoidS VoidS)
decl ann = dropSrc decl' where
decl' (CLet binder body) = ULet ann <$> patOptAnn binder <*> block body
decl' (CBind _ _) = throw SyntaxErr "Arrow binder syntax <- not permitted at the top level, because the binding would have unbounded scope."
decl' (CDef name params maybeTy body) = do
params' <- concat <$> (mapM argument $ nary Juxtapose params)
case maybeTy of
Expand Down Expand Up @@ -409,6 +410,11 @@ method (name, body) = UMethodDef (fromString name) <$> block body
block :: CBlock -> SyntaxM (UExpr VoidS)
block (CBlock []) = throw SyntaxErr "Block must end in expression"
block (CBlock [ExprDecl g]) = expr g
block (CBlock ((WithSrc pos (CBind binder rhs)):ds)) = do
binder' <- patOptAnn binder
rhs' <- block rhs
body <- block $ CBlock ds
return $ WithSrcE pos $ UApp rhs' $ ns $ ULam $ ULamExpr PlainArrow binder' body
block (CBlock (d@(WithSrc pos _):ds)) = do
d' <- decl PlainLet d
e' <- block $ CBlock ds
Expand Down
18 changes: 12 additions & 6 deletions src/lib/ConcreteSyntax.hs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ type NameAndArgs = (SourceName, [Group])
type CDecl = WithSrc CDecl'
data CDecl'
= CLet Group CBlock
-- Arrow binder <-
| CBind Group CBlock
-- name, args, type, body. The header should contain the parameters,
-- optional effects, and return type
| CDef SourceName Group (Maybe Group) CBlock
Expand Down Expand Up @@ -475,8 +477,8 @@ cBlock' = Left <$> realBlock <|> Right <$> cGroupNoSeparators where

cDecl :: Parser CDecl
cDecl = instanceDef True <|> (do
lhs <- funDefLet <|> (try $ simpleLet <* lookAhead (sym "="))
rhs <- sym "=" >> cBlock
lhs <- funDefLet <|> (try simpleLet)
rhs <- cBlock
return $ lhs rhs) <|> (ExprDecl <$> cGroup)

instanceDef :: Bool -> Parser CDecl
Expand All @@ -497,17 +499,21 @@ instanceMethod = do
return (fromString v, rhs)

simpleLet :: Parser (CBlock -> CDecl)
simpleLet = withSrc1 do
simpleLet = withSrc1 $ do
binder <- cGroupNoEqual
return $ CLet binder
next <- nextChar
case next of
'=' -> sym "=" >> return (CLet binder)
'<' -> sym "<-" >> return (CBind binder)
_ -> fail ""

funDefLet :: Parser (CBlock -> CDecl)
funDefLet = label "function definition" $ mayBreak $ withSrc1 do
funDefLet = label "function definition" (mayBreak $ withSrc1 do
keyWord DefKW
name <- anyName
args <- cGroupNoColon <|> pure (WithSrc Nothing CEmpty)
typeAnn <- optional (sym ":" >> cGroupNoEqual)
return (CDef name args typeAnn)
return (CDef name args typeAnn)) <* sym "="

cGroup :: Parser Group
cGroup = makeExprParser (withSrc leafGroup) ops
Expand Down
8 changes: 7 additions & 1 deletion src/lib/Lexing.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import Text.Megaparsec hiding (Label, State)
import Text.Megaparsec.Char hiding (space, eol)
import qualified Text.Megaparsec.Char as MC
import qualified Text.Megaparsec.Char.Lexer as L
import Text.Megaparsec.Debug

import Err
import LabeledItems
Expand All @@ -33,6 +34,11 @@ mustParseit s p = case parseit s p of
Success x -> x
Failure e -> error $ "This shouldn't happen:\n" ++ pprint e

debug :: (Show a) => String -> Parser a -> Parser a
debug lbl action = do
ctx <- ask
lift $ dbg lbl $ runReaderT action ctx

-- === Lexemes ===

type Lexer = Parser
Expand Down Expand Up @@ -141,7 +147,7 @@ doubleLit = lexeme $
knownSymStrs :: HS.HashSet String
knownSymStrs = HS.fromList
[ ".", ":", "::", "!", "=", "-", "+", "||", "&&"
, "$", "&", "&>", "|", ",", ",>", "+=", ":="
, "$", "&", "&>", "|", ",", ",>", "<-", "+=", ":="
, "->", "=>", "?->", "?=>", "--o", "--", "<<<", ">>>", "<<&", "&>>"
, "..", "<..", "..<", "..<", "<..<", "?", "#", "##", "#?", "#&", "#|", "@"]

Expand Down
1 change: 1 addition & 0 deletions src/lib/PPrint.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1144,6 +1144,7 @@ instance Pretty CDecl where

instance Pretty CDecl' where
pretty (CLet pat blk) = pArg pat <+> "=" <+> p blk
pretty (CBind pat blk) = pArg pat <+> "<-" <+> p blk
pretty (CDef name args (Just ty) blk) =
"def " <> fromString name <> " " <> pArg args <> " : " <> pArg ty <> " ="
<> nest 2 (hardline <> p blk)
Expand Down
9 changes: 9 additions & 0 deletions tests/parser-tests.dx
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,12 @@ data MyPair3 (a:Type) (b:Type) = MkPair3 (x:a) (y:b)

data TableInType n a [Ix n] table:(n=>a) =
MkTableInType -- Doesn't store any data except in the type!

'Left arrow <- desugars to a continuation lambda
(feature test for Issue 1137)

:p
x <- with_state 0
x := 4
get x
> 4