From b3b7c4b415a5b4c1cbaf2d8bc7062ffd6da3d054 Mon Sep 17 00:00:00 2001 From: jayme-github Date: Thu, 1 Aug 2024 05:35:59 +0200 Subject: [PATCH] Add CertProvider to hot reload TLS certs for gRPC service (#587) Signed-off-by: Janis Meybohm --- Makefile | 7 +- README.md | 17 ++++- src/provider/cert_provider.go | 109 +++++++++++++++++++++++++++ src/server/server_impl.go | 49 ++++++------ test/integration/integration_test.go | 87 +++++++++++++++++++++ 5 files changed, 241 insertions(+), 28 deletions(-) create mode 100644 src/provider/cert_provider.go diff --git a/Makefile b/Makefile index 2f98413d..342e2beb 100644 --- a/Makefile +++ b/Makefile @@ -11,6 +11,11 @@ BUILDX_PLATFORMS := linux/amd64,linux/arm64/v8 # Root dir returns absolute path of current directory. It has a trailing "/". PROJECT_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST)))) export PROJECT_DIR +ifneq ($(shell docker compose version 2>/dev/null),) + DOCKER_COMPOSE=docker compose +else + DOCKER_COMPOSE=docker-compose +endif .PHONY: bootstrap bootstrap: ; @@ -142,7 +147,7 @@ docker_multiarch_push: docker_multiarch_image .PHONY: integration_tests integration_tests: - docker-compose --project-directory $(PWD) -f integration-test/docker-compose-integration-test.yml up --build --exit-code-from tester + $(DOCKER_COMPOSE) --project-directory $(PWD) -f integration-test/docker-compose-integration-test.yml up --build --exit-code-from tester .PHONY: precommit_install precommit_install: diff --git a/README.md b/README.md index 25a5daee..e4bc1558 100644 --- a/README.md +++ b/README.md @@ -1109,17 +1109,26 @@ otelcol-contrib --config examples/otlp-collector/config.yaml docker run -d --name jaeger -p 16686:16686 -p 14250:14250 jaegertracing/all-in-one:1.33 ``` +# TLS + +Ratelimit supports TLS for its gRPC endpoint. + +The following environment variables control the TLS feature: + +1. `GRPC_SERVER_USE_TLS` - Enables gRPC connections to server over TLS +1. 
`GRPC_SERVER_TLS_CERT` - Path to the file containing the server cert chain +1. `GRPC_SERVER_TLS_KEY` - Path to the file containing the server private key + +Ratelimit uses [goruntime](https://github.com/lyft/goruntime) to watch the TLS certificate and key and will hot reload them on changes. + # mTLS Ratelimit supports mTLS when Envoy sends requests to the service. -The following environment variables control the mTLS feature: +TLS must be enabled on the gRPC endpoint in order for mTLS to work; see [TLS](#TLS). The following variables can be set to enable mTLS on the Ratelimit service. -1. `GRPC_SERVER_USE_TLS` - Enables gprc connections to server over TLS -1. `GRPC_SERVER_TLS_CERT` - Path to the file containing the server cert chain -1. `GRPC_SERVER_TLS_KEY` - Path to the file containing the server private key 1. `GRPC_CLIENT_TLS_CACERT` - Path to the file containing the client CA certificate. 1. `GRPC_CLIENT_TLS_SAN` - (Optional) DNS Name to validate from the client cert during mTLS auth diff --git a/src/provider/cert_provider.go b/src/provider/cert_provider.go new file mode 100644 index 00000000..321a14fa --- /dev/null +++ b/src/provider/cert_provider.go @@ -0,0 +1,109 @@ +package provider + +import ( + "crypto/tls" + "path/filepath" + "sync" + + "github.com/lyft/goruntime/loader" + gostats "github.com/lyft/gostats" + logger "github.com/sirupsen/logrus" + + "github.com/envoyproxy/ratelimit/src/settings" +) + +// CertProvider will watch certDirectory for changes via goruntime/loader and reload the cert and key files +type CertProvider struct { + settings settings.Settings + runtime loader.IFace + runtimeUpdateEvent chan int + rootStore gostats.Store + certLock sync.RWMutex + cert *tls.Certificate + certDirectory string + certFile string + keyFile string +} + +// GetCertificateFunc returns a function compatible with tls.Config.GetCertificate, fetching the current certificate +func (p *CertProvider) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, 
error) { + return func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + p.certLock.RLock() + defer p.certLock.RUnlock() + return p.cert, nil + } +} + +func (p *CertProvider) watch() { + p.runtime.AddUpdateCallback(p.runtimeUpdateEvent) + + go func() { + for { + logger.Debugf("CertProvider: waiting for runtime update") + <-p.runtimeUpdateEvent + logger.Debugf("CertProvider: got runtime update and reloading config") + p.reloadCert() + } + }() +} + +// reloadCert loads the cert and key files and, on success, replaces the tls.Certificate held in memory +func (p *CertProvider) reloadCert() { + tlsKeyPair, err := tls.LoadX509KeyPair(p.certFile, p.keyFile) + if err != nil { + logger.Errorf("CertProvider failed to load TLS key pair (%s, %s): %v", p.certFile, p.keyFile, err) + // abort via logger.Fatalf in case there is no cert already loaded as this would mean starting up without TLS + if p.cert == nil { + logger.Fatalf("CertProvider failed to load any certificate, exiting.") + } + return // keep the old cert if we have one + } + p.certLock.Lock() + defer p.certLock.Unlock() + p.cert = &tlsKeyPair + logger.Infof("CertProvider reloaded cert from (%s, %s)", p.certFile, p.keyFile) +} + +// setupRuntime sets up the goruntime loader to watch the certDirectory +// Calls logger.Fatalf if it fails to set up the loader +func (p *CertProvider) setupRuntime() { + var err error + + // runtimePath is the parent folder of certDirectory + runtimePath := filepath.Dir(p.certDirectory) + // runtimeSubdirectory is the name of the folder to watch, containing the certs + runtimeSubdirectory := filepath.Base(p.certDirectory) + + p.runtime, err = loader.New2( + runtimePath, + runtimeSubdirectory, + p.rootStore.ScopeWithTags("certs", p.settings.ExtraTags), + &loader.DirectoryRefresher{}, + loader.IgnoreDotFiles) + + if err != nil { + logger.Fatalf("Failed to set up goruntime loader: %v", err) + } +} + +// NewCertProvider creates a new CertProvider +// Calls logger.Fatalf if certFile and keyFile are in different directories, if it fails to set up goruntime or if it fails to load the initial certificate +func 
NewCertProvider(settings settings.Settings, rootStore gostats.Store, certFile, keyFile string) *CertProvider { + certDirectory := filepath.Dir(certFile) + if certDirectory != filepath.Dir(keyFile) { + logger.Fatalf("certFile and keyFile must be in the same directory") + } + p := &CertProvider{ + settings: settings, + runtimeUpdateEvent: make(chan int), + rootStore: rootStore, + certDirectory: certDirectory, + certFile: certFile, + keyFile: keyFile, + } + p.setupRuntime() + // Initially load the certificate (reloadCert calls logger.Fatalf if none can be loaded) + p.reloadCert() + go p.watch() + return p +} diff --git a/src/server/server_impl.go b/src/server/server_impl.go index 0b42b40f..f2341917 100644 --- a/src/server/server_impl.go +++ b/src/server/server_impl.go @@ -26,7 +26,6 @@ import ( pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" "github.com/gorilla/mux" reuseport "github.com/kavu/go_reuseport" - "github.com/lyft/goruntime/loader" gostats "github.com/lyft/gostats" logger "github.com/sirupsen/logrus" "google.golang.org/grpc" @@ -58,20 +57,20 @@ const ( ) type server struct { - httpAddress string - grpcAddress string - grpcListenType grpcListenType - debugAddress string - router *mux.Router - grpcServer *grpc.Server - store gostats.Store - scope gostats.Scope - provider provider.RateLimitConfigProvider - runtime loader.IFace - debugListener serverDebugListener - httpServer *http.Server - listenerMu sync.Mutex - health *HealthChecker + httpAddress string + grpcAddress string + grpcListenType grpcListenType + debugAddress string + router *mux.Router + grpcServer *grpc.Server + store gostats.Store + scope gostats.Scope + provider provider.RateLimitConfigProvider + debugListener serverDebugListener + httpServer *http.Server + listenerMu sync.Mutex + health *HealthChecker + grpcCertProvider *provider.CertProvider } func (server *server) AddDebugHttpEndpoint(path string, help string, handler http.HandlerFunc) { @@ -242,6 +241,14 @@ func newServer(s settings.Settings, name string, 
statsManager stats.Manager, loc ret := new(server) + // setup stats + ret.store = statsManager.GetStatsStore() + ret.scope = ret.store.ScopeWithTags(name, s.ExtraTags) + ret.store.AddStatGenerator(gostats.NewRuntimeStats(ret.scope.Scope("go"))) + if localCache != nil { + ret.store.AddStatGenerator(limiter.NewLocalCacheStats(localCache, ret.scope.Scope("localcache"))) + } + keepaliveOpt := grpc.KeepaliveParams(keepalive.ServerParameters{ MaxConnectionAge: s.GrpcMaxConnectionAge, MaxConnectionAgeGrace: s.GrpcMaxConnectionAgeGrace, @@ -256,6 +263,10 @@ func newServer(s settings.Settings, name string, statsManager stats.Manager, loc } if s.GrpcServerUseTLS { grpcServerTlsConfig := s.GrpcServerTlsConfig + ret.grpcCertProvider = provider.NewCertProvider(s, ret.store, s.GrpcServerTlsCert, s.GrpcServerTlsKey) + // Remove the static certificates and use the provider via the GetCertificate function + grpcServerTlsConfig.Certificates = nil + grpcServerTlsConfig.GetCertificate = ret.grpcCertProvider.GetCertificateFunc() // Verify client SAN if provided if s.GrpcClientTlsSAN != "" { grpcServerTlsConfig.VerifyPeerCertificate = verifyClient(grpcServerTlsConfig.ClientCAs, s.GrpcClientTlsSAN) @@ -275,14 +286,6 @@ func newServer(s settings.Settings, name string, statsManager stats.Manager, loc } ret.debugAddress = net.JoinHostPort(s.DebugHost, strconv.Itoa(s.DebugPort)) - // setup stats - ret.store = statsManager.GetStatsStore() - ret.scope = ret.store.ScopeWithTags(name, s.ExtraTags) - ret.store.AddStatGenerator(gostats.NewRuntimeStats(ret.scope.Scope("go"))) - if localCache != nil { - ret.store.AddStatGenerator(limiter.NewLocalCacheStats(localCache, ret.scope.Scope("localcache"))) - } - // setup config provider ret.provider = getProviderImpl(s, statsManager, ret.store) diff --git a/test/integration/integration_test.go b/test/integration/integration_test.go index fda35cd9..5a8379a4 100644 --- a/test/integration/integration_test.go +++ b/test/integration/integration_test.go @@ -4,6 
+4,7 @@ package integration_test import ( "crypto/tls" + "crypto/x509" "fmt" "io" "math/rand" @@ -294,6 +295,92 @@ func Test_mTLS(t *testing.T) { defer conn.Close() } +func TestReloadGRPCServerCerts(t *testing.T) { + common.WithMultiRedis(t, []common.RedisConfig{ + {Port: 6383}, + }, func() { + s := makeSimpleRedisSettings(6383, 6380, false, 0) + assert := assert.New(t) + // TLS setup initially used to configure the server + initialServerCAFile, initialServerCertFile, initialServerCertKey, err := mTLSSetup(utils.ServerCA) + assert.NoError(err) + // Second TLS setup that will replace the above during test + newServerCAFile, newServerCertFile, newServerCertKey, err := mTLSSetup(utils.ServerCA) + assert.NoError(err) + // Create CertPools and tls.Configs for both CAs + initialCaCert, err := os.ReadFile(initialServerCAFile) + assert.NoError(err) + initialCertPool := x509.NewCertPool() + initialCertPool.AppendCertsFromPEM(initialCaCert) + initialTlsConfig := &tls.Config{ + RootCAs: initialCertPool, + } + newCaCert, err := os.ReadFile(newServerCAFile) + assert.NoError(err) + newCertPool := x509.NewCertPool() + newCertPool.AppendCertsFromPEM(newCaCert) + newTlsConfig := &tls.Config{ + RootCAs: newCertPool, + } + connStr := fmt.Sprintf("localhost:%v", s.GrpcPort) + + // Set up ratelimit with the initial certificate + s.GrpcServerUseTLS = true + s.GrpcServerTlsCert = initialServerCertFile + s.GrpcServerTlsKey = initialServerCertKey + settings.GrpcServerTlsConfig()(&s) + runner := startTestRunner(t, s) + defer runner.Stop() + + // Ensure TLS validation works with the initial CA in cert pool + t.Run("WithInitialCert", func(t *testing.T) { + conn, err := tls.Dial("tcp", connStr, initialTlsConfig) + assert.NoError(err) + conn.Close() + }) + + // Ensure TLS validation fails with the new CA in cert pool + t.Run("WithNewCertFail", func(t *testing.T) { + conn, err := tls.Dial("tcp", connStr, newTlsConfig) + assert.Error(err) + if err == nil { + conn.Close() + } + }) + + // Replace 
the initial certificate with the new one + err = os.Rename(newServerCertFile, initialServerCertFile) + assert.NoError(err) + err = os.Rename(newServerCertKey, initialServerCertKey) + assert.NoError(err) + + // Ensure TLS validation works with the new CA in cert pool + t.Run("WithNewCertOK", func(t *testing.T) { + // If this takes longer than 10s, something is probably wrong + wait := 10 + for i := 0; i < wait; i++ { + // Ensure the new certificate is being used; assign to the outer err (no :=) so the assertion after the loop checks the last dial result + var conn *tls.Conn + if conn, err = tls.Dial("tcp", connStr, newTlsConfig); err == nil { + conn.Close() + break + } + time.Sleep(1 * time.Second) + } + assert.NoError(err) + }) + + // Ensure TLS validation fails with the initial CA in cert pool + t.Run("WithInitialCertFail", func(t *testing.T) { + conn, err := tls.Dial("tcp", connStr, initialTlsConfig) + assert.Error(err) + if err == nil { + conn.Close() + } + }) + }) +} + func testBasicConfigAuthTLS(perSecond bool, local_cache_size int) func(*testing.T) { s := makeSimpleRedisSettings(16381, 16382, perSecond, local_cache_size) s.RedisTlsConfig = &tls.Config{}