diff --git a/core/src/main/scala/cats/data/Validated.scala b/core/src/main/scala/cats/data/Validated.scala index 22f97f78a6..69002d74c0 100644 --- a/core/src/main/scala/cats/data/Validated.scala +++ b/core/src/main/scala/cats/data/Validated.scala @@ -203,6 +203,20 @@ sealed abstract class Validated[+E, +A] extends Product with Serializable { case Valid(a) => Invalid(a) case Invalid(e) => Valid(e) } + + /** + * Ensure that a successful result passes the given predicate, + * falling back to an Invalid of `onFailure` if the predicate + * returns false. + * + * For example: + * {{{ + * scala> Validated.valid("").ensure(new IllegalArgumentException("Must not be empty"))(_.nonEmpty) + * res0: Validated[IllegalArgumentException,String] = Invalid(java.lang.IllegalArgumentException: Must not be empty) + * }}} + */ + def ensure[EE >: E](onFailure: => EE)(f: A => Boolean): Validated[EE, A] = + fold(_ => this, a => if (f(a)) this else Validated.invalid(onFailure)) } object Validated extends ValidatedInstances with ValidatedFunctions{ diff --git a/tests/src/test/scala/cats/tests/ValidatedTests.scala b/tests/src/test/scala/cats/tests/ValidatedTests.scala index e755b1719d..947a8b899a 100644 --- a/tests/src/test/scala/cats/tests/ValidatedTests.scala +++ b/tests/src/test/scala/cats/tests/ValidatedTests.scala @@ -188,4 +188,20 @@ class ValidatedTests extends CatsSuite { val z = x.map2(y)((i, b) => if (b) i + 1 else i) z should === (NonEmptyList("error 1", "error 2").invalid[Int]) } + + test("ensure on Invalid is identity") { + forAll { (x: Validated[Int,String], i: Int, p: String => Boolean) => + if (x.isInvalid) { + x.ensure(i)(p) should === (x) + } + } + } + + test("ensure should fail if predicate not satisfied") { + forAll { (x: Validated[String, Int], s: String, p: Int => Boolean) => + if (x.exists(!p(_))) { + x.ensure(s)(p) should === (Validated.invalid(s)) + } + } + } }