Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stack-safe FreeAp #1748

Merged
merged 6 commits into from
Jul 27, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions docs/src/main/tut/datatypes/freeapplicative.md
Original file line number Diff line number Diff line change
Expand Up @@ -164,5 +164,104 @@ val prodCompiler: FunctionK[ValidationOp, ValidateAndLog] = parCompiler and logC
val prodValidation = prog.foldMap[ValidateAndLog](prodCompiler)
```

### The way FreeApplicative#foldMap works
Despite being an imperative loop, there is a functional intuition behind `FreeApplicative#foldMap`.

The new `FreeApplicative`'s `foldMap` is a sort of mutually-recursive function that operates on an argument stack and a
function stack, where the argument stack has type `List[FreeApplicative[F, _]]` and the functions have type `List[Fn[G, _, _]]`.
`Fn[G[_, _]]` contains a function to be `Ap`'d that has already been translated to the target `Applicative`,
as well as the number of functions that were `Ap`'d immediately subsequently to it.

#### Main re-association loop
Pull an argument out of the stack, eagerly remove right-associated `Ap` nodes, by looping on the right and
adding the `Ap` nodes' arguments on the left to the argument stack; at the end, pushes a single function to the
function stack of the applied functions, the rest of which will be pushed in this loop in later iterations.
Once all of the `Ap` nodes on the right are removed, the loop resets to deal with the ones on the left.

Here's an example `FreeApplicative` value to demonstrate the loop's function, at the end of every iteration.
Every node in the tree is annotated with an identifying number and the concrete type of the node
(A -> `Ap`, L -> `Lift`, P -> `Pure`), and an apostrophe to denote where `argF` (the current argument) currently
points; as well the argument and function branches off `Ap` nodes are explicitly denoted.

```
==> begin.
'1A
/ \
arg/ \fun
/ \
/ \
2A 3A
arg/ \fun arg/ \fun
/ \ / \
4L 5P 6L 7L

args: Nil
functions: Nil
==> loop.

1A
/ \
arg/ \fun
/ \
/ \
2A '3A
arg/ \fun arg/ \fun
/ \ / \
4L 5P 6L 7L

args: 2A :: Nil
functions: Nil
==> loop.

1A
/ \
arg/ \fun
/ \
/ \
2A 3A
arg/ \fun arg/ \fun
/ \ / \
4L 5P 6L '7L

args: 6L :: 2A :: Nil
functions: Fn(gab = foldArg(7L), argc = 2) :: Nil
==> finished.
```

At the end of the loop the entire right branch of `Ap`s under `argF` has been peeled off into a single curried function,
all of the arguments to that function are on the argument stack and that function itself is on the function stack,
annotated with the amount of arguments it takes.

#### Function application loop
Once `argF` isn't an `Ap` node, a loop runs which pulls functions from the stack until it reaches a curried function,
in which case it applies the function to `argF` transformed into a `G[Any]` value, and pushes the resulting function
back to the function stack, before returning to the main loop.

I'll continue the example from before here:
```
==> loop.
1A
/ \
arg/ \fun
/ \
/ \
2A 3A
arg/ \fun arg/ \fun
/ \ / \
4L 5P '6L 7L

args: 2A :: Nil
functions: Fn(gab = foldArg(7L) ap foldArg(6L), argc = 1) :: Nil
==> finished.
```

At the end of this loop every function on the top of the function stack with `length == 1` (not curried)
has been applied to a single argument from the argument stack, and the first curried function (`length != 1`)
on the stack has been applied to a single argument from the argument stack.
The reason we can't keep applying the curried function to arguments is that the node on top of the argument
stack *must* be an `Ap` node if the function is curried, so we can't translate it directly to `G[_]`.

Once the last function has been applied to the last argument, the fold has finished and the result is returned.

