Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,19 @@ class MessageRepository(queueName: String, db: DB) extends Logging {

sql"""
create table if not exists $tableName (
message_id varchar unique,
message_id varchar(255) not null unique,
delivery_receipts blob,
next_delivery bigint,
content blob,
attributes blob,
created bigint,
received bigint,
receive_count int,
group_id varchar,
deduplication_id varchar,
tracing_id varchar,
sequence_number varchar,
dead_letter_source_queue_name varchar
group_id varchar(255),
deduplication_id varchar(255),
tracing_id varchar(255),
sequence_number varchar(255),
dead_letter_source_queue_name varchar(255)
)""".execute.apply()

def drop(): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class QueueRepository(db: DB) extends Logging {

sql"""
create table if not exists $tableName (
name varchar unique,
name varchar(255) not null unique,
data blob
)""".execute.apply()

Expand Down
2 changes: 2 additions & 0 deletions rest/rest-sqs/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ rest-sqs {
bind-hostname = "0.0.0.0"
# Possible values: relaxed, strict
sqs-limits = strict
aws-access-key-id = ""
aws-secret-access-key = ""
}

rest-stats {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,15 @@ object SQSException {
errorMessage = Some("BatchEntryIdsNotDistinct")
)

def invalidClientTokenId(message: String): SQSException = {
new SQSException(
"InvalidClientTokenId",
403,
"AuthFailure",
Some(message)
)
}

def tooManyEntriesInBatchRequest: SQSException = new SQSException(
"AWS.SimpleQueueService.TooManyEntriesInBatchRequest",
errorType = "com.amazonaws.sqs#TooManyEntriesInBatchRequest",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import org.elasticmq.metrics.QueuesMetrics
import org.elasticmq.rest.sqs.Constants._
import org.elasticmq.rest.sqs.XmlNsVersion.extractXmlNs
import org.elasticmq.rest.sqs.directives.{
AWSCredentialDirectives,
AWSProtocolDirectives,
AnyParamDirectives,
ElasticMQDirectives,
Expand Down Expand Up @@ -127,6 +128,10 @@ case class TheSQSRestServerBuilder(
this.copy(queueEventListener = Some(_queueEventListener))

def start(): SQSRestServer = {
val rootConfig = ConfigFactory.load()
val restSqsConfig = rootConfig.getConfig("rest-sqs")
val credentials = AWSCredentials.fromConfig(restSqsConfig)

val (theActorSystem, stopActorSystem) = getOrCreateActorSystem
val theQueueManagerActor = getOrCreateQueueManagerActor(theActorSystem)
val theServerAddress =
Expand Down Expand Up @@ -173,7 +178,9 @@ case class TheSQSRestServerBuilder(
with ListDeadLetterSourceQueuesDirectives
with StartMessageMoveTaskDirectives
with CancelMessageMoveTaskDirectives
with ListMessageMoveTasksDirectives {
with ListMessageMoveTasksDirectives
with AWSCredentialsModule
with AWSCredentialDirectives {

def serverAddress = currentServerAddress.get()

Expand All @@ -183,6 +190,7 @@ case class TheSQSRestServerBuilder(
lazy val sqsLimits = theLimits
lazy val timeout = Timeout(21, TimeUnit.SECONDS) // see application.conf
lazy val contextPath = serverAddress.contextPathStripped
lazy val awsCredentials: AWSCredentials = credentials

lazy val awsRegion: String = _awsRegion
lazy val awsAccountId: String = _awsAccountId
Expand Down Expand Up @@ -235,13 +243,15 @@ case class TheSQSRestServerBuilder(
implicit val protocol: AWSProtocol = _protocol
handleServerExceptions(protocol) {
handleRejectionsWithSQSError(protocol) {
anyParamsMap(protocol) { p =>
val marshallerDependencies = MarshallerDependencies(protocol, version)
if (config.debug) {
logRequestResult("") {
rawRoutes(p)(marshallerDependencies)
}
} else rawRoutes(p)(marshallerDependencies)
verifyAWSAccessKeyId(protocol) {
anyParamsMap(protocol) { p =>
val marshallerDependencies = MarshallerDependencies(protocol, version)
if (config.debug) {
logRequestResult("") {
rawRoutes(p)(marshallerDependencies)
}
} else rawRoutes(p)(marshallerDependencies)
}
}
}
}
Expand Down Expand Up @@ -506,6 +516,21 @@ trait SQSLimitsModule {
def sqsLimits: Limits
}

trait AWSCredentialsModule {
def awsCredentials: AWSCredentials
}

case class AWSCredentials(accessKey: String, secretKey: String)

object AWSCredentials {
def fromConfig(config: com.typesafe.config.Config): AWSCredentials = {
AWSCredentials(
config.getString("aws-access-key-id"),
config.getString("aws-secret-access-key")
)
}
}

class ElasticMQConfig {
private lazy val rootConfig = ConfigFactory.load()
private lazy val elasticMQConfig = rootConfig.getConfig("elasticmq")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package org.elasticmq.rest.sqs.directives

import org.apache.pekko.http.scaladsl.server.{Directive0, Directives}
import org.elasticmq.rest.sqs.{AWSCredentialsModule, AWSProtocol, SQSException}

trait AWSCredentialDirectives extends Directives {
this: AWSCredentialsModule with ElasticMQDirectives =>

private val accessKeyRegex = "Credential=([^/]+)/".r

def verifyAWSAccessKeyId(protocol: AWSProtocol): Directive0 = {
if (awsCredentials.accessKey.nonEmpty) {
// Optional header in case it's missing
optionalHeaderValueByName("Authorization").flatMap {
case Some(authHeader) =>
accessKeyRegex.findFirstMatchIn(authHeader) match {
case Some(m) if m.group(1) == awsCredentials.accessKey =>
pass
case _ =>
// Must return a Directive0 here
complete(
SQSException.invalidClientTokenId(
"The security token included in the request is invalid."
)
)
}
case None =>
complete(
SQSException.invalidClientTokenId(
"The security token included in the request is invalid."
)
)
}
} else {
pass
}
}
}