Improve batcher error handling and remove Akka from the Lambda (#2796)
* make error handling easier to reason about

* idx->path

* improve access control

* whoops, missing file

* Remove Akka from the Batcher Lambda (#2797)

* push akka to the periphery of the worker service

* Apply auto-formatting rules

---------

Co-authored-by: Github on behalf of Wellcome Collection <[email protected]>

---------

Co-authored-by: Github on behalf of Wellcome Collection <[email protected]>
paul-butcher and weco-bot authored Jan 6, 2025
1 parent c55ad24 commit a25e871
Showing 8 changed files with 256 additions and 157 deletions.
BatcherWorkerService.scala
@@ -5,7 +5,6 @@ import scala.concurrent.duration._
 import scala.util.Try
 import org.apache.pekko.{Done, NotUsed}
 import org.apache.pekko.stream.scaladsl._
-import org.apache.pekko.stream.Materializer
 import software.amazon.awssdk.services.sqs.model.{Message => SQSMessage}
 import grizzled.slf4j.Logging
 import weco.messaging.MessageSender
@@ -22,7 +21,7 @@ class BatcherWorkerService[MsgDestination](
   flushInterval: FiniteDuration,
   maxProcessedPaths: Int,
   maxBatchSize: Int
-)(implicit ec: ExecutionContext, materializer: Materializer)
+)(implicit ec: ExecutionContext)
     extends Runnable
     with Logging {
 
@@ -33,14 +32,13 @@
     source
       .map {
         case (msg: SQSMessage, notificationMessage: NotificationMessage) =>
-          (msg, notificationMessage.body)
+          PathFromSQS(notificationMessage.body, msg)
       }
       .groupedWithin(maxProcessedPaths, flushInterval)
-      .map(_.toList.unzip)
       .mapAsync(1) {
-        case (msgs, paths) =>
+        paths =>
           info(s"Processing ${paths.size} input paths")
-          processPaths(msgs, paths)
+          processPaths(paths)
       }
       .flatMapConcat(identity)
   }
@@ -51,17 +49,15 @@
    * corresponding batches have been succesfully sent.
    */
   private def processPaths(
-    msgs: List[SQSMessage],
-    paths: List[String]
+    paths: Seq[PathFromSQS]
   ): Future[Source[SQSMessage, NotUsed]] =
     PathsProcessor(maxBatchSize, paths, SNSDownstream)
       .map {
-        failedIndices =>
-          val failedIdxSet = failedIndices.toSet
-          Source(msgs).zipWithIndex
-            .collect {
-              case (msg, idx) if !failedIdxSet.contains(idx) => msg
-            }
+        failedPaths =>
+          val failedPathSet = failedPaths.toSet
+          Source(paths.collect {
+            case path if !failedPathSet.contains(path) => path.referent
+          }.toList)
       }
 
   private object SNSDownstream extends Downstream {
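The key simplification in this file: instead of tracking failures by position (zipWithIndex against a parallel list of messages), each path now carries its own SQS message as a referent, so filtering out failures becomes a plain set-membership test. A minimal sketch of that idea, where Msg and PathWithRef are illustrative stand-ins rather than the repo's own types:

// Identity-based failure filtering, as in processPaths above.
final case class Msg(id: String)
final case class PathWithRef(path: String, referent: Msg)

def successfulReferents(
  all: Seq[PathWithRef],
  failed: Set[PathWithRef]
): Seq[Msg] =
  all.collect { case p if !failed.contains(p) => p.referent }

val m1 = Msg("1"); val m2 = Msg("2")
val inputs = Seq(PathWithRef("a/b", m1), PathWithRef("a/c", m2))
assert(successfulReferents(inputs, Set(inputs.head)) == Seq(m2))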
CLIMain.scala
@@ -31,9 +31,12 @@ object CLIMain extends App {
   val toStringFlow: Flow[ByteString, String, NotUsed] =
     Flow[ByteString].map(_.utf8String)
 
-  val pathsProcessorFlow: Flow[Seq[String], Future[Seq[Long]], NotUsed] =
-    Flow[Seq[String]].map {
-      paths: Seq[String] =>
+  val toPathFlow: Flow[String, Path, NotUsed] =
+    Flow[String].map(PathFromString)
+
+  val pathsProcessorFlow: Flow[Seq[Path], Future[Seq[Path]], NotUsed] =
+    Flow[Seq[Path]].map {
+      paths: Seq[Path] =>
         PathsProcessor(
           40, // TODO: 40 is the number in the config used by Main, do this properly later
           paths.toList,
@@ -44,6 +47,7 @@
   stdinSource
     .via(lineDelimiter)
     .via(toStringFlow)
+    .via(toPathFlow)
     // this number is pretty arbitrary, but grouping of some kind is needed in order to
     // provide a list to the next step, rather than individual paths
     .grouped(10000)
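The new toPathFlow stage wraps each input line in a Path before batching. A runnable sketch of the same pipeline shape, assuming Pekko Streams is on the classpath (SimplePath and the stage names here are illustrative, not the repo's):

import org.apache.pekko.NotUsed
import org.apache.pekko.actor.ActorSystem
import org.apache.pekko.stream.scaladsl.{Flow, Sink, Source}

object FlowSketch extends App {
  implicit val system: ActorSystem = ActorSystem("sketch")
  final case class SimplePath(path: String)

  // String -> Path stage, mirroring toPathFlow above
  val toPath: Flow[String, SimplePath, NotUsed] =
    Flow[String].map(SimplePath)

  Source(List("a/b", "a/c", "d/e"))
    .via(toPath)
    .grouped(2) // batch paths so the next stage receives a Seq, not single items
    .runWith(Sink.foreach(println))
    .onComplete(_ => system.terminate())(system.dispatcher)
}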
LambdaMain.scala
@@ -3,7 +3,6 @@ package weco.pipeline.batcher
 import com.amazonaws.services.lambda.runtime.{Context, RequestHandler}
 import grizzled.slf4j.Logging
 import com.amazonaws.services.lambda.runtime.events.SQSEvent
-import org.apache.pekko.actor.ActorSystem
 import weco.messaging.typesafe.SNSBuilder
 import weco.json.JsonUtil._
 import com.typesafe.config.ConfigFactory
@@ -12,6 +11,7 @@ import scala.concurrent.Await
 import scala.concurrent.duration.DurationInt
 import scala.concurrent.ExecutionContext
 import scala.util.Try
+import ExecutionContext.Implicits.global
 
 object LambdaMain extends RequestHandler[SQSEvent, String] with Logging {
   import weco.pipeline.batcher.lib.SQSEventOps._
@@ -39,14 +39,9 @@ object LambdaMain extends RequestHandler[SQSEvent, String] with Logging {
   ): String = {
     debug(s"Running batcher lambda, got event: $event")
 
-    implicit val actorSystem: ActorSystem =
-      ActorSystem("main-actor-system")
-    implicit val ec: ExecutionContext =
-      actorSystem.dispatcher
-
     val f = PathsProcessor(
       config.requireInt("batcher.max_batch_size"),
-      event.extractPaths,
+      event.extractPaths.map(PathFromString),
       downstream
     )
 
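With no Pekko materializer needed downstream, the handler can run entirely on the global execution context and simply block on the result, which Lambda's synchronous invocation model requires anyway (the file already imports Await and DurationInt for this). A minimal sketch of that pattern; doWork and the timeout are illustrative, not the repo's code:

import scala.concurrent.{Await, Future}
import scala.concurrent.duration.DurationInt
import scala.concurrent.ExecutionContext.Implicits.global

// Stand-in for the real work (PathsProcessor in LambdaMain)
def doWork(paths: Seq[String]): Future[Seq[String]] =
  Future { paths.filter(_.nonEmpty) }

def handle(paths: Seq[String]): String = {
  // A Lambda handler must not return before the Future completes
  Await.result(doWork(paths), 15.minutes)
  "Done"
}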
Path.scala (new file)
@@ -0,0 +1,17 @@
+package weco.pipeline.batcher
+import software.amazon.awssdk.services.sqs.model.{Message => SQSMessage}
+
+sealed trait Path extends Ordered[Path] {
+  val path: String
+  override def toString: String = path
+  override def compare(that: Path): Int = this.path compare that.path
+}
+
+sealed trait PathWithReferent[T] extends Path {
+  val referent: T
+}
+
+case class PathFromSQS(val path: String, val referent: SQSMessage)
+    extends PathWithReferent[SQSMessage]
+
+case class PathFromString(val path: String) extends Path
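Because Path extends Ordered[Path] and compares on the underlying string, any mix of Path implementations can be sorted and compared uniformly, which is what lets PathsProcessor report failed Paths rather than indices. A small usage sketch against standalone stand-ins (the real types are the ones above):

// Stand-in mirroring the Path trait above
sealed trait P extends Ordered[P] {
  val path: String
  override def compare(that: P): Int = this.path compare that.path
}
final case class FromString(path: String) extends P

val sorted = List[P](FromString("b/c"), FromString("a/z")).sorted
assert(sorted.map(_.path) == List("a/z", "b/c"))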
PathsProcessor.scala
@@ -1,8 +1,5 @@
 package weco.pipeline.batcher
 import grizzled.slf4j.Logging
-import org.apache.pekko.NotUsed
-import org.apache.pekko.stream.Materializer
-import org.apache.pekko.stream.scaladsl.{Sink, Source}
 
 import scala.concurrent.{ExecutionContext, Future}
 import scala.util.{Failure, Success}
@@ -22,27 +19,43 @@ object PathsProcessor extends Logging {
    * SQS/SNS-driven. Should just be the actual failed paths, and the caller
    * should build a map to work it out if it wants to)
    */
-  def apply(maxBatchSize: Int, paths: List[String], downstream: Downstream)(
-    implicit ec: ExecutionContext,
-    materializer: Materializer
-  ): Future[Seq[Long]] = {
+  def apply(maxBatchSize: Int, paths: Seq[Path], downstream: Downstream)(
+    implicit ec: ExecutionContext
+  ): Future[Seq[Path]] = {
     info(s"Processing ${paths.size} paths with max batch size $maxBatchSize")
 
-    generateBatches(maxBatchSize, paths)
-      .mapAsyncUnordered(10) {
-        case (batch, msgIndices) =>
+    Future
+      .sequence {
+        generateBatches(maxBatchSize, paths).map {
+          case (batch, msgPaths) =>
+            notifyDownstream(downstream, batch, msgPaths)
+        }
+      }
+      .flatMap {
+        results =>
           Future {
-            downstream.notify(batch) match {
-              case Success(_) => None
-              case Failure(err) =>
-                error(s"Failed processing batch $batch with error: $err")
-                Some(msgIndices)
-            }
+            results.collect {
+              case Some(failedPaths) => failedPaths
+            }.flatten
           }
       }
-      .collect { case Some(failedIndices) => failedIndices }
-      .mapConcat(identity)
-      .runWith(Sink.seq)
   }
 
+  private def notifyDownstream(
+    downstream: Downstream,
+    batch: Batch,
+    msgPaths: List[Path]
+  )(
+    implicit ec: ExecutionContext
+  ): Future[Option[List[Path]]] = {
+    Future {
+      downstream.notify(batch) match {
+        case Success(_) => None
+        case Failure(err) =>
+          error(s"Failed processing batch $batch with error: $err")
+          Some(msgPaths)
+      }
+    }
+  }
+
   /** Given a list of input paths, generate the minimal set of selectors
@@ -51,10 +64,35 @@
    */
   private def generateBatches(
     maxBatchSize: Int,
-    paths: List[String]
-  ): Source[(Batch, List[Long]), NotUsed] = {
+    paths: Seq[Path]
+  ): Seq[(Batch, List[Path])] = {
     val selectors = Selector.forPaths(paths)
     val groupedSelectors = selectors.groupBy(_._1.rootPath)
+
+    logSelectors(paths, selectors, groupedSelectors)
+
+    groupedSelectors.map {
+      case (rootPath, selectorsAndPaths) =>
+        // For batches consisting of a really large number of selectors, we
+        // should just send the whole tree: this avoids really long queries
+        // in the relation embedder, or duplicate work of creating the archives
+        // cache multiple times, and it is likely pretty much all the nodes will
+        // be denormalised anyway.
+        val (selectors, inputPaths) = selectorsAndPaths.unzip(identity)
+        val batch =
+          if (selectors.size > maxBatchSize)
+            Batch(rootPath, List(Selector.Tree(rootPath)))
+          else
+            Batch(rootPath, selectors)
+        batch -> inputPaths
+    }.toSeq
+  }
+
+  private def logSelectors(
+    paths: Seq[Path],
+    selectors: List[(Selector, Path)],
+    groupedSelectors: Map[String, List[(Selector, Path)]]
+  ): Unit = {
     info(
       s"Generated ${selectors.size} selectors spanning ${groupedSelectors.size} trees from ${paths.size} paths."
     )
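The hunk above moves the batch-collapsing heuristic into generateBatches: once a root path accumulates more selectors than maxBatchSize, a single whole-tree selector replaces the list. The decision in isolation, where Node and Tree are simplified stand-ins for the repo's Selector types:

sealed trait Sel
final case class Node(path: String) extends Sel
final case class Tree(root: String) extends Sel

// Collapse an oversized selector list into one whole-tree selector
def batchFor(root: String, selectors: List[Sel], maxBatchSize: Int): List[Sel] =
  if (selectors.size > maxBatchSize) List(Tree(root))
  else selectors

assert(batchFor("a", List(Node("a/b"), Node("a/c")), maxBatchSize = 1) == List(Tree("a")))
assert(batchFor("a", List(Node("a/b")), maxBatchSize = 1) == List(Node("a/b")))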
@@ -71,20 +109,5 @@
         s"Selectors for root path $rootPath: ${selectors.map(_._1).mkString(", ")}"
       )
     }
-    Source(groupedSelectors.toList).map {
-      case (rootPath, selectorsAndIndices) =>
-        // For batches consisting of a really large number of selectors, we
-        // should just send the whole tree: this avoids really long queries
-        // in the relation embedder, or duplicate work of creating the archives
-        // cache multiple times, and it is likely pretty much all the nodes will
-        // be denormalised anyway.
-        val (selectors, msgIndices) = selectorsAndIndices.unzip(identity)
-        val batch =
-          if (selectors.size > maxBatchSize)
-            Batch(rootPath, List(Selector.Tree(rootPath)))
-          else
-            Batch(rootPath, selectors)
-        batch -> msgIndices
-    }
   }
 }
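Taken together, PathsProcessor.apply now expresses its error handling with plain Futures: each batch notification yields a Future[Option[...]] (None on success, Some(paths) on failure), and Future.sequence gathers the failures without a stream materializer or actor system. The pattern in isolation, assuming only the standard library (attempt and the odd/even rule are illustrative):

import scala.concurrent.{Await, Future}
import scala.concurrent.duration.DurationInt
import scala.concurrent.ExecutionContext.Implicits.global
import scala.util.{Failure, Success, Try}

// Pretend even inputs succeed and odd ones fail
def attempt(n: Int): Try[Int] =
  if (n % 2 == 0) Success(n) else Failure(new Exception(s"odd: $n"))

val work: Seq[Future[Option[Int]]] =
  (1 to 5).map { n =>
    Future {
      attempt(n) match {
        case Success(_) => None    // success: nothing to report
        case Failure(_) => Some(n) // failure: report the offending input
      }
    }
  }

// Collect only the failures; the overall Future still succeeds
val failed: Seq[Int] =
  Await.result(Future.sequence(work).map(_.flatten), 10.seconds)
assert(failed == Seq(1, 3, 5))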