Skip to content

Commit

Permalink
Implement multi result primitives
Browse files Browse the repository at this point in the history
Closes #1555
  • Loading branch information
martijnbastiaan committed Oct 29, 2020
1 parent ae21bb4 commit f325ea0
Show file tree
Hide file tree
Showing 30 changed files with 695 additions and 254 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,4 @@ bbTemplate bbCtx = do
Just compName' = exprToString intrinsicName
compName = TextS.pack compName'

(Identifier result Nothing,_) = bbResult bbCtx
[(Identifier result Nothing,_)] = bbResults bbCtx
Original file line number Diff line number Diff line change
Expand Up @@ -104,4 +104,4 @@ sbioTemplate bbCtx = do
, (outputEnable, Bool, _)
] = bbInputs bbCtx

(Identifier result Nothing,resTy) = bbResult bbCtx
[(Identifier result Nothing,resTy)] = bbResults bbCtx
8 changes: 7 additions & 1 deletion clash-ghc/src-ghc/Clash/GHC/Evaluator.hs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ stepPrim pInfo m tcm
[] -> ghcPrimStep tcm (forcePrims m) pInfo [] [] m
tys -> newBinder tys (Prim pInfo) m tcm

stepMultiPrim :: PrimInfo -> Step
stepMultiPrim pInfo m tcm =
case fst $ splitFunForallTy (multiPrimType pInfo) of
[] -> Nothing -- don't evaluate multi prims
tys -> newBinder tys (MultiPrim pInfo) m tcm

stepLam :: Id -> Term -> Step
stepLam x e = ghcUnwind (Lambda x e)

Expand Down Expand Up @@ -234,6 +240,7 @@ ghcStep m = case mTerm m of
Data dc -> stepData dc m
Literal l -> stepLiteral l m
Prim p -> stepPrim p m
MultiPrim p -> stepMultiPrim p m
Lam v x -> stepLam v x m
TyLam v x -> stepTyLam v x m
App x y -> stepApp x y m
Expand Down Expand Up @@ -456,4 +463,3 @@ letSubst h acc id0 =
where
(i,ids') = freshId ids
x' = modifyVarName (`setUnique` i) x

13 changes: 7 additions & 6 deletions clash-ghc/src-ghc/Clash/GHC/GenerateBindings.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
-}

{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}

module Clash.GHC.GenerateBindings
Expand Down Expand Up @@ -215,8 +216,8 @@ mkBindings primMap bindings clsOps unlocatable = do
checkPrimitive :: CompiledPrimMap -> GHC.CoreBndr -> C2C ()
checkPrimitive primMap v = do
nm <- qualifiedNameString (GHC.varName v)
case HashMap.lookup nm primMap of
Just (extractPrim -> Just (BlackBox _ _ _ _ _ _ _ _ _ inc r ri templ)) -> do
case HashMap.lookup nm primMap >>= extractPrim of
Just (BlackBox{resultNames, resultInits, template, includes}) -> do
let
info = GHC.idInfo v
inline = GHC.inlinePragmaSpec $ GHC.inlinePragInfo info
Expand All @@ -231,10 +232,10 @@ checkPrimitive primMap v = do
warnIf cond msg = traceIf cond ("\n"++loc++"Warning: "++msg) return ()
qName <- Text.unpack <$> qualifiedNameString (GHC.varName v)
let primStr = "primitive " ++ qName ++ " "
let usedArgs = concat [ maybe [] getUsedArguments r
, maybe [] getUsedArguments ri
, getUsedArguments templ
, concatMap (getUsedArguments . snd) inc
let usedArgs = concat [ concatMap getUsedArguments resultNames
, concatMap getUsedArguments resultInits
, getUsedArguments template
, concatMap (getUsedArguments . snd) includes
]

let warnArgs [] = return ()
Expand Down
7 changes: 7 additions & 0 deletions clash-lib/prims/common/Clash_Transformations.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,11 @@
, "template" : "~ERRORO"
}
}
, { "BlackBox" :
{ "name" : "c$multiPrimSelect"
, "workInfo" : "Always"
, "kind" : "Expression"
, "template" : "!__SHOULD NOT BE RENDERED__! ~ARG[0]~ARG[1]"
}
}
]
51 changes: 28 additions & 23 deletions clash-lib/src/Clash/Core/Evaluator/Models.hs
Original file line number Diff line number Diff line change
Expand Up @@ -279,11 +279,12 @@ data PatResult
-- case x of ... Stuck if "x" is neutral (cannot choose an alternative)
--
data Neutral a
= NeVar !Id
| NePrim !PrimInfo ![Either a Type]
| NeApp !(Neutral a) !a
| NeTyApp !(Neutral a) !Type
| NeCase !a !Type ![(Pat, a)]
= NeVar !Id
| NePrim !PrimInfo ![Either a Type]
| NeMultiPrim !PrimInfo ![Either a Type]
| NeApp !(Neutral a) !a
| NeTyApp !(Neutral a) !Type
| NeCase !a !Type ![(Pat, a)]
deriving (Show)

