diff --git a/service_provider.go b/service_provider.go index 30b35670..61afa1b1 100644 --- a/service_provider.go +++ b/service_provider.go @@ -5,6 +5,8 @@ import ( "compress/flate" "context" "crypto/rsa" + "crypto/sha256" + "crypto/sha512" "crypto/tls" "crypto/x509" "encoding/base64" @@ -91,6 +93,18 @@ type ServiceProvider struct { // IDPMetadata is the metadata from the identity provider. IDPMetadata *EntityDescriptor + // IDPCertificateFingerprint is fingerprint of the idp public certificate. If this field is specified, + // IDPCertificateFingerprintAlgorithm must also be specified, and IDPCertificate must not be specified. + IDPCertificateFingerprint *string + // IDPCertificateFingerprintAlgorithm is fingerprint algorithm used to obtain fingerprint of the idp public + // certificate. + // If this field is specified, IDPCertificateFingerprint must also be specified, and IDPCertificate must not be specified. + IDPCertificateFingerprintAlgorithm *string + + // IDPCertificate to use as idp public certificate. If this field is specified, IDPCertificateFingerprint and + // IDPCertificateFingerprintAlgorithm must not be specified. + IDPCertificate *string + // AuthnNameIDFormat is the format used in the NameIDPolicy for // authentication requests AuthnNameIDFormat NameIDFormat @@ -378,6 +392,85 @@ func (sp *ServiceProvider) getIDPSigningCerts() ([]*x509.Certificate, error) { return certs, nil } +func (sp *ServiceProvider) getCertBasedOnFingerprint(el *etree.Element) ([]*x509.Certificate, error) { + x509CertEl := el.FindElement("./Signature/KeyInfo/X509Data/X509Certificate") + if x509CertEl == nil { + return nil, fmt.Errorf("cannot validate signature on %s: no certificate present", el.Tag) + } + if len(x509CertEl.Child) != 1 { + return nil, fmt.Errorf("cannot validate signature on %s: x509 cert el child len != 1: %d", el.Tag, len(x509CertEl.Child)) + } + + x509CertElCharData, ok := x509CertEl.Child[0].(*etree.CharData) + if !ok { + return nil, fmt.Errorf("cannot validate signature on %s: x509 cert el first child not char data: %T", el.Tag, x509CertEl.Child[0]) + } + + cert, err := parseCert(x509CertElCharData.Data) + if err != nil { + return nil, fmt.Errorf("cannot validate signature on %s: %w", el.Tag, err) + } + + finP, err := fingerprint(cert, *sp.IDPCertificateFingerprintAlgorithm) + if err != nil { + return nil, fmt.Errorf("cannot validate signature on %s: %w", el.Tag, err) + } + + if *sp.IDPCertificateFingerprint != finP { + return nil, fmt.Errorf("cannot validate signature on %s: fingerprint mismatch", el.Tag) + } + + return []*x509.Certificate{cert}, nil + +} + +func parseCert(x509Data string) (*x509.Certificate, error) { + // cleanup whitespace + regex := regexp.MustCompile(`\s+`) + certStr := regex.ReplaceAllString(x509Data, "") + certBytes, err := base64.StdEncoding.DecodeString(certStr) + if err != nil { + return nil, fmt.Errorf("parse cert, cannot base64 decode cert string: %w", err) + } + + parsedCert, err := x509.ParseCertificate(certBytes) + if err != nil { + return nil, fmt.Errorf("parse cert, cannot parse certificate: %w", err) + } + + return parsedCert, nil +} + +func fingerprint(cert *x509.Certificate, fingerprintAlgorithm string) (string, error) { + switch fingerprintAlgorithm { + case "http://www.w3.org/2001/04/xmlenc#sha256": + fp := sha256.Sum256(cert.Raw) + return fingerprintFormat(fp[:]) + case "http://www.w3.org/2001/04/xmlenc#sha512": + fp := sha512.Sum512(cert.Raw) + return fingerprintFormat(fp[:]) + default: + return "", fmt.Errorf("fingerprint, unknown algorithm: %s", fingerprintAlgorithm) + } +} + +func fingerprintFormat(fp []byte) (string, error) { + var buf bytes.Buffer + for i, f := range fp { + if i > 0 { + _, err := fmt.Fprintf(&buf, ":") + if err != nil { + return "", fmt.Errorf("fingerprint format, print ':': %w", err) + } + } + _, err := fmt.Fprintf(&buf, "%02X", f) + if err != nil { + return "", fmt.Errorf("fingerprint format, print bytes: %w", err) + } + } + return buf.String(), nil +} + // MakeArtifactResolveRequest produces a new ArtifactResolve object to send to the idp's Artifact resolver func (sp *ServiceProvider) MakeArtifactResolveRequest(artifactID string) (*ArtifactResolve, error) { req := ArtifactResolve{ @@ -1101,9 +1194,28 @@ func (sp *ServiceProvider) validateSignature(el *etree.Element) error { return errSignatureElementNotPresent } - certs, err := sp.getIDPSigningCerts() - if err != nil { - return fmt.Errorf("cannot validate signature on %s: %v", el.Tag, err) + var certs []*x509.Certificate + if sp.IDPMetadata != nil && sp.IDPCertificateFingerprint == nil && sp.IDPCertificateFingerprintAlgorithm == nil && sp.IDPCertificate == nil { + certs, err = sp.getIDPSigningCerts() + if err != nil { + return fmt.Errorf("cannot validate signature on %s: %v", el.Tag, err) + } + } + if sp.IDPMetadata != nil && sp.IDPCertificateFingerprint != nil && sp.IDPCertificateFingerprintAlgorithm != nil && sp.IDPCertificate == nil { + certs, err = sp.getCertBasedOnFingerprint(el) + if err != nil { + return fmt.Errorf("cannot validate signature on %s: %v", el.Tag, err) + } + } + if sp.IDPMetadata != nil && sp.IDPCertificateFingerprint == nil && sp.IDPCertificateFingerprintAlgorithm == nil && sp.IDPCertificate != nil { + cert, err := parseCert(*sp.IDPCertificate) + if err != nil { + return fmt.Errorf("cannot validate signature on %s: %w", el.Tag, err) + } + certs = append(certs, cert) + } + if len(certs) == 0 { + return fmt.Errorf("cannot validate signature on %s: saml config not set up properly, specify either idp metadata url, fingerprints or actual certificate", el.Tag) } certificateStore := dsig.MemoryX509CertificateStore{