Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR] Don't generate 'fixBy' if you don't need to #5954

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,51 @@ applyFun = runQuote $ do
. lamAbs () x (TyVar () a)
$ apply () (var () f) (var () x)

{- Note [Recursion combinators]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, I confused branches and added the Note in this one rather than the strictify-fix one. Doesn't matter much, I guess.

We create singly recursive and mutually recursive functions using different combinators.

For singly recursive functions we use the Z combinator (a strict cousin of the Y combinator) that in
UPLC looks like this:

\f -> (\s -> s s) (\s -> f (\x -> s s x))

We have benchmarked its Haskell version at
https://github.com/IntersectMBO/plutus/tree/9538fc9829426b2ecb0628d352e2d7af96ec8204/doc/notes/fomega/z-combinator-benchmarks
and observed that in Haskell there's no detectable difference in performance of functions defined
using explicit recursion versus the Z combinator. However Haskell is a compiled language and Plutus
is interpreted, so it's very likely that natively supporting recursion in Plutus instead of
compiling recursive functions to combinators would significantly boost performance.

We've tried using

\f -> (\s -> s s) (\s x -> f (s s) x)

instead of

\f -> (\s -> s s) (\s -> f (\x -> s s x))

and while it worked OK at the PLC level, it wasn't a suitable primitive for compilation of recursive
functions, because it would add laziness in unexpected places, see
https://github.com/IntersectMBO/plutus/issues/5961
so we had to change it.

We use

\f -> (\s -> s s) (\s x -> f (s s) x)

instead of the more standard

\f -> (\s x -> f (s s) x) (\s x -> f (s s) x)

because in practice @f@ gets inlined and we wouldn't be able to do so if it occurred twice in the
term. Plus the former also allows us to save on the size of the term.

For mutually recursive functions we use the 'fixBy' combinator, which is, to the best of our
knowledge, our own invention. It was first described at
https://github.com/IntersectMBO/plutus/blob/067e74f0606fddc5e183dd45209b461e293a6224/doc/notes/fomega/mutual-term-level-recursion/FixN.agda
and fully specified in our "Unraveling recursion: compiling an IR with recursion to System F" paper.
-}

-- | @Self@ as a PLC type.
--
-- > fix \(self :: * -> *) (a :: *) -> self a -> a
Expand Down Expand Up @@ -144,7 +189,6 @@ fixAndType = runQuote $ do
$ TyFun () (TyFun () funAB funAB) funAB
pure (fixTerm, fixType)


-- | A type that looks like a transformation.
--
-- > trans F G Q : F Q -> G Q
Expand Down Expand Up @@ -337,6 +381,7 @@ fixNAndType n fixByTerm = runQuote $ do
]
pure (fixNTerm, fixNType)

-- See Note [Recursion combinators].
-- | Get the fixed-point of a single recursive function.
getSingleFixOf
:: (TermLike term TyName Name uni fun)
Expand All @@ -346,6 +391,7 @@ getSingleFixOf ann fix1 fun@FunctionDef{_functionDefType=(FunctionType _ dom cod
abstractedBody = mkIterLamAbs [functionDefVarDecl fun] $ _functionDefTerm fun
in apply ann instantiatedFix abstractedBody

-- See Note [Recursion combinators].
-- | Get the fixed-point of a list of mutually recursive functions.
--
-- > MutualFixOf _ fixN [ FunctionDef _ fN1 (FunctionType _ a1 b1) f1
Expand Down
11 changes: 6 additions & 5 deletions plutus-core/plutus-ir/src/PlutusIR/Compiler/Recursion.hs
Original file line number Diff line number Diff line change
Expand Up @@ -98,15 +98,16 @@ mkFixpoint bs = do
name <- liftQuote $ toProgramName fixByKey
let (fixByTerm, fixByType) = Function.fixByAndType
pure (PLC.Def (PLC.VarDecl noProvenance name (noProvenance <$ fixByType)) (noProvenance <$ fixByTerm, Strict), mempty)
fixBy <- lookupOrDefineTerm p0 fixByKey mkFixByDef
Unisay marked this conversation as resolved.
Show resolved Hide resolved

let mkFixNDef = do
name <- liftQuote $ toProgramName fixNKey
let ((fixNTerm, fixNType), fixNDeps) =
if arity == 1
then (Function.fixAndType, mempty)
((fixNTerm, fixNType), fixNDeps) <-
if arity == 1
then pure (Function.fixAndType, mempty)
-- fixN depends on fixBy
else (Function.fixNAndType arity (void fixBy), Set.singleton fixByKey)
else do
fixBy <- lookupOrDefineTerm p0 fixByKey mkFixByDef
pure (Function.fixNAndType arity (void fixBy), Set.singleton fixByKey)
pure (PLC.Def (PLC.VarDecl noProvenance name (noProvenance <$ fixNType)) (noProvenance <$ fixNTerm, Strict), fixNDeps)
fixN <- lookupOrDefineTerm p0 fixNKey mkFixNDef

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,27 @@
1.1.0
[
[
(lam s_1651 [ s_1651 s_1651 ])
(lam s_1609 [ s_1609 s_1609 ])
(lam
s_1652
s_1610
(lam
i_1653
i_1611
[
[
[
[
(force (builtin ifThenElse))
[ [ (builtin equalsInteger) (con integer 0) ] i_1653 ]
[ [ (builtin equalsInteger) (con integer 0) ] i_1611 ]
]
(lam u_1654 (con integer 1))
(lam u_1612 (con integer 1))
]
(lam
u_1655
u_1613
[
[ (builtin multiplyInteger) i_1653 ]
[ (builtin multiplyInteger) i_1611 ]
[
(lam x_1656 [ [ s_1652 s_1652 ] x_1656 ])
[ [ (builtin subtractInteger) i_1653 ] (con integer 1) ]
(lam x_1614 [ [ s_1610 s_1610 ] x_1614 ])
[ [ (builtin subtractInteger) i_1611 ] (con integer 1) ]
]
]
)
Expand Down
Loading