-- | A term which has been normalised to weak head normal form (WHNF). This has
Expand Down Expand Up @@ -312,15 +313,16 @@ data Neutral a
-- respects lexical scoping.
--
data Value
= VNeu !(Neutral Value)
| VLit !Literal
| VData !DataCon ![Either Term Type] !LocalEnv
| VPrim !PrimInfo ![Either Term Type] !LocalEnv
| VLam !Id !Term !LocalEnv
| VTyLam !TyVar !Term !LocalEnv
| VCast !Value !Type !Type
| VTick !Value !TickInfo
| VThunk !Term !LocalEnv
= VNeu !(Neutral Value)
| VLit !Literal
| VData !DataCon ![Either Term Type] !LocalEnv
| VPrim !PrimInfo ![Either Term Type] !LocalEnv
| VMultiPrim !PrimInfo ![Either Term Type] !LocalEnv
| VLam !Id !Term !LocalEnv
| VTyLam !TyVar !Term !LocalEnv
| VCast !Value !Type !Type
| VTick !Value !TickInfo
| VThunk !Term !LocalEnv
deriving (Show)

collectValueTicks :: Value -> (Value, [TickInfo])
Expand All @@ -336,14 +338,15 @@ addTicks = foldr (flip VTick)
-- and all partially applied functions in subterms are eta-expanded.
--
data Nf
= NNeu !(Neutral Nf)
| NLit !Literal
| NData !DataCon ![Either Nf Type]
| NPrim !PrimInfo ![Either Nf Type]
| NLam !Id !Nf
| NTyLam !TyVar !Nf
| NCast !Nf !Type !Type
| NTick !Nf !TickInfo
= NNeu !(Neutral Nf)
| NLit !Literal
| NData !DataCon ![Either Nf Type]
| NPrim !PrimInfo ![Either Nf Type]
| NMultiPrim !PrimInfo ![Either Nf Type]
| NLam !Id !Nf
| NTyLam !TyVar !Nf
| NCast !Nf !Type !Type
| NTick !Nf !TickInfo
deriving (Show)

