From 5cdf69f345c4acd0c29fee2969231001a6bb87ba Mon Sep 17 00:00:00 2001 From: geirolz Date: Fri, 19 Jul 2024 15:31:18 +0200 Subject: [PATCH] Add ifTrueM, ifFalseM, ensureTrue and ensureFalse --- core/src/main/scala-2/cats/syntax/MonadOps.scala | 5 +++++ core/src/main/scala/cats/Monad.scala | 14 ++++++++++++++ core/src/main/scala/cats/MonadError.scala | 12 ++++++++++++ core/src/main/scala/cats/syntax/monadError.scala | 7 +++++++ .../test/scala/cats/tests/MonadErrorSuite.scala | 10 ++++++++++ .../src/test/scala/cats/tests/MonadSuite.scala | 15 +++++++++++++++ 6 files changed, 63 insertions(+) diff --git a/core/src/main/scala-2/cats/syntax/MonadOps.scala b/core/src/main/scala-2/cats/syntax/MonadOps.scala index cc2b782378..c64b9bf42b 100644 --- a/core/src/main/scala-2/cats/syntax/MonadOps.scala +++ b/core/src/main/scala-2/cats/syntax/MonadOps.scala @@ -21,6 +21,7 @@ package cats.syntax +import cats.kernel.Monoid import cats.{Alternative, Monad} final class MonadOps[F[_], A](private val fa: F[A]) extends AnyVal { @@ -32,4 +33,8 @@ final class MonadOps[F[_], A](private val fa: F[A]) extends AnyVal { def iterateUntil(p: A => Boolean)(implicit M: Monad[F]): F[A] = M.iterateUntil(fa)(p) def flatMapOrKeep[A1 >: A](pfa: PartialFunction[A, F[A1]])(implicit M: Monad[F]): F[A1] = M.flatMapOrKeep[A, A1](fa)(pfa) + def ifTrueM[B: Monoid](ifTrue: => F[B])(implicit env: F[A] <:< F[Boolean], M: Monad[F]): F[B] = + M.ifTrueM(fa)(ifTrue) + def ifFalseM[B: Monoid](ifFalse: => F[B])(implicit env: F[A] <:< F[Boolean],M: Monad[F]): F[B] = + M.ifFalseM(fa)(ifFalse) } diff --git a/core/src/main/scala/cats/Monad.scala b/core/src/main/scala/cats/Monad.scala index c694f854e3..353c35705c 100644 --- a/core/src/main/scala/cats/Monad.scala +++ b/core/src/main/scala/cats/Monad.scala @@ -21,6 +21,8 @@ package cats +import cats.kernel.Monoid + /** * Monad. * @@ -162,6 +164,18 @@ trait Monad[F[_]] extends FlatMap[F] with Applicative[F] { tailRecM(branches.toList)(step) } + /** + * If the `F[Boolean]` is `true` then return `ifTrue` otherwise return `ifFalse` + */ + def ifTrueM[B: Monoid](fa: F[Boolean])(ifTrue: => F[B]): F[B] = + ifM(fa)(ifTrue, pure(Monoid[B].empty)) + + /** + * If the `F[Boolean]` is `false` then return `ifFalse` otherwise return `Monoid[A].empty` + */ + def ifFalseM[B: Monoid](fa: F[Boolean])(ifFalse: => F[B]): F[B] = + ifM(fa)(pure(Monoid[B].empty), ifFalse) + /** * Modifies the `A` value in `F[A]` with the supplied function, if the function is defined for the value. * Example: diff --git a/core/src/main/scala/cats/MonadError.scala b/core/src/main/scala/cats/MonadError.scala index c6e1bfa136..4a19c49278 100644 --- a/core/src/main/scala/cats/MonadError.scala +++ b/core/src/main/scala/cats/MonadError.scala @@ -40,6 +40,18 @@ trait MonadError[F[_], E] extends ApplicativeError[F, E] with Monad[F] { def ensureOr[A](fa: F[A])(error: A => E)(predicate: A => Boolean): F[A] = flatMap(fa)(a => if (predicate(a)) pure(a) else raiseError(error(a))) + /** + * Ensures that a `F[Boolean]` is `true`, otherwise raises an error. + */ + def ensureTrue(fa: F[Boolean])(error: => E): F[Boolean] = + ensure(fa)(error)(identity) + + /** + * Ensures that a `F[Boolean]` is `false`, otherwise raises an error. + */ + def ensureFalse(fa: F[Boolean])(error: => E): F[Boolean] = + ensure(fa)(error)(bool => !bool) + /** * Inverse of `attempt` * diff --git a/core/src/main/scala/cats/syntax/monadError.scala b/core/src/main/scala/cats/syntax/monadError.scala index 5b13b52059..48ce310bd7 100644 --- a/core/src/main/scala/cats/syntax/monadError.scala +++ b/core/src/main/scala/cats/syntax/monadError.scala @@ -33,12 +33,19 @@ trait MonadErrorSyntax { } final class MonadErrorOps[F[_], E, A](private val fa: F[A]) extends AnyVal { + def ensure(error: => E)(predicate: A => Boolean)(implicit F: MonadError[F, E]): F[A] = F.ensure(fa)(error)(predicate) def ensureOr(error: A => E)(predicate: A => Boolean)(implicit F: MonadError[F, E]): F[A] = F.ensureOr(fa)(error)(predicate) + def ensureTrue(error: => E)(implicit env: F[A] <:< F[Boolean], F: MonadError[F, E]): F[Boolean] = + F.ensureTrue(fa)(error) + + def ensureFalse(error: => E)(implicit env: F[A] <:< F[Boolean], F: MonadError[F, E]): F[Boolean] = + F.ensureFalse(fa)(error) + /** * Turns a successful value into the error returned by a given partial function if it is * in the partial function's domain. diff --git a/tests/shared/src/test/scala/cats/tests/MonadErrorSuite.scala b/tests/shared/src/test/scala/cats/tests/MonadErrorSuite.scala index dd355641f1..57d11ea8bf 100644 --- a/tests/shared/src/test/scala/cats/tests/MonadErrorSuite.scala +++ b/tests/shared/src/test/scala/cats/tests/MonadErrorSuite.scala @@ -95,4 +95,14 @@ class MonadErrorSuite extends CatsSuite { test("rethrow returns the successful value, when applied to a Right of a specialized successful value") { assert(successful.attempt.asInstanceOf[Try[Either[IllegalArgumentException, Int]]].rethrow === successful) } + + test("ensureTrue raise an error only when the value is true") { + Try(true).ensureTrue(failedValue) === Failure(failedValue) + Try(false).ensureTrue(failedValue) === Success(false) + } + + test("ensureFalse raise an error only when the value is false") { + Try(true).ensureFalse(failedValue) === Success(true) + Try(false).ensureFalse(failedValue) === Failure(failedValue) + } } diff --git a/tests/shared/src/test/scala/cats/tests/MonadSuite.scala b/tests/shared/src/test/scala/cats/tests/MonadSuite.scala index aa5fa28794..173ea48061 100644 --- a/tests/shared/src/test/scala/cats/tests/MonadSuite.scala +++ b/tests/shared/src/test/scala/cats/tests/MonadSuite.scala @@ -154,4 +154,19 @@ class MonadSuite extends CatsSuite { assert(actual.value === 2) } + test("ifTrueM"){ + val actual1: Eval[Int] = Eval.later(true).ifTrueM(Eval.later(1)) + assert(actual1.value === 1) + + val actual2: Eval[Int] = Eval.later(false).ifTrueM(Eval.later(1)) + assert(actual2.value === 0) + } + + test("ifFalseM"){ + val actual1: Eval[Int] = Eval.later(true).ifFalseM(Eval.later(1)) + assert(actual1.value === 0) + + val actual2: Eval[Int] = Eval.later(false).ifFalseM(Eval.later(1)) + assert(actual2.value === 1) + } }