Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #92: Sync.map should suspend evaluation #96

Merged
merged 4 commits into from
Dec 14, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 31 additions & 25 deletions core/shared/src/main/scala/cats/effect/IO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -296,34 +296,40 @@ private[effect] trait IOLowPriorityInstances {
private[effect] trait IOInstances extends IOLowPriorityInstances {

implicit val ioEffect: Effect[IO] = new Effect[IO] {

def pure[A](a: A) = IO.pure(a)

def flatMap[A, B](ioa: IO[A])(f: A => IO[B]): IO[B] = ioa.flatMap(f)

// this will use stack proportional to the maximum number of joined async suspensions
def tailRecM[A, B](a: A)(f: A => IO[Either[A, B]]): IO[B] = f(a) flatMap {
case Left(a) => tailRecM(a)(f)
case Right(b) => pure(b)
}

override def attempt[A](ioa: IO[A]): IO[Either[Throwable, A]] = ioa.attempt

def handleErrorWith[A](ioa: IO[A])(f: Throwable => IO[A]): IO[A] =
override def pure[A](a: A): IO[A] =
IO.pure(a)
override def flatMap[A, B](ioa: IO[A])(f: A => IO[B]): IO[B] =
ioa.flatMap(f)
override def map[A, B](fa: IO[A])(f: A => B): IO[B] =
fa.map(f)
override def delay[A](thunk: => A): IO[A] =
IO(thunk)
override def unit: IO[Unit] =
IO.unit
override def attempt[A](ioa: IO[A]): IO[Either[Throwable, A]] =
ioa.attempt
override def handleErrorWith[A](ioa: IO[A])(f: Throwable => IO[A]): IO[A] =
IO.Bind(ioa, IOFrame.errorHandler(f))

def raiseError[A](e: Throwable): IO[A] = IO.raiseError(e)

def suspend[A](thunk: => IO[A]): IO[A] = IO.suspend(thunk)

def async[A](k: (Either[Throwable, A] => Unit) => Unit): IO[A] = IO.async(k)

def runAsync[A](ioa: IO[A])(cb: Either[Throwable, A] => IO[Unit]): IO[Unit] = ioa.runAsync(cb)

override def raiseError[A](e: Throwable): IO[A] =
IO.raiseError(e)
override def suspend[A](thunk: => IO[A]): IO[A] =
IO.suspend(thunk)
override def async[A](k: (Either[Throwable, A] => Unit) => Unit): IO[A] =
IO.async(k)
override def runAsync[A](ioa: IO[A])(cb: Either[Throwable, A] => IO[Unit]): IO[Unit] =
ioa.runAsync(cb)
// creates a new call-site, so *very* slightly faster than using the default
override def shift(implicit ec: ExecutionContext) = IO.shift(ec)
override def shift(implicit ec: ExecutionContext): IO[Unit] =
IO.shift(ec)
override def liftIO[A](ioa: IO[A]): IO[A] =
ioa

override def liftIO[A](ioa: IO[A]) = ioa
// this will use stack proportional to the maximum number of joined async suspensions
override def tailRecM[A, B](a: A)(f: A => IO[Either[A, B]]): IO[B] =
f(a) flatMap {
case Left(a) => tailRecM(a)(f)
case Right(b) => pure(b)
}
}

implicit def ioMonoid[A: Monoid]: Monoid[IO[A]] = new IOSemigroup[A] with Monoid[IO[A]] {
Expand Down
32 changes: 31 additions & 1 deletion laws/shared/src/main/scala/cats/effect/laws/SyncLaws.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,26 @@ trait SyncLaws[F[_]] extends MonadErrorLaws[F, Throwable] {
fa <-> F.raiseError(t)
}

def bindSuspendsEvaluation[A](fa: F[A], a1: A, f: (A, A) => A) = {
var state = a1
val evolve = F.flatMap(fa) { a2 =>
state = f(state, a2)
F.pure(state)
}
// Observing `state` before and after `evolve`
F.map2(F.pure(state), evolve)(f) <-> F.map(fa)(a2 => f(a1, f(a1, a2)))
}

def mapSuspendsEvaluation[A](fa: F[A], a1: A, f: (A, A) => A) = {
var state = a1
val evolve = F.map(fa) { a2 =>
state = f(state, a2)
state
}
// Observing `state` before and after `evolve`
F.map2(F.pure(state), evolve)(f) <-> F.map(fa)(a2 => f(a1, f(a1, a2)))
}

lazy val stackSafetyOnRepeatedLeftBinds = {
val result = (0 until 10000).foldLeft(F.delay(())) { (acc, _) =>
acc.flatMap(_ => F.delay(()))
Expand All @@ -75,12 +95,22 @@ trait SyncLaws[F[_]] extends MonadErrorLaws[F, Throwable] {
}

lazy val stackSafetyOnRepeatedAttempts = {
// Note this isn't enough to guarantee stack safety, unless
// coupled with `bindSuspendsEvaluation`
val result = (0 until 10000).foldLeft(F.delay(())) { (acc, _) =>
F.attempt(acc).map(_ => ())
}

result <-> F.pure(())
}

lazy val stackSafetyOnRepeatedMaps = {
// Note this isn't enough to guarantee stack safety, unless
// coupled with `mapSuspendsEvaluation`
val result = (0 until 10000).foldLeft(F.delay(0)) { (acc, _) =>
F.map(acc)(_ + 1)
}
result <-> F.pure(10000)
}
}

object SyncLaws {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,15 @@ trait SyncTests[F[_]] extends MonadErrorTests[F, Throwable] with TestsPlatform {
"throw in suspend is raiseError" -> forAll(laws.suspendThrowIsRaiseError[A] _),
"unsequenced delay is no-op" -> forAll(laws.unsequencedDelayIsNoop[A] _),
"repeated sync evaluation not memoized" -> forAll(laws.repeatedSyncEvaluationNotMemoized[A] _),
"propagate errors through bind (suspend)" -> forAll(laws.propagateErrorsThroughBindSuspend[A] _))
"propagate errors through bind (suspend)" -> forAll(laws.propagateErrorsThroughBindSuspend[A] _),
"bind suspends evaluation" -> forAll(laws.bindSuspendsEvaluation[A] _),
"map suspends evaluation" -> forAll(laws.mapSuspendsEvaluation[A] _))

val jvmProps = Seq(
"stack-safe on left-associated binds" -> Prop.lzy(laws.stackSafetyOnRepeatedLeftBinds),
"stack-safe on right-associated binds" -> Prop.lzy(laws.stackSafetyOnRepeatedRightBinds),
"stack-safe on repeated attempts" -> Prop.lzy(laws.stackSafetyOnRepeatedAttempts))
"stack-safe on repeated attempts" -> Prop.lzy(laws.stackSafetyOnRepeatedAttempts),
"stack-safe on repeated maps" -> Prop.lzy(laws.stackSafetyOnRepeatedMaps))

val jsProps = Seq()

Expand Down