-- Embedding WHNF and HNF values back into Term.
Expand All @@ -358,6 +361,7 @@ instance (AsTerm a) => AsTerm (Neutral a) where
asTerm = \case
NeVar v -> Var v
NePrim p args -> mkApps (Prim p) (first asTerm <$> args)
NeMultiPrim p args -> mkApps (MultiPrim p) (first asTerm <$> args)
NeApp x y -> App (asTerm x) (asTerm y)
NeTyApp x ty -> TyApp (asTerm x) ty
NeCase x ty alts -> Case (asTerm x) ty (second asTerm <$> alts)
Expand All @@ -368,6 +372,7 @@ instance AsTerm Value where
VLit l -> Literal l
VData dc args env -> instHeap env . bindHeap env $ mkApps (Data dc) args
VPrim p args env -> instHeap env . bindHeap env $ mkApps (Prim p) args
VMultiPrim p args env -> instHeap env . bindHeap env $ mkApps (MultiPrim p) args
VLam x e env -> instHeap env $ bindHeap env (Lam x e)
VTyLam x e env -> instHeap env $ bindHeap env (TyLam x e)
VCast x a b -> Cast (asTerm x) a b
Expand Down Expand Up @@ -403,11 +408,11 @@ instance AsTerm Nf where
NLit l -> Literal l
NData dc args -> mkApps (Data dc) (first asTerm <$> args)
NPrim p args -> mkApps (Prim p) (first asTerm <$> args)
NMultiPrim p args -> mkApps (MultiPrim p) (first asTerm <$> args)
NLam x e -> Lam x (asTerm e)
NTyLam x e -> TyLam x (asTerm e)
NCast x a b -> Cast (asTerm x) a b
NTick x ti -> Tick ti (asTerm x)

instance (AsTerm a, AsTerm b) => AsTerm (Either a b) where
asTerm = either asTerm asTerm

49 changes: 48 additions & 1 deletion clash-lib/src/Clash/Core/Evaluator/Semantics.hs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ evaluateWith matchLit matchData evalPrim = go
Literal l -> pure (VLit l)
Data dc -> evaluateDataWith go dc
Prim p -> evaluatePrimWith go evalPrim p
MultiPrim p -> evaluateMultiPrimWith go p
Lam x e -> evaluateLam x e
TyLam x e -> evaluateTyLam x e
App x y -> evaluateAppWith go evalPrim apply x (Left y)
Expand Down Expand Up @@ -125,6 +126,10 @@ etaExpand x =
tcm <- getTyConMap
expand tcm (primType p) y

y@(MultiPrim p, _) -> do
tcm <- getTyConMap
expand tcm (multiPrimType p) y

_ -> pure x
where
expand :: TyConMap -> Type -> (Term, [Either Term Type]) -> Eval Term
Expand Down Expand Up @@ -233,6 +238,19 @@ evaluatePrimWith eval evalPrim p
| isFullyApplied (primType p) [] = evalPrim p []
| otherwise = etaExpand (Prim p) >>= eval

-- | Default implementation for evaluating primitive operations. If the primop
-- is not nullary, then it is eta-expanded and its eta-expanded form returned
-- as a neutral term.
--
-- TODO: Implement evaluation for multi return primitives
evaluateMultiPrimWith
:: (Term -> Eval Value)
-> PrimInfo
-> Eval Value
evaluateMultiPrimWith eval p
| isFullyApplied (primType p) [] = pure (VNeu (NeMultiPrim p []))
| otherwise = etaExpand (MultiPrim p) >>= eval

-- | Default implementation for evaluating lambdas. As a term with a lambda
-- at the head is already in WHNF, this simply returns the term under the
-- lambda with the current local environment.
Expand Down Expand Up @@ -275,6 +293,14 @@ evaluateAppWith eval evalPrim apply x y
primRes <- evalPrim p pArgs
foldM apply primRes (first Left <$> rArgs)

| MultiPrim p <- f
, nArgs <- length . fst $ splitFunForallTy (multiPrimType p)
= case compare (length args) nArgs of
LT -> etaExpand term >>= eval
_ -> do
nfArgs <- mapM (bitraverse eval pure) args
pure (VNeu (NeMultiPrim p nfArgs))

