Skip to content

Commit

Permalink
Use delegation method for TraT verification in stocks service
Browse files Browse the repository at this point in the history
  • Loading branch information
kchiranjewee63 committed Aug 5, 2024
1 parent 4544ab3 commit 3263038
Show file tree
Hide file tree
Showing 11 changed files with 209 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ metadata:
name: gateway
namespace: alpha-stocks-dev
spec:
replicas: 3
replicas: 1
selector:
matchLabels:
app: gateway
Expand Down
2 changes: 1 addition & 1 deletion deploy/alpha-stocks-dev/deployments/order-deployment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ metadata:
name: order
namespace: alpha-stocks-dev
spec:
replicas: 2
replicas: 1
selector:
matchLabels:
app: order
Expand Down
6 changes: 4 additions & 2 deletions deploy/alpha-stocks-dev/deployments/stocks-deployment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ metadata:
name: stocks
namespace: alpha-stocks-dev
spec:
replicas: 2
replicas: 1
selector:
matchLabels:
app: stocks
template:
metadata:
annotations:
tratteria/inject-sidecar: "true"
tratteria/service-port: "8070"
tratteria/agent-mode: "delegation"
labels:
app: stocks
spec:
Expand All @@ -29,6 +29,8 @@ spec:
value: spiffe://dev.alphastocks.com/order
- name: STOCKS_SERVICE_SPIFFE_ID
value: spiffe://dev.alphastocks.com/stocks
- name: TRAT_VERIFY_ENDPOINT
value: http://localhost:9030/verify-trat
image: ghcr.io/tratteria/example-application/stocks:latest
name: stocks
ports:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ metadata:
name: tratteria
namespace: alpha-stocks-dev # Replace [your-namespace] with your Kubernetes namespace
spec:
replicas: 3
replicas: 1
selector:
matchLabels:
app: tratteria
Expand Down
6 changes: 3 additions & 3 deletions deploy/tconfigd/installation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ Update the `config.yaml` file to match your specific deployment settings:
- **Description**: Host directory where the SPIRE agent's socket resides. Update this value if it is different in your SPIRE installation.
- `tratteriaSpiffeId`: `"spiffe://[your-trust-domain]/tratteria"`
- **Description**: SPIFFE ID used to register [tratteria service](https://github.com/tratteria/tratteria), an open source Transaction Tokens (TraTs) Service.
- `agentHttpApiPort`: "`9030`"
- **Description**: Port number for the tratteria agent HTTP APIs. Do not change this unless you have some specific need.
- `agentInterceptorPort`: "`9050`"
- `agentApiPort`: "`9030`"
- **Description**: Port number for the tratteria agent APIs. Do not change this unless you have some specific need.
- `agentInterceptorPort`: "`9050`"
- **Description**: The port number for the tratteria agent's incoming requests interceptor. Do not change this unless you have some specific need.


Expand Down
2 changes: 1 addition & 1 deletion deploy/tconfigd/installation/config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
enableTratInterception: "true" # Enable or disable incoming requests interception for TraT verification
spireAgentHostDir: "/run/spire/sockets" # Host directory where the SPIRE agent's socket resides; update this if different in your environment
tratteriaSpiffeId: "spiffe://dev.alphastocks.com/tratteria" # Replace "trust-domain" with your trust domain
agentHttpApiPort: "9030" # Port number for the tratteria agent HTTP APIs
agentApiPort: "9030" # Port number for the tratteria agent APIs
agentInterceptorPort: "9050" # Port number for the tratteria agent incoming requests interceptor
4 changes: 3 additions & 1 deletion stocks/cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type App struct {
DB *sql.DB
Config *config.StocksConfig
SpireJwtSource *workloadapi.JWTSource
HTTPClient *http.Client
Logger *zap.Logger
}

Expand Down Expand Up @@ -60,10 +61,11 @@ func main() {
DB: db,
Config: appConfig,
SpireJwtSource: spireJwtSource,
HTTPClient: &http.Client{},
Logger: logger,
}

middleware := middleware.GetMiddleware(appConfig, app.SpireJwtSource, app.Logger)
middleware := middleware.GetMiddleware(appConfig, app.SpireJwtSource, app.HTTPClient, app.Logger)

app.Router.Use(middleware)

Expand Down
14 changes: 13 additions & 1 deletion stocks/pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package config
import (
"context"
"fmt"
"net/url"
"os"
"time"

Expand All @@ -19,7 +20,8 @@ type spiffeIDs struct {
}

type StocksConfig struct {
SpiffeIDs *spiffeIDs
SpiffeIDs *spiffeIDs
TratVerifyEndpoint *url.URL
}

func GetAppConfig() *StocksConfig {
Expand All @@ -29,6 +31,7 @@ func GetAppConfig() *StocksConfig {
Order: spiffeid.RequireFromString(getEnv("ORDER_SERVICE_SPIFFE_ID")),
Stocks: spiffeid.RequireFromString(getEnv("STOCKS_SERVICE_SPIFFE_ID")),
},
TratVerifyEndpoint: parseURL(getEnv("TRAT_VERIFY_ENDPOINT")),
}
}

Expand All @@ -52,3 +55,12 @@ func getEnv(key string) string {

return value
}

func parseURL(rawurl string) *url.URL {
parsedURL, err := url.Parse(rawurl)
if err != nil {
panic(fmt.Sprintf("Error parsing URL %s: %v", rawurl, err))
}

return parsedURL
}
4 changes: 3 additions & 1 deletion stocks/pkg/middleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ func CombineMiddleware(middleware ...func(http.Handler) http.Handler) func(http.
}
}

