Skip to content
Open
2 changes: 2 additions & 0 deletions api/adscert.proto
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ message RequestInfo {
bytes url_hash = 2;
bytes body_hash = 3;
repeated SignatureInfo signature_info = 4;
// useful if 1 signatory is managing multiple origin domains such as in resellers case.
string origin_domain = 5;
}

// SignatureInfo captures the signature generated for the signing request. It
Expand Down
4 changes: 0 additions & 4 deletions cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,6 @@ func main() {
logger.SetLevel(parsedLogLevel)
logger.Infof("Log Level: %s, parsed as iota %v", *logLevel, parsedLogLevel)

if *origin == "" {
logger.Fatalf("Origin ads.cert Call Sign domain name is required")
}

if *privateKey == "" {
logger.Fatalf("Private key is required")
}
Expand Down
4 changes: 4 additions & 0 deletions examples/signer-client/signer-client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

var (
serverAddress = flag.String("server_address", "localhost:3000", "address of grpc server")
originDomain = flag.String("origin_domain", "", "Origin domain")
destinationURL = flag.String("url", "https://google.com/gen_204", "URL to invoke")
body = flag.String("body", "", "POST request body")
signingTimeout = flag.Duration("signing_timeout", 5*time.Millisecond, "Specifies how long this client will wait for signing to finish before abandoning.")
Expand Down Expand Up @@ -49,6 +50,9 @@ func main() {
// destination URL and body, setting these value on the RequestInfo message.
reqInfo := &api.RequestInfo{}
signatory.SetRequestInfo(reqInfo, *destinationURL, []byte(*body))
if originDomain != nil {
reqInfo.OriginDomain = *originDomain
}

// Request the signature.
logger.Infof("signing request for url: %v", *destinationURL)
Expand Down
1 change: 1 addition & 0 deletions examples/signer-server/signer-server.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ func main() {
base64PrivateKeys := signatory.GenerateFakePrivateKeysForTesting(*origin)

signatoryApi := signatory.NewLocalAuthenticatedConnectionsSignatory(
"info",
*origin,
crypto_rand.Reader,
clock.New(),
Expand Down
1 change: 1 addition & 0 deletions examples/verifier-parser/verifier-parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ func main() {
base64PrivateKeys := signatory.GenerateFakePrivateKeysForTesting(*origin)

signatoryApi := signatory.NewLocalAuthenticatedConnectionsSignatory(
"info",
*origin,
crypto_rand.Reader,
clock.New(),
Expand Down
1 change: 1 addition & 0 deletions examples/verifier-server/verifier-server.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ func main() {
}

signatoryApi := signatory.NewLocalAuthenticatedConnectionsSignatory(
"info",
*origin,
crypto_rand.Reader,
clock.New(),
Expand Down
1 change: 1 addition & 0 deletions internal/server/server_reference_implementation.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

func SetUpAdsCertSignatoryServer(grpcServer *grpc.Server, adscertCallSign string, domainCheckInterval time.Duration, domainRenewalInterval time.Duration, privateKeys []string) {
signatoryApi := signatory.NewLocalAuthenticatedConnectionsSignatory(
"info",
adscertCallSign,
crypto_rand.Reader,
clock.New(),
Expand Down
305 changes: 158 additions & 147 deletions pkg/adscert/api/adscert.pb.go

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions pkg/adscert/api/adscert_grpc.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

47 changes: 28 additions & 19 deletions pkg/adscert/discovery/domain_indexer_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ func NewDefaultDomainIndexer(dnsResolver DNSResolver, domainStore DomainStore, d
domainRenewalInterval: domainRenewalInterval,
dnsResolver: dnsResolver,
domainStore: domainStore,
currentPrivateKey: make(map[string]keyAlias),
}

myPrivateKeys, err := privateKeysToKeyMap(base64PrivateKeys)
Expand All @@ -39,12 +40,14 @@ func NewDefaultDomainIndexer(dnsResolver DNSResolver, domainStore DomainStore, d
}
di.myPrivateKeys = myPrivateKeys

for _, privateKey := range di.myPrivateKeys {
// since iterating over a map is non-deterministic, we can make sure to set the key
// either if it is not already set or it is alphabetically less than current key at the index when
// iterating over the private keys map.
if di.currentPrivateKey == "" || di.currentPrivateKey < privateKey.alias {
di.currentPrivateKey = privateKey.alias
for originCallsign := range di.myPrivateKeys {
for _, privateKey := range di.myPrivateKeys[originCallsign] {
// since iterating over a map is non-deterministic, we can make sure to set the key
// either if it is not already set or it is alphabetically less than current key at the index when
// iterating over the private keys map.
if di.currentPrivateKey[originCallsign] == "" || di.currentPrivateKey[originCallsign] < privateKey.alias {
di.currentPrivateKey[originCallsign] = privateKey.alias
}
}
}

Expand All @@ -62,8 +65,8 @@ type defaultDomainIndexer struct {
lastRun time.Time
lastRunLock sync.RWMutex

myPrivateKeys keyMap
currentPrivateKey keyAlias
myPrivateKeys map[string]keyMap
currentPrivateKey map[string]keyAlias

dnsResolver DNSResolver
domainStore DomainStore
Expand Down Expand Up @@ -227,21 +230,27 @@ func (di *defaultDomainIndexer) checkDomainForKeyRecords(ctx context.Context, cu
}

// create shared secrets for each private key + public key combination
for _, myKey := range di.myPrivateKeys {
for _, theirKey := range currentDomainInfo.allPublicKeys {
keyPairAlias := newKeyPairAlias(myKey.alias, theirKey.alias)
if currentDomainInfo.allSharedSecrets[keyPairAlias] == nil {
currentDomainInfo.allSharedSecrets[keyPairAlias], err = calculateSharedSecret(myKey, theirKey)
if err != nil {
logger.Warningf("error calculating shared secret for record %s: %v", currentDomainInfo.Domain, err)
currentDomainInfo.domainStatus = DomainStatusErrorOnSharedSecretCalculation
for originCallsign := range di.myPrivateKeys {
if originCallsign != currentDomainInfo.Domain {
continue
}

for _, myKey := range di.myPrivateKeys[originCallsign] {
for _, theirKey := range currentDomainInfo.allPublicKeys {
keyPairAlias := newKeyPairAlias(myKey.alias, theirKey.alias)
if currentDomainInfo.allSharedSecrets[keyPairAlias] == nil {
currentDomainInfo.allSharedSecrets[keyPairAlias], err = calculateSharedSecret(myKey, theirKey)
if err != nil {
logger.Warningf("error calculating shared secret for record %s: %v", currentDomainInfo.Domain, err)
currentDomainInfo.domainStatus = DomainStatusErrorOnSharedSecretCalculation
}
}
}
}
}

currentDomainInfo.currentSharedSecretId = newKeyPairAlias(di.currentPrivateKey, currentDomainInfo.currentPublicKeyId)
currentDomainInfo.lastUpdateTime = time.Now()
currentDomainInfo.currentSharedSecretId = newKeyPairAlias(di.currentPrivateKey[originCallsign], currentDomainInfo.currentPublicKeyId)
currentDomainInfo.lastUpdateTime = time.Now()
}
}

func parsePolicyRecords(baseSubdomain string, baseSubdomainRecords []string) (foundDomains []string, parseError bool) {
Expand Down
23 changes: 17 additions & 6 deletions pkg/adscert/discovery/internal_base_key.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package discovery

import (
"errors"
"fmt"
"strings"

"github.com/IABTechLab/adscert/internal/formats"
"github.com/IABTechLab/adscert/pkg/adscert/logger"
Expand Down Expand Up @@ -76,11 +78,14 @@ func calculateSharedSecret(originPrivateKey *x25519Key, remotePublicKey *x25519K
return result, err
}

func privateKeysToKeyMap(privateKeys []string) (keyMap, error) {
result := keyMap{}

func privateKeysToKeyMap(privateKeys []string) (map[string]keyMap, error) {
results := map[string]keyMap{}
for _, privateKeyBase64 := range privateKeys {
privateKey, err := parseKeyFromString(privateKeyBase64)
sp := strings.SplitN(privateKeyBase64, "|", 2)
if len(sp) < 2 {
return nil, errors.New("missing origin callsign")
}
privateKey, err := parseKeyFromString(sp[1])
if err != nil {
return nil, err
}
Expand All @@ -90,10 +95,16 @@ func privateKeysToKeyMap(privateKeys []string) (keyMap, error) {

keyAlias := keyAlias(formats.ExtractKeyAliasFromPublicKeyBase64(formats.EncodeKeyBase64(publicBytes[:])))
privateKey.alias = keyAlias
result[keyAlias] = privateKey

km := results[sp[0]]
if km == nil {
km = keyMap{}
}
km[keyAlias] = privateKey
results[sp[0]] = km
}

return result, nil
return results, nil
}

func parseKeyFromString(base64EncodedKey string) (*x25519Key, error) {
Expand Down
19 changes: 15 additions & 4 deletions pkg/adscert/signatory/signatory_local_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
)

func NewLocalAuthenticatedConnectionsSignatory(
logLevel string,
originCallsign string,
secureRandom io.Reader,
clock clock.Clock,
Expand All @@ -26,6 +27,12 @@ func NewLocalAuthenticatedConnectionsSignatory(
domainCheckInterval time.Duration,
domainRenewalInterval time.Duration,
base64PrivateKeys []string) *LocalAuthenticatedConnectionsSignatory {
logger.SetLevel(logger.GetLevelFromString(logLevel))
if originCallsign != "" {
for i := range base64PrivateKeys {
base64PrivateKeys[i] = originCallsign + "|" + base64PrivateKeys[i]
}
}
return &LocalAuthenticatedConnectionsSignatory{
originCallsign: originCallsign,
secureRandom: secureRandom,
Expand Down Expand Up @@ -91,12 +98,16 @@ func (s *LocalAuthenticatedConnectionsSignatory) SignAuthenticatedConnection(req
}

func (s *LocalAuthenticatedConnectionsSignatory) signSingleMessage(request *api.AuthenticatedConnectionSignatureRequest, domainInfo discovery.DomainInfo) (*api.SignatureInfo, error) {

sigInfo := &api.SignatureInfo{}
acs, err := formats.NewAuthenticatedConnectionSignature(formats.StatusOK, s.originCallsign, request.RequestInfo.InvokingDomain)

var originCallsign string
if request.RequestInfo.OriginDomain != "" {
originCallsign = request.RequestInfo.OriginDomain
} else {
originCallsign = s.originCallsign
}
acs, err := formats.NewAuthenticatedConnectionSignature(formats.StatusOK, originCallsign, request.RequestInfo.InvokingDomain)
if err != nil {
acs.SetStatus(formats.StatusErrorOnSignature)
setSignatureInfoFromAuthenticatedConnection(sigInfo, acs)
return sigInfo, fmt.Errorf("error constructing authenticated connection signature format: %v", err)
}

Expand Down