Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add/fix Foldable extensions: findM and collectFirstSomeM #2421

Merged
merged 4 commits into from
Aug 28, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions core/src/main/scala/cats/Foldable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -601,11 +601,11 @@ object Foldable {
* It could be made a value class after
* https://github.com/scala/bug/issues/9600 is resolved.
*/
private sealed abstract class Source[+A] {
private[cats] sealed abstract class Source[+A] {
def uncons: Option[(A, Eval[Source[A]])]
}

private object Source {
private[cats] object Source {
val Empty: Source[Nothing] = new Source[Nothing] {
def uncons = None
}
Expand Down
58 changes: 48 additions & 10 deletions core/src/main/scala/cats/syntax/foldable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -95,29 +95,67 @@ final class FoldableOps[F[_], A](val fa: F[A]) extends AnyVal {

/**
* Monadic version of `collectFirstSome`.
*
* If there are no elements, the result is `None`. `collectFirstSomeM` short-circuits,
* i.e. once a Some element is found, no further effects are produced.
*
* For example:
* {{{
* scala> import cats.implicits._
* scala> def parseInt(s: String): Either[String, Int] = Either.catchOnly[NumberFormatException](s.toInt).leftMap(_.getMessage)
* scala> val keys1 = List("1", "2", "4", "5")
* scala> val map1 = Map(4 -> "Four", 5 -> "Five")
* scala> keys1.collectFirstSomeM(parseInt(_) map map1.get)
* res1: scala.util.Either[String,Option[String]] = Right(Some(Four))
* res0: scala.util.Either[String,Option[String]] = Right(Some(Four))
*
* scala> val map2 = Map(6 -> "Six", 7 -> "Seven")
* scala> keys1.collectFirstSomeM(parseInt(_) map map2.get)
* res2: scala.util.Either[String,Option[String]] = Right(None)
* res1: scala.util.Either[String,Option[String]] = Right(None)
*
* scala> val keys2 = List("1", "x", "4", "5")
* scala> keys2.collectFirstSomeM(parseInt(_) map map1.get)
* res3: scala.util.Either[String,Option[String]] = Left(For input string: "x")
* res2: scala.util.Either[String,Option[String]] = Left(For input string: "x")
*
* scala> val keys3 = List("1", "2", "4", "x")
* scala> keys3.collectFirstSomeM(parseInt(_) map map1.get)
* res4: scala.util.Either[String,Option[String]] = Right(Some(Four))
* res3: scala.util.Either[String,Option[String]] = Right(Some(Four))
* }}}
*/
def collectFirstSomeM[G[_], B](f: A => G[Option[B]])(implicit F: Foldable[F], G: Monad[G]): G[Option[B]] =
F.foldRight(fa, Eval.now(G.pure(Option.empty[B])))((a, lb) =>
Eval.now(G.flatMap(f(a)) {
case None => lb.value
case s => G.pure(s)
})
).value
G.tailRecM(Foldable.Source.fromFoldable(fa))(_.uncons match {
case Some((a, src)) => G.map(f(a)) {
case None => Left(src.value)
case s => Right(s)
}
case None => G.pure(Right(None))
})

/**
* Find the first element matching the effectful predicate, if one exists.
*
* If there are no elements, the result is `None`. `findM` short-circuits,
* i.e. once an element is found, no further effects are produced.
*
* For example:
* {{{
* scala> import cats.implicits._
* scala> val list = List(1,2,3,4)
* scala> list.findM(n => (n >= 2).asRight[String])
* res0: Either[String,Option[Int]] = Right(Some(2))
*
* scala> list.findM(n => (n > 4).asRight[String])
* res1: Either[String,Option[Int]] = Right(None)
*
* scala> list.findM(n => Either.cond(n < 3, n >= 2, "error"))
* res2: Either[String,Option[Int]] = Right(Some(2))
*
* scala> list.findM(n => Either.cond(n < 3, false, "error"))
* res3: Either[String,Option[Int]] = Left(error)
* }}}
*/
def findM[G[_]](p: A => G[Boolean])(implicit F: Foldable[F], G: Monad[G]): G[Option[A]] =
G.tailRecM(Foldable.Source.fromFoldable(fa))(_.uncons match {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

couldn't this be collectFirstSome { a => p(a).map(if (_) Some(a) else None) }?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure it could but would require an extra map for each iteration. I just followed the recommendation of other contributors to prefer efficiency over conciseness in code.

case Some((a, src)) => G.map(p(a))(if (_) Right(Some(a)) else Left(src.value))
case None => G.pure(Right(None))
})
}
74 changes: 54 additions & 20 deletions tests/src/test/scala/cats/tests/FoldableSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,12 @@ abstract class FoldableSuite[F[_]: Foldable](name: String)(
}
}

test(s"Foldable[$name].find/exists/forall/existsM/forallM/filter_/dropWhile_") {
test(s"Foldable[$name].find/exists/forall/findM/existsM/forallM/filter_/dropWhile_") {
forAll { (fa: F[Int], n: Int) =>
fa.find(_ > n) should === (iterator(fa).find(_ > n))
fa.exists(_ > n) should === (iterator(fa).exists(_ > n))
fa.forall(_ > n) should === (iterator(fa).forall(_ > n))
fa.findM(k => Option(k > n)) should === (Option(iterator(fa).find(_ > n)))
fa.existsM(k => Option(k > n)) should === (Option(iterator(fa).exists(_ > n)))
fa.forallM(k => Option(k > n)) should === (Option(iterator(fa).forall(_ > n)))
fa.filter_(_ > n) should === (iterator(fa).filter(_ > n).toList)
Expand Down Expand Up @@ -199,16 +200,42 @@ class FoldableSuiteAdditional extends CatsSuite {
larger.value should === (large.map(_ + 1))
}

def checkFoldMStackSafety[F[_]](fromRange: Range => F[Int])(implicit F: Foldable[F]): Unit = {
def checkMonadicFoldsStackSafety[F[_]](fromRange: Range => F[Int])(implicit F: Foldable[F]): Unit = {
def nonzero(acc: Long, x: Int): Option[Long] =
if (x == 0) None else Some(acc + x)

def gte(lb: Int, x: Int): Option[Boolean] =
if (x >= lb) Some(true) else Some(false)

def gteSome(lb: Int, x: Int): Option[Option[Int]] =
if (x >= lb) Some(Some(x)) else Some(None)

val n = 100000
val expected = n.toLong*(n.toLong+1)/2
val foldMResult = F.foldM(fromRange(1 to n), 0L)(nonzero)
assert(foldMResult.get == expected)
val src = fromRange(1 to n)

val foldMExpected = n.toLong*(n.toLong+1)/2
val foldMResult = F.foldM(src, 0L)(nonzero)
assert(foldMResult.get == foldMExpected)

val existsMExpected = true
val existsMResult = F.existsM(src)(gte(n, _))
assert(existsMResult.get == existsMExpected)

val forallMExpected = true
val forallMResult = F.forallM(src)(gte(0, _))
assert(forallMResult.get == forallMExpected)

val findMExpected = Some(n)
val findMResult = src.findM(gte(n, _))
assert(findMResult.get == findMExpected)

val collectFirstSomeMExpected = Some(n)
val collectFirstSomeMResult = src.collectFirstSomeM(gteSome(n, _))
assert(collectFirstSomeMResult.get == collectFirstSomeMExpected)

()
}

test(s"Foldable.iterateRight") {
forAll { (fa: List[Int]) =>
val eval = Foldable.iterateRight(fa, Eval.later(0)) { (a, eb) =>
Expand All @@ -222,36 +249,36 @@ class FoldableSuiteAdditional extends CatsSuite {
}
}

test("Foldable[List].foldM stack safety") {
checkFoldMStackSafety[List](_.toList)
test("Foldable[List].foldM/existsM/forallM/findM/collectFirstSomeM stack safety") {
checkMonadicFoldsStackSafety[List](_.toList)
}

test("Foldable[Stream].foldM stack safety") {
checkFoldMStackSafety[Stream](_.toStream)
checkMonadicFoldsStackSafety[Stream](_.toStream)
}

test("Foldable[Vector].foldM stack safety") {
checkFoldMStackSafety[Vector](_.toVector)
test("Foldable[Vector].foldM/existsM/forallM/findM/collectFirstSomeM stack safety") {
checkMonadicFoldsStackSafety[Vector](_.toVector)
}

test("Foldable[SortedSet].foldM stack safety") {
checkFoldMStackSafety[SortedSet](s => SortedSet(s:_*))
test("Foldable[SortedSet].foldM/existsM/forallM/findM/collectFirstSomeM stack safety") {
checkMonadicFoldsStackSafety[SortedSet](s => SortedSet(s:_*))
}

test("Foldable[SortedMap[String, ?]].foldM stack safety") {
checkFoldMStackSafety[SortedMap[String, ?]](xs => SortedMap.empty[String, Int] ++ xs.map(x => x.toString -> x).toMap)
test("Foldable[SortedMap[String, ?]].foldM/existsM/forallM/findM/collectFirstSomeM stack safety") {
checkMonadicFoldsStackSafety[SortedMap[String, ?]](xs => SortedMap.empty[String, Int] ++ xs.map(x => x.toString -> x).toMap)
}

test("Foldable[NonEmptyList].foldM stack safety") {
checkFoldMStackSafety[NonEmptyList](xs => NonEmptyList.fromListUnsafe(xs.toList))
test("Foldable[NonEmptyList].foldM/existsM/forallM/findM/collectFirstSomeM stack safety") {
checkMonadicFoldsStackSafety[NonEmptyList](xs => NonEmptyList.fromListUnsafe(xs.toList))
}

test("Foldable[NonEmptyVector].foldM stack safety") {
checkFoldMStackSafety[NonEmptyVector](xs => NonEmptyVector.fromVectorUnsafe(xs.toVector))
test("Foldable[NonEmptyVector].foldM/existsM/forallM/findM/collectFirstSomeM stack safety") {
checkMonadicFoldsStackSafety[NonEmptyVector](xs => NonEmptyVector.fromVectorUnsafe(xs.toVector))
}

test("Foldable[NonEmptyStream].foldM stack safety") {
checkFoldMStackSafety[NonEmptyStream](xs => NonEmptyStream(xs.head, xs.tail: _*))
test("Foldable[NonEmptyStream].foldM/existsM/forallM/findM/collectFirstSomeM stack safety") {
checkMonadicFoldsStackSafety[NonEmptyStream](xs => NonEmptyStream(xs.head, xs.tail: _*))
}

test("Foldable[Stream]") {
Expand Down Expand Up @@ -324,6 +351,13 @@ class FoldableSuiteAdditional extends CatsSuite {
assert(F.forallM[Id, Boolean](false #:: boom)(identity) == false)
}

test(".findM/.collectFirstSomeM short-circuiting") {
implicit val F = foldableStreamWithDefaultImpl
def boom: Stream[Int] = sys.error("boom")
assert((1 #:: boom).findM[Id](_ > 0) == Some(1))
assert((1 #:: boom).collectFirstSomeM[Id, Int](Option.apply) == Some(1))
}

test("Foldable[List] doesn't break substitution") {
val result = List.range(0,10).foldM(List.empty[Int])((accum, elt) => Eval.always(elt :: accum))

Expand Down