Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions pkg/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/sosedoff/pgweb/pkg/bookmarks"
"github.com/sosedoff/pgweb/pkg/client"
"github.com/sosedoff/pgweb/pkg/command"
"github.com/sosedoff/pgweb/pkg/connect"
"github.com/sosedoff/pgweb/pkg/connection"
"github.com/sosedoff/pgweb/pkg/metrics"
"github.com/sosedoff/pgweb/pkg/queries"
Expand Down Expand Up @@ -92,18 +93,18 @@ func GetSessions(c *gin.Context) {

// ConnectWithBackend creates a new connection based on backend resource
func ConnectWithBackend(c *gin.Context) {
// Setup a new backend client
backend := Backend{
Endpoint: command.Opts.ConnectBackend,
Token: command.Opts.ConnectToken,
PassHeaders: strings.Split(command.Opts.ConnectHeaders, ","),
backend := connect.NewBackend(command.Opts.ConnectBackend, command.Opts.ConnectToken)
backend.SetLogger(logger)

if command.Opts.ConnectHeaders != "" {
backend.SetPassHeaders(strings.Split(command.Opts.ConnectHeaders, ","))
}

ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()

// Fetch connection credentials
cred, err := backend.FetchCredential(ctx, c.Param("resource"), c)
cred, err := backend.FetchCredential(ctx, c.Param("resource"), c.Request.Header)
if err != nil {
badRequest(c, err)
return
Expand Down
87 changes: 0 additions & 87 deletions pkg/api/backend.go

This file was deleted.

2 changes: 0 additions & 2 deletions pkg/api/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,10 @@ import (
var (
errNotConnected = errors.New("Not connected")
errNotPermitted = errors.New("Not permitted")
errConnStringRequired = errors.New("Connection string is required")
errInvalidConnString = errors.New("Invalid connection string")
errSessionRequired = errors.New("Session ID is required")
errSessionLocked = errors.New("Session is locked")
errURLRequired = errors.New("URL parameter is required")
errQueryRequired = errors.New("Query parameter is required")
errDatabaseNameRequired = errors.New("Database name is required")
errBackendConnectError = errors.New("Unable to connect to the auth backend")
)
92 changes: 92 additions & 0 deletions pkg/connect/backend.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package connect

import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"strings"

"github.com/sirupsen/logrus"
)

type Backend struct {
Endpoint string
Token string
PassHeaders []string

logger *logrus.Logger
}

func NewBackend(endpoint string, token string) Backend {
return Backend{
Endpoint: endpoint,
Token: token,
logger: logrus.StandardLogger(),
}
}

func (be *Backend) SetLogger(logger *logrus.Logger) {
be.logger = logger
}

func (be *Backend) SetPassHeaders(headers []string) {
be.PassHeaders = headers
}

func (be *Backend) FetchCredential(ctx context.Context, resource string, headers http.Header) (*Credential, error) {
be.logger.WithField("resource", resource).Debug("fetching database credential")

request := Request{
Resource: resource,
Token: be.Token,
Headers: map[string]string{},
}

// Pass allow-listed client headers to the backend request
for _, name := range be.PassHeaders {
request.Headers[strings.ToLower(name)] = headers.Get(name)
}

body, err := json.Marshal(request)
if err != nil {
be.logger.WithField("resource", resource).Error("backend request serialization error:", err)
return nil, err
}

req, err := http.NewRequestWithContext(ctx, http.MethodPost, be.Endpoint, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("content-type", "application/json")

resp, err := http.DefaultClient.Do(req)
if err != nil {
be.logger.WithField("resource", resource).Error("backend credential fetch failed:", err)
return nil, errBackendConnectError
}
defer resp.Body.Close()

if resp.StatusCode != 200 {
err = fmt.Errorf("backend credential fetch received HTTP status code %v", resp.StatusCode)

be.logger.
WithField("resource", request.Resource).
WithField("status", resp.StatusCode).
Error(err)

return nil, err
}

cred := &Credential{}
if err := json.NewDecoder(resp.Body).Decode(cred); err != nil {
return nil, err
}

if cred.DatabaseURL == "" {
return nil, errConnStringRequired
}

return cred, nil
}
36 changes: 14 additions & 22 deletions pkg/api/backend_test.go → pkg/connect/backend_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package api
package connect

import (
"context"
Expand All @@ -9,6 +9,7 @@ import (
"time"

"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
)

Expand All @@ -17,8 +18,8 @@ func TestBackendFetchCredential(t *testing.T) {
name string
backend Backend
resourceName string
cred *BackendCredential
reqCtx *gin.Context
cred *Credential
headers http.Header
ctx func() (context.Context, context.CancelFunc)
err error
}{
Expand All @@ -33,12 +34,12 @@ func TestBackendFetchCredential(t *testing.T) {
ctx: func() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), time.Millisecond*100)
},
err: errors.New("Unable to connect to the auth backend"),
err: errors.New("unable to connect to the auth backend"),
},
{
name: "Empty response",
backend: Backend{Endpoint: "http://localhost:5555/empty-response"},
err: errors.New("Connection string is required"),
err: errors.New("connection string is required"),
},
{
name: "Missing header",
Expand All @@ -51,19 +52,15 @@ func TestBackendFetchCredential(t *testing.T) {
Endpoint: "http://localhost:5555/pass-header",
PassHeaders: []string{"x-foo"},
},
reqCtx: &gin.Context{
Request: &http.Request{
Header: http.Header{
"X-Foo": []string{"bar"},
},
},
headers: http.Header{
"X-Foo": []string{"bar"},
},
cred: &BackendCredential{DatabaseURL: "postgres://hostname/bar"},
cred: &Credential{DatabaseURL: "postgres://hostname/bar"},
},
{
name: "Success",
backend: Backend{Endpoint: "http://localhost:5555/success"},
cred: &BackendCredential{DatabaseURL: "postgres://hostname/dbname"},
cred: &Credential{DatabaseURL: "postgres://hostname/dbname"},
},
}

Expand All @@ -73,21 +70,16 @@ func TestBackendFetchCredential(t *testing.T) {
startTestBackend(srvCtx, "localhost:5555")

for _, ex := range examples {
ex.backend.logger = logrus.StandardLogger()

t.Run(ex.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
if ex.ctx != nil {
ctx, cancel = ex.ctx()
}
defer cancel()

reqCtx := ex.reqCtx
if reqCtx == nil {
reqCtx = &gin.Context{
Request: &http.Request{},
}
}

cred, err := ex.backend.FetchCredential(ctx, ex.resourceName, reqCtx)
cred, err := ex.backend.FetchCredential(ctx, ex.resourceName, ex.headers)
assert.Equal(t, ex.err, err)
assert.Equal(t, ex.cred, cred)
})
Expand Down Expand Up @@ -117,7 +109,7 @@ func startTestBackend(ctx context.Context, listenAddr string) {
})

router.POST("/pass-header", func(c *gin.Context) {
req := BackendRequest{}
req := Request{}
if err := c.BindJSON(&req); err != nil {
panic(err)
}
Expand Down
20 changes: 20 additions & 0 deletions pkg/connect/types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package connect

import "errors"

var (
errBackendConnectError = errors.New("unable to connect to the auth backend")
errConnStringRequired = errors.New("connection string is required")
)

// Request holds the resource request details
type Request struct {
Resource string `json:"resource"`
Token string `json:"token"`
Headers map[string]string `json:"headers,omitempty"`
}

// Credential holds the database connection string
type Credential struct {
DatabaseURL string `json:"database_url"`
}
Loading