Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix JuvixTree type unification #2972

Merged
merged 3 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Asm/Extra/Memory.hs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ checkValueStack' loc tab tys mem = do
mapM_
( \(ty, idx) -> do
let ty' = fromJust $ topValueStack idx mem
unless (isSubtype' ty' ty) $
unless (isSubtype ty' ty) $
throw $
AsmError loc $
"type mismatch on value stack cell "
Expand Down
238 changes: 119 additions & 119 deletions src/Juvix/Compiler/Tree/Extra/Type.hs
Original file line number Diff line number Diff line change
Expand Up @@ -39,127 +39,127 @@ curryType ty = case typeArgs ty of
in foldr (\tyarg ty'' -> mkTypeFun [tyarg] ty'') (typeTarget ty') tyargs

isSubtype :: Type -> Type -> Bool
isSubtype ty1 ty2 = case (ty1, ty2) of
(TyDynamic, _) -> True
(_, TyDynamic) -> True
(TyConstr TypeConstr {..}, TyInductive TypeInductive {..}) ->
_typeConstrInductive == _typeInductiveSymbol
(TyConstr c1, TyConstr c2) ->
c1 ^. typeConstrInductive == c2 ^. typeConstrInductive
&& c1 ^. typeConstrTag == c2 ^. typeConstrTag
&& all (uncurry isSubtype) (zip (c1 ^. typeConstrFields) (c2 ^. typeConstrFields))
(TyFun t1, TyFun t2) ->
let l1 = toList (t1 ^. typeFunArgs)
l2 = toList (t2 ^. typeFunArgs)
r1 = t1 ^. typeFunTarget
r2 = t2 ^. typeFunTarget
in length l1 == length l2 && all (uncurry isSubtype) (zip l2 l1) && isSubtype r1 r2
(TyInteger (TypeInteger l1 u1), TyInteger (TypeInteger l2 u2)) ->
checkBounds (>=) l1 l2 && checkBounds (<=) u1 u2
where
checkBounds :: (Integer -> Integer -> Bool) -> Maybe Integer -> Maybe Integer -> Bool
checkBounds _ Nothing Nothing = True
checkBounds _ Nothing (Just _) = False
checkBounds _ (Just _) Nothing = True
checkBounds cmp (Just x) (Just y) = cmp x y
(TyBool {}, TyBool {}) -> True
(TyString, TyString) -> True
(TyField, TyField) -> True
(TyByteArray, TyByteArray) -> True
(TyUnit, TyUnit) -> True
(TyVoid, TyVoid) -> True
(TyInductive {}, TyInductive {}) -> ty1 == ty2
(TyUnit, _) -> False
(_, TyUnit) -> False
(TyVoid, _) -> False
(_, TyVoid) -> False
(TyInteger {}, _) -> False
(_, TyInteger {}) -> False
(TyString, _) -> False
(_, TyString) -> False
(TyField, _) -> False
(_, TyField) -> False
(TyByteArray, _) -> False
(_, TyByteArray) -> False
(TyBool {}, _) -> False
(_, TyBool {}) -> False
(TyFun {}, _) -> False
(_, TyFun {}) -> False
(_, TyConstr {}) -> False

isSubtype' :: Type -> Type -> Bool
isSubtype' ty1 ty2
-- The guard is to ensure correct behaviour with dynamic type targets. E.g.
-- `A -> B -> C -> D` should be a subtype of `(A, B) -> *`.
| tgt1 == TyDynamic || tgt2 == TyDynamic =
isSubtype
(curryType ty1)
(curryType ty2)
where
tgt1 = typeTarget (uncurryType ty1)
tgt2 = typeTarget (uncurryType ty2)
isSubtype' ty1 ty2 =
isSubtype ty1 ty2
isSubtype ty1 ty2 =
let (ty1', ty2') =
if
| typeTarget (uncurryType ty1) == TyDynamic || typeTarget (uncurryType ty2) == TyDynamic ->
(curryType ty1, curryType ty2)
| otherwise ->
(ty1, ty2)
in case (ty1', ty2') of
(TyDynamic, _) -> True
(_, TyDynamic) -> True
(TyConstr TypeConstr {..}, TyInductive TypeInductive {..}) ->
_typeConstrInductive == _typeInductiveSymbol
(TyConstr c1, TyConstr c2) ->
c1 ^. typeConstrInductive == c2 ^. typeConstrInductive
&& c1 ^. typeConstrTag == c2 ^. typeConstrTag
&& all (uncurry isSubtype) (zip (c1 ^. typeConstrFields) (c2 ^. typeConstrFields))
(TyFun t1, TyFun t2) ->
let l1 = toList (t1 ^. typeFunArgs)
l2 = toList (t2 ^. typeFunArgs)
r1 = t1 ^. typeFunTarget
r2 = t2 ^. typeFunTarget
in length l1 == length l2 && all (uncurry isSubtype) (zip l2 l1) && isSubtype r1 r2
(TyInteger (TypeInteger l1 u1), TyInteger (TypeInteger l2 u2)) ->
checkBounds (>=) l1 l2 && checkBounds (<=) u1 u2
where
checkBounds :: (Integer -> Integer -> Bool) -> Maybe Integer -> Maybe Integer -> Bool
checkBounds _ Nothing Nothing = True
checkBounds _ Nothing (Just _) = False
checkBounds _ (Just _) Nothing = True
checkBounds cmp (Just x) (Just y) = cmp x y
(TyBool {}, TyBool {}) -> True
(TyString, TyString) -> True
(TyField, TyField) -> True
(TyByteArray, TyByteArray) -> True
(TyUnit, TyUnit) -> True
(TyVoid, TyVoid) -> True
(TyInductive {}, TyInductive {}) -> ty1 == ty2
(TyUnit, _) -> False
(_, TyUnit) -> False
(TyVoid, _) -> False
(_, TyVoid) -> False
(TyInteger {}, _) -> False
(_, TyInteger {}) -> False
(TyString, _) -> False
(_, TyString) -> False
(TyField, _) -> False
(_, TyField) -> False
(TyByteArray, _) -> False
(_, TyByteArray) -> False
(TyBool {}, _) -> False
(_, TyBool {}) -> False
(TyFun {}, _) -> False
(_, TyFun {}) -> False
(_, TyConstr {}) -> False

unifyTypes :: forall t e r. (Members '[Error TreeError, Reader (Maybe Location), Reader (InfoTable' t e)] r) => Type -> Type -> Sem r Type
unifyTypes ty1 ty2 = case (ty1, ty2) of
(TyDynamic, x) -> return x
(x, TyDynamic) -> return x
(TyInductive TypeInductive {..}, TyConstr TypeConstr {..})
| _typeInductiveSymbol == _typeConstrInductive ->
return ty1
(TyConstr {}, TyInductive {}) -> unifyTypes @t @e ty2 ty1
(TyConstr c1, TyConstr c2)
| c1 ^. typeConstrInductive == c2 ^. typeConstrInductive
&& c1 ^. typeConstrTag == c2 ^. typeConstrTag -> do
flds <- zipWithM (unifyTypes @t @e) (c1 ^. typeConstrFields) (c2 ^. typeConstrFields)
return $ TyConstr (set typeConstrFields flds c1)
(TyConstr c1, TyConstr c2)
| c1 ^. typeConstrInductive == c2 ^. typeConstrInductive ->
return $ TyInductive (TypeInductive (c1 ^. typeConstrInductive))
(TyFun t1, TyFun t2)
| length (t1 ^. typeFunArgs) == length (t2 ^. typeFunArgs) -> do
let args1 = toList (t1 ^. typeFunArgs)
args2 = toList (t2 ^. typeFunArgs)
tgt1 = t1 ^. typeFunTarget
tgt2 = t2 ^. typeFunTarget
args <- zipWithM (unifyTypes @t @e) args1 args2
tgt <- unifyTypes @t @e tgt1 tgt2
return $ TyFun (TypeFun (nonEmpty' args) tgt)
(TyInteger (TypeInteger l1 u1), TyInteger (TypeInteger l2 u2)) ->
return $ TyInteger (TypeInteger (unifyBounds min l1 l2) (unifyBounds max u1 u2))
where
unifyBounds :: (Integer -> Integer -> Integer) -> Maybe Integer -> Maybe Integer -> Maybe Integer
unifyBounds _ Nothing _ = Nothing
unifyBounds _ _ Nothing = Nothing
unifyBounds f (Just x) (Just y) = Just (f x y)
(TyBool {}, TyBool {})
| ty1 == ty2 -> return ty1
(TyString, TyString) -> return TyString
(TyField, TyField) -> return TyField
(TyByteArray, TyByteArray) -> return TyByteArray
(TyUnit, TyUnit) -> return TyUnit
(TyVoid, TyVoid) -> return TyVoid
(TyInductive {}, TyInductive {})
| ty1 == ty2 -> return ty1
(TyUnit, _) -> err
(_, TyUnit) -> err
(TyVoid, _) -> err
(_, TyVoid) -> err
(TyInteger {}, _) -> err
(_, TyInteger {}) -> err
(TyString, _) -> err
(_, TyString) -> err
(TyField, _) -> err
(_, TyField) -> err
(TyByteArray, _) -> err
(_, TyByteArray) -> err
(TyBool {}, _) -> err
(_, TyBool {}) -> err
(TyFun {}, _) -> err
(_, TyFun {}) -> err
(TyInductive {}, _) -> err
(_, TyConstr {}) -> err
unifyTypes ty1 ty2 =
let (ty1', ty2') =
if
| typeTarget (uncurryType ty1) == TyDynamic || typeTarget (uncurryType ty2) == TyDynamic ->
(curryType ty1, curryType ty2)
| otherwise ->
(ty1, ty2)
in case (ty1', ty2') of
(TyDynamic, x) -> return x
(x, TyDynamic) -> return x
(TyInductive TypeInductive {..}, TyConstr TypeConstr {..})
| _typeInductiveSymbol == _typeConstrInductive ->
return ty1
(TyConstr {}, TyInductive {}) -> unifyTypes @t @e ty2 ty1
(TyConstr c1, TyConstr c2)
| c1 ^. typeConstrInductive == c2 ^. typeConstrInductive
&& c1 ^. typeConstrTag == c2 ^. typeConstrTag -> do
flds <- zipWithM (unifyTypes @t @e) (c1 ^. typeConstrFields) (c2 ^. typeConstrFields)
return $ TyConstr (set typeConstrFields flds c1)
(TyConstr c1, TyConstr c2)
| c1 ^. typeConstrInductive == c2 ^. typeConstrInductive ->
return $ TyInductive (TypeInductive (c1 ^. typeConstrInductive))
(TyFun t1, TyFun t2)
| length (t1 ^. typeFunArgs) == length (t2 ^. typeFunArgs) -> do
let args1 = toList (t1 ^. typeFunArgs)
args2 = toList (t2 ^. typeFunArgs)
tgt1 = t1 ^. typeFunTarget
tgt2 = t2 ^. typeFunTarget
args <- zipWithM (unifyTypes @t @e) args1 args2
tgt <- unifyTypes @t @e tgt1 tgt2
return $ TyFun (TypeFun (nonEmpty' args) tgt)
(TyInteger (TypeInteger l1 u1), TyInteger (TypeInteger l2 u2)) ->
return $ TyInteger (TypeInteger (unifyBounds min l1 l2) (unifyBounds max u1 u2))
where
unifyBounds :: (Integer -> Integer -> Integer) -> Maybe Integer -> Maybe Integer -> Maybe Integer
unifyBounds _ Nothing _ = Nothing
unifyBounds _ _ Nothing = Nothing
unifyBounds f (Just x) (Just y) = Just (f x y)
(TyBool {}, TyBool {})
| ty1 == ty2 -> return ty1
(TyString, TyString) -> return TyString
(TyField, TyField) -> return TyField
(TyByteArray, TyByteArray) -> return TyByteArray
(TyUnit, TyUnit) -> return TyUnit
(TyVoid, TyVoid) -> return TyVoid
(TyInductive {}, TyInductive {})
| ty1 == ty2 -> return ty1
(TyUnit, _) -> err
(_, TyUnit) -> err
(TyVoid, _) -> err
(_, TyVoid) -> err
(TyInteger {}, _) -> err
(_, TyInteger {}) -> err
(TyString, _) -> err
(_, TyString) -> err
(TyField, _) -> err
(_, TyField) -> err
(TyByteArray, _) -> err
(_, TyByteArray) -> err
(TyBool {}, _) -> err
(_, TyBool {}) -> err
(TyFun {}, _) -> err
(_, TyFun {}) -> err
(TyInductive {}, _) -> err
(_, TyConstr {}) -> err
where
err :: Sem r a
err = do
Expand Down
8 changes: 6 additions & 2 deletions test/Tree/Asm/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module Tree.Asm.Base where
import Asm.Run.Base qualified as Asm
import Base
import Juvix.Compiler.Asm.Translation.FromTree qualified as Asm
import Juvix.Compiler.Tree.Pipeline qualified as Tree
import Juvix.Compiler.Tree.Translation.FromSource
import Juvix.Data.PPOutput

Expand All @@ -18,5 +19,8 @@ treeAsmAssertion mainFile expectedFile step = do
Left err -> assertFailure (prettyString err)
Right tabIni -> do
step "Translate"
let tab = Asm.fromTree tabIni
Asm.asmRunAssertion' tab expectedFile step
case run $ runError @JuvixError $ Tree.toAsm tabIni of
Left err -> assertFailure (prettyString (fromJuvixError @GenericError err))
Right tab -> do
let tab' = Asm.fromTree tab
Asm.asmRunAssertion' tab' expectedFile step
7 changes: 6 additions & 1 deletion test/Tree/Eval/Positive.hs
Original file line number Diff line number Diff line change
Expand Up @@ -239,5 +239,10 @@ tests =
"Test040: ByteArray"
$(mkRelDir ".")
$(mkRelFile "test040.jvt")
$(mkRelFile "out/test040.out")
$(mkRelFile "out/test040.out"),
PosTest
"Test041: Type unification"
$(mkRelDir ".")
$(mkRelFile "test041.jvt")
$(mkRelFile "out/test041.out")
]
1 change: 1 addition & 0 deletions tests/Tree/positive/out/test041.out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
0
41 changes: 41 additions & 0 deletions tests/Tree/positive/test041.jvt
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
type Foldable {
mkFoldable : ((* → * → *) → * → * → *) → Foldable;
}

type Box {
mkBox : * → Box;
}

function lambda_16(integer, integer) : integer;
function lambda_18((integer, integer) → integer, integer, Box) : integer;
function foldableBoxintegerI() : Foldable;
function go_17(integer) : integer;
function main() : integer;

function lambda_16(_X : integer, _X' : integer) : integer {
_X'
}

function lambda_18(f : (integer, integer) → integer, ini : integer, _X : Box) : integer {
case[Box](_X) {
mkBox: save {
call[go_17](tmp[0].mkBox[0])
}
}
}

function foldableBoxintegerI() : Foldable {
alloc[mkFoldable](calloc[lambda_18]())
}

function go_17(x' : integer) : integer {
x'
}

function main() : integer {
case[Foldable](call[foldableBoxintegerI]()) {
mkFoldable: save {
ccall(tmp[0].mkFoldable[0], calloc[lambda_16](), 0, alloc[mkBox](0))
}
}
}
Loading