diff --git a/free/src/main/scala/cats/free/FreeT.scala b/free/src/main/scala/cats/free/FreeT.scala index a830eeaa62..caa8b9e69d 100644 --- a/free/src/main/scala/cats/free/FreeT.scala +++ b/free/src/main/scala/cats/free/FreeT.scala @@ -234,11 +234,20 @@ sealed abstract private[free] class FreeTInstances extends FreeTInstances0 { FreeT.liftT(E.raiseError[A](e))(M) } + // not to be confused with defer... which is something different... sigh... + implicit def catsDeferForFreeT[S[_], M[_]: Applicative]: Defer[FreeT[S, M, *]] = + new Defer[FreeT[S, M, *]] { + def defer[A](fa: => FreeT[S, M, A]): FreeT[S, M, A] = + FreeT.pure[S, M, Unit](()).flatMap(_ => fa) + } + implicit def catsFreeMonadErrorForFreeT2[S[_], M[_], E](implicit E: MonadError[M, E], S: Functor[S]): MonadError[FreeT[S, M, *], E] = new MonadError[FreeT[S, M, *], E] with FreeTMonad[S, M] { override def M = E + private[this] val RealDefer = catsDeferForFreeT[S, M] + /* * Quick explanation... The previous version of this function (retained above for * bincompat) was only able to look at the *top* level M[_] suspension in a Free @@ -263,7 +272,9 @@ sealed abstract private[free] class FreeTInstances extends FreeTInstances0 { val ft = FreeT.liftT[S, M, FreeT[S, M, A]] { val resultsM = E.map(fa.resume) { case Left(se) => - FreeT.liftF[S, M, FreeT[S, M, A]](S.map(se)(handleErrorWith(_)(f))).flatMap(identity) + // we defer here in order to ensure stack-safety in the results even when M[_] is not itself stack-safe + // there's some small performance loss as a consequence, but really, if you care that much about performance, why are you using FreeT? + RealDefer.defer(FreeT.liftF[S, M, FreeT[S, M, A]](S.map(se)(handleErrorWith(_)(f))).flatMap(identity)) case Right(a) => pure(a) diff --git a/free/src/test/scala/cats/free/FreeTSuite.scala b/free/src/test/scala/cats/free/FreeTSuite.scala index 5e211133f9..2210a23816 100644 --- a/free/src/test/scala/cats/free/FreeTSuite.scala +++ b/free/src/test/scala/cats/free/FreeTSuite.scala @@ -44,6 +44,8 @@ class FreeTSuite extends CatsSuite { SerializableTests.serializable(MonadError[FreeTOption, Unit])) } + checkAll("FreeT[Option, Option, Int", DeferTests[FreeTOption].defer[Int]) + test("FlatMap stack safety tested with 50k flatMaps") { val expected = Applicative[FreeTOption].pure(()) val result =