func GetMiddleware(stocksConfig *config.StocksConfig, spireJwtSource *workloadapi.JWTSource, logger *zap.Logger) func(http.Handler) http.Handler {
func GetMiddleware(stocksConfig *config.StocksConfig, spireJwtSource *workloadapi.JWTSource, httpClient *http.Client, logger *zap.Logger) func(http.Handler) http.Handler {
middlewareList := []func(http.Handler) http.Handler{}

middlewareList = append(middlewareList, getSpiffeMiddleware(stocksConfig, spireJwtSource, logger))

middlewareList = append(middlewareList, getTraTVerifierMiddleware(stocksConfig.TratVerifyEndpoint, httpClient, logger))

return CombineMiddleware(middlewareList...)
}
178 changes: 178 additions & 0 deletions stocks/pkg/middleware/tratverifiermiddleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
package middleware

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"

"go.uber.org/zap"
)

type VerifyTraTRequest struct {
Path string `json:"endpoint"`
Method string `json:"method"`
QueryParameters json.RawMessage `json:"queryParameters"`
Headers json.RawMessage `json:"headers"`
Body json.RawMessage `json:"body"`
}

type VerifyTraTResponse struct {
Valid bool `json:"valid"`
Reason string `json:"reason"`
}

func readAndReplaceBody(r *http.Request) (json.RawMessage, error) {
if r.Body == nil {
return []byte("{}"), nil
}

data, err := io.ReadAll(r.Body)
if err != nil {
return nil, err
}

r.Body.Close()

r.Body = io.NopCloser(bytes.NewBuffer(data))

if len(data) == 0 {
return []byte("{}"), nil
}

return data, nil
}

func convertMapToJson(data map[string]string) (json.RawMessage, error) {
bytes, err := json.Marshal(data)
if err != nil {
return nil, err
}

if len(bytes) == 0 {
bytes = []byte("{}")
}

return json.RawMessage(bytes), nil
}

// TODO: handle keys with multiple values.
func convertHeaderToJson(headers http.Header) (json.RawMessage, error) {
headerMap := make(map[string]string)
for key, values := range headers {
headerMap[key] = values[0]
}

return convertMapToJson(headerMap)
}

func getVerifyTraTRequest(r *http.Request) (*VerifyTraTRequest, error) {
body, err := readAndReplaceBody(r)
if err != nil {
return nil, fmt.Errorf("error reading request body: %w", err)
}

headersJson, err := convertHeaderToJson(r.Header)
if err != nil {
return nil, fmt.Errorf("error reading request header: %w", err)
}

//TODO: handle keys with multiple values.
queryParams := make(map[string]string)
for key, values := range r.URL.Query() {
queryParams[key] = values[0]
}

queryParamsJson, err := convertMapToJson(queryParams)
if err != nil {
return nil, fmt.Errorf("error reading query parameters: %w", err)
}

details := &VerifyTraTRequest{
Path: r.URL.Path,
Method: r.Method,
QueryParameters: queryParamsJson,
Headers: headersJson,
Body: body,
}

return details, nil
}

func getTraTVerifierMiddleware(traTVerifierEndpoint *url.URL, httpClient *http.Client, logger *zap.Logger) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
verifyTraTRequest, err := getVerifyTraTRequest(r)
if err != nil {
logger.Error("Error creating verify trat request.", zap.Error(err))
http.Error(w, "Internal Server Error", http.StatusInternalServerError)

return
}

requestBody, err := json.Marshal(verifyTraTRequest)
if err != nil {
logger.Error("Error marshalling verify trat request to JSON.", zap.Error(err))
http.Error(w, "Internal Server Error", http.StatusInternalServerError)

return
}

req, err := http.NewRequest(http.MethodPost, traTVerifierEndpoint.String(), bytes.NewBuffer(requestBody))
if err != nil {
logger.Error("Error creating trat verification request.", zap.Error(err))
http.Error(w, "Internal Server Error", http.StatusInternalServerError)

return
}

req.Header.Set("Content-Type", "application/json")

resp, err := httpClient.Do(req)
if err != nil {
logger.Error("Error sending trat verification request.", zap.Error(err))
http.Error(w, "Internal Server Error", http.StatusInternalServerError)

return
}

defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
logger.Error("Received non-OK response on trat verification.", zap.Int("status_code", resp.StatusCode))
http.Error(w, "Internal Server Error", http.StatusInternalServerError)

return
}

body, err := io.ReadAll(resp.Body)
if err != nil {
logger.Error("Error reading trat verification response body.", zap.Error(err))
http.Error(w, "Internal Server Error", http.StatusInternalServerError)

return
}

var verifyResponse VerifyTraTResponse

err = json.Unmarshal(body, &verifyResponse)
if err != nil {
logger.Error("Error unmarshalling trat verification response body.", zap.Error(err))
http.Error(w, "Internal Server Error", http.StatusInternalServerError)

return
}

if !verifyResponse.Valid {
logger.Info("Invalid trat.", zap.String("reason", verifyResponse.Reason))
http.Error(w, "Invalid trat", http.StatusForbidden)

return
}

next.ServeHTTP(w, r)
})
}
}
1 change: 1 addition & 0 deletions stocks/pkg/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ func (s *Service) GetUserHoldings(username string) (Holdings, error) {

for rows.Next() {
var holding Holding

var currentPrice float64

if err := rows.Scan(&holding.StockID, &holding.StockSymbol, &holding.Quantity, &currentPrice); err != nil {
Expand Down

0 comments on commit 3263038

Please sign in to comment.