-- Evaluating a function application may change the amount of fuel (e.g. if
-- the LHS of the application is a recursive function.) If we do not add fuel
-- back after calling apply, other subterms may not be unfolded as much as
Expand Down Expand Up @@ -478,6 +504,7 @@ quoteWith eval = go
VLit l -> pure (NLit l)
VData dc args env -> quoteDataWith (eval >=> go) dc args env
VPrim p args env -> quotePrimWith (eval >=> go) p args env
VMultiPrim p args env -> quoteMultiPrimWith (eval >=> go) p args env
VLam i x env -> quoteLamWith go apply (Left i) x env
VTyLam i x env -> quoteLamWith go apply (Right i) x env
VCast x a b -> quoteCastWith go x a b
Expand All @@ -487,6 +514,7 @@ quoteWith eval = go
goNe = \case
NeVar v -> quoteNeVar v
NePrim p args -> quoteNePrimWith go p args
NeMultiPrim p args -> quoteNeMultiPrimWith go p args
NeApp x y -> quoteNeAppWith go goNe x y
NeTyApp x ty -> quoteNeTyAppWith goNe x ty
NeCase x ty xs -> quoteNeCaseWith go x ty xs
Expand Down Expand Up @@ -516,6 +544,17 @@ quotePrimWith quote p args env =
quoteArgs <- traverse (bitraverse quote pure) args
pure (NPrim p quoteArgs)

quoteMultiPrimWith
:: (Term -> Eval Nf)
-> PrimInfo
-> [Either Term Type]
-> LocalEnv
-> Eval Nf
quoteMultiPrimWith quote p args env =
withLocalEnv env $ do
quoteArgs <- traverse (bitraverse quote pure) args
pure (NMultiPrim p quoteArgs)

quoteLamWith
:: (Value -> Eval Nf)
-> (Value -> Either TermOrValue Type -> Eval Value)
Expand Down Expand Up @@ -576,6 +615,15 @@ quoteNePrimWith quote p args = do
quoteArgs <- traverse (bitraverse quote pure) args
pure (NePrim p quoteArgs)

quoteNeMultiPrimWith
:: (Value -> Eval Nf)
-> PrimInfo
-> [Either Value Type]
-> Eval (Neutral Nf)
quoteNeMultiPrimWith quote p args = do
quoteArgs <- traverse (bitraverse quote pure) args
pure (NeMultiPrim p quoteArgs)

quoteNeAppWith
:: (Value -> Eval Nf)
-> (Neutral Value -> Eval (Neutral Nf))
Expand Down Expand Up @@ -606,4 +654,3 @@ quoteNeCaseWith quote x ty xs = do
quoteX <- quote x
quoteXs <- traverse (bitraverse pure quote) xs
pure (NeCase quoteX ty quoteXs)

1 change: 1 addition & 0 deletions clash-lib/src/Clash/Core/Pretty.hs
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ instance PrettyPrec Term where
Data dc -> pprPrec prec dc
Literal l -> pprPrec prec l
Prim p -> pprPrecPrim prec (primName p)
MultiPrim p -> pprPrecPrim prec ("multi$" <> primName p)
Lam v e1 -> annotate (AnnContext $ LamBody v) <$>
pprPrecLam prec [v] e1
TyLam tv e1 -> annotate (AnnContext $ TyLamBody tv) <$>
Expand Down
25 changes: 13 additions & 12 deletions clash-lib/src/Clash/Core/Subst.hs
Original file line number Diff line number Diff line change
Expand Up @@ -915,18 +915,19 @@ acmpTerm' inScope = go (mkRnEnv inScope)

getRank :: Term -> Word
getRank = \case
Var {} -> 0
Data {} -> 1
Literal {} -> 2
Prim {} -> 3
Cast {} -> 4
App {} -> 5
TyApp {} -> 6
Lam {} -> 7
TyLam {} -> 8
Letrec {} -> 9
Case {} -> 10
Tick {} -> 11
Var {} -> 0
Data {} -> 1
Literal {} -> 2
Prim {} -> 3
Cast {} -> 4
App {} -> 5
TyApp {} -> 6
Lam {} -> 7
TyLam {} -> 8
Letrec {} -> 9
Case {} -> 10
Tick {} -> 11
MultiPrim _ -> 12

thenCompare :: Ordering -> Ordering -> Ordering
thenCompare EQ rel = rel
Expand Down
Loading

0 comments on commit f325ea0

Please sign in to comment.