AWS S3: Add getObjectByRanges to S3 API #2982

Open · wants to merge 2 commits into base: main
Changes from 1 commit
128 changes: 0 additions & 128 deletions s3/src/main/scala/akka/stream/alpakka/s3/impl/MergeOrderedN.scala

This file was deleted.

106 changes: 82 additions & 24 deletions s3/src/main/scala/akka/stream/alpakka/s3/impl/S3Stream.scala
@@ -4,9 +4,6 @@

package akka.stream.alpakka.s3.impl

import java.net.InetSocketAddress
import java.time.{Instant, ZoneOffset, ZonedDateTime}
import scala.annotation.nowarn
import akka.actor.ActorSystem
import akka.annotation.InternalApi
import akka.dispatch.ExecutionContexts
@@ -21,15 +18,17 @@ import akka.http.scaladsl.{ClientTransport, Http}
import akka.stream.alpakka.s3.BucketAccess.{AccessDenied, AccessGranted, NotExists}
import akka.stream.alpakka.s3._
import akka.stream.alpakka.s3.impl.auth.{CredentialScope, Signer, SigningKey}
import akka.stream.alpakka.s3.scaladsl.S3
import akka.stream.scaladsl.{Flow, Keep, RetryFlow, RunnableGraph, Sink, Source, Tcp}
import akka.stream.{Attributes, Materializer}
import akka.util.ByteString
import akka.{Done, NotUsed}
import software.amazon.awssdk.regions.Region

import scala.collection.immutable
import java.net.InetSocketAddress
import java.time.{Instant, ZoneOffset, ZonedDateTime}
import scala.annotation.{nowarn, tailrec}
import scala.collection.mutable.ListBuffer
import scala.collection.{immutable, mutable}
import scala.concurrent.{Future, Promise}
import scala.util.{Failure, Success, Try}

@@ -175,6 +174,8 @@ import scala.util.{Failure, Success, Try}
val atLeastOneByteString: Flow[ByteString, ByteString, NotUsed] =
Flow[ByteString].orElse(Source.single(ByteString.empty))

private val RangeEndMarker = "$END$"

// def because tokens can expire
private def signingKey(implicit settings: S3Settings) = {
val requestDate = ZonedDateTime.now(ZoneOffset.UTC)
@@ -255,19 +256,11 @@
Source.empty[ByteString]
case Some(s3Meta) =>
objectMetadataMat.success(s3Meta)
val byteRanges = computeByteRanges(s3Meta.contentLength, rangeSize)
if (byteRanges.size <= 1) {
getObject(s3Location, None, versionId, s3Headers)
} else {
val rangeSources = prepareRangeSources(s3Location, versionId, s3Headers, byteRanges)
Source.combine[ByteString, ByteString](
rangeSources.head,
rangeSources(1),
rangeSources.drop(2): _*
)(p => MergeOrderedN(p, parallelism))
}
doGetByRanges(s3Location, versionId, s3Headers, s3Meta.contentLength, rangeSize, parallelism)
case None =>
Source.failed(throw new NoSuchElementException(s"Object does not exist at location [${s3Location.mkString}]"))
val exc = new NoSuchElementException(s"Object does not exist at location [${s3Location.mkString}]")
objectMetadataMat.failure(exc)
Source.failed(exc)
Reviewer:
This will close the downstream ASAP, do you want to defer it?

@gael-ft (Contributor Author), Aug 29, 2023:

It happens after the getObjectMetadata result has been pulled, so I am not sure what that implies in this context.

The idea is that no ObjectMetadata means no S3 object, so I think the source should fail, and the materialized Future as well.

}
.mapError {
case e: Throwable =>
@@ -279,6 +272,29 @@
.mapMaterializedValue(_.flatMap(identity)(ExecutionContexts.parasitic))
}

private def doGetByRanges(
    s3Location: S3Location,
    versionId: Option[String],
    s3Headers: S3Headers,
    contentLength: Long,
    rangeSize: Long,
    parallelism: Int
): Source[ByteString, Any] = {
  val byteRanges = computeByteRanges(contentLength, rangeSize)
  if (byteRanges.size <= 1) {
    getObject(s3Location, None, versionId, s3Headers)
  } else {
    Source(byteRanges)
      .zipWithIndex
      .flatMapMerge(parallelism, brToIdx => {
        val (br, idx) = brToIdx
        // terminate each range's sub-stream with a marker so the re-ordering
        // stage below knows when a range has completed
        val endMarker = Source.single(ByteString(RangeEndMarker))
        getObject(s3Location, Some(br), versionId, s3Headers).concat(endMarker).map(_ -> idx)
      })
      // flatMapMerge emits sub-stream elements in arrival order; RangeMapConcat
      // re-orders the (bytes, rangeIndex) pairs back into range order
      .statefulMapConcat(RangeMapConcat)
Member:
Composing each range-source with a buffer to allow parallel fetching, and then concatenating the resulting streams to get the resulting bytes out in the right order seems like it would achieve the same but much simpler.

Am I missing something clever that this does?

@gael-ft (Contributor Author), May 29, 2023:

If I understood correctly, you are thinking about conflate or something similar to buffer range sources.
Something like:

getObject(s3Location, Some(br), versionId, s3Headers).conflate(_ ++ _).concat(endMarker).map(_ -> idx)

As flatMapMerge may emit in any order, I still need the range idx to order the (possibly buffered) bytes.
So each output item of flatMapMerge looks like (ByteString, Long) and can arrive in any order (with respect to the Long).

How can I order them back without statefulMapConcat? Range 2 could emit before range 1 is complete, and range 2 could complete before range 1.

Note I am not trying to buffer the "next" range: if bytes of the "next" range are pushed, I'll push them directly downstream, as buffering those bytes seems useless.


Also, regarding buffers, I was not sure whether it would be useful to "hard pull" the upstreams until parallelism * rangeSize bytes are consumed. Something like:
Something like:

//...
  .statefulMapConcat(RangeMapConcat)
  // Might be useful to consume elements of all flatMapMerge materialized upstreams 
  .batchWeighted(parallelism * rangeSize, _.size, identity)(_ ++ _)

Member:
I was thinking something like

Source(byteRanges)
  .mapAsync(parallelism)(range => 
    getObjectByRanges(...).buffer(size, Backpressure)
  ).flatMapConcat(identity)

Member:
But of course that may not be good enough with a buffer sized in chunks instead of bytes; we don't have a buffer with weighted size calculation, though maybe batchWeighted could do, not sure.

@gael-ft (Contributor Author), Jun 2, 2023:

Hmm, I can't make it work. I tried:

Source(byteRanges)
  .mapAsync(parallelism)(br => Future.successful(
    getObject(s3Location, Some(br), versionId, s3Headers).batchWeighted(rangeSize, _.size, identity)(_ ++ _)
  ))
  .flatMapConcat(identity)

and

Source(byteRanges)
  .mapAsync(parallelism)(br => Future.successful(
    Source.fromMaterializer { case (mat, _) =>
      getObject(s3Location, Some(br), versionId, s3Headers)
        .preMaterialize()(mat)
        ._2
        .batchWeighted(rangeSize, _.size, identity)(_ ++ _)
    }
  ))
  .flatMapConcat(identity)

But in both situations the ranges are fetched one by one, and download performance looks like plain getObject. It is as if .mapAsync(P)(_ => someSource).flatMapConcat(identity) were not enough to materialize P sources at the same time.
Which leaves us with flatMapMerge...

Member:
Ah, of course, they aren't materialized, so they can't start consuming bytes until flatMapConcat:ed; I didn't think of that. Pre-materialization creates a running source, but the downstream is not materialized until it is used, so you would need to put the batching before preMaterialize.
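
For what it's worth, here is a minimal sketch of that suggestion (hypothetical, not part of this PR; it assumes an implicit Materializer is in scope and reuses s3Location, versionId, s3Headers, byteRanges, rangeSize and parallelism from the surrounding code):

import akka.stream.OverflowStrategy

// Sketch only: the weighted batching sits *before* the pre-materialization
// point, so the already-running source has a buffer it can fill while the
// downstream has not reached this range yet.
def eagerRangeSource(br: ByteRange)(implicit mat: Materializer): Source[ByteString, NotUsed] =
  getObject(s3Location, Some(br), versionId, s3Headers)
    .batchWeighted(rangeSize, _.size, identity)(_ ++ _) // accumulate up to ~rangeSize bytes while unconsumed
    .preMaterialize()._2                                // start the download immediately

Source(byteRanges)
  .map(br => eagerRangeSource(br))
  .buffer(math.max(parallelism - 1, 1), OverflowStrategy.backpressure) // keep the next ranges running ahead
  .flatMapConcat(identity)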

Member:
For the record, I created an upstream issue with an idea that could make this kind of thing easier: akka/akka#31958 (continuing with the current solution here, though).

  }
}
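
For context, this is how the new API might be called from user code (hypothetical usage; the exact public signature that scaladsl.S3 exposes may differ, and the bucket/key values are made up):

// Downloads a large object as several parallel ranged GETs while still
// emitting the bytes in order; the object's metadata is the materialized value.
val download: Source[ByteString, Future[ObjectMetadata]] =
  S3.getObjectByRanges(
    bucket = "my-bucket",         // made-up bucket
    key = "path/to/large-object", // made-up key
    rangeSize = 8 * 1024 * 1024,  // assumed parameter: bytes per ranged request
    parallelism = 4               // assumed parameter: concurrent range downloads
  )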

private def computeByteRanges(contentLength: Long, rangeSize: Long): Seq[ByteRange] = {
require(contentLength >= 0, s"contentLength ($contentLength) must be >= 0")
require(rangeSize > 0, s"rangeSize ($rangeSize) must be > 0")
@@ -296,13 +312,55 @@
}
}
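
The body of computeByteRanges is collapsed in this diff view. A minimal sketch of what such a computation could look like (not the PR's actual body), given the require checks above and the fact that doGetByRanges falls back to a single plain GET when byteRanges.size <= 1; note that ByteRange bounds are inclusive per RFC 7233:

// Sketch only: split [0, contentLength) into rangeSize-sized inclusive byte
// ranges; an object no longer than one range yields no ranges at all, so the
// caller issues a single un-ranged GET instead.
private def computeByteRangesSketch(contentLength: Long, rangeSize: Long): Seq[ByteRange] =
  if (contentLength <= rangeSize) Nil
  else
    (0L until contentLength by rangeSize).map { start =>
      ByteRange(start, math.min(start + rangeSize, contentLength) - 1)
    }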

private def prepareRangeSources(
    s3Location: S3Location,
    versionId: Option[String],
    s3Headers: S3Headers,
    byteRanges: Seq[ByteRange]
): Seq[Source[ByteString, Future[ObjectMetadata]]] =
  byteRanges.map(br => getObject(s3Location, Some(br), versionId, s3Headers))
// Stateful function for statefulMapConcat: emits bytes of the current range
// immediately, buffers bytes of later ranges, and flushes buffered ranges as
// soon as all earlier ranges have completed (signalled by the end marker).
private val RangeMapConcat: () => ((ByteString, Long)) => IterableOnce[ByteString] = () => {
  var currentRangeIdx = 0L
  var completedRanges = Set.empty[Long]
  var bufferByRangeIdx = Map.empty[Long, mutable.Queue[ByteString]]

  // cheap length check before comparing the marker's UTF-8 content
  val isEndMarker: ByteString => Boolean =
    bs => bs.size == RangeEndMarker.length && bs.utf8String == RangeEndMarker

  // concatenates the buffered bytes of every range that is now contiguous with
  // what has already been emitted, advancing currentRangeIdx as it goes
  def foldRangeBuffers(): Option[ByteString] = {
    @tailrec
    def innerFoldRangeBuffers(acc: Option[ByteString]): Option[ByteString] = {
      bufferByRangeIdx.get(currentRangeIdx) match {
        case None =>
          if (completedRanges.contains(currentRangeIdx))
            currentRangeIdx += 1
          if (bufferByRangeIdx.contains(currentRangeIdx))
            innerFoldRangeBuffers(acc)
          else
            acc
        case Some(queue) =>
          val next = queue.dequeueAll(_ => true).foldLeft(acc.getOrElse(ByteString.empty))(_ ++ _)
          bufferByRangeIdx -= currentRangeIdx
          if (completedRanges.contains(currentRangeIdx))
            currentRangeIdx += 1
          if (bufferByRangeIdx.contains(currentRangeIdx))
            innerFoldRangeBuffers(Some(next))
          else
            Some(next)
      }
    }

    innerFoldRangeBuffers(None)
  }

  bsToIdx => {
    val (bs, idx) = bsToIdx
    if (isEndMarker(bs)) {
      completedRanges = completedRanges + idx
      foldRangeBuffers().toList
    } else if (idx == currentRangeIdx) {
      bs :: Nil
    } else {
      // idx is already a Long, so no narrowing conversion is needed here
      bufferByRangeIdx = bufferByRangeIdx.updatedWith(idx) {
        case Some(queue) => Some(queue.enqueue(bs))
        case None => Some(mutable.Queue(bs))
      }
      Nil
    }
  }
}
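
To illustrate the re-ordering contract (a hypothetical walk-through, not a test from this PR), consider two ranges whose chunks arrive interleaved from flatMapMerge:

// ("b1", 1)  -> buffered: range 0 is still the current range
// ("a1", 0)  -> emitted immediately
// ("a2", 0)  -> emitted immediately
// (END, 0)   -> range 0 complete: buffered "b1" is flushed, range 1 becomes current
// ("b2", 1)  -> emitted immediately
// (END, 1)   -> range 1 complete: nothing buffered, nothing emitted
Source(
  List(
    ByteString("b1") -> 1L,
    ByteString("a1") -> 0L,
    ByteString("a2") -> 0L,
    ByteString(RangeEndMarker) -> 0L,
    ByteString("b2") -> 1L,
    ByteString(RangeEndMarker) -> 1L
  )
).statefulMapConcat(RangeMapConcat)
// downstream sees: a1, a2, b1, b2 (the original byte order)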

/**
* An ADT that represents the current state of pagination