Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix JuvixTree type unification #2972

Merged
merged 3 commits into from
Aug 27, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
fix JuvixTree type unification
lukaszcz committed Aug 23, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
commit 5b02f81f76d6c9d78390197f41495fdf4ad0a1b7
224 changes: 119 additions & 105 deletions src/Juvix/Compiler/Tree/Extra/Type.hs
Original file line number Diff line number Diff line change
@@ -39,53 +39,60 @@ curryType ty = case typeArgs ty of
in foldr (\tyarg ty'' -> mkTypeFun [tyarg] ty'') (typeTarget ty') tyargs

isSubtype :: Type -> Type -> Bool
isSubtype ty1 ty2 = case (ty1, ty2) of
(TyDynamic, _) -> True
(_, TyDynamic) -> True
(TyConstr TypeConstr {..}, TyInductive TypeInductive {..}) ->
_typeConstrInductive == _typeInductiveSymbol
(TyConstr c1, TyConstr c2) ->
c1 ^. typeConstrInductive == c2 ^. typeConstrInductive
&& c1 ^. typeConstrTag == c2 ^. typeConstrTag
&& all (uncurry isSubtype) (zip (c1 ^. typeConstrFields) (c2 ^. typeConstrFields))
(TyFun t1, TyFun t2) ->
let l1 = toList (t1 ^. typeFunArgs)
l2 = toList (t2 ^. typeFunArgs)
r1 = t1 ^. typeFunTarget
r2 = t2 ^. typeFunTarget
in length l1 == length l2 && all (uncurry isSubtype) (zip l2 l1) && isSubtype r1 r2
(TyInteger (TypeInteger l1 u1), TyInteger (TypeInteger l2 u2)) ->
checkBounds (>=) l1 l2 && checkBounds (<=) u1 u2
where
checkBounds :: (Integer -> Integer -> Bool) -> Maybe Integer -> Maybe Integer -> Bool
checkBounds _ Nothing Nothing = True
checkBounds _ Nothing (Just _) = False
checkBounds _ (Just _) Nothing = True
checkBounds cmp (Just x) (Just y) = cmp x y
(TyBool {}, TyBool {}) -> True
(TyString, TyString) -> True
(TyField, TyField) -> True
(TyByteArray, TyByteArray) -> True
(TyUnit, TyUnit) -> True
(TyVoid, TyVoid) -> True
(TyInductive {}, TyInductive {}) -> ty1 == ty2
(TyUnit, _) -> False
(_, TyUnit) -> False
(TyVoid, _) -> False
(_, TyVoid) -> False
(TyInteger {}, _) -> False
(_, TyInteger {}) -> False
(TyString, _) -> False
(_, TyString) -> False
(TyField, _) -> False
(_, TyField) -> False
(TyByteArray, _) -> False
(_, TyByteArray) -> False
(TyBool {}, _) -> False
(_, TyBool {}) -> False
(TyFun {}, _) -> False
(_, TyFun {}) -> False
(_, TyConstr {}) -> False
isSubtype ty1 ty2 =
let (ty1', ty2') =
if
| typeTarget (uncurryType ty1) == TyDynamic || typeTarget (uncurryType ty2) == TyDynamic ->
(curryType ty1, curryType ty2)
| otherwise ->
(ty1, ty2)
in case (ty1', ty2') of
(TyDynamic, _) -> True
(_, TyDynamic) -> True
(TyConstr TypeConstr {..}, TyInductive TypeInductive {..}) ->
_typeConstrInductive == _typeInductiveSymbol
(TyConstr c1, TyConstr c2) ->
c1 ^. typeConstrInductive == c2 ^. typeConstrInductive
&& c1 ^. typeConstrTag == c2 ^. typeConstrTag
&& all (uncurry isSubtype) (zip (c1 ^. typeConstrFields) (c2 ^. typeConstrFields))
(TyFun t1, TyFun t2) ->
let l1 = toList (t1 ^. typeFunArgs)
l2 = toList (t2 ^. typeFunArgs)
r1 = t1 ^. typeFunTarget
r2 = t2 ^. typeFunTarget
in length l1 == length l2 && all (uncurry isSubtype) (zip l2 l1) && isSubtype r1 r2
(TyInteger (TypeInteger l1 u1), TyInteger (TypeInteger l2 u2)) ->
checkBounds (>=) l1 l2 && checkBounds (<=) u1 u2
where
checkBounds :: (Integer -> Integer -> Bool) -> Maybe Integer -> Maybe Integer -> Bool
checkBounds _ Nothing Nothing = True
checkBounds _ Nothing (Just _) = False
checkBounds _ (Just _) Nothing = True
checkBounds cmp (Just x) (Just y) = cmp x y
(TyBool {}, TyBool {}) -> True
(TyString, TyString) -> True
(TyField, TyField) -> True
(TyByteArray, TyByteArray) -> True
(TyUnit, TyUnit) -> True
(TyVoid, TyVoid) -> True
(TyInductive {}, TyInductive {}) -> ty1 == ty2
(TyUnit, _) -> False
(_, TyUnit) -> False
(TyVoid, _) -> False
(_, TyVoid) -> False
(TyInteger {}, _) -> False
(_, TyInteger {}) -> False
(TyString, _) -> False
(_, TyString) -> False
(TyField, _) -> False
(_, TyField) -> False
(TyByteArray, _) -> False
(_, TyByteArray) -> False
(TyBool {}, _) -> False
(_, TyBool {}) -> False
(TyFun {}, _) -> False
(_, TyFun {}) -> False
(_, TyConstr {}) -> False

isSubtype' :: Type -> Type -> Bool
isSubtype' ty1 ty2
@@ -102,64 +109,71 @@ isSubtype' ty1 ty2 =
isSubtype ty1 ty2

unifyTypes :: forall t e r. (Members '[Error TreeError, Reader (Maybe Location), Reader (InfoTable' t e)] r) => Type -> Type -> Sem r Type
unifyTypes ty1 ty2 = case (ty1, ty2) of
(TyDynamic, x) -> return x
(x, TyDynamic) -> return x
(TyInductive TypeInductive {..}, TyConstr TypeConstr {..})
| _typeInductiveSymbol == _typeConstrInductive ->
return ty1
(TyConstr {}, TyInductive {}) -> unifyTypes @t @e ty2 ty1
(TyConstr c1, TyConstr c2)
| c1 ^. typeConstrInductive == c2 ^. typeConstrInductive
&& c1 ^. typeConstrTag == c2 ^. typeConstrTag -> do
flds <- zipWithM (unifyTypes @t @e) (c1 ^. typeConstrFields) (c2 ^. typeConstrFields)
return $ TyConstr (set typeConstrFields flds c1)
(TyConstr c1, TyConstr c2)
| c1 ^. typeConstrInductive == c2 ^. typeConstrInductive ->
return $ TyInductive (TypeInductive (c1 ^. typeConstrInductive))
(TyFun t1, TyFun t2)
| length (t1 ^. typeFunArgs) == length (t2 ^. typeFunArgs) -> do
let args1 = toList (t1 ^. typeFunArgs)
args2 = toList (t2 ^. typeFunArgs)
tgt1 = t1 ^. typeFunTarget
tgt2 = t2 ^. typeFunTarget
args <- zipWithM (unifyTypes @t @e) args1 args2
tgt <- unifyTypes @t @e tgt1 tgt2
return $ TyFun (TypeFun (nonEmpty' args) tgt)
(TyInteger (TypeInteger l1 u1), TyInteger (TypeInteger l2 u2)) ->
return $ TyInteger (TypeInteger (unifyBounds min l1 l2) (unifyBounds max u1 u2))
where
unifyBounds :: (Integer -> Integer -> Integer) -> Maybe Integer -> Maybe Integer -> Maybe Integer
unifyBounds _ Nothing _ = Nothing
unifyBounds _ _ Nothing = Nothing
unifyBounds f (Just x) (Just y) = Just (f x y)
(TyBool {}, TyBool {})
| ty1 == ty2 -> return ty1
(TyString, TyString) -> return TyString
(TyField, TyField) -> return TyField
(TyByteArray, TyByteArray) -> return TyByteArray
(TyUnit, TyUnit) -> return TyUnit
(TyVoid, TyVoid) -> return TyVoid
(TyInductive {}, TyInductive {})
| ty1 == ty2 -> return ty1
(TyUnit, _) -> err
(_, TyUnit) -> err
(TyVoid, _) -> err
(_, TyVoid) -> err
(TyInteger {}, _) -> err
(_, TyInteger {}) -> err
(TyString, _) -> err
(_, TyString) -> err
(TyField, _) -> err
(_, TyField) -> err
(TyByteArray, _) -> err
(_, TyByteArray) -> err
(TyBool {}, _) -> err
(_, TyBool {}) -> err
(TyFun {}, _) -> err
(_, TyFun {}) -> err
(TyInductive {}, _) -> err
(_, TyConstr {}) -> err
unifyTypes ty1 ty2 =
let (ty1', ty2') =
if
| typeTarget (uncurryType ty1) == TyDynamic || typeTarget (uncurryType ty2) == TyDynamic ->
(curryType ty1, curryType ty2)
| otherwise ->
(ty1, ty2)
in case (ty1', ty2') of
(TyDynamic, x) -> return x
(x, TyDynamic) -> return x
(TyInductive TypeInductive {..}, TyConstr TypeConstr {..})
| _typeInductiveSymbol == _typeConstrInductive ->
return ty1
(TyConstr {}, TyInductive {}) -> unifyTypes @t @e ty2 ty1
(TyConstr c1, TyConstr c2)
| c1 ^. typeConstrInductive == c2 ^. typeConstrInductive
&& c1 ^. typeConstrTag == c2 ^. typeConstrTag -> do
flds <- zipWithM (unifyTypes @t @e) (c1 ^. typeConstrFields) (c2 ^. typeConstrFields)
return $ TyConstr (set typeConstrFields flds c1)
(TyConstr c1, TyConstr c2)
| c1 ^. typeConstrInductive == c2 ^. typeConstrInductive ->
return $ TyInductive (TypeInductive (c1 ^. typeConstrInductive))
(TyFun t1, TyFun t2)
| length (t1 ^. typeFunArgs) == length (t2 ^. typeFunArgs) -> do
let args1 = toList (t1 ^. typeFunArgs)
args2 = toList (t2 ^. typeFunArgs)
tgt1 = t1 ^. typeFunTarget
tgt2 = t2 ^. typeFunTarget
args <- zipWithM (unifyTypes @t @e) args1 args2
tgt <- unifyTypes @t @e tgt1 tgt2
return $ TyFun (TypeFun (nonEmpty' args) tgt)
(TyInteger (TypeInteger l1 u1), TyInteger (TypeInteger l2 u2)) ->
return $ TyInteger (TypeInteger (unifyBounds min l1 l2) (unifyBounds max u1 u2))
where
unifyBounds :: (Integer -> Integer -> Integer) -> Maybe Integer -> Maybe Integer -> Maybe Integer
unifyBounds _ Nothing _ = Nothing
unifyBounds _ _ Nothing = Nothing
unifyBounds f (Just x) (Just y) = Just (f x y)
(TyBool {}, TyBool {})
| ty1 == ty2 -> return ty1
(TyString, TyString) -> return TyString
(TyField, TyField) -> return TyField
(TyByteArray, TyByteArray) -> return TyByteArray
(TyUnit, TyUnit) -> return TyUnit
(TyVoid, TyVoid) -> return TyVoid
(TyInductive {}, TyInductive {})
| ty1 == ty2 -> return ty1
(TyUnit, _) -> err
(_, TyUnit) -> err
(TyVoid, _) -> err
(_, TyVoid) -> err
(TyInteger {}, _) -> err
(_, TyInteger {}) -> err
(TyString, _) -> err
(_, TyString) -> err
(TyField, _) -> err
(_, TyField) -> err
(TyByteArray, _) -> err
(_, TyByteArray) -> err
(TyBool {}, _) -> err
(_, TyBool {}) -> err
(TyFun {}, _) -> err
(_, TyFun {}) -> err
(TyInductive {}, _) -> err
(_, TyConstr {}) -> err
where
err :: Sem r a
err = do