Skip to content

Commit

Permalink
Add CertProvider to hot reload TLS certs for gRPC service (#587)
Browse files Browse the repository at this point in the history
Signed-off-by: Janis Meybohm <[email protected]>
  • Loading branch information
jayme-github authored Aug 1, 2024
1 parent f4af2db commit b3b7c4b
Show file tree
Hide file tree
Showing 5 changed files with 241 additions and 28 deletions.
7 changes: 6 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ BUILDX_PLATFORMS := linux/amd64,linux/arm64/v8
# Root dir returns absolute path of current directory. It has a trailing "/".
PROJECT_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
export PROJECT_DIR
ifneq ($(shell docker compose version 2>/dev/null),)
DOCKER_COMPOSE=docker compose
else
DOCKER_COMPOSE=docker-compose
endif

.PHONY: bootstrap
bootstrap: ;
Expand Down Expand Up @@ -142,7 +147,7 @@ docker_multiarch_push: docker_multiarch_image

.PHONY: integration_tests
integration_tests:
docker-compose --project-directory $(PWD) -f integration-test/docker-compose-integration-test.yml up --build --exit-code-from tester
$(DOCKER_COMPOSE) --project-directory $(PWD) -f integration-test/docker-compose-integration-test.yml up --build --exit-code-from tester

.PHONY: precommit_install
precommit_install:
Expand Down
17 changes: 13 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1109,17 +1109,26 @@ otelcol-contrib --config examples/otlp-collector/config.yaml
docker run -d --name jaeger -p 16686:16686 -p 14250:14250 jaegertracing/all-in-one:1.33
```
# TLS
Ratelimit supports TLS for its gRPC endpoint.

The following environment variables control the TLS feature:

1. `GRPC_SERVER_USE_TLS` - Enables gRPC connections to server over TLS
1. `GRPC_SERVER_TLS_CERT` - Path to the file containing the server cert chain
1. `GRPC_SERVER_TLS_KEY` - Path to the file containing the server private key

Ratelimit uses [goruntime](https://github.com/lyft/goruntime) to watch the TLS certificate and key and will hot reload them on changes.

# mTLS

Ratelimit supports mTLS when Envoy sends requests to the service.

The following environment variables control the mTLS feature:
TLS must be enabled on the gRPC endpoint in order for mTLS to work; see [TLS](#tls).

The following variables can be set to enable mTLS on the Ratelimit service.

1. `GRPC_SERVER_USE_TLS` - Enables gRPC connections to server over TLS
1. `GRPC_SERVER_TLS_CERT` - Path to the file containing the server cert chain
1. `GRPC_SERVER_TLS_KEY` - Path to the file containing the server private key
1. `GRPC_CLIENT_TLS_CACERT` - Path to the file containing the client CA certificate.
1. `GRPC_CLIENT_TLS_SAN` - (Optional) DNS Name to validate from the client cert during mTLS auth

Expand Down
109 changes: 109 additions & 0 deletions src/provider/cert_provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package provider

import (
"crypto/tls"
"path/filepath"
"sync"

"github.com/lyft/goruntime/loader"
gostats "github.com/lyft/gostats"
logger "github.com/sirupsen/logrus"

"github.com/envoyproxy/ratelimit/src/settings"
)

// CertProvider will watch certDirectory for changes via goruntime/loader and reload the cert and key files
type CertProvider struct {
settings settings.Settings
runtime loader.IFace
runtimeUpdateEvent chan int
rootStore gostats.Store
certLock sync.RWMutex
cert *tls.Certificate
certDirectory string
certFile string
keyFile string
}

// GetCertificateFunc builds a callback suitable for tls.Config.GetCertificate.
// The callback ignores the ClientHelloInfo and returns whichever certificate
// is currently loaded, reading it under the read lock.
func (p *CertProvider) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
	current := func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
		p.certLock.RLock()
		loaded := p.cert
		p.certLock.RUnlock()
		return loaded, nil
	}
	return current
}

// watch registers the provider's update channel with the goruntime loader and
// starts a goroutine that reloads the certificate each time an update arrives.
// The goroutine loops forever; nothing in this type stops it.
func (p *CertProvider) watch() {
	updates := p.runtimeUpdateEvent
	p.runtime.AddUpdateCallback(updates)

	// Block on the channel, reload, repeat.
	waitAndReload := func() {
		logger.Debugf("CertProvider: waiting for runtime update")
		<-updates
		logger.Debugf("CertProvider: got runtime update and reloading config")
		p.reloadCert()
	}

	go func() {
		for {
			waitAndReload()
		}
	}()
}

// reloadCert loads the cert and key files from disk and, on success, swaps the
// in-memory tls.Certificate for the freshly loaded one.
//
// On a load failure the error is logged and the previously loaded certificate
// (if any) stays in place. If no certificate has been loaded yet,
// logger.Fatalf is called, because continuing would mean serving without a
// certificate.
func (p *CertProvider) reloadCert() {
	tlsKeyPair, err := tls.LoadX509KeyPair(p.certFile, p.keyFile)
	if err != nil {
		logger.Errorf("CertProvider failed to load TLS key pair (%s, %s): %v", p.certFile, p.keyFile, err)

		// p.cert is guarded by certLock everywhere else, so read it under the lock here too.
		p.certLock.RLock()
		haveCert := p.cert != nil
		p.certLock.RUnlock()

		// Bail out if there is no cert already loaded, as this would mean starting up without TLS.
		if !haveCert {
			logger.Fatalf("CertProvider failed to load any certificate, exiting.")
		}
		return // keep the old cert if we have one
	}

	// Hold the write lock only for the pointer swap; logging happens outside it.
	p.certLock.Lock()
	p.cert = &tlsKeyPair
	p.certLock.Unlock()

	logger.Infof("CertProvider reloaded cert from (%s, %s)", p.certFile, p.keyFile)
}

// setupRuntime creates the goruntime loader that watches certDirectory.
// The loader is rooted at the parent of certDirectory and is given the
// directory's base name as the subdirectory to watch. If the loader cannot be
// created, the failure is reported through logger.Fatalf.
func (p *CertProvider) setupRuntime() {
	// Split certDirectory into the loader's root path and the watched folder name.
	parentDir, watchedDir := filepath.Dir(p.certDirectory), filepath.Base(p.certDirectory)
	certScope := p.rootStore.ScopeWithTags("certs", p.settings.ExtraTags)

	rt, err := loader.New2(parentDir, watchedDir, certScope, &loader.DirectoryRefresher{}, loader.IgnoreDotFiles)
	p.runtime = rt
	if err != nil {
		logger.Fatalf("Failed to set up goruntime loader: %v", err)
	}
}

// NewCertProvider creates a CertProvider for the given cert and key files and
// starts watching their directory for changes.
//
// certFile and keyFile must live in the same directory, because that single
// directory is what gets watched. logger.Fatalf is called if the files are in
// different directories, if the goruntime loader cannot be set up, or if the
// initial certificate cannot be loaded.
func NewCertProvider(settings settings.Settings, rootStore gostats.Store, certFile, keyFile string) *CertProvider {
	certDirectory := filepath.Dir(certFile)
	if certDirectory != filepath.Dir(keyFile) {
		logger.Fatalf("certFile and keyFile must be in the same directory")
	}
	p := &CertProvider{
		settings:           settings,
		runtimeUpdateEvent: make(chan int),
		rootStore:          rootStore,
		certDirectory:      certDirectory,
		certFile:           certFile,
		keyFile:            keyFile,
	}
	p.setupRuntime()
	// Initially load the certificate (or exit via logger.Fatalf).
	p.reloadCert()
	// watch starts its own goroutine for the reload loop. Calling it directly
	// (rather than with `go`) registers the update callback before this
	// constructor returns, so a change made right after construction is not missed.
	p.watch()
	return p
}
49 changes: 26 additions & 23 deletions src/server/server_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import (
pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3"
"github.com/gorilla/mux"
reuseport "github.com/kavu/go_reuseport"
"github.com/lyft/goruntime/loader"
gostats "github.com/lyft/gostats"
logger "github.com/sirupsen/logrus"
"google.golang.org/grpc"
Expand Down Expand Up @@ -58,20 +57,20 @@ const (
)

type server struct {
httpAddress string
grpcAddress string
grpcListenType grpcListenType
debugAddress string
router *mux.Router
grpcServer *grpc.Server
store gostats.Store
scope gostats.Scope
provider provider.RateLimitConfigProvider
runtime loader.IFace
debugListener serverDebugListener
httpServer *http.Server
listenerMu sync.Mutex
health *HealthChecker
httpAddress string
grpcAddress string
grpcListenType grpcListenType
debugAddress string
router *mux.Router
grpcServer *grpc.Server
store gostats.Store
scope gostats.Scope
provider provider.RateLimitConfigProvider
debugListener serverDebugListener
httpServer *http.Server
listenerMu sync.Mutex
health *HealthChecker
grpcCertProvider *provider.CertProvider
}

func (server *server) AddDebugHttpEndpoint(path string, help string, handler http.HandlerFunc) {
Expand Down Expand Up @@ -242,6 +241,14 @@ func newServer(s settings.Settings, name string, statsManager stats.Manager, loc

ret := new(server)

// setup stats
ret.store = statsManager.GetStatsStore()
ret.scope = ret.store.ScopeWithTags(name, s.ExtraTags)
ret.store.AddStatGenerator(gostats.NewRuntimeStats(ret.scope.Scope("go")))
if localCache != nil {
ret.store.AddStatGenerator(limiter.NewLocalCacheStats(localCache, ret.scope.Scope("localcache")))
}

keepaliveOpt := grpc.KeepaliveParams(keepalive.ServerParameters{
MaxConnectionAge: s.GrpcMaxConnectionAge,
MaxConnectionAgeGrace: s.GrpcMaxConnectionAgeGrace,
Expand All @@ -256,6 +263,10 @@ func newServer(s settings.Settings, name string, statsManager stats.Manager, loc
}
if s.GrpcServerUseTLS {
grpcServerTlsConfig := s.GrpcServerTlsConfig
ret.grpcCertProvider = provider.NewCertProvider(s, ret.store, s.GrpcServerTlsCert, s.GrpcServerTlsKey)
// Remove the static certificates and use the provider via the GetCertificate function
grpcServerTlsConfig.Certificates = nil
grpcServerTlsConfig.GetCertificate = ret.grpcCertProvider.GetCertificateFunc()
// Verify client SAN if provided
if s.GrpcClientTlsSAN != "" {
grpcServerTlsConfig.VerifyPeerCertificate = verifyClient(grpcServerTlsConfig.ClientCAs, s.GrpcClientTlsSAN)
Expand All @@ -275,14 +286,6 @@ func newServer(s settings.Settings, name string, statsManager stats.Manager, loc
}
ret.debugAddress = net.JoinHostPort(s.DebugHost, strconv.Itoa(s.DebugPort))

// setup stats
ret.store = statsManager.GetStatsStore()
ret.scope = ret.store.ScopeWithTags(name, s.ExtraTags)
ret.store.AddStatGenerator(gostats.NewRuntimeStats(ret.scope.Scope("go")))
if localCache != nil {
ret.store.AddStatGenerator(limiter.NewLocalCacheStats(localCache, ret.scope.Scope("localcache")))
}

// setup config provider
ret.provider = getProviderImpl(s, statsManager, ret.store)

Expand Down
87 changes: 87 additions & 0 deletions test/integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package integration_test

import (
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"math/rand"
Expand Down Expand Up @@ -294,6 +295,92 @@ func Test_mTLS(t *testing.T) {
defer conn.Close()
}

// TestReloadGRPCServerCerts verifies that the gRPC server picks up a replaced
// TLS certificate without a restart.
//
// It starts the server with an initial cert/key pair, checks that only the
// initial CA validates the connection, replaces the files on disk with a pair
// issued by a second CA, and then checks that only the new CA validates.
func TestReloadGRPCServerCerts(t *testing.T) {
	common.WithMultiRedis(t, []common.RedisConfig{
		{Port: 6383},
	}, func() {
		s := makeSimpleRedisSettings(6383, 6380, false, 0)
		assert := assert.New(t)
		// TLS setup initially used to configure the server
		initialServerCAFile, initialServerCertFile, initialServerCertKey, err := mTLSSetup(utils.ServerCA)
		assert.NoError(err)
		// Second TLS setup that will replace the above during test
		newServerCAFile, newServerCertFile, newServerCertKey, err := mTLSSetup(utils.ServerCA)
		assert.NoError(err)
		// Create CertPools and tls.Configs for both CAs
		initialCaCert, err := os.ReadFile(initialServerCAFile)
		assert.NoError(err)
		initialCertPool := x509.NewCertPool()
		initialCertPool.AppendCertsFromPEM(initialCaCert)
		initialTlsConfig := &tls.Config{
			RootCAs: initialCertPool,
		}
		newCaCert, err := os.ReadFile(newServerCAFile)
		assert.NoError(err)
		newCertPool := x509.NewCertPool()
		newCertPool.AppendCertsFromPEM(newCaCert)
		newTlsConfig := &tls.Config{
			RootCAs: newCertPool,
		}
		connStr := fmt.Sprintf("localhost:%v", s.GrpcPort)

		// Set up ratelimit with the initial certificate
		s.GrpcServerUseTLS = true
		s.GrpcServerTlsCert = initialServerCertFile
		s.GrpcServerTlsKey = initialServerCertKey
		settings.GrpcServerTlsConfig()(&s)
		runner := startTestRunner(t, s)
		defer runner.Stop()

		// Ensure TLS validation works with the initial CA in cert pool
		t.Run("WithInitialCert", func(t *testing.T) {
			conn, err := tls.Dial("tcp", connStr, initialTlsConfig)
			// conn is nil when the dial fails, so only close it on success.
			if assert.NoError(err) {
				conn.Close()
			}
		})

		// Ensure TLS validation fails with the new CA in cert pool
		t.Run("WithNewCertFail", func(t *testing.T) {
			conn, err := tls.Dial("tcp", connStr, newTlsConfig)
			assert.Error(err)
			if err == nil {
				conn.Close()
			}
		})

		// Replace the initial certificate with the new one
		err = os.Rename(newServerCertFile, initialServerCertFile)
		assert.NoError(err)
		err = os.Rename(newServerCertKey, initialServerCertKey)
		assert.NoError(err)

		// Ensure TLS validation works with the new CA in cert pool
		t.Run("WithNewCertOK", func(t *testing.T) {
			// If this takes longer than 10s, something is probably wrong
			wait := 10
			// dialErr lives outside the loop so the assertion below sees the
			// result of the last dial attempt instead of a stale outer error.
			var dialErr error
			for i := 0; i < wait; i++ {
				// Ensure the new certificate is being used
				var conn *tls.Conn
				conn, dialErr = tls.Dial("tcp", connStr, newTlsConfig)
				if dialErr == nil {
					conn.Close()
					break
				}
				time.Sleep(1 * time.Second)
			}
			assert.NoError(dialErr)
		})

		// Ensure TLS validation fails with the initial CA in cert pool
		t.Run("WithInitialCertFail", func(t *testing.T) {
			conn, err := tls.Dial("tcp", connStr, initialTlsConfig)
			assert.Error(err)
			if err == nil {
				conn.Close()
			}
		})
	})
}

func testBasicConfigAuthTLS(perSecond bool, local_cache_size int) func(*testing.T) {
s := makeSimpleRedisSettings(16381, 16382, perSecond, local_cache_size)
s.RedisTlsConfig = &tls.Config{}
Expand Down

0 comments on commit b3b7c4b

Please sign in to comment.