diff --git a/domains-cats/build.sbt b/domains-cats/build.sbt index f89a996f..3cfade73 100644 --- a/domains-cats/build.sbt +++ b/domains-cats/build.sbt @@ -6,19 +6,19 @@ libraryDependencies += "org.typelevel" %%% "cats-core" % "2.0.0-M4" libraryDependencies += "org.scalatest" %%% "scalatest" % "3.0.8" % Test -libraryDependencies += "com.thoughtworks.dsl" %%% "keywords-catch" % "1.3.2" +libraryDependencies += "com.thoughtworks.dsl" %%% "keywords-catch" % "1.4.0" -libraryDependencies += "com.thoughtworks.dsl" %%% "keywords-monadic" % "1.3.2" +libraryDependencies += "com.thoughtworks.dsl" %%% "keywords-monadic" % "1.4.0" -libraryDependencies += "com.thoughtworks.dsl" %%% "keywords-return" % "1.3.2" +libraryDependencies += "com.thoughtworks.dsl" %%% "keywords-return" % "1.4.0" -libraryDependencies += "com.thoughtworks.dsl" %%% "keywords-shift" % "1.3.2" % Optional +libraryDependencies += "com.thoughtworks.dsl" %%% "keywords-shift" % "1.4.0" % Optional -libraryDependencies += "com.thoughtworks.dsl" %%% "keywords-yield" % "1.3.2" % Optional +libraryDependencies += "com.thoughtworks.dsl" %%% "keywords-yield" % "1.4.0" % Optional -addCompilerPlugin("com.thoughtworks.dsl" %% "compilerplugins-bangnotation" % "1.3.2") +addCompilerPlugin("com.thoughtworks.dsl" %% "compilerplugins-bangnotation" % "1.4.0") -addCompilerPlugin("com.thoughtworks.dsl" %% "compilerplugins-reseteverywhere" % "1.3.2") +addCompilerPlugin("com.thoughtworks.dsl" %% "compilerplugins-reseteverywhere" % "1.4.0") scalacOptions ++= { import Ordering.Implicits._ diff --git a/domains-cats/src/main/scala/com/thoughtworks/dsl/domains/cats.scala b/domains-cats/src/main/scala/com/thoughtworks/dsl/domains/cats.scala index 2577591e..7c3c6761 100644 --- a/domains-cats/src/main/scala/com/thoughtworks/dsl/domains/cats.scala +++ b/domains-cats/src/main/scala/com/thoughtworks/dsl/domains/cats.scala @@ -3,7 +3,7 @@ package domains import _root_.cats.{Applicative, FlatMap} import com.thoughtworks.dsl.Dsl -import com.thoughtworks.dsl.Dsl.!! +import com.thoughtworks.dsl.Dsl.{!!, TryCatch, TryFinally} import _root_.cats.MonadError import com.thoughtworks.Extractor._ import com.thoughtworks.dsl.keywords.Catch.CatchDsl @@ -12,6 +12,7 @@ import com.thoughtworks.dsl.keywords.{Monadic, Return} import scala.language.higherKinds import scala.language.implicitConversions import scala.util.control.Exception.Catcher +import scala.util.control.NonFatal /** Contains interpreters to enable [[Dsl.Keyword#unary_$bang !-notation]] * for [[keywords.Monadic]] and other keywords @@ -112,8 +113,54 @@ object cats { (keyword: Return[A], handler: Nothing => F[B]) => applicative.pure(restReturnDsl.cpsApply(keyword, identity)) } + @inline private def catchNativeException[F[_], A](continuation: F[A] !! A)( + implicit monadThrowable: MonadThrowable[F]): F[A] = { + try { + continuation(monadThrowable.pure(_)) + } catch { + case NonFatal(e) => + monadThrowable.raiseError(e) + } + } + + implicit def catsTryFinally[F[_], A, B]( + implicit monadError: MonadThrowable[F]): TryFinally[A, F[B], F[A], F[Unit]] = { + (block: F[A] !! A, finalizer: F[Unit] !! Unit, outerSuccessHandler: A => F[B]) => + @inline + def injectFinalizer[A](f: Unit => F[A]): F[A] = { + monadError.flatMap(catchNativeException(finalizer))(f) + } + + monadError.flatMap(monadError.handleErrorWith(catchNativeException(block)) { e: Throwable => + injectFinalizer { _: Unit => + monadError.raiseError(e) + } + }) { a => + injectFinalizer { _: Unit => + outerSuccessHandler(a) + } + } + } + + implicit def catsTryCatch[F[_], A, B](implicit monadError: MonadThrowable[F]): TryCatch[A, F[B], F[A]] = { + (block: F[A] !! A, catcher: Catcher[F[A] !! A], outerSuccessHandler: A => F[B]) => + def errorHandler(e: Throwable): F[A] = { + (try { + catcher.lift(e) + } catch { + case NonFatal(extractorException) => + return monadError.raiseError(extractorException) + }) match { + case None => + monadError.raiseError(e) + case Some(recovered) => + catchNativeException(recovered) + } + } + monadError.flatMap(monadError.handleErrorWith(catchNativeException(block))(errorHandler))(outerSuccessHandler) + } - implicit def catsCatchDsl[F[_], A, B](implicit monadError: MonadThrowable[F]): CatchDsl[F[A], F[B], A] = { + private[dsl] def catsCatchDsl[F[_], A, B](implicit monadError: MonadThrowable[F]): CatchDsl[F[A], F[B], A] = { (block: F[A] !! A, catcher: Catcher[F[A] !! A], handler: A => F[B]) => val fa = monadError.flatMap(monadError.pure(block))(_(monadError.pure)) val protectedFa = monadError.handleErrorWith(fa) {