Skip to content

Commit

Permalink
Merge pull request #9 from 9seconds/rework-ca
Browse files Browse the repository at this point in the history
Rework ca
  • Loading branch information
9seconds authored Dec 22, 2019
2 parents 64554e6 + 817fd0a commit aee88b1
Show file tree
Hide file tree
Showing 11 changed files with 217 additions and 388 deletions.
229 changes: 52 additions & 177 deletions ca/ca.go
Original file line number Diff line number Diff line change
@@ -1,48 +1,28 @@
package ca

import (
"crypto/hmac"
"crypto/rsa"
"crypto/sha1" // nolint: gosec
"context"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/binary"
"hash"
"math/big"
"math/rand"
"net"
"hash/fnv"
"runtime"
"sync"
"time"

"github.com/karlseguin/ccache"
lru "github.com/hashicorp/golang-lru"
"golang.org/x/xerrors"
)

const (
// TTLForCertificate is minimal time until certificate is
// considered as expired. This is a time from the last usage. If the
// certificate is expired, it does not necessary wiped out from LRU
// cache.
TTLForCertificate = 10 * time.Minute

// RSAKeyLength is the bit size of RSA private key for the certificate.
RSAKeyLength = 2048
)

var (
certWorkerCount = uint32(runtime.NumCPU())
bigBangTime = time.Unix(0, 0)
)

// CertificateMetrics is a subset of the main Metrics interface which
// provides callbacks for certificates.
type CertificateMetrics interface {
NewCertificate()
DropCertificate()
}

// DefaultMaxSize defines a default value for TLS certificates to store
// in LRU cache.
const DefaultMaxSize = 1024

// CA is a datastructure which presents TLS CA (certificate authority).
// The main purpose of this type is to generate TLS certificates
// on-the-fly, using given CA certificate and private key.
Expand All @@ -51,180 +31,75 @@ type CertificateMetrics interface {
// number of concurrently generated certificates is equal to the number
// of CPUs.
type CA struct {
ca tls.Certificate
orgNames []string
secret []byte
requestChans []chan *signRequest
cache *ccache.Cache
wg *sync.WaitGroup
metrics CertificateMetrics
cache *lru.Cache
cancel context.CancelFunc
workers []worker
wg sync.WaitGroup
}

// Get returns generated TLSConfig instance for the given hostname.
func (c *CA) Get(host string) (TLSConfig, error) {
item := c.cache.TrackingGet(host)

if item == ccache.NilTracked {
newRequest := signRequestPool.Get().(*signRequest)
defer signRequestPool.Put(newRequest)

newRequest.host = host
c.getWorkerChan(host) <- newRequest
response := <-newRequest.response

defer signResponsePool.Put(response)
func (c *CA) Get(host string) (*tls.Config, error) {
if item, ok := c.cache.Get(host); ok {
return item.(*tls.Config), nil
}

if response.err != nil {
return TLSConfig{}, xerrors.Errorf("cannot create TLS certificate for host %s: %w",
host, response.err)
}
hashFunc := fnv.New32a()
hashFunc.Write([]byte(host)) // nolint: errcheck

item = response.item
}
num := int(hashFunc.Sum32() % uint32(len(c.workers)))

return TLSConfig{item}, nil
return c.workers[num].get(host)
}

// Close stops CA instance. This includes all signing workers and LRU
// cache.
func (c *CA) Close() error {
for _, ch := range c.requestChans {
close(ch)
}

func (c *CA) Close() {
c.cancel()
c.wg.Wait()
c.cache.Stop()

return nil
}

func (c *CA) worker(requests chan *signRequest, wg *sync.WaitGroup) {
defer wg.Done()

for req := range requests {
resp := signResponsePool.Get().(*signResponse)
resp.err = nil

if item := c.cache.TrackingGet(req.host); item != ccache.NilTracked {
resp.item = item
req.response <- resp

continue
}

cert, err := c.sign(req.host)
if err != nil {
resp.err = err
req.response <- resp

continue
}

c.metrics.NewCertificate()

conf := &tls.Config{InsecureSkipVerify: true} // nolint: gosec
conf.Certificates = append(conf.Certificates, cert)
c.cache.Set(req.host, conf, TTLForCertificate)
resp.item = c.cache.TrackingGet(req.host)
req.response <- resp
}
c.cache.Purge()
}

func (c *CA) sign(host string) (tls.Certificate, error) {
template := x509.Certificate{
SerialNumber: &big.Int{},
Issuer: c.ca.Leaf.Subject,
Subject: pkix.Name{Organization: c.orgNames},
NotBefore: bigBangTime,
NotAfter: timeNotAfter(),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}

hash := hmac.New(sha1.New, c.secret)
hash.Write([]byte(host)) // nolint: errcheck

if ip := net.ParseIP(host); ip != nil {
template.IPAddresses = append(template.IPAddresses, ip)
} else {
template.DNSNames = append(template.DNSNames, host)
template.Subject.CommonName = host
}

hashed := hash.Sum(nil)
template.SerialNumber.SetBytes(hashed)
hash.Write(c.secret) // nolint: errcheck

randSeed := int64(binary.LittleEndian.Uint64(hash.Sum(nil)[:8]))
randGen := rand.New(rand.NewSource(randSeed))

certpriv, err := rsa.GenerateKey(randGen, RSAKeyLength)
func NewCA(certCA, certKey []byte, metrics CertificateMetrics, maxSize int, orgNames []string) (*CA, error) {
ca, err := tls.X509KeyPair(certCA, certKey)
if err != nil {
panic(err)
return nil, xerrors.Errorf("invalid certificates: %w", err)
}

derBytes, err := x509.CreateCertificate(randGen, &template, c.ca.Leaf,
&certpriv.PublicKey, c.ca.PrivateKey)
if err != nil {
return tls.Certificate{}, xerrors.Errorf("cannot generate TLS certificate: %w", err)
if ca.Leaf, err = x509.ParseCertificate(ca.Certificate[0]); err != nil {
return nil, xerrors.Errorf("invalid certificates: %w", err)
}

return tls.Certificate{
Certificate: [][]byte{derBytes, c.ca.Certificate[0]},
PrivateKey: certpriv,
}, nil
}

func (c *CA) getWorkerChan(host string) chan<- *signRequest {
newHash := hashPool.Get().(hash.Hash32)
newHash.Reset()
newHash.Write([]byte(host)) // nolint: errcheck
chanNumber := newHash.Sum32() % certWorkerCount
hashPool.Put(newHash)

return c.requestChans[chanNumber]
}

// NewCA creates new instance of TLS CA.
func NewCA(certCA, certKey []byte, metrics CertificateMetrics,
cacheMaxSize int64, cacheItemsToPrune uint32, orgNames ...string) (CA, error) {
ca, err := tls.X509KeyPair(certCA, certKey)
if err != nil {
return CA{}, xerrors.Errorf("invalid certificates: %w", err)
if maxSize <= 0 {
maxSize = DefaultMaxSize
}

if ca.Leaf, err = x509.ParseCertificate(ca.Certificate[0]); err != nil {
return CA{}, xerrors.Errorf("invalid certificates: %w", err)
cache, err := lru.NewWithEvict(maxSize, func(_, _ interface{}) {
metrics.DropCertificate()
})
if err != nil {
return nil, xerrors.Errorf("cannot make a new cache: %w", err)
}

ccacheConf := ccache.Configure()
ccacheConf = ccacheConf.MaxSize(cacheMaxSize)
ccacheConf = ccacheConf.ItemsToPrune(cacheItemsToPrune)
ccacheConf = ccacheConf.OnDelete(func(_ *ccache.Item) { metrics.DropCertificate() })

obj := CA{
ca: ca,
metrics: metrics,
secret: certKey,
orgNames: orgNames,
cache: ccache.New(ccacheConf),
requestChans: make([]chan *signRequest, 0, certWorkerCount),
wg: &sync.WaitGroup{},
ctx, cancel := context.WithCancel(context.Background())
obj := &CA{
cache: cache,
workers: make([]worker, runtime.NumCPU()),
cancel: cancel,
}

for i := 0; i < int(certWorkerCount); i++ {
newChan := make(chan *signRequest)
obj.requestChans = append(obj.requestChans, newChan)
obj.wg.Add(1)

go obj.worker(newChan, obj.wg)
obj.wg.Add(len(obj.workers))

for i := range obj.workers {
obj.workers[i] = worker{
ca: ca,
cache: cache,
orgNames: orgNames,
secret: certKey,
ctx: ctx,
metrics: metrics,
channelRequests: make(chan workerRequest),
}
go obj.workers[i].run(&obj.wg)
}

return obj, nil
}

func timeNotAfter() time.Time {
now := time.Now()
return time.Date(now.Year()+10, now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
}
47 changes: 10 additions & 37 deletions ca/ca_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package ca

import (
"crypto/x509"
"testing"

"github.com/stretchr/testify/mock"
Expand Down Expand Up @@ -63,57 +62,31 @@ type CATestSuite struct {
func (suite *CATestSuite) SetupTest() {
suite.mock = &MockCertificateMetrics{}

ca, err := NewCA(testCaCACert, testCaPrivateKey, suite.mock, 1000, 100)
ca, err := NewCA(testCaCACert, testCaPrivateKey, suite.mock, -1, []string{"name"})
if err != nil {
panic(err)
}

suite.ca = &ca
}

func (suite *CATestSuite) TearDownTest() {
suite.ca.Close()
suite.ca = ca
}

func (suite *CATestSuite) TestDoubleGet() {
suite.mock.On("NewCertificate")
suite.mock.On("NewCertificate").Once()
suite.mock.On("DropCertificate").Maybe()

conf1, err := suite.ca.Get("hostname")
suite.Nil(err)
suite.NoError(err)

conf2, err := suite.ca.Get("hostname")
suite.Nil(err)
suite.NoError(err)

suite.Equal(conf1.Get().Certificates[0].PrivateKey, conf2.Get().Certificates[0].PrivateKey)
suite.Equal(conf1.Get().Certificates[0].Certificate[0], conf2.Get().Certificates[0].Certificate[0])

suite.mock.AssertExpectations(suite.T())
suite.Equal(conf1.Certificates[0].PrivateKey, conf2.Certificates[0].PrivateKey)
suite.Equal(conf1.Certificates[0].Certificate[0], conf2.Certificates[0].Certificate[0])
}

func (suite *CATestSuite) TestSigner() {
suite.mock.On("NewCertificate")

conf, err := suite.ca.Get("hostname")
suite.Nil(err)

cert := conf.Get().Certificates[0]
cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0])
suite.Nil(err)

pool := x509.NewCertPool()
pool.AddCert(suite.ca.ca.Leaf)

suite.Nil(cert.Leaf.VerifyHostname("hostname"))
suite.NotNil(cert.Leaf.VerifyHostname("hostname2"))
suite.Nil(cert.Leaf.CheckSignatureFrom(suite.ca.ca.Leaf))

_, err = cert.Leaf.Verify(x509.VerifyOptions{
DNSName: "hostname",
Roots: pool,
})
suite.Nil(err)

func (suite *CATestSuite) TearDownTest() {
suite.mock.AssertExpectations(suite.T())
suite.ca.Close()
}

func TestCA(t *testing.T) {
Expand Down
40 changes: 0 additions & 40 deletions ca/pools.go

This file was deleted.

Loading

0 comments on commit aee88b1

Please sign in to comment.