Skip to content

Commit

Permalink
Stack-safe FreeApplicative
Browse files Browse the repository at this point in the history
  • Loading branch information
edmundnoble committed Jul 5, 2017
1 parent 1931c93 commit 5e6adae
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 25 deletions.
162 changes: 137 additions & 25 deletions free/src/main/scala/cats/free/FreeApplicative.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,82 +4,194 @@ package free
import cats.arrow.FunctionK
import cats.data.Const

import scala.annotation.tailrec

/** Applicative Functor for Free */
sealed abstract class FreeApplicative[F[_], A] extends Product with Serializable { self =>
sealed abstract class FreeApplicative[F[_], A] extends Product with Serializable {
self =>
// ap => apply alias needed so we can refer to both
// FreeApplicative.ap and FreeApplicative#ap
import FreeApplicative.{FA, Pure, Ap, ap => apply, lift}
import FreeApplicative.{FA, Pure, Ap, lift}

final def ap[B](b: FA[F, A => B]): FA[F, B] =
final def ap[B](b: FA[F, A => B]): FA[F, B] = {
b match {
case Pure(f) =>
this.map(f)
case Ap(pivot, fn) =>
apply(pivot)(self.ap(fn.map(fx => a => p => fx(p)(a))))
case _ =>
Ap(b, this)
}
}

final def map[B](f: A => B): FA[F, B] =
final def map[B](f: A => B): FA[F, B] = {
this match {
case Pure(a) => Pure(f(a))
case Ap(pivot, fn) => apply(pivot)(fn.map(f compose _))
case _ => Ap(Pure(f), this)
}
}

/** Interprets/Runs the sequence of operations using the semantics of Applicative G
* Tail recursive only if G provides tail recursive interpretation (ie G is FreeMonad)
*/
final def foldMap[G[_]](f: FunctionK[F, G])(implicit G: Applicative[G]): G[A] =
this match {
case Pure(a) => G.pure(a)
case Ap(pivot, fn) => G.map2(f(pivot), fn.foldMap(f))((a, g) => g(a))
/** Interprets/Runs the sequence of operations using the semantics of `Applicative` G[_].
* Tail recursive.
*/
// scalastyle:off method.length
final def foldMap[G[_]](f: F ~> G)(implicit G: Applicative[G]): G[A] = {
import FreeApplicative._
// the remaining arguments to G[A => B]'s
var argsF: List[FA[F, Any]] = this.asInstanceOf[FA[F, Any]] :: Nil
var argsFLength: Int = 1
// the remaining stack of G[A => B]'s to be applied to the arguments
// Fn#length denotes the amount of curried params remaining
var fns: List[Fn[G, Any, Any]] = Nil
var fnsLength: Int = 0

@tailrec
def loop(): G[Any] = {
var argF: FA[F, Any] = argsF.head
argsF = argsF.tail
argsFLength -= 1

// rip off every `Ap` in `argF`, peeling off left-associated prefixes
if (argF.isInstanceOf[Ap[F, _, _]]) {
val lengthInitial = argsFLength
// reassociate the functions into a single fn,
// and move the arguments into argsF
do {
argF match {
case Ap(fn, fp) =>
argsF ::= fp.asInstanceOf[FA[F, Any]]
argsFLength += 1
argF = fn.asInstanceOf[FA[F, Any]]
case _ => ()
}
} while (argF.isInstanceOf[Ap[F, _, _]])
// argF is no longer an `Ap` node, so the entire topmost
// left-associated application branch has been looped through
// we've moved (`argsFLength` - `lengthInitial`) arguments to the stack, through
// (`argsFLength` - `lengthInitial`) `Ap` nodes, thus the function that consumes
// them all must have (`argsFLength` - `lengthInitial`) arguments
val fnLength = argsFLength - lengthInitial
fns ::= Fn[G, Any, Any](foldArg(argF.asInstanceOf[FA[F, Any => Any]], f), fnLength)
fnsLength += 1
loop()
} else {
val argT: G[Any] = foldArg(argF, f)
if (fns ne Nil) {
// single right-associated function application
var fn = fns.head
fns = fns.tail
fnsLength -= 1
var res = G.ap(fn.gab)(argT)
if (fn.length > 1) {
// this function has more than 1 argument,
// bail out of nested right-associated function application
fns ::= Fn(res.asInstanceOf[G[Any => Any]], fn.length - 1)
fnsLength += 1
loop()
} else {
if (fnsLength > 0) {
// we've got a nested right-associated `Ap` tree,
// so apply as many functions as possible
@tailrec
def innerLoop(): Unit = {
fn = fns.head
fns = fns.tail
fnsLength -= 1
res = G.ap(fn.gab)(res)
if (fn.length > 1) {
fns ::= Fn(res.asInstanceOf[G[Any => Any]], fn.length - 1)
fnsLength += 1
res = G.ap(fn.gab)(res)
}
if (fn.length == 1 && fnsLength > 0) innerLoop()
}
innerLoop()
}
if (fnsLength == 0) res
else loop()
}
} else argT
}
}

loop().asInstanceOf[G[A]]
}
// scalastyle:on method.length


/** Interpret/run the operations using the semantics of `Applicative[F]`.
* Tail recursive only if `F` provides tail recursive interpretation.
*/
final def fold(implicit F: Applicative[F]): F[A] =
* Tail recursive only if `F` provides tail recursive interpretation.
*/
final def fold(implicit F: Applicative[F]): F[A] = {
foldMap(FunctionK.id[F])
}

/** Interpret this algebra into another FreeApplicative */
final def compile[G[_]](f: FunctionK[F, G]): FA[G, A] =
final def compile[G[_]](f: F ~> G): FA[G, A] = {
foldMap[FA[G, ?]] {
λ[FunctionK[F, FA[G, ?]]](fa => lift(f(fa)))
}
}

def flatCompile[G[_]](f: F ~> FA[G, ?]): FA[G, A] = {
foldMap(f)
}

/** Interpret this algebra into a Monoid */
final def analyze[M:Monoid](f: FunctionK[F, λ[α => M]]): M =
final def analyze[M: Monoid](f: FunctionK[F, λ[α => M]]): M = {
foldMap[Const[M, ?]](
λ[FunctionK[F, Const[M, ?]]](x => Const(f(x)))
).getConst
}

/** Compile this FreeApplicative algebra into a Free algebra. */
final def monad: Free[F, A] =
final def monad: Free[F, A] = {
foldMap[Free[F, ?]] {
λ[FunctionK[F, Free[F, ?]]](fa => Free.liftF(fa))
}
}

override def toString: String = "FreeApplicative(...)"
}

object FreeApplicative {
type FA[F[_], A] = FreeApplicative[F, A]

// Internal helper function for foldMap, it folds only Pure and Lift nodes
private[free] def foldArg[F[_], G[_], A](node: FA[F, A], f: F ~> G)(implicit G: Applicative[G]): G[A] = {
node match {
case Pure(x) =>
G.pure(x)
case Lift(fa) =>
f(fa)
case _ => sys.error("\"impossible\", foldArg should be called on Pure or Lift nodes only")
}
}

private final case class Fn[G[_], A, B](gab: G[A => B], length: Int)

private final case class Pure[F[_], A](a: A) extends FA[F, A]

private final case class Ap[F[_], P, A](pivot: F[P], fn: FA[F, P => A]) extends FA[F, A]
private final case class Lift[F[_], A](fa: F[A]) extends FA[F, A]

private final case class Ap[F[_], P, A](fn: FA[F, P => A], fp: FA[F, P]) extends FA[F, A]

final def pure[F[_], A](a: A): FA[F, A] =
final def pure[F[_], A](a: A): FA[F, A] = {
Pure(a)
}

final def ap[F[_], P, A](fp: F[P])(f: FA[F, P => A]): FA[F, A] = Ap(fp, f)
final def ap[F[_], P, A](fp: F[P])(f: FA[F, P => A]): FA[F, A] = Ap(f, Lift(fp))

final def lift[F[_], A](fa: F[A]): FA[F, A] =
ap(fa)(Pure(a => a))
final def lift[F[_], A](fa: F[A]): FA[F, A] = {
Lift(fa)
}

implicit final def freeApplicative[S[_]]: Applicative[FA[S, ?]] = {
new Applicative[FA[S, ?]] {
override def product[A, B](fa: FA[S, A], fb: FA[S, B]): FA[S, (A, B)] = ap(fa.map((a: A) => (b: B) => (a, b)))(fb)

override def map[A, B](fa: FA[S, A])(f: A => B): FA[S, B] = fa.map(f)

override def ap[A, B](f: FA[S, A => B])(fa: FA[S, A]): FA[S, B] = fa.ap(f)

def pure[A](a: A): FA[S, A] = Pure(a)
}
}
Expand Down
25 changes: 25 additions & 0 deletions free/src/test/scala/cats/free/FreeApplicativeTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ class FreeApplicativeTests extends CatsSuite {
rr.toString.length should be > 0
}

test("fold/map is stack-safe") {
val r = FreeApplicative.lift[List, Int](List(333))
val rr = (1 to 10000).foldLeft(r)((r, _) => r.ap(FreeApplicative.lift[List, Int => Int](List((_: Int) + 1))))
rr.fold should be (List(333 + 10000))
val rx = (1 to 10000).foldRight(r)((_, r) => r.ap(FreeApplicative.lift[List, Int => Int](List((_: Int) + 1))))
rx.fold should be (List(333 + 10000))
}

test("FreeApplicative#fold") {
val n = 2
val o1 = Option(1)
Expand All @@ -56,6 +64,17 @@ class FreeApplicativeTests extends CatsSuite {
r1.foldMap(nt) should === (r2.foldMap(nt))
}

test("FreeApplicative#flatCompile") {
val x = FreeApplicative.lift[Id, Int](1)
val y = FreeApplicative.pure[Id, Int](2)
val f = x.map(i => (j: Int) => i + j)
val nt: Id ~> FreeApplicative[Id, ?] = new FunctionK[Id, FreeApplicative[Id, ?]] {
def apply[A](a: Id[A]): FreeApplicative[Id, A] = FreeApplicative.pure(a)
}
val r1 = y.ap(f)
r1.foldMap[FreeApplicative[Id, ?]](nt)(FreeApplicative.freeApplicative[Id]).fold should === (r1.flatCompile[Id](nt).fold)
}

test("FreeApplicative#monad") {
val x = FreeApplicative.lift[Id, Int](1)
val y = FreeApplicative.pure[Id, Int](2)
Expand All @@ -66,6 +85,12 @@ class FreeApplicativeTests extends CatsSuite {
r1.foldMap(nt) should === (r2.foldMap(nt))
}

test("FreeApplicative#ap") {
val x = FreeApplicative.ap[Id, Int, Int](1)(FreeApplicative.pure((_: Int) + 1))
val y = FreeApplicative.lift[Id, Int](1).ap(FreeApplicative.pure((_: Int) + 1))
x should === (y)
}

// Ensure that syntax and implicit resolution work as expected.
// If it compiles, it passes the "test".
object SyntaxTests {
Expand Down

0 comments on commit 5e6adae

Please sign in to comment.