## References
Deeper explanations can be found in this paper [Free Applicative Functors by Paolo Capriotti](http://www.paolocapriotti.com/assets/applicative.pdf)
169 changes: 145 additions & 24 deletions free/src/main/scala/cats/free/FreeApplicative.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,48 +4,147 @@ package free
import cats.arrow.FunctionK
import cats.data.Const

/** Applicative Functor for Free */
sealed abstract class FreeApplicative[F[_], A] extends Product with Serializable { self =>
import scala.annotation.tailrec

/**
* Applicative Functor for Free,
* implementation inspired by https://github.com/safareli/free/pull/31/
*/
sealed abstract class FreeApplicative[F[_], A] extends Product with Serializable {
self =>
// ap => apply alias needed so we can refer to both
// FreeApplicative.ap and FreeApplicative#ap
import FreeApplicative.{FA, Pure, Ap, ap => apply, lift}
import FreeApplicative.{FA, Pure, Ap, lift}

final def ap[B](b: FA[F, A => B]): FA[F, B] =
final def ap[B](b: FA[F, A => B]): FA[F, B] = {
b match {
case Pure(f) =>
this.map(f)
case Ap(pivot, fn) =>
apply(pivot)(self.ap(fn.map(fx => a => p => fx(p)(a))))
case _ =>
Ap(b, this)
}
}

final def map[B](f: A => B): FA[F, B] =
final def map[B](f: A => B): FA[F, B] = {
this match {
case Pure(a) => Pure(f(a))
case Ap(pivot, fn) => apply(pivot)(fn.map(f compose _))
case _ => Ap(Pure(f), this)
}
}

/** Interprets/Runs the sequence of operations using the semantics of Applicative G
* Tail recursive only if G provides tail recursive interpretation (ie G is FreeMonad)
*/
final def foldMap[G[_]](f: FunctionK[F, G])(implicit G: Applicative[G]): G[A] =
this match {
case Pure(a) => G.pure(a)
case Ap(pivot, fn) => G.map2(f(pivot), fn.foldMap(f))((a, g) => g(a))
/** Interprets/Runs the sequence of operations using the semantics of `Applicative` G[_].
* Tail recursive.
*/
// scalastyle:off method.length
final def foldMap[G[_]](f: F ~> G)(implicit G: Applicative[G]): G[A] = {
import FreeApplicative._
// the remaining arguments to G[A => B]'s
var argsF: List[FA[F, Any]] = this.asInstanceOf[FA[F, Any]] :: Nil
var argsFLength: Int = 1
// the remaining stack of G[A => B]'s to be applied to the arguments
var fns: List[Fn[G, Any, Any]] = Nil
var fnsLength: Int = 0

@tailrec
def loop(): G[Any] = {
var argF: FA[F, Any] = argsF.head
argsF = argsF.tail
argsFLength -= 1

// rip off every `Ap` in `argF` in function position
if (argF.isInstanceOf[Ap[F, _, _]]) {
val lengthInitial = argsFLength
// reassociate the functions into a single fn,
// and move the arguments into argsF
do {
val ap = argF.asInstanceOf[Ap[F, Any, Any]]
argsF ::= ap.fp
argsFLength += 1
argF = ap.fn.asInstanceOf[FA[F, Any]]
} while (argF.isInstanceOf[Ap[F, _, _]])
// consecutive `ap` calls have been queued as operations;
// argF is no longer an `Ap` node, so the entire topmost left-associated
// function application branch has been looped through and we've
// moved (`argsFLength` - `lengthInitial`) arguments to the stack, through
// (`argsFLength` - `lengthInitial`) `Ap` nodes, so the function on the right
// which consumes them all must have (`argsFLength` - `lengthInitial`) arguments
val argc = argsFLength - lengthInitial
fns ::= Fn[G, Any, Any](foldArg(argF.asInstanceOf[FA[F, Any => Any]], f), argc)
fnsLength += 1
loop()
} else {
val argT: G[Any] = foldArg(argF, f)
if (fns ne Nil) {
// single right-associated function application
var fn = fns.head
fns = fns.tail
fnsLength -= 1
var res = G.ap(fn.gab)(argT)
if (fn.argc > 1) {
// this function has more than 1 argument,
// bail out of nested right-associated function application
fns ::= Fn(res.asInstanceOf[G[Any => Any]], fn.argc - 1)
fnsLength += 1
loop()
} else {
if (fnsLength > 0) {
// we've got a nested right-associated `Ap` tree,
// so apply as many functions as possible
@tailrec
def innerLoop(): Unit = {
fn = fns.head
fns = fns.tail
fnsLength -= 1
res = G.ap(fn.gab)(res)
if (fn.argc > 1) {
fns ::= Fn(res.asInstanceOf[G[Any => Any]], fn.argc - 1)
fnsLength += 1
}
// we have to bail out if fn has more than one argument,
// because it means we may have more left-associated trees
// deeper to the right in the application tree
if (fn.argc == 1 && fnsLength > 0) innerLoop()
}

innerLoop()
}
if (fnsLength == 0) res
else loop()
}
} else argT
}
}

/** Interpret/run the operations using the semantics of `Applicative[F]`.
* Tail recursive only if `F` provides tail recursive interpretation.
*/
loop().asInstanceOf[G[A]]
}
// scalastyle:on method.length


/**
* Interpret/run the operations using the semantics of `Applicative[F]`.
* Stack-safe.
*/
final def fold(implicit F: Applicative[F]): F[A] =
foldMap(FunctionK.id[F])

/** Interpret this algebra into another FreeApplicative */
final def compile[G[_]](f: FunctionK[F, G]): FA[G, A] =
/**
* Interpret this algebra into another algebra.
* Stack-safe.
*/
final def compile[G[_]](f: F ~> G): FA[G, A] =
foldMap[FA[G, ?]] {
λ[FunctionK[F, FA[G, ?]]](fa => lift(f(fa)))
}

/** Interpret this algebra into a Monoid */

/**
* Interpret this algebra into a FreeApplicative over another algebra.
* Stack-safe.
*/
def flatCompile[G[_]](f: F ~> FA[G, ?]): FA[G, A] =
foldMap(f)

/** Interpret this algebra into a Monoid. */
final def analyze[M: Monoid](f: FunctionK[F, λ[α => M]]): M =
foldMap[Const[M, ?]](
λ[FunctionK[F, Const[M, ?]]](x => Const(f(x)))
Expand All @@ -63,23 +162,45 @@ sealed abstract class FreeApplicative[F[_], A] extends Product with Serializable
object FreeApplicative {
type FA[F[_], A] = FreeApplicative[F, A]

// Internal helper function for foldMap, it folds only Pure and Lift nodes
private[free] def foldArg[F[_], G[_], A](node: FA[F, A], f: F ~> G)(implicit G: Applicative[G]): G[A] =
if (node.isInstanceOf[Pure[F, A]]) {
val Pure(x) = node
G.pure(x)
} else {
val Lift(fa) = node
f(fa)
}

/** Represents a curried function `F[A => B => C => ...]`
* that has been constructed with chained `ap` calls.
* Fn#argc denotes the amount of curried params remaining.
*/
private final case class Fn[G[_], A, B](gab: G[A => B], argc: Int)

private final case class Pure[F[_], A](a: A) extends FA[F, A]

private final case class Ap[F[_], P, A](pivot: F[P], fn: FA[F, P => A]) extends FA[F, A]
private final case class Lift[F[_], A](fa: F[A]) extends FA[F, A]

private final case class Ap[F[_], P, A](fn: FA[F, P => A], fp: FA[F, P]) extends FA[F, A]

final def pure[F[_], A](a: A): FA[F, A] =
Pure(a)

final def ap[F[_], P, A](fp: F[P])(f: FA[F, P => A]): FA[F, A] = Ap(fp, f)
final def ap[F[_], P, A](fp: F[P])(f: FA[F, P => A]): FA[F, A] =
Ap(f, Lift(fp))

final def lift[F[_], A](fa: F[A]): FA[F, A] =
ap(fa)(Pure(a => a))
Lift(fa)

implicit final def freeApplicative[S[_]]: Applicative[FA[S, ?]] = {
new Applicative[FA[S, ?]] {
override def product[A, B](fa: FA[S, A], fb: FA[S, B]): FA[S, (A, B)] = ap(fa.map((a: A) => (b: B) => (a, b)))(fb)

override def map[A, B](fa: FA[S, A])(f: A => B): FA[S, B] = fa.map(f)

override def ap[A, B](f: FA[S, A => B])(fa: FA[S, A]): FA[S, B] = fa.ap(f)

def pure[A](a: A): FA[S, A] = Pure(a)
}
}
Expand Down
Loading