diff --git a/deploy/alpha-stocks-dev/deployments/gateway-deployment.yaml b/deploy/alpha-stocks-dev/deployments/gateway-deployment.yaml index 5a00b76..405b8a8 100644 --- a/deploy/alpha-stocks-dev/deployments/gateway-deployment.yaml +++ b/deploy/alpha-stocks-dev/deployments/gateway-deployment.yaml @@ -4,7 +4,7 @@ metadata: name: gateway namespace: alpha-stocks-dev spec: - replicas: 3 + replicas: 1 selector: matchLabels: app: gateway diff --git a/deploy/alpha-stocks-dev/deployments/order-deployment.yaml b/deploy/alpha-stocks-dev/deployments/order-deployment.yaml index 768505e..d37aeb7 100644 --- a/deploy/alpha-stocks-dev/deployments/order-deployment.yaml +++ b/deploy/alpha-stocks-dev/deployments/order-deployment.yaml @@ -4,7 +4,7 @@ metadata: name: order namespace: alpha-stocks-dev spec: - replicas: 2 + replicas: 1 selector: matchLabels: app: order diff --git a/deploy/alpha-stocks-dev/deployments/stocks-deployment.yaml b/deploy/alpha-stocks-dev/deployments/stocks-deployment.yaml index c4af0e0..d161a08 100644 --- a/deploy/alpha-stocks-dev/deployments/stocks-deployment.yaml +++ b/deploy/alpha-stocks-dev/deployments/stocks-deployment.yaml @@ -4,7 +4,7 @@ metadata: name: stocks namespace: alpha-stocks-dev spec: - replicas: 2 + replicas: 1 selector: matchLabels: app: stocks @@ -12,7 +12,7 @@ spec: metadata: annotations: tratteria/inject-sidecar: "true" - tratteria/service-port: "8070" + tratteria/agent-mode: "delegation" labels: app: stocks spec: @@ -29,6 +29,8 @@ spec: value: spiffe://dev.alphastocks.com/order - name: STOCKS_SERVICE_SPIFFE_ID value: spiffe://dev.alphastocks.com/stocks + - name: TRAT_VERIFY_ENDPOINT + value: http://localhost:9030/verify-trat image: ghcr.io/tratteria/example-application/stocks:latest name: stocks ports: diff --git a/deploy/alpha-stocks-dev/tratteria/kubernetes/tratteria-deployment.yaml b/deploy/alpha-stocks-dev/tratteria/kubernetes/tratteria-deployment.yaml index 59a8162..dfc621d 100644 --- a/deploy/alpha-stocks-dev/tratteria/kubernetes/tratteria-deployment.yaml +++ b/deploy/alpha-stocks-dev/tratteria/kubernetes/tratteria-deployment.yaml @@ -4,7 +4,7 @@ metadata: name: tratteria namespace: alpha-stocks-dev # Replace [your-namespace] with your Kubernetes namespace spec: - replicas: 3 + replicas: 1 selector: matchLabels: app: tratteria diff --git a/deploy/tconfigd/installation/README.md b/deploy/tconfigd/installation/README.md index 892a410..c4ea12a 100644 --- a/deploy/tconfigd/installation/README.md +++ b/deploy/tconfigd/installation/README.md @@ -94,9 +94,9 @@ Update the `config.yaml` file to match your specific deployment settings: - **Description**: Host directory where the SPIRE agent's socket resides. Update this value if it is different in your SPIRE installation. - `tratteriaSpiffeId`: `"spiffe://[your-trust-domain]/tratteria"` - **Description**: SPIFFE ID used to register [tratteria service](https://github.com/tratteria/tratteria), an open source Transaction Tokens (TraTs) Service. - - `agentHttpApiPort`: "`9030`" - - **Description**: Port number for the tratteria agent HTTP APIs. Do not change this unless you have some specific need. - - `agentInterceptorPort`: "`9050`" + - `agentApiPort`: "`9030`" + - **Description**: Port number for the tratteria agent APIs. Do not change this unless you have some specific need. + - `agentInterceptorPort`: "`9050`" - **Description**: The port number for the tratteria agent's incoming requests interceptor. Do not change this unless you have some specific need. diff --git a/deploy/tconfigd/installation/config.yaml b/deploy/tconfigd/installation/config.yaml index f7230ab..d4bdfad 100644 --- a/deploy/tconfigd/installation/config.yaml +++ b/deploy/tconfigd/installation/config.yaml @@ -1,5 +1,5 @@ enableTratInterception: "true" # Enable or disable incoming requests interception for TraT verification spireAgentHostDir: "/run/spire/sockets" # Host directory where the SPIRE agent's socket resides; update this if different in your environment tratteriaSpiffeId: "spiffe://dev.alphastocks.com/tratteria" # Replace "trust-domain" with your trust domain -agentHttpApiPort: "9030" # Port number for the tratteria agent HTTP APIs +agentApiPort: "9030" # Port number for the tratteria agent APIs agentInterceptorPort: "9050" # Port number for the tratteria agent incoming requests interceptor diff --git a/stocks/cmd/main.go b/stocks/cmd/main.go index 4eb6edf..c29b628 100644 --- a/stocks/cmd/main.go +++ b/stocks/cmd/main.go @@ -22,6 +22,7 @@ type App struct { DB *sql.DB Config *config.StocksConfig SpireJwtSource *workloadapi.JWTSource + HTTPClient *http.Client Logger *zap.Logger } @@ -60,10 +61,11 @@ func main() { DB: db, Config: appConfig, SpireJwtSource: spireJwtSource, + HTTPClient: &http.Client{}, Logger: logger, } - middleware := middleware.GetMiddleware(appConfig, app.SpireJwtSource, app.Logger) + middleware := middleware.GetMiddleware(appConfig, app.SpireJwtSource, app.HTTPClient, app.Logger) app.Router.Use(middleware) diff --git a/stocks/pkg/config/config.go b/stocks/pkg/config/config.go index 184aebb..72f9c2a 100644 --- a/stocks/pkg/config/config.go +++ b/stocks/pkg/config/config.go @@ -3,6 +3,7 @@ package config import ( "context" "fmt" + "net/url" "os" "time" @@ -19,7 +20,8 @@ type spiffeIDs struct { } type StocksConfig struct { - SpiffeIDs *spiffeIDs + SpiffeIDs *spiffeIDs + TratVerifyEndpoint *url.URL } func GetAppConfig() *StocksConfig { @@ -29,6 +31,7 @@ func GetAppConfig() *StocksConfig { Order: spiffeid.RequireFromString(getEnv("ORDER_SERVICE_SPIFFE_ID")), Stocks: spiffeid.RequireFromString(getEnv("STOCKS_SERVICE_SPIFFE_ID")), }, + TratVerifyEndpoint: parseURL(getEnv("TRAT_VERIFY_ENDPOINT")), } } @@ -52,3 +55,12 @@ func getEnv(key string) string { return value } + +func parseURL(rawurl string) *url.URL { + parsedURL, err := url.Parse(rawurl) + if err != nil { + panic(fmt.Sprintf("Error parsing URL %s: %v", rawurl, err)) + } + + return parsedURL +} diff --git a/stocks/pkg/middleware/middleware.go b/stocks/pkg/middleware/middleware.go index 8553ab8..0b6e919 100644 --- a/stocks/pkg/middleware/middleware.go +++ b/stocks/pkg/middleware/middleware.go @@ -19,10 +19,12 @@ func CombineMiddleware(middleware ...func(http.Handler) http.Handler) func(http. } } -func GetMiddleware(stocksConfig *config.StocksConfig, spireJwtSource *workloadapi.JWTSource, logger *zap.Logger) func(http.Handler) http.Handler { +func GetMiddleware(stocksConfig *config.StocksConfig, spireJwtSource *workloadapi.JWTSource, httpClient *http.Client, logger *zap.Logger) func(http.Handler) http.Handler { middlewareList := []func(http.Handler) http.Handler{} middlewareList = append(middlewareList, getSpiffeMiddleware(stocksConfig, spireJwtSource, logger)) + middlewareList = append(middlewareList, getTraTVerifierMiddleware(stocksConfig.TratVerifyEndpoint, httpClient, logger)) + return CombineMiddleware(middlewareList...) } diff --git a/stocks/pkg/middleware/tratverifiermiddleware.go b/stocks/pkg/middleware/tratverifiermiddleware.go new file mode 100644 index 0000000..9644083 --- /dev/null +++ b/stocks/pkg/middleware/tratverifiermiddleware.go @@ -0,0 +1,178 @@ +package middleware + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + + "go.uber.org/zap" +) + +type VerifyTraTRequest struct { + Path string `json:"endpoint"` + Method string `json:"method"` + QueryParameters json.RawMessage `json:"queryParameters"` + Headers json.RawMessage `json:"headers"` + Body json.RawMessage `json:"body"` +} + +type VerifyTraTResponse struct { + Valid bool `json:"valid"` + Reason string `json:"reason"` +} + +func readAndReplaceBody(r *http.Request) (json.RawMessage, error) { + if r.Body == nil { + return []byte("{}"), nil + } + + data, err := io.ReadAll(r.Body) + if err != nil { + return nil, err + } + + r.Body.Close() + + r.Body = io.NopCloser(bytes.NewBuffer(data)) + + if len(data) == 0 { + return []byte("{}"), nil + } + + return data, nil +} + +func convertMapToJson(data map[string]string) (json.RawMessage, error) { + bytes, err := json.Marshal(data) + if err != nil { + return nil, err + } + + if len(bytes) == 0 { + bytes = []byte("{}") + } + + return json.RawMessage(bytes), nil +} + +// TODO: handle keys with multiple values. +func convertHeaderToJson(headers http.Header) (json.RawMessage, error) { + headerMap := make(map[string]string) + for key, values := range headers { + headerMap[key] = values[0] + } + + return convertMapToJson(headerMap) +} + +func getVerifyTraTRequest(r *http.Request) (*VerifyTraTRequest, error) { + body, err := readAndReplaceBody(r) + if err != nil { + return nil, fmt.Errorf("error reading request body: %w", err) + } + + headersJson, err := convertHeaderToJson(r.Header) + if err != nil { + return nil, fmt.Errorf("error reading request header: %w", err) + } + + //TODO: handle keys with multiple values. + queryParams := make(map[string]string) + for key, values := range r.URL.Query() { + queryParams[key] = values[0] + } + + queryParamsJson, err := convertMapToJson(queryParams) + if err != nil { + return nil, fmt.Errorf("error reading query parameters: %w", err) + } + + details := &VerifyTraTRequest{ + Path: r.URL.Path, + Method: r.Method, + QueryParameters: queryParamsJson, + Headers: headersJson, + Body: body, + } + + return details, nil +} + +func getTraTVerifierMiddleware(traTVerifierEndpoint *url.URL, httpClient *http.Client, logger *zap.Logger) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + verifyTraTRequest, err := getVerifyTraTRequest(r) + if err != nil { + logger.Error("Error creating verify trat request.", zap.Error(err)) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + + return + } + + requestBody, err := json.Marshal(verifyTraTRequest) + if err != nil { + logger.Error("Error marshalling verify trat request to JSON.", zap.Error(err)) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + + return + } + + req, err := http.NewRequest(http.MethodPost, traTVerifierEndpoint.String(), bytes.NewBuffer(requestBody)) + if err != nil { + logger.Error("Error creating trat verification request.", zap.Error(err)) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + + return + } + + req.Header.Set("Content-Type", "application/json") + + resp, err := httpClient.Do(req) + if err != nil { + logger.Error("Error sending trat verification request.", zap.Error(err)) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + + return + } + + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + logger.Error("Received non-OK response on trat verification.", zap.Int("status_code", resp.StatusCode)) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + + return + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + logger.Error("Error reading trat verification response body.", zap.Error(err)) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + + return + } + + var verifyResponse VerifyTraTResponse + + err = json.Unmarshal(body, &verifyResponse) + if err != nil { + logger.Error("Error unmarshalling trat verification response body.", zap.Error(err)) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + + return + } + + if !verifyResponse.Valid { + logger.Info("Invalid trat.", zap.String("reason", verifyResponse.Reason)) + http.Error(w, "Invalid trat", http.StatusForbidden) + + return + } + + next.ServeHTTP(w, r) + }) + } +} diff --git a/stocks/pkg/service/service.go b/stocks/pkg/service/service.go index 4353121..bbe31d5 100644 --- a/stocks/pkg/service/service.go +++ b/stocks/pkg/service/service.go @@ -230,6 +230,7 @@ func (s *Service) GetUserHoldings(username string) (Holdings, error) { for rows.Next() { var holding Holding + var currentPrice float64 if err := rows.Scan(&holding.StockID, &holding.StockSymbol, &holding.Quantity, ¤tPrice); err != nil {