Skip to content

Commit

Permalink
fix(rabbitmq): Parse vhost correctly if it's provided in the host url (
Browse files Browse the repository at this point in the history
  • Loading branch information
JorTurFer authored Sep 1, 2022
1 parent 8d06f6f commit f5b7234
Show file tree
Hide file tree
Showing 8 changed files with 286 additions and 42 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ To learn more about our roadmap, we recommend reading [this document](ROADMAP.md

- **General:** Metrics endpoint returns correct HPA values ([#3554](https://github.com/kedacore/keda/issues/3554))
- **Datadog Scaler:** Fix: panic in datadog scaler ([#3448](https://github.com/kedacore/keda/issues/3448))
- **RabbitMQ Scaler:** Parse vhost correctly if it's provided in the host url ([#3602](https://github.com/kedacore/keda/issues/3602))

### Deprecations

Expand Down
27 changes: 7 additions & 20 deletions pkg/scalers/rabbitmq_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ type rabbitMQMetadata struct {
activationValue float64 // activation value
host string // connection string for either HTTP or AMQP protocol
protocol string // either http or amqp protocol
vhostName *string // override the vhost from the connection info
vhostName string // override the vhost from the connection info
useRegex bool // specify if the queueName contains a rexeg
excludeUnacknowledged bool // specify if the QueueLength value should exclude Unacknowledged messages (Ready messages only)
pageSize int64 // specify the page size if useRegex is enabled
Expand Down Expand Up @@ -121,12 +121,12 @@ func NewRabbitMQScaler(config *ScalerConfig) (Scaler, error) {
if meta.protocol == amqpProtocol {
// Override vhost if requested.
host := meta.host
if meta.vhostName != nil {
if meta.vhostName != "" {
hostURI, err := amqp.ParseURI(host)
if err != nil {
return nil, fmt.Errorf("error parsing rabbitmq connection string: %s", err)
}
hostURI.Vhost = *meta.vhostName
hostURI.Vhost = meta.vhostName
host = hostURI.String()
}

Expand Down Expand Up @@ -193,7 +193,7 @@ func parseRabbitMQMetadata(config *ScalerConfig) (*rabbitMQMetadata, error) {

// Resolve vhostName
if val, ok := config.TriggerMetadata["vhostName"]; ok {
meta.vhostName = &val
meta.vhostName = val
}

err := parseRabbitMQHttpProtocolMetadata(config, &meta)
Expand Down Expand Up @@ -457,24 +457,11 @@ func (s *rabbitMQScaler) getQueueInfoViaHTTP() (*queueInfo, error) {
// Extract vhost from URL's path.
vhost := parsedURL.Path

// If the URL's path only contains a slash, it represents the trailing slash and
// must be ignored because it may cause confusion with the '/' vhost.
if vhost == "/" {
vhost = ""
if s.metadata.vhostName != "" {
vhost = "/" + url.QueryEscape(s.metadata.vhostName)
}

// Override vhost if requested.
if s.metadata.vhostName != nil {
// If the desired vhost is "All" vhosts, no path is necessary
if *s.metadata.vhostName == "" {
vhost = ""
} else {
vhost = "/" + url.QueryEscape(*s.metadata.vhostName)
}
}

// Encode the '/' vhost if necessary.
if vhost == "//" {
if vhost == "" || vhost == "/" || vhost == "//" {
vhost = rabbitRootVhostPath
}

Expand Down
33 changes: 19 additions & 14 deletions pkg/scalers/rabbitmq_scaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,18 +225,8 @@ func TestGetQueueInfo(t *testing.T) {
for _, testData := range allTestData {
testData := testData

var expectedVhostPath string
switch testData.vhostPath {
case "/myhost":
expectedVhostPath = "/myhost"
case rabbitRootVhostPath, "//":
expectedVhostPath = rabbitRootVhostPath
default:
expectedVhostPath = ""
}

var apiStub = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
expectedPath := fmt.Sprintf("/api/queues%s/evaluate_trials", expectedVhostPath)
expectedPath := fmt.Sprintf("/api/queues%s/evaluate_trials", getExpectedVhost(testData.vhostPath))
if r.RequestURI != expectedPath {
t.Error("Expect request path to =", expectedPath, "but it is", r.RequestURI)
}
Expand Down Expand Up @@ -373,7 +363,7 @@ func TestGetQueueInfoWithRegex(t *testing.T) {

for _, testData := range allTestData {
var apiStub = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
expectedPath := fmt.Sprintf("/api/queues%s?page=1&use_regex=true&pagination=false&name=%%5Eevaluate_trials%%24&page_size=100", testData.vhostPath)
expectedPath := fmt.Sprintf("/api/queues%s?page=1&use_regex=true&pagination=false&name=%%5Eevaluate_trials%%24&page_size=100", getExpectedVhost(testData.vhostPath))
if r.RequestURI != expectedPath {
t.Error("Expect request path to =", expectedPath, "but it is", r.RequestURI)
}
Expand Down Expand Up @@ -453,7 +443,7 @@ func TestGetPageSizeWithRegex(t *testing.T) {

for _, testData := range allTestData {
var apiStub = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
expectedPath := fmt.Sprintf("/api/queues%s?page=1&use_regex=true&pagination=false&name=%%5Eevaluate_trials%%24&page_size=%d", testData.queueInfo.vhostPath, testData.pageSize)
expectedPath := fmt.Sprintf("/api/queues%s?page=1&use_regex=true&pagination=false&name=%%5Eevaluate_trials%%24&page_size=%d", getExpectedVhost(testData.queueInfo.vhostPath), testData.pageSize)
if r.RequestURI != expectedPath {
t.Error("Expect request path to =", expectedPath, "but it is", r.RequestURI)
}
Expand Down Expand Up @@ -575,7 +565,7 @@ var testRegexQueueInfoNavigationTestData = []getQueueInfoNavigationTestData{
func TestRegexQueueMissingError(t *testing.T) {
for _, testData := range testRegexQueueInfoNavigationTestData {
var apiStub = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
expectedPath := "/api/queues?page=1&use_regex=true&pagination=false&name=evaluate_trials&page_size=100"
expectedPath := "/api/queues/%2F?page=1&use_regex=true&pagination=false&name=evaluate_trials&page_size=100"
if r.RequestURI != expectedPath {
t.Error("Expect request path to =", expectedPath, "but it is", r.RequestURI)
}
Expand Down Expand Up @@ -619,3 +609,18 @@ func TestRegexQueueMissingError(t *testing.T) {
}
}
}

func getExpectedVhost(vhostPath string) string {
switch vhostPath {
case "":
return rabbitRootVhostPath
case "/":
return rabbitRootVhostPath
case "//":
return rabbitRootVhostPath
case rabbitRootVhostPath:
return rabbitRootVhostPath
default:
return vhostPath
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ var (
queueName = "hello"
user = fmt.Sprintf("%s-user", testName)
password = fmt.Sprintf("%s-password", testName)
vhost = fmt.Sprintf("%s-vhost", testName)
connectionString = fmt.Sprintf("amqp://%s:%s@rabbitmq.%s.svc.cluster.local/%s", user, password, rmqNamespace, vhost)
vhost = "/"
connectionString = fmt.Sprintf("amqp://%s:%s@rabbitmq.%s.svc.cluster.local", user, password, rmqNamespace)
messageCount = 100
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
//go:build e2e
// +build e2e

package rabbitmq_queue_amqp_vhost_test

import (
"encoding/base64"
"fmt"
"testing"

"github.com/joho/godotenv"
"github.com/stretchr/testify/assert"
"k8s.io/client-go/kubernetes"

. "github.com/kedacore/keda/v2/tests/helper"
. "github.com/kedacore/keda/v2/tests/scalers/rabbitmq"
)

// Load environment variables from .env file
var _ = godotenv.Load("../../.env")

const (
testName = "rmq-queue-amqp-vhost-test"
)

var (
testNamespace = fmt.Sprintf("%s-ns", testName)
rmqNamespace = fmt.Sprintf("%s-rmq", testName)
deploymentName = fmt.Sprintf("%s-deployment", testName)
secretName = fmt.Sprintf("%s-secret", testName)
scaledObjectName = fmt.Sprintf("%s-so", testName)
queueName = "hello"
user = fmt.Sprintf("%s-user", testName)
password = fmt.Sprintf("%s-password", testName)
vhost = fmt.Sprintf("%s-vhost", testName)
connectionString = fmt.Sprintf("amqp://%s:%s@rabbitmq.%s.svc.cluster.local/%s", user, password, rmqNamespace, vhost)
messageCount = 100
)

const (
scaledObjectTemplate = `
apiVersion: keda.sh/v1alpha1
kind: ScaledObject
metadata:
name: {{.ScaledObjectName}}
namespace: {{.TestNamespace}}
spec:
scaleTargetRef:
name: {{.DeploymentName}}
pollingInterval: 5
cooldownPeriod: 10
minReplicaCount: 0
maxReplicaCount: 4
triggers:
- type: rabbitmq
metadata:
queueName: {{.QueueName}}
hostFromEnv: RabbitApiHost
mode: QueueLength
value: '10'
activationValue: '5'
`
)

type templateData struct {
TestNamespace string
DeploymentName string
ScaledObjectName string
SecretName string
QueueName string
Connection, Base64Connection string
}

func TestScaler(t *testing.T) {
// setup
t.Log("--- setting up ---")

// Create kubernetes resources
kc := GetKubernetesClient(t)
data, templates := getTemplateData()

RMQInstall(t, kc, rmqNamespace, user, password, vhost)
CreateKubernetesResources(t, kc, testNamespace, data, templates)

assert.True(t, WaitForDeploymentReplicaReadyCount(t, kc, deploymentName, testNamespace, 0, 60, 1),
"replica count should be 0 after 1 minute")

testScaling(t, kc)

testActivationValue(t, kc)

// cleanup
t.Log("--- cleaning up ---")
DeleteKubernetesResources(t, kc, testNamespace, data, templates)
RMQUninstall(t, kc, rmqNamespace, user, password, vhost)
}

func getTemplateData() (templateData, []Template) {
return templateData{
TestNamespace: testNamespace,
DeploymentName: deploymentName,
ScaledObjectName: scaledObjectName,
SecretName: secretName,
QueueName: queueName,
Connection: connectionString,
Base64Connection: base64.StdEncoding.EncodeToString([]byte(connectionString)),
}, []Template{
{Name: "deploymentTemplate", Config: RMQTargetDeploymentTemplate},
{Name: "scaledObjectTemplate", Config: scaledObjectTemplate},
}
}

func testScaling(t *testing.T, kc *kubernetes.Clientset) {
t.Log("--- testing scale up ---")
RMQPublishMessages(t, rmqNamespace, connectionString, queueName, messageCount)
assert.True(t, WaitForDeploymentReplicaReadyCount(t, kc, deploymentName, testNamespace, 4, 60, 1),
"replica count should be 4 after 1 minute")

t.Log("--- testing scale down ---")
assert.True(t, WaitForDeploymentReplicaReadyCount(t, kc, deploymentName, testNamespace, 0, 60, 1),
"replica count should be 0 after 1 minute")
}

func testActivationValue(t *testing.T, kc *kubernetes.Clientset) {
t.Log("--- testing activation value ---")
messagesToQueue := 3
RMQPublishMessages(t, rmqNamespace, connectionString, queueName, messagesToQueue)

AssertReplicaCountNotChangeDuringTimePeriod(t, kc, deploymentName, testNamespace, 0, 60)
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ var (
queueName = "hello"
user = fmt.Sprintf("%s-user", testName)
password = fmt.Sprintf("%s-password", testName)
vhost = fmt.Sprintf("%s-vhost", testName)
connectionString = fmt.Sprintf("amqp://%s:%s@rabbitmq.%s.svc.cluster.local/%s", user, password, rmqNamespace, vhost)
httpConnectionString = fmt.Sprintf("http://%s:%s@rabbitmq.%s.svc.cluster.local/%s", user, password, rmqNamespace, vhost)
vhost = "/"
connectionString = fmt.Sprintf("amqp://%s:%s@rabbitmq.%s.svc.cluster.local/", user, password, rmqNamespace)
httpConnectionString = fmt.Sprintf("http://%s:%s@rabbitmq.%s.svc.cluster.local/", user, password, rmqNamespace)
messageCount = 100
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ var (
queueRegex = "^hell.{1}$"
user = fmt.Sprintf("%s-user", testName)
password = fmt.Sprintf("%s-password", testName)
vhost = fmt.Sprintf("%s-vhost", testName)
connectionString = fmt.Sprintf("amqp://%s:%s@rabbitmq.%s.svc.cluster.local/%s", user, password, rmqNamespace, vhost)
httpConnectionString = fmt.Sprintf("http://%s:%s@rabbitmq.%s.svc.cluster.local/%s", user, password, rmqNamespace, vhost)
vhost = "/"
connectionString = fmt.Sprintf("amqp://%s:%s@rabbitmq.%s.svc.cluster.local/", user, password, rmqNamespace)
httpConnectionString = fmt.Sprintf("http://%s:%s@rabbitmq.%s.svc.cluster.local/", user, password, rmqNamespace)
messageCount = 100
)

Expand Down
Loading

0 comments on commit f5b7234

Please sign in to comment.