AWS S3: Add getObjectByRanges to S3 API #2982

Open · wants to merge 2 commits into main

Changes from 1 commit
128 changes: 128 additions & 0 deletions s3/src/main/scala/akka/stream/alpakka/s3/impl/MergeOrderedN.scala
@@ -0,0 +1,128 @@
package akka.stream.alpakka.s3.impl

import akka.annotation.InternalApi
import akka.stream.stage.{GraphStage, GraphStageLogic, InHandler, OutHandler}
import akka.stream.{Attributes, Inlet, Outlet, UniformFanInShape}

import scala.collection.{immutable, mutable}

@InternalApi private[impl] object MergeOrderedN {

  /** @see [[MergeOrderedN]] */
  def apply[T](inputPorts: Int, breadth: Int) =
    new MergeOrderedN[T](inputPorts, breadth)
}

/**
 * Takes multiple streams and emits their elements in ascending order of input ports: elements from a given
 * stream are pushed only after all elements from the previous stream(s) have been pushed downstream.
 *
 * The `breadth` controls how many upstreams are pulled in parallel.
 * That means elements might be received in any order, but are buffered (if necessary) until their turn comes.
 *
 * '''Emits when''' the next element from upstream (in ascending order of input ports) is available
 *
 * '''Backpressures when''' downstream backpressures
 *
 * '''Completes when''' all upstreams complete and there are no more buffered elements
 *
 * '''Cancels when''' downstream cancels
 */
@InternalApi private[impl] final class MergeOrderedN[T](val inputPorts: Int, val breadth: Int) extends GraphStage[UniformFanInShape[T, T]] {
Review comment (Member): `ConcatPreFetch` or something like that would be more correct; "merge" in Akka Streams generally means emit in any order.

  require(inputPorts > 1, "input ports must be > 1")
  require(breadth > 0, "breadth must be > 0")

  val in: immutable.IndexedSeq[Inlet[T]] = Vector.tabulate(inputPorts)(i => Inlet[T]("MergeOrderedN.in" + i))
  val out: Outlet[T] = Outlet[T]("MergeOrderedN.out")
  override val shape: UniformFanInShape[T, T] = UniformFanInShape(out, in: _*)

  override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
    new GraphStageLogic(shape) with OutHandler {
      private val bufferByInPort = mutable.Map.empty[Int, mutable.Queue[T]] // Queues must not be empty; drained entries are removed
      private var currentHeadInPortIdx = 0
      private var currentLastInPortIdx = 0
      private val overallLastInPortIdx = inputPorts - 1

      setHandler(out, this)

      in.zipWithIndex.foreach {
        case (inPort, idx) =>
          setHandler(inPort, new InHandler {
            override def onPush(): Unit = {
              val elem = grab(inPort)
              if (currentHeadInPortIdx != idx || !isAvailable(out)) {
                // Not this port's turn (or downstream not ready): buffer the element
                bufferByInPort.updateWith(idx) {
                  case Some(inPortBuffer) =>
                    Some(inPortBuffer.enqueue(elem))
                  case None =>
                    val inPortBuffer = mutable.Queue.empty[T]
                    inPortBuffer.enqueue(elem)
                    Some(inPortBuffer)
                }
              } else {
                pushUsingQueue(Some(elem))
              }
              tryPull(inPort)
            }

            override def onUpstreamFinish(): Unit = {
              if (canCompleteStage)
                completeStage()
              else if (canSlideFrame)
                slideFrame()
            }
          })
      }

      override def onPull(): Unit = pushUsingQueue()

      private def pushUsingQueue(next: Option[T] = None): Unit = {
        val maybeBuffer = bufferByInPort.get(currentHeadInPortIdx)
        if (maybeBuffer.forall(_.isEmpty) && next.nonEmpty) {
          // Nothing buffered for the head port: push the fresh element directly
          push(out, next.get)
        } else if (maybeBuffer.exists(_.nonEmpty) && next.nonEmpty) {
          // Preserve arrival order: enqueue the fresh element, push the oldest buffered one
          maybeBuffer.get.enqueue(next.get)
          push(out, maybeBuffer.get.dequeue())
        } else if (maybeBuffer.exists(_.nonEmpty) && next.isEmpty) {
          push(out, maybeBuffer.get.dequeue())
        } else {
          // Both empty
        }

        if (maybeBuffer.exists(_.isEmpty))
          bufferByInPort.remove(currentHeadInPortIdx)

        if (canCompleteStage)
          completeStage()
        else if (canSlideFrame)
          slideFrame()
      }

      override def preStart(): Unit = {
        if (breadth >= inputPorts) {
          in.foreach(pull)
          currentLastInPortIdx = overallLastInPortIdx
        } else {
          in.slice(0, breadth).foreach(pull)
          currentLastInPortIdx = breadth - 1
        }
      }

      private def canSlideFrame: Boolean =
        (!bufferByInPort.contains(currentHeadInPortIdx) || bufferByInPort(currentHeadInPortIdx).isEmpty) &&
        isClosed(in(currentHeadInPortIdx))

      private def canCompleteStage: Boolean =
        canSlideFrame && currentHeadInPortIdx == overallLastInPortIdx

      // Advances the head to the next input port and widens the pull window by one
      private def slideFrame(): Unit = {
        currentHeadInPortIdx += 1

        if (isAvailable(out))
          pushUsingQueue()

        if (currentLastInPortIdx != overallLastInPortIdx)
          currentLastInPortIdx += 1

        if (!hasBeenPulled(in(currentLastInPortIdx)))
          tryPull(in(currentLastInPortIdx))
      }
    }
}
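For context, here is a minimal sketch of how this stage composes with `Source.combine` — the same pattern `S3Stream.getObjectByRanges` below uses. The sources and `breadth` value are illustrative only, and since the stage is `private[impl]` the snippet assumes it is compiled inside the `akka.stream.alpakka.s3.impl` package:

import akka.actor.ActorSystem
import akka.stream.scaladsl.{Sink, Source}

implicit val system: ActorSystem = ActorSystem("merge-ordered-demo")

// Three already-ordered sub-streams; with breadth = 2 the stage pulls at
// most two upstreams at a time but still emits all of s1 before s2, and
// all of s2 before s3.
val s1 = Source(1 to 3)
val s2 = Source(4 to 6)
val s3 = Source(7 to 9)

Source
  .combine(s1, s2, s3)(ports => MergeOrderedN[Int](ports, 2))
  .runWith(Sink.foreach(println)) // prints 1 through 9 in order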
72 changes: 72 additions & 0 deletions s3/src/main/scala/akka/stream/alpakka/s3/impl/S3Stream.scala
@@ -12,6 +12,7 @@ import akka.annotation.InternalApi
import akka.dispatch.ExecutionContexts
import akka.http.scaladsl.Http.OutgoingConnection
import akka.http.scaladsl.model.StatusCodes.{NoContent, NotFound, OK}
import akka.http.scaladsl.model.headers.ByteRange.FromOffset
import akka.http.scaladsl.model.headers._
import akka.http.scaladsl.model.{headers => http, _}
import akka.http.scaladsl.settings.{ClientConnectionSettings, ConnectionPoolSettings}
@@ -20,13 +21,15 @@ import akka.http.scaladsl.{ClientTransport, Http}
import akka.stream.alpakka.s3.BucketAccess.{AccessDenied, AccessGranted, NotExists}
import akka.stream.alpakka.s3._
import akka.stream.alpakka.s3.impl.auth.{CredentialScope, Signer, SigningKey}
import akka.stream.alpakka.s3.scaladsl.S3
import akka.stream.scaladsl.{Flow, Keep, RetryFlow, RunnableGraph, Sink, Source, Tcp}
import akka.stream.{Attributes, Materializer}
import akka.util.ByteString
import akka.{Done, NotUsed}
import software.amazon.awssdk.regions.Region

import scala.collection.immutable
import scala.collection.mutable.ListBuffer
import scala.concurrent.{Future, Promise}
import scala.util.{Failure, Success, Try}

@@ -37,6 +40,9 @@ import scala.util.{Failure, Success, Try}
    BucketAndKey.validateObjectKey(key, conf)
    this
  }

  def mkString: String =
    s"s3://$bucket/$key"
}

/** Internal Api */
@@ -165,6 +171,7 @@
  import Marshalling._

  val MinChunkSize: Int = 5 * 1024 * 1024 // in bytes
  val DefaultByteRangeSize: Long = 8 * 1024 * 1024
  val atLeastOneByteString: Flow[ByteString, ByteString, NotUsed] =
    Flow[ByteString].orElse(Source.single(ByteString.empty))

@@ -232,6 +239,71 @@
      .mapMaterializedValue(_.flatMap(identity)(ExecutionContexts.parasitic))
  }

  def getObjectByRanges(
      s3Location: S3Location,
      versionId: Option[String],
      s3Headers: S3Headers,
      rangeSize: Long = DefaultByteRangeSize,
      parallelism: Int = 4
  ): Source[ByteString, Future[ObjectMetadata]] = {
    Source
      .fromMaterializer { (_, _) =>
        val objectMetadataMat = Promise[ObjectMetadata]()
        getObjectMetadata(s3Location.bucket, s3Location.key, versionId, s3Headers)
          .flatMapConcat {
            case Some(s3Meta) if s3Meta.contentLength == 0 =>
              // Empty object: complete the metadata promise and emit nothing
              objectMetadataMat.success(s3Meta)
              Source.empty[ByteString]
            case Some(s3Meta) =>
              objectMetadataMat.success(s3Meta)
              val byteRanges = computeByteRanges(s3Meta.contentLength, rangeSize)
              if (byteRanges.size <= 1) {
                getObject(s3Location, None, versionId, s3Headers)
              } else {
                val rangeSources = prepareRangeSources(s3Location, versionId, s3Headers, byteRanges)
                Source.combine[ByteString, ByteString](
                  rangeSources.head,
                  rangeSources(1),
                  rangeSources.drop(2): _*
                )(p => MergeOrderedN(p, parallelism))
              }
            case None =>
              // Fail the stream rather than throwing while building it
              Source.failed(new NoSuchElementException(s"Object does not exist at location [${s3Location.mkString}]"))
          }
          .mapError {
            case e: Throwable =>
              objectMetadataMat.tryFailure(e)
              e
          }
          .mapMaterializedValue(_ => objectMetadataMat.future)
      }
      .mapMaterializedValue(_.flatMap(identity)(ExecutionContexts.parasitic))
  }

  private def computeByteRanges(contentLength: Long, rangeSize: Long): Seq[ByteRange] = {
    require(contentLength >= 0, s"contentLength ($contentLength) must be >= 0")
    require(rangeSize > 0, s"rangeSize ($rangeSize) must be > 0")
    if (contentLength <= rangeSize)
      Nil
    else {
      val ranges = ListBuffer[ByteRange]()
      for (i <- 0L until contentLength by rangeSize) {
        if ((i + rangeSize) >= contentLength)
          ranges += FromOffset(i)
        else
          ranges += ByteRange(i, i + rangeSize - 1)
      }
      ranges.result()
    }
  }

  private def prepareRangeSources(
      s3Location: S3Location,
      versionId: Option[String],
      s3Headers: S3Headers,
      byteRanges: Seq[ByteRange]
  ): Seq[Source[ByteString, Future[ObjectMetadata]]] =
    byteRanges.map(br => getObject(s3Location, Some(br), versionId, s3Headers))

  /**
   * An ADT that represents the current state of pagination
   */
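To make the range arithmetic above concrete, here is a small illustrative sketch (the 25-byte length and 10-byte range size are made-up values) that re-derives what `computeByteRanges` produces; the final, open-ended `FromOffset` covers the trailing partial range:

import akka.http.scaladsl.model.headers.ByteRange
import akka.http.scaladsl.model.headers.ByteRange.FromOffset

val contentLength = 25L
val rangeSize = 10L

// Mirrors the loop in computeByteRanges above
val ranges: IndexedSeq[ByteRange] =
  for (i <- 0L until contentLength by rangeSize)
    yield
      if (i + rangeSize >= contentLength) FromOffset(i) // last slice, open-ended
      else ByteRange(i, i + rangeSize - 1) // closed rangeSize-byte slice

// ranges == Vector(ByteRange(0, 9), ByteRange(10, 19), FromOffset(20))
// Note: when contentLength <= rangeSize the helper returns Nil and the caller
// falls back to a single un-ranged getObject request.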
49 changes: 49 additions & 0 deletions s3/src/main/scala/akka/stream/alpakka/s3/scaladsl/S3.scala
@@ -263,6 +263,55 @@ object S3 {
  ): Source[ByteString, Future[ObjectMetadata]] =
    S3Stream.getObject(S3Location(bucket, key), range, versionId, s3Headers)

  /**
   * Gets an S3 object using `Byte-Range Fetches`
   *
   * @param bucket the s3 bucket name
   * @param key the s3 object key
   * @param versionId [optional] the version id of the object
   * @param sse [optional] the server side encryption used on upload
   * @param rangeSize size of each range to request
   * @param parallelism number of ranges to request in parallel
   *
   * @return A [[akka.stream.scaladsl.Source]] containing the object's data as a [[akka.util.ByteString]] along with a materialized value containing the
   *         [[akka.stream.alpakka.s3.ObjectMetadata]]
   */
  def getObjectByRanges(
      bucket: String,
      key: String,
      versionId: Option[String] = None,
      sse: Option[ServerSideEncryption] = None,
      rangeSize: Long = MinChunkSize,
      parallelism: Int = 4
  ): Source[ByteString, Future[ObjectMetadata]] =
    S3Stream.getObjectByRanges(S3Location(bucket, key),
                               versionId,
                               S3Headers.empty.withOptionalServerSideEncryption(sse),
                               rangeSize,
                               parallelism)

  /**
   * Gets an S3 object using `Byte-Range Fetches`
   *
   * @param bucket the s3 bucket name
   * @param key the s3 object key
   * @param versionId [optional] the version id of the object
   * @param s3Headers any headers you want to add
   * @param rangeSize size of each range to request
   * @param parallelism number of ranges to request in parallel
   *
   * @return A [[akka.stream.scaladsl.Source]] containing the object's data as a [[akka.util.ByteString]] along with a materialized value containing the
   *         [[akka.stream.alpakka.s3.ObjectMetadata]]
   */
  def getObjectByRanges(
      bucket: String,
      key: String,
      versionId: Option[String],
      s3Headers: S3Headers,
      rangeSize: Long,
      parallelism: Int
  ): Source[ByteString, Future[ObjectMetadata]] =
    S3Stream.getObjectByRanges(S3Location(bucket, key), versionId, s3Headers, rangeSize, parallelism)

  /**
   * Will return a list containing all of the buckets for the current AWS account
   *
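As a usage sketch of the new public API (the bucket and key names are hypothetical, and the range settings just make the documented knobs explicit), mirroring how the test further below consumes the source:

import akka.actor.ActorSystem
import akka.stream.alpakka.s3.ObjectMetadata
import akka.stream.alpakka.s3.scaladsl.S3
import akka.stream.scaladsl.{Keep, Sink, Source}
import akka.util.ByteString

import scala.concurrent.Future

implicit val system: ActorSystem = ActorSystem("s3-ranged-get-demo")

// Download "my-key" from "my-bucket" in 5 MiB ranges, four requests in flight.
val download: Source[ByteString, Future[ObjectMetadata]] =
  S3.getObjectByRanges("my-bucket", "my-key", rangeSize = 5 * 1024 * 1024, parallelism = 4)

// Materialize the object metadata alongside the re-assembled body.
val (metadataF, bodyF) =
  download.toMat(Sink.fold[ByteString, ByteString](ByteString.empty)(_ ++ _))(Keep.both).run()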
@@ -5,6 +5,7 @@
package akka.stream.alpakka.s3.scaladsl

import akka.actor.ActorSystem
import akka.http.scaladsl.model.headers.ByteRange
import akka.stream.alpakka.s3.S3Settings
import akka.stream.alpakka.s3.headers.ServerSideEncryption
import akka.stream.alpakka.s3.impl.S3Stream
@@ -177,6 +178,19 @@ abstract class S3WireMockBase(_system: ActorSystem, val _wireMockServer: WireMoc
)
)

  def mockRangedDownload(byteRange: ByteRange, range: String): Unit =
    mock
      .register(
        get(urlEqualTo(s"/$bucketKey"))
          .withHeader("Range", new EqualToPattern(s"bytes=$byteRange"))
          .willReturn(
            aResponse()
              .withStatus(200)
              .withHeader("ETag", """"fba9dede5f27731c9771645a39863328"""")
              .withBody(range)
          )
      )

  def mockRangedDownload(): Unit =
    mock
      .register(
38 changes: 38 additions & 0 deletions s3/src/test/scala/docs/scaladsl/S3SourceSpec.scala
@@ -5,6 +5,7 @@
package docs.scaladsl

import akka.http.scaladsl.model.headers.ByteRange
import akka.http.scaladsl.model.headers.ByteRange.FromOffset
import akka.http.scaladsl.model.{ContentType, ContentTypes, HttpEntity, HttpResponse, IllegalUriException}
import akka.stream.Attributes
import akka.stream.alpakka.s3.BucketAccess.{AccessDenied, AccessGranted, NotExists}
@@ -26,6 +27,43 @@ class S3SourceSpec extends S3WireMockBase with S3ClientIntegrationSpec {
  override protected def afterEach(): Unit =
    mock.removeMappings()

"S3Source" should "download a stream of bytes by ranges from S3" in {

val bodyBytes = ByteString(body)
val bodyRanges = bodyBytes.grouped(10).toList
val rangeHeaders = bodyRanges.zipWithIndex.map {
case (_, idx) if idx != bodyRanges.size - 1 =>
ByteRange(idx * 10, (idx * 10) + 10 - 1)
case (_, idx) =>
FromOffset(idx * 10)
}
val rangesWithHeaders = bodyRanges.zip(rangeHeaders)

mockHead(bodyBytes.size)
rangesWithHeaders.foreach { case (bs, br) => mockRangedDownload(br, bs.utf8String)}

val s3Source: Source[ByteString, Future[ObjectMetadata]] =
S3.getObjectByRanges(bucket, bucketKey, rangeSize = 10L)

val (metadataFuture, dataFuture) =
s3Source.toMat(Sink.reduce[ByteString](_ ++ _))(Keep.both).run()

val data = dataFuture.futureValue
val metadata = metadataFuture.futureValue

data.utf8String shouldBe body

HttpResponse(
entity = HttpEntity(
metadata.contentType
.flatMap(ContentType.parse(_).toOption)
.getOrElse(ContentTypes.`application/octet-stream`),
metadata.contentLength,
s3Source
)
)
}

"S3Source" should "download a stream of bytes from S3" in {

mockDownload()