Skip to content
Open
2 changes: 2 additions & 0 deletions api/adscert.proto
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ message RequestInfo {
bytes url_hash = 2;
bytes body_hash = 3;
repeated SignatureInfo signature_info = 4;
// useful if 1 signatory is managing multiple origin domains such as in resellers case.
string origin_domain = 5;
}

// SignatureInfo captures the signature generated for the signing request. It
Expand Down
4 changes: 0 additions & 4 deletions cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,6 @@ func main() {
logger.SetLevel(parsedLogLevel)
logger.Infof("Log Level: %s, parsed as iota %v", *logLevel, parsedLogLevel)

if *origin == "" {
logger.Fatalf("Origin ads.cert Call Sign domain name is required")
}

if *privateKey == "" {
logger.Fatalf("Private key is required")
}
Expand Down
4 changes: 4 additions & 0 deletions examples/signer-client/signer-client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

var (
serverAddress = flag.String("server_address", "localhost:3000", "address of grpc server")
originDomain = flag.String("origin_domain", "", "Origin domain")
destinationURL = flag.String("url", "https://google.com/gen_204", "URL to invoke")
body = flag.String("body", "", "POST request body")
signingTimeout = flag.Duration("signing_timeout", 5*time.Millisecond, "Specifies how long this client will wait for signing to finish before abandoning.")
Expand Down Expand Up @@ -49,6 +50,9 @@ func main() {
// destination URL and body, setting these value on the RequestInfo message.
reqInfo := &api.RequestInfo{}
signatory.SetRequestInfo(reqInfo, *destinationURL, []byte(*body))
if originDomain != nil {
reqInfo.OriginDomain = *originDomain
}

// Request the signature.
logger.Infof("signing request for url: %v", *destinationURL)
Expand Down
16 changes: 8 additions & 8 deletions internal/formats/adscert_connection_signature.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,21 +133,21 @@ func EncodeSignatureSuffix(

func NewAuthenticatedConnectionSignature(status AuthenticatedConnectionProtocolStatus, from string, invoking string) (*AuthenticatedConnectionSignature, error) {

s := &AuthenticatedConnectionSignature{}
s.status = status
s.from = from
s.invoking = invoking

if status == StatusUnspecified {
return nil, ErrParamMissingStatus
return s, ErrParamMissingStatus
}
if from == "" {
return nil, ErrParamMissingFrom
return s, ErrParamMissingFrom
}
if invoking == "" {
return nil, ErrParamMissingInvoking
return s, ErrParamMissingInvoking
}

s := &AuthenticatedConnectionSignature{}
s.status = status
s.from = from
s.invoking = invoking

return s, nil
}

Expand Down
12 changes: 4 additions & 8 deletions internal/formats/adscert_connection_signature_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ func TestNewAuthenticatedConnectionSignature(t *testing.T) {
nonce string

wantNewACSErr error
wantNilACS bool
wantAddParamsForSignatureErr error
wantUnsignedBaseMessage string
wantUnsignedExtendedMessage string
Expand Down Expand Up @@ -58,7 +57,6 @@ func TestNewAuthenticatedConnectionSignature(t *testing.T) {
invoking: "invoking.com",

wantNewACSErr: formats.ErrParamMissingStatus,
wantNilACS: true,
},
{
desc: "check ErrParamMissingFrom",
Expand All @@ -67,7 +65,6 @@ func TestNewAuthenticatedConnectionSignature(t *testing.T) {
invoking: "invoking.com",

wantNewACSErr: formats.ErrParamMissingFrom,
wantNilACS: true,
},
{
desc: "check ErrParamMissingInvoking",
Expand All @@ -76,7 +73,6 @@ func TestNewAuthenticatedConnectionSignature(t *testing.T) {
invoking: "",

wantNewACSErr: formats.ErrParamMissingInvoking,
wantNilACS: true,
},

{
Expand Down Expand Up @@ -167,12 +163,12 @@ func TestNewAuthenticatedConnectionSignature(t *testing.T) {
t.Errorf("NewAuthenticatedConnectionSignature() %s error check: got %v, want %v", tC.desc, gotErr, tC.wantNewACSErr)
}

gotNilACS := (acs == nil)
if tC.wantNilACS != gotNilACS {
t.Fatalf("NewAuthenticatedConnectionSignature() %s nil check: got (acs == nil) %v, want %v", tC.desc, gotNilACS, tC.wantNilACS)
if acs == nil {
t.Fatalf("NewAuthenticatedConnectionSignature() %s nil check: got (acs == nil), should not be nil", tC.desc)
}

if gotNilACS {
// skip rest of tests if an error was returned
if gotErr != nil {
return
}

Expand Down
305 changes: 158 additions & 147 deletions pkg/adscert/api/adscert.pb.go

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions pkg/adscert/api/adscert_grpc.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

47 changes: 28 additions & 19 deletions pkg/adscert/discovery/domain_indexer_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ func NewDefaultDomainIndexer(dnsResolver DNSResolver, domainStore DomainStore, d
domainRenewalInterval: domainRenewalInterval,
dnsResolver: dnsResolver,
domainStore: domainStore,
currentPrivateKey: make(map[string]keyAlias),
}

myPrivateKeys, err := privateKeysToKeyMap(base64PrivateKeys)
Expand All @@ -39,12 +40,14 @@ func NewDefaultDomainIndexer(dnsResolver DNSResolver, domainStore DomainStore, d
}
di.myPrivateKeys = myPrivateKeys

for _, privateKey := range di.myPrivateKeys {
// since iterating over a map is non-deterministic, we can make sure to set the key
// either if it is not already set or it is alphabetically less than current key at the index when
// iterating over the private keys map.
if di.currentPrivateKey == "" || di.currentPrivateKey < privateKey.alias {
di.currentPrivateKey = privateKey.alias
for originCallsign := range di.myPrivateKeys {
for _, privateKey := range di.myPrivateKeys[originCallsign] {
// since iterating over a map is non-deterministic, we can make sure to set the key
// either if it is not already set or it is alphabetically less than current key at the index when
// iterating over the private keys map.
if di.currentPrivateKey[originCallsign] == "" || di.currentPrivateKey[originCallsign] < privateKey.alias {
di.currentPrivateKey[originCallsign] = privateKey.alias
}
}
}

Expand All @@ -62,8 +65,8 @@ type defaultDomainIndexer struct {
lastRun time.Time
lastRunLock sync.RWMutex

myPrivateKeys keyMap
currentPrivateKey keyAlias
myPrivateKeys map[string]keyMap
currentPrivateKey map[string]keyAlias

dnsResolver DNSResolver
domainStore DomainStore
Expand Down Expand Up @@ -227,21 +230,27 @@ func (di *defaultDomainIndexer) checkDomainForKeyRecords(ctx context.Context, cu
}

// create shared secrets for each private key + public key combination
for _, myKey := range di.myPrivateKeys {
for _, theirKey := range currentDomainInfo.allPublicKeys {
keyPairAlias := newKeyPairAlias(myKey.alias, theirKey.alias)
if currentDomainInfo.allSharedSecrets[keyPairAlias] == nil {
currentDomainInfo.allSharedSecrets[keyPairAlias], err = calculateSharedSecret(myKey, theirKey)
if err != nil {
logger.Warningf("error calculating shared secret for record %s: %v", currentDomainInfo.Domain, err)
currentDomainInfo.domainStatus = DomainStatusErrorOnSharedSecretCalculation
for originCallsign := range di.myPrivateKeys {
if originCallsign != currentDomainInfo.Domain {
continue
}

for _, myKey := range di.myPrivateKeys[originCallsign] {
for _, theirKey := range currentDomainInfo.allPublicKeys {
keyPairAlias := newKeyPairAlias(myKey.alias, theirKey.alias)
if currentDomainInfo.allSharedSecrets[keyPairAlias] == nil {
currentDomainInfo.allSharedSecrets[keyPairAlias], err = calculateSharedSecret(myKey, theirKey)
if err != nil {
logger.Warningf("error calculating shared secret for record %s: %v", currentDomainInfo.Domain, err)
currentDomainInfo.domainStatus = DomainStatusErrorOnSharedSecretCalculation
}
}
}
}
}

currentDomainInfo.currentSharedSecretId = newKeyPairAlias(di.currentPrivateKey, currentDomainInfo.currentPublicKeyId)
currentDomainInfo.lastUpdateTime = time.Now()
currentDomainInfo.currentSharedSecretId = newKeyPairAlias(di.currentPrivateKey[originCallsign], currentDomainInfo.currentPublicKeyId)
currentDomainInfo.lastUpdateTime = time.Now()
}
}

func parsePolicyRecords(baseSubdomain string, baseSubdomainRecords []string) (foundDomains []string, parseError bool) {
Expand Down
23 changes: 17 additions & 6 deletions pkg/adscert/discovery/internal_base_key.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package discovery

import (
"errors"
"fmt"
"strings"

"github.com/IABTechLab/adscert/internal/formats"
"github.com/IABTechLab/adscert/pkg/adscert/logger"
Expand Down Expand Up @@ -76,11 +78,14 @@ func calculateSharedSecret(originPrivateKey *x25519Key, remotePublicKey *x25519K
return result, err
}

func privateKeysToKeyMap(privateKeys []string) (keyMap, error) {
result := keyMap{}

func privateKeysToKeyMap(privateKeys []string) (map[string]keyMap, error) {
results := map[string]keyMap{}
for _, privateKeyBase64 := range privateKeys {
privateKey, err := parseKeyFromString(privateKeyBase64)
sp := strings.SplitN(privateKeyBase64, "|", 2)
if len(sp) < 2 {
return nil, errors.New("missing origin callsign")
}
privateKey, err := parseKeyFromString(sp[1])
if err != nil {
return nil, err
}
Expand All @@ -90,10 +95,16 @@ func privateKeysToKeyMap(privateKeys []string) (keyMap, error) {

keyAlias := keyAlias(formats.ExtractKeyAliasFromPublicKeyBase64(formats.EncodeKeyBase64(publicBytes[:])))
privateKey.alias = keyAlias
result[keyAlias] = privateKey

km := results[sp[0]]
if km == nil {
km = keyMap{}
}
km[keyAlias] = privateKey
results[sp[0]] = km
}

return result, nil
return results, nil
}

func parseKeyFromString(base64EncodedKey string) (*x25519Key, error) {
Expand Down
15 changes: 13 additions & 2 deletions pkg/adscert/signatory/signatory_local_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ func NewLocalAuthenticatedConnectionsSignatory(
domainCheckInterval time.Duration,
domainRenewalInterval time.Duration,
base64PrivateKeys []string) *LocalAuthenticatedConnectionsSignatory {
if originCallsign != "" {
for i := range base64PrivateKeys {
base64PrivateKeys[i] = originCallsign + "|" + base64PrivateKeys[i]
}
}
return &LocalAuthenticatedConnectionsSignatory{
originCallsign: originCallsign,
secureRandom: secureRandom,
Expand Down Expand Up @@ -91,9 +96,15 @@ func (s *LocalAuthenticatedConnectionsSignatory) SignAuthenticatedConnection(req
}

func (s *LocalAuthenticatedConnectionsSignatory) signSingleMessage(request *api.AuthenticatedConnectionSignatureRequest, domainInfo discovery.DomainInfo) (*api.SignatureInfo, error) {

sigInfo := &api.SignatureInfo{}
acs, err := formats.NewAuthenticatedConnectionSignature(formats.StatusOK, s.originCallsign, request.RequestInfo.InvokingDomain)

var originCallsign string
if request.RequestInfo.OriginDomain != "" {
originCallsign = request.RequestInfo.OriginDomain
} else {
originCallsign = s.originCallsign
}
acs, err := formats.NewAuthenticatedConnectionSignature(formats.StatusOK, originCallsign, request.RequestInfo.InvokingDomain)
if err != nil {
acs.SetStatus(formats.StatusErrorOnSignature)
setSignatureInfoFromAuthenticatedConnection(sigInfo, acs)
Expand Down