Skip to content

Commit

Permalink
Support service_account authentication for both staging/prod environm…
Browse files Browse the repository at this point in the history
…ent + refactoring (#102)

Co-authored-by: James Kwon <[email protected]>
  • Loading branch information
james03160927 and james03160927 authored Dec 31, 2024
1 parent 67b2b18 commit 3bd0ea4
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 96 deletions.
3 changes: 3 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,7 @@ type Config struct {
JWTSecret string
DiscordSecurityChannelWebhook string
SecretScannerURL string
IDTokenAudience string
AlgoliaAppID string
AlgoliaAPIKey string
}
107 changes: 89 additions & 18 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,53 +4,124 @@ import (
"context"
"fmt"
"os"
"os/signal"
"registry-backend/config"
"registry-backend/ent"
"registry-backend/ent/migrate"
drip_logging "registry-backend/logging"
"registry-backend/server"
"syscall"

"github.com/rs/zerolog/log"

_ "github.com/lib/pq"
)

// validateEnvVars ensures all required environment variables are set based on the environment.
func validateEnvVars(env string) {
// Variables mandatory for all environments
mandatoryVars := []string{
"DB_CONNECTION_STRING",
"PROJECT_ID",
"DRIP_ENV",
"JWT_SECRET",
}

// Additional variables mandatory for production and staging
prodStagingVars := []string{
"SLACK_REGISTRY_CHANNEL_WEBHOOK",
"SECRET_SCANNER_URL",
"SECURITY_COUNCIL_DISCORD_WEBHOOK",
"ALGOLIA_APP_ID",
"ALGOLIA_API_KEY",
"ID_TOKEN_AUDIENCE",
}

// Add production and staging-specific variables
if env == "prod" || env == "staging" {
mandatoryVars = append(mandatoryVars, prodStagingVars...)
}

// Validate that all mandatory environment variables are set
missingVars := []string{}
for _, key := range mandatoryVars {
if os.Getenv(key) == "" {
missingVars = append(missingVars, key)
}
}

// Log and terminate if mandatory variables are missing
if len(missingVars) > 0 {
log.Fatal().Msgf("Missing mandatory environment variables for '%s': %v", env, missingVars)
}
}

func main() {
// Retrieve the current environment
dripEnv := os.Getenv("DRIP_ENV")
if dripEnv == "" {
log.Fatal().Msg("Environment variable DRIP_ENV is not set.")
}

// Validate environment variables based on the current environment
validateEnvVars(dripEnv)

// Set global log level based on the LOG_LEVEL environment variable
drip_logging.SetGlobalLogLevel(os.Getenv("LOG_LEVEL"))

connection_string := os.Getenv("DB_CONNECTION_STRING")
// Retrieve the database connection string
connectionString := os.Getenv("DB_CONNECTION_STRING")

config := config.Config{
// Build the application configuration
appConfig := config.Config{
ProjectID: os.Getenv("PROJECT_ID"),
DripEnv: os.Getenv("DRIP_ENV"),
DripEnv: dripEnv,
SlackRegistryChannelWebhook: os.Getenv("SLACK_REGISTRY_CHANNEL_WEBHOOK"),
JWTSecret: os.Getenv("JWT_SECRET"),
SecretScannerURL: os.Getenv("SECRET_SCANNER_URL"),
DiscordSecurityChannelWebhook: os.Getenv("SECURITY_COUNCIL_DISCORD_WEBHOOK"),
AlgoliaAppID: os.Getenv("ALGOLIA_APP_ID"),
AlgoliaAPIKey: os.Getenv("ALGOLIA_API_KEY"),
IDTokenAudience: os.Getenv("ID_TOKEN_AUDIENCE"),
}

// Construct the database connection string
var dsn string
if os.Getenv("DRIP_ENV") == "localdev" {
dsn = fmt.Sprintf("%s sslmode=disable", connection_string)
if dripEnv == "localdev" {
// For local development, disable SSL for easier setup
dsn = fmt.Sprintf("%s sslmode=disable", connectionString)
} else {
dsn = connection_string
// Use the connection string as-is for non-development environments
dsn = connectionString
}

// Initialize the database client
client, err := ent.Open("postgres", dsn)

if err != nil {
log.Fatal().Err(err).Msg("failed opening connection to postgres.")
}
defer client.Close()
// Run the auto migration tool for localdev.
if os.Getenv("DRIP_ENV") == "localdev" {
log.Info().Msg("Running migrations")
if err := client.Schema.Create(context.Background(), migrate.WithDropIndex(true),
migrate.WithDropColumn(true)); err != nil {
log.Fatal().Err(err).Msg("failed creating schema resources.")
log.Fatal().Err(err).Msg("Failed to establish a connection to the PostgreSQL database.")
}
defer client.Close() // Ensure the database client is closed when the application exits

// Run database migrations in local development to keep the schema up to date
if dripEnv == "localdev" {
log.Info().Msg("Running migrations for local development.")
if err := client.Schema.Create(context.Background(), migrate.WithDropIndex(true), migrate.WithDropColumn(true)); err != nil {
log.Fatal().Err(err).Msg("Failed to create schema resources during migration.")
}
}

server := server.NewServer(client, &config)
log.Fatal().Err(server.Start()).Msg("Server stopped")
// Handle graceful shutdown
go func() {
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
<-sigs
log.Info().Msg("Shutting down server gracefully.")
client.Close()
os.Exit(0)
}()

// Initialize and start the server
server := server.NewServer(client, &appConfig)
log.Info().Msg("Starting the server.")
log.Fatal().Err(server.Start()).Msg("Server has stopped unexpectedly.")
}
2 changes: 1 addition & 1 deletion node-pack-extract/cloudbuild.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ steps:
entrypoint: "bash"
args:
- -c
- gcloud auth print-identity-token --audiences="https://api.comfy.org" | tee /workspace/token
- gcloud auth print-identity-token --audiences="$_REGISTRY_BACKEND_URL" | tee /workspace/token

- name: "gcr.io/cloud-builders/curl"
entrypoint: "bash"
Expand Down
86 changes: 44 additions & 42 deletions run-service-prod.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,51 +14,53 @@ spec:
run.googleapis.com/minScale: '1'
spec:
containers:
- image: registry-backend-image-substitute
env:
- name: DRIP_ENV
value: prod
- name: DB_CONNECTION_STRING
valueFrom:
secretKeyRef:
key: 1
name: PROD_SUPABASE_CONNECTION_STRING
- name: JWT_SECRET
valueFrom:
secretKeyRef:
- image: registry-backend-image-substitute
env:
- name: DRIP_ENV
value: prod
- name: DB_CONNECTION_STRING
valueFrom:
secretKeyRef:
key: 1
name: PROD_SUPABASE_CONNECTION_STRING
- name: JWT_SECRET
valueFrom:
secretKeyRef:
key: 1
name: PROD_JWT_SECRET
- name: SLACK_REGISTRY_CHANNEL_WEBHOOK
valueFrom:
secretKeyRef:
key: 1
name: PROD_SLACK_REGISTRY_CHANNEL_WEBHOOK
- name: PROJECT_ID
value: dreamboothy
# TODO(robinhuang): Switch to a list of strings
- name: CORS_ORIGIN
value: https://comfyregistry.org
- name: SECRET_SCANNER_URL
valueFrom:
secretKeyRef:
- name: SLACK_REGISTRY_CHANNEL_WEBHOOK
valueFrom:
secretKeyRef:
key: 1
name: PROD_SLACK_REGISTRY_CHANNEL_WEBHOOK
- name: PROJECT_ID
value: dreamboothy
# TODO(robinhuang): Switch to a list of strings
- name: CORS_ORIGIN
value: https://comfyregistry.org
- name: SECRET_SCANNER_URL
valueFrom:
secretKeyRef:
key: 1
name: SECURITY_SCANNER_CLOUD_FUNCTION_URL
- name: SECURITY_COUNCIL_DISCORD_WEBHOOK
valueFrom:
secretKeyRef:
- name: SECURITY_COUNCIL_DISCORD_WEBHOOK
valueFrom:
secretKeyRef:
key: 1
name: SECURITY_COUNCIL_DISCORD_WEBHOOK
- name: ALGOLIA_APP_ID
valueFrom:
secretKeyRef:
key: 2
name: PROD_ALGOLIA_APP_ID
- name: ALGOLIA_API_KEY
valueFrom:
secretKeyRef:
key: 2
name: PROD_ALGOLIA_API_KEY
resources:
limits:
cpu: 4000m
memory: 2Gi
- name: ALGOLIA_APP_ID
valueFrom:
secretKeyRef:
key: 2
name: PROD_ALGOLIA_APP_ID
- name: ALGOLIA_API_KEY
valueFrom:
secretKeyRef:
key: 2
name: PROD_ALGOLIA_API_KEY
- name: ID_TOKEN_AUDIENCE
value: https://api.comfy.org
resources:
limits:
cpu: 4000m
memory: 2Gi
66 changes: 34 additions & 32 deletions run-service-staging.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,42 +15,44 @@ spec:
run.googleapis.com/startup-cpu-boost: 'false'
spec:
containers:
- image: registry-backend-image-substitute
env:
- name: DRIP_ENV
value: staging
- name: DB_CONNECTION_STRING
valueFrom:
secretKeyRef:
key: 1
name: STAGING_SUPABASE_CONNECTION_STRING
- name: JWT_SECRET
valueFrom:
secretKeyRef:
key: 1
name: STAGING_JWT_SECRET
- name: PROJECT_ID
value: dreamboothy
# TODO(robinhuang): Switch to a list of strings
- name: CORS_ORIGIN
value: https://staging.comfyregistry.org
- name: SECRET_SCANNER_URL
valueFrom:
secretKeyRef:
- image: registry-backend-image-substitute
env:
- name: DRIP_ENV
value: staging
- name: DB_CONNECTION_STRING
valueFrom:
secretKeyRef:
key: 1
name: STAGING_SUPABASE_CONNECTION_STRING
- name: JWT_SECRET
valueFrom:
secretKeyRef:
key: 1
name: STAGING_JWT_SECRET
- name: PROJECT_ID
value: dreamboothy
# TODO(robinhuang): Switch to a list of strings
- name: CORS_ORIGIN
value: https://staging.comfyregistry.org
- name: SECRET_SCANNER_URL
valueFrom:
secretKeyRef:
key: 1
name: SECURITY_SCANNER_CLOUD_FUNCTION_URL
- name: SECURITY_COUNCIL_DISCORD_WEBHOOK
valueFrom:
secretKeyRef:
- name: SECURITY_COUNCIL_DISCORD_WEBHOOK
valueFrom:
secretKeyRef:
key: 1
name: SECURITY_COUNCIL_DISCORD_WEBHOOK
- name: ALGOLIA_APP_ID
valueFrom:
secretKeyRef:
- name: ALGOLIA_APP_ID
valueFrom:
secretKeyRef:
key: 2
name: STAGING_ALGOLIA_APP_ID
- name: ALGOLIA_API_KEY
valueFrom:
secretKeyRef:
- name: ALGOLIA_API_KEY
valueFrom:
secretKeyRef:
key: 2
name: STAGING_ALGOLIA_API_KEY
name: STAGING_ALGOLIA_API_KEY
- name: ID_TOKEN_AUDIENCE
value: https://stagingapi.comfy.org
14 changes: 11 additions & 3 deletions server/middleware/authentication/service_account_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package drip_authentication

import (
"net/http"
"os"
"regexp"
"strings"

Expand Down Expand Up @@ -42,7 +43,7 @@ func ServiceAccountAuthMiddleware() echo.MiddlewareFunc {
return next(ctx)
}

// validate token
// Validate token
authHeader := ctx.Request().Header.Get("Authorization")
token := ""
if strings.HasPrefix(authHeader, "Bearer ") {
Expand All @@ -53,9 +54,16 @@ func ServiceAccountAuthMiddleware() echo.MiddlewareFunc {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing token")
}

log.Ctx(ctx.Request().Context()).Info().Msgf("Validating google id token %s for path %s and method %s", token, reqPath, reqMethod)
log.Ctx(ctx.Request().Context()).Info().Msgf("Validating Google ID token %s for path %s and method %s", token, reqPath, reqMethod)

payload, err := idtoken.Validate(ctx.Request().Context(), token, "https://api.comfy.org")
// Get the audience from the environment variable
audience := os.Getenv("ID_TOKEN_AUDIENCE")
if audience == "" {
log.Ctx(ctx.Request().Context()).Error().Msg("ID_TOKEN_AUDIENCE environment variable is not set")
return echo.NewHTTPError(http.StatusInternalServerError, "Server misconfiguration")
}

payload, err := idtoken.Validate(ctx.Request().Context(), token, audience)
if err != nil {
log.Ctx(ctx.Request().Context()).Error().Err(err).Msg("Invalid token")
return ctx.JSON(http.StatusUnauthorized, "Invalid token")
Expand Down

0 comments on commit 3bd0ea4

Please sign in to comment.