From b667a3d56d6fdb73f769991c5a7234b11feb6133 Mon Sep 17 00:00:00 2001 From: Stojan Dimitrovski Date: Tue, 24 Dec 2024 18:48:29 +0100 Subject: [PATCH] feat: cover 100% of `crypto` with tests --- internal/api/hooks.go | 33 ++++++-- internal/api/magic_link.go | 5 +- internal/api/mail.go | 62 +++++--------- internal/api/mfa.go | 7 +- internal/api/mfa_test.go | 2 +- internal/api/phone.go | 7 +- internal/crypto/crypto.go | 94 +++++---------------- internal/crypto/crypto_test.go | 83 ++++++++++++++++++- internal/crypto/password.go | 89 +++++++------------- internal/crypto/password_test.go | 135 +++++++++++++++++++++---------- internal/crypto/utils.go | 9 +++ internal/crypto/utils_test.go | 14 ++++ 12 files changed, 303 insertions(+), 237 deletions(-) create mode 100644 internal/crypto/utils.go create mode 100644 internal/crypto/utils_test.go diff --git a/internal/api/hooks.go b/internal/api/hooks.go index 14e2b3cd4..2cf99cd2a 100644 --- a/internal/api/hooks.go +++ b/internal/api/hooks.go @@ -14,14 +14,12 @@ import ( "time" "github.com/gofrs/uuid" - "github.com/supabase/auth/internal/observability" + "github.com/sirupsen/logrus" + standardwebhooks "github.com/standard-webhooks/standard-webhooks/libraries/go" "github.com/supabase/auth/internal/conf" - "github.com/supabase/auth/internal/crypto" - - "github.com/sirupsen/logrus" "github.com/supabase/auth/internal/hooks" - + "github.com/supabase/auth/internal/observability" "github.com/supabase/auth/internal/storage" ) @@ -103,7 +101,7 @@ func (a *API) runHTTPHook(r *http.Request, hookConfig conf.ExtensibilityPointCon } msgID := uuid.Must(uuid.NewV4()) currentTime := time.Now() - signatureList, err := crypto.GenerateSignatures(hookConfig.HTTPHookSecrets, msgID, currentTime, inputPayload) + signatureList, err := generateSignatures(hookConfig.HTTPHookSecrets, msgID, currentTime, inputPayload) if err != nil { return nil, err } @@ -382,3 +380,26 @@ func (a *API) runHook(r *http.Request, conn *storage.Connection, hookConfig conf return response, nil } + +func generateSignatures(secrets []string, msgID uuid.UUID, currentTime time.Time, inputPayload []byte) ([]string, error) { + SymmetricSignaturePrefix := "v1," + // TODO(joel): Handle asymmetric case once library has been upgraded + var signatureList []string + for _, secret := range secrets { + if strings.HasPrefix(secret, SymmetricSignaturePrefix) { + trimmedSecret := strings.TrimPrefix(secret, SymmetricSignaturePrefix) + wh, err := standardwebhooks.NewWebhook(trimmedSecret) + if err != nil { + return nil, err + } + signature, err := wh.Sign(msgID.String(), currentTime, inputPayload) + if err != nil { + return nil, err + } + signatureList = append(signatureList, signature) + } else { + return nil, errors.New("invalid signature format") + } + } + return signatureList, nil +} diff --git a/internal/api/magic_link.go b/internal/api/magic_link.go index 15030bde2..57b0a7d8b 100644 --- a/internal/api/magic_link.go +++ b/internal/api/magic_link.go @@ -83,10 +83,7 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { if isNewUser { // User either doesn't exist or hasn't completed the signup process. // Sign them up with temporary password. - password, err := crypto.GeneratePassword(config.Password.RequiredCharacters, 33) - if err != nil { - return internalServerError("error creating user").WithInternalError(err) - } + password := crypto.GeneratePassword(config.Password.RequiredCharacters, 33) signUpParams := &SignupParams{ Email: params.Email, diff --git a/internal/api/mail.go b/internal/api/mail.go index b1492dffb..f2ea69b80 100644 --- a/internal/api/mail.go +++ b/internal/api/mail.go @@ -87,11 +87,7 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { var url string now := time.Now() - otp, err := crypto.GenerateOtp(config.Mailer.OtpLength) - if err != nil { - // OTP generation must always succeed - panic(err) - } + otp := crypto.GenerateOtp(config.Mailer.OtpLength) hashedToken := crypto.GenerateTokenHash(params.Email, otp) @@ -300,19 +296,18 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { } func (a *API) sendConfirmation(r *http.Request, tx *storage.Connection, u *models.User, flowType models.FlowType) error { + var err error + config := a.config maxFrequency := config.SMTP.MaxFrequency otpLength := config.Mailer.OtpLength - if err := validateSentWithinFrequencyLimit(u.ConfirmationSentAt, maxFrequency); err != nil { + if err = validateSentWithinFrequencyLimit(u.ConfirmationSentAt, maxFrequency); err != nil { return err } oldToken := u.ConfirmationToken - otp, err := crypto.GenerateOtp(otpLength) - if err != nil { - // OTP generation must succeeed - panic(err) - } + otp := crypto.GenerateOtp(otpLength) + token := crypto.GenerateTokenHash(u.GetEmail(), otp) u.ConfirmationToken = addFlowPrefixToToken(token, flowType) now := time.Now() @@ -342,11 +337,8 @@ func (a *API) sendInvite(r *http.Request, tx *storage.Connection, u *models.User otpLength := config.Mailer.OtpLength var err error oldToken := u.ConfirmationToken - otp, err := crypto.GenerateOtp(otpLength) - if err != nil { - // OTP generation must succeed - panic(err) - } + otp := crypto.GenerateOtp(otpLength) + u.ConfirmationToken = crypto.GenerateTokenHash(u.GetEmail(), otp) now := time.Now() if err = a.sendEmail(r, tx, u, mail.InviteVerification, otp, "", u.ConfirmationToken); err != nil { @@ -382,15 +374,12 @@ func (a *API) sendPasswordRecovery(r *http.Request, tx *storage.Connection, u *m } oldToken := u.RecoveryToken - otp, err := crypto.GenerateOtp(otpLength) - if err != nil { - // OTP generation must succeed - panic(err) - } + otp := crypto.GenerateOtp(otpLength) + token := crypto.GenerateTokenHash(u.GetEmail(), otp) u.RecoveryToken = addFlowPrefixToToken(token, flowType) now := time.Now() - if err = a.sendEmail(r, tx, u, mail.RecoveryVerification, otp, "", u.RecoveryToken); err != nil { + if err := a.sendEmail(r, tx, u, mail.RecoveryVerification, otp, "", u.RecoveryToken); err != nil { u.RecoveryToken = oldToken if errors.Is(err, EmailRateLimitExceeded) { return tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) @@ -422,11 +411,8 @@ func (a *API) sendReauthenticationOtp(r *http.Request, tx *storage.Connection, u } oldToken := u.ReauthenticationToken - otp, err := crypto.GenerateOtp(otpLength) - if err != nil { - // OTP generation must succeed - panic(err) - } + otp := crypto.GenerateOtp(otpLength) + u.ReauthenticationToken = crypto.GenerateTokenHash(u.GetEmail(), otp) now := time.Now() @@ -452,6 +438,7 @@ func (a *API) sendReauthenticationOtp(r *http.Request, tx *storage.Connection, u } func (a *API) sendMagicLink(r *http.Request, tx *storage.Connection, u *models.User, flowType models.FlowType) error { + var err error config := a.config otpLength := config.Mailer.OtpLength @@ -462,11 +449,8 @@ func (a *API) sendMagicLink(r *http.Request, tx *storage.Connection, u *models.U } oldToken := u.RecoveryToken - otp, err := crypto.GenerateOtp(otpLength) - if err != nil { - // OTP generation must succeed - panic(err) - } + otp := crypto.GenerateOtp(otpLength) + token := crypto.GenerateTokenHash(u.GetEmail(), otp) u.RecoveryToken = addFlowPrefixToToken(token, flowType) @@ -501,22 +485,16 @@ func (a *API) sendEmailChange(r *http.Request, tx *storage.Connection, u *models return err } - otpNew, err := crypto.GenerateOtp(otpLength) - if err != nil { - // OTP generation must succeed - panic(err) - } + otpNew := crypto.GenerateOtp(otpLength) + u.EmailChange = email token := crypto.GenerateTokenHash(u.EmailChange, otpNew) u.EmailChangeTokenNew = addFlowPrefixToToken(token, flowType) otpCurrent := "" if config.Mailer.SecureEmailChangeEnabled && u.GetEmail() != "" { - otpCurrent, err = crypto.GenerateOtp(otpLength) - if err != nil { - // OTP generation must succeed - panic(err) - } + otpCurrent = crypto.GenerateOtp(otpLength) + currentToken := crypto.GenerateTokenHash(u.GetEmail(), otpCurrent) u.EmailChangeTokenCurrent = addFlowPrefixToToken(currentToken, flowType) } diff --git a/internal/api/mfa.go b/internal/api/mfa.go index c1b83f9c8..4ac2b9b49 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -393,10 +393,9 @@ func (a *API) challengePhoneFactor(w http.ResponseWriter, r *http.Request) error return tooManyRequestsError(ErrorCodeOverSMSSendRateLimit, generateFrequencyLimitErrorMessage(factor.LastChallengedAt, config.MFA.Phone.MaxFrequency)) } } - otp, err := crypto.GenerateOtp(config.MFA.Phone.OtpLength) - if err != nil { - panic(err) - } + + otp := crypto.GenerateOtp(config.MFA.Phone.OtpLength) + challenge, err := factor.CreatePhoneChallenge(ipAddress, otp, config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID, config.Security.DBEncryption.EncryptionKey) if err != nil { return internalServerError("error creating SMS Challenge") diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go index 4e2a79758..653f38f68 100644 --- a/internal/api/mfa_test.go +++ b/internal/api/mfa_test.go @@ -525,7 +525,7 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() { } else if v.factorType == models.Phone { friendlyName := uuid.Must(uuid.NewV4()).String() numDigits := 10 - otp, err := crypto.GenerateOtp(numDigits) + otp := crypto.GenerateOtp(numDigits) require.NoError(ts.T(), err) phone := fmt.Sprintf("+%s", otp) f = models.NewPhoneFactor(ts.TestUser, phone, friendlyName) diff --git a/internal/api/phone.go b/internal/api/phone.go index f7a8fbda8..503388809 100644 --- a/internal/api/phone.go +++ b/internal/api/phone.go @@ -77,7 +77,6 @@ func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, use now := time.Now() var otp, messageID string - var err error if testOTP, ok := config.Sms.GetTestOTP(phone, now); ok { otp = testOTP @@ -93,10 +92,8 @@ func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, use return "", tooManyRequestsError(ErrorCodeOverSMSSendRateLimit, "SMS rate limit exceeded") } } - otp, err = crypto.GenerateOtp(config.Sms.OtpLength) - if err != nil { - return "", internalServerError("error generating otp").WithInternalError(err) - } + otp = crypto.GenerateOtp(config.Sms.OtpLength) + if config.Hook.SendSMS.Enabled { input := hooks.SendSMSInput{ User: user, diff --git a/internal/crypto/crypto.go b/internal/crypto/crypto.go index 52ff2f5fb..bd7764a70 100644 --- a/internal/crypto/crypto.go +++ b/internal/crypto/crypto.go @@ -13,74 +13,44 @@ import ( "math/big" "strconv" "strings" - "time" - "github.com/gofrs/uuid" - standardwebhooks "github.com/standard-webhooks/standard-webhooks/libraries/go" "golang.org/x/crypto/hkdf" - - "github.com/pkg/errors" ) // SecureToken creates a new random token func SecureToken(options ...int) string { length := 16 - if len(options) > 0 { + if len(options) == 1 { length = options[0] + } else if len(options) > 1 { + panic("crypto: only zero or one arguments allowed on SecureToken") } + b := make([]byte, length) - if _, err := io.ReadFull(rand.Reader, b); err != nil { - panic(err.Error()) // rand should never fail - } + must(io.ReadFull(rand.Reader, b)) + return base64.RawURLEncoding.EncodeToString(b) } // GenerateOtp generates a random n digit otp -func GenerateOtp(digits int) (string, error) { +func GenerateOtp(digits int) string { upper := math.Pow10(digits) - val, err := rand.Int(rand.Reader, big.NewInt(int64(upper))) - if err != nil { - return "", errors.WithMessage(err, "Error generating otp") - } + val := must(rand.Int(rand.Reader, big.NewInt(int64(upper)))) + // adds a variable zero-padding to the left to ensure otp is uniformly random expr := "%0" + strconv.Itoa(digits) + "v" otp := fmt.Sprintf(expr, val.String()) - return otp, nil + + return otp } func GenerateTokenHash(emailOrPhone, otp string) string { return fmt.Sprintf("%x", sha256.Sum224([]byte(emailOrPhone+otp))) } // Generated a random secure integer from [0, max[ -func secureRandomInt(max int) (int, error) { - randomInt, err := rand.Int(rand.Reader, big.NewInt(int64(max))) - if err != nil { - return 0, errors.WithMessage(err, "Error generating random integer") - } - return int(randomInt.Int64()), nil -} - -func GenerateSignatures(secrets []string, msgID uuid.UUID, currentTime time.Time, inputPayload []byte) ([]string, error) { - SymmetricSignaturePrefix := "v1," - // TODO(joel): Handle asymmetric case once library has been upgraded - var signatureList []string - for _, secret := range secrets { - if strings.HasPrefix(secret, SymmetricSignaturePrefix) { - trimmedSecret := strings.TrimPrefix(secret, SymmetricSignaturePrefix) - wh, err := standardwebhooks.NewWebhook(trimmedSecret) - if err != nil { - return nil, err - } - signature, err := wh.Sign(msgID.String(), currentTime, inputPayload) - if err != nil { - return nil, err - } - signatureList = append(signatureList, signature) - } else { - return nil, errors.New("invalid signature format") - } - } - return signatureList, nil +func secureRandomInt(max int) int { + randomInt := must(rand.Int(rand.Reader, big.NewInt(int64(max)))) + return int(randomInt.Int64()) } type EncryptedString struct { @@ -111,15 +81,8 @@ func (es *EncryptedString) Decrypt(id string, decryptionKeys map[string]string) return nil, err } - aes, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - - cipher, err := cipher.NewGCM(aes) - if err != nil { - return nil, err - } + block := must(aes.NewCipher(key)) + cipher := must(cipher.NewGCM(block)) decrypted, err := cipher.Open(nil, es.Nonce, es.Data, nil) // #nosec G407 if err != nil { @@ -148,10 +111,7 @@ func ParseEncryptedString(str string) *EncryptedString { } func (es *EncryptedString) String() string { - out, err := json.Marshal(es) - if err != nil { - panic(err) - } + out := must(json.Marshal(es)) return string(out) } @@ -179,9 +139,7 @@ func deriveSymmetricKey(id, keyID, keyBase64URL string) ([]byte, error) { keyReader := hkdf.New(sha256.New, hkdfKey, nil, []byte(id)) key := make([]byte, 256/8) - if _, err := io.ReadFull(keyReader, key); err != nil { - panic(err) - } + must(io.ReadFull(keyReader, key)) return key, nil } @@ -192,15 +150,8 @@ func NewEncryptedString(id string, data []byte, keyID string, keyBase64URL strin return nil, err } - aes, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - - cipher, err := cipher.NewGCM(aes) - if err != nil { - panic(err) - } + block := must(aes.NewCipher(block)) + cipher := must(cipher.NewGCM(aes)) es := EncryptedString{ KeyID: keyID, @@ -208,10 +159,7 @@ func NewEncryptedString(id string, data []byte, keyID string, keyBase64URL strin Nonce: make([]byte, 12), } - if _, err := io.ReadFull(rand.Reader, es.Nonce); err != nil { - panic(err) - } - + must(io.ReadFull(rand.Reader, es.Nonce)) es.Data = cipher.Seal(nil, es.Nonce, data, nil) // #nosec G407 return &es, nil diff --git a/internal/crypto/crypto_test.go b/internal/crypto/crypto_test.go index b677b918d..0d343e9d0 100644 --- a/internal/crypto/crypto_test.go +++ b/internal/crypto/crypto_test.go @@ -5,9 +5,10 @@ import ( "github.com/gofrs/uuid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func TestEncryptedString(t *testing.T) { +func TestEncryptedStringPositive(t *testing.T) { id := uuid.Must(uuid.NewV4()).String() es, err := NewEncryptedString(id, []byte("data"), "key-id", "pwFoiPyybQMqNmYVN0gUnpbfpGQV2sDv9vp0ZAxi_Y4") @@ -32,3 +33,83 @@ func TestEncryptedString(t *testing.T) { assert.NoError(t, err) assert.Equal(t, []byte("data"), decrypted) } + +func TestParseEncryptedStringNegative(t *testing.T) { + negativeExamples := []string{ + "not-an-encrypted-string", + // not json + "{{", + // not parsable json + `{"key_id":1}`, + `{"alg":1}`, + `{"data":"!!!"}`, + `{"nonce":"!!!"}`, + // not valid + `{}`, + `{"key_id":"key_id"}`, + `{"key_id":"key_id","alg":"different","data":"AQAB=","nonce":"AQAB="}`, + } + + for _, example := range negativeExamples { + assert.Nil(t, ParseEncryptedString(example)) + } +} + +func TestEncryptedStringDecryptNegative(t *testing.T) { + id := uuid.Must(uuid.NewV4()).String() + + // short key + _, err := NewEncryptedString(id, []byte("data"), "key-id", "short_key") + assert.Error(t, err) + + // not base64 + _, err = NewEncryptedString(id, []byte("data"), "key-id", "!!!") + assert.Error(t, err) + + es, err := NewEncryptedString(id, []byte("data"), "key-id", "pwFoiPyybQMqNmYVN0gUnpbfpGQV2sDv9vp0ZAxi_Y4") + assert.NoError(t, err) + + dec := ParseEncryptedString(es.String()) + assert.NotNil(t, dec) + + _, err = dec.Decrypt(id, map[string]string{ + // empty map + }) + assert.Error(t, err) + + // short key + _, err = dec.Decrypt(id, map[string]string{ + "key-id": "AQAB", + }) + assert.Error(t, err) + + // key not base64 + _, err = dec.Decrypt(id, map[string]string{ + "key-id": "!!!", + }) + assert.Error(t, err) + + // bad key + _, err = dec.Decrypt(id, map[string]string{ + "key-id": "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + }) + assert.Error(t, err) + + // bad tag for AEAD failure + dec.Data[len(dec.Data)-1] += 1 + + _, err = dec.Decrypt(id, map[string]string{ + "key-id": "pwFoiPyybQMqNmYVN0gUnpbfpGQV2sDv9vp0ZAxi_Y4", + }) + assert.Error(t, err) +} + +func TestSecureToken(t *testing.T) { + require.Panics(t, func() { + SecureToken(1, 2) + }) + + // no panics + SecureToken() + SecureToken(1) +} diff --git a/internal/crypto/password.go b/internal/crypto/password.go index db146e520..26599ccbb 100644 --- a/internal/crypto/password.go +++ b/internal/crypto/password.go @@ -57,8 +57,8 @@ var ErrArgon2MismatchedHashAndPassword = errors.New("crypto: argon2 hash and pas var ErrScryptMismatchedHashAndPassword = errors.New("crypto: fbscrypt hash and password mismatch") // argon2HashRegexp https://github.com/P-H-C/phc-string-format/blob/master/phc-sf-spec.md#argon2-encoding -var argon2HashRegexp = regexp.MustCompile("^[$](?Pargon2(d|i|id))[$]v=(?P(16|19))[$]m=(?P[0-9]+),t=(?P[0-9]+),p=(?P

[0-9]+)(,keyid=(?P[^,$]+))?(,data=(?P[^$]+))?[$](?P[^$]+)[$](?P.+)$") -var scryptHashRegexp = regexp.MustCompile(`^\$(?Pfbscrypt)\$v=(?P[0-9]+),n=(?P[0-9]+),r=(?P[0-9]+),p=(?P

[0-9]+)(?:,ss=(?P[^,]+))?(?:,sk=(?P[^$]+))?\$(?P[^$]+)\$(?P.+)$`) +var argon2HashRegexp = regexp.MustCompile("^[$](?Pargon2(d|i|id))[$]v=(?P(16|19))[$]m=(?P[0-9]+),t=(?P[0-9]+),p=(?P

[0-9]+)(,keyid=(?P[^,$]+))?(,data=(?P[^$]+))?[$](?P[^$]*)[$](?P.*)$") +var scryptHashRegexp = regexp.MustCompile(`^\$fbscrypt\$v=(?P[0-9]+),n=(?P[0-9]+),r=(?P[0-9]+),p=(?P

[0-9]+)(?:,ss=(?P[^,]+))?(?:,sk=(?P[^$]+))?\$(?P[^$]+)\$(?P.+)$`) type Argon2HashInput struct { alg string @@ -73,7 +73,6 @@ type Argon2HashInput struct { } type FirebaseScryptHashInput struct { - alg string v string memory uint64 rounds uint64 @@ -91,7 +90,6 @@ func ParseFirebaseScryptHash(hash string) (*FirebaseScryptHashInput, error) { return nil, errors.New("crypto: incorrect scrypt hash format") } - alg := string(scryptHashRegexp.ExpandString(nil, "$alg", hash, submatch)) v := string(scryptHashRegexp.ExpandString(nil, "$v", hash, submatch)) n := string(scryptHashRegexp.ExpandString(nil, "$n", hash, submatch)) r := string(scryptHashRegexp.ExpandString(nil, "$r", hash, submatch)) @@ -101,9 +99,6 @@ func ParseFirebaseScryptHash(hash string) (*FirebaseScryptHashInput, error) { saltB64 := string(scryptHashRegexp.ExpandString(nil, "$salt", hash, submatch)) hashB64 := string(scryptHashRegexp.ExpandString(nil, "$hash", hash, submatch)) - if alg != "fbscrypt" { - return nil, fmt.Errorf("crypto: Firebase scrypt hash uses unsupported algorithm %q only fbscrypt supported", alg) - } if v != "1" { return nil, fmt.Errorf("crypto: Firebase scrypt hash uses unsupported version %q only version 1 is supported", v) } @@ -112,19 +107,23 @@ func ParseFirebaseScryptHash(hash string) (*FirebaseScryptHashInput, error) { return nil, fmt.Errorf("crypto: Firebase scrypt hash has invalid n parameter %q %w", n, err) } if memoryPower == 0 { - return nil, fmt.Errorf("crypto: Firebase scrypt hash has invalid n parameter %q: must be greater than 0", n) + return nil, fmt.Errorf("crypto: Firebase scrypt hash has invalid n=0") } - // Exponent is passed in - memory := uint64(1) << memoryPower rounds, err := strconv.ParseUint(r, 10, 64) if err != nil { return nil, fmt.Errorf("crypto: Firebase scrypt hash has invalid r parameter %q: %w", r, err) } + if rounds == 0 { + return nil, fmt.Errorf("crypto: Firebase scrypt hash has invalid r=0") + } threads, err := strconv.ParseUint(p, 10, 8) if err != nil { return nil, fmt.Errorf("crypto: Firebase scrypt hash has invalid p parameter %q %w", p, err) } + if threads == 0 { + return nil, fmt.Errorf("crypto: Firebase scrypt hash has invalid p=0") + } rawHash, err := base64.StdEncoding.DecodeString(hashB64) if err != nil { @@ -145,9 +144,8 @@ func ParseFirebaseScryptHash(hash string) (*FirebaseScryptHashInput, error) { } input := &FirebaseScryptHashInput{ - alg: alg, v: v, - memory: memory, + memory: uint64(1) << memoryPower, rounds: rounds, threads: threads, salt: salt, @@ -288,7 +286,6 @@ func compareHashAndPasswordFirebaseScrypt(ctx context.Context, hash, password st } attributes := []attribute.KeyValue{ - attribute.String("alg", input.alg), attribute.String("v", input.v), attribute.Int64("n", int64(input.memory)), attribute.Int64("r", int64(input.rounds)), @@ -297,48 +294,33 @@ func compareHashAndPasswordFirebaseScrypt(ctx context.Context, hash, password st } // #nosec G115 var match bool - var derivedKey []byte compareHashAndPasswordSubmittedCounter.Add(ctx, 1, metric.WithAttributes(attributes...)) defer func() { attributes = append(attributes, attribute.Bool("match", match)) compareHashAndPasswordCompletedCounter.Add(ctx, 1, metric.WithAttributes(attributes...)) }() - switch input.alg { - case "fbscrypt": - derivedKey, err = firebaseScrypt([]byte(password), input.salt, input.signerKey, input.saltSeparator, input.memory, input.rounds, input.threads, FirebaseScryptKeyLen) - if err != nil { - return err - } + derivedKey := firebaseScrypt([]byte(password), input.salt, input.signerKey, input.saltSeparator, input.memory, input.rounds, input.threads) - match = subtle.ConstantTimeCompare(derivedKey, input.rawHash) == 1 - if !match { - return ErrScryptMismatchedHashAndPassword - } - - default: - return fmt.Errorf("unsupported algorithm: %s", input.alg) + match = subtle.ConstantTimeCompare(derivedKey, input.rawHash) == 1 + if !match { + return ErrScryptMismatchedHashAndPassword } return nil } -func firebaseScrypt(password, salt, signerKey, saltSeparator []byte, memCost, rounds, p, keyLen uint64) ([]byte, error) { - ck, err := scrypt.Key(password, append(salt, saltSeparator...), int(memCost), int(rounds), int(p), int(keyLen)) // #nosec G115 - if err != nil { - return nil, err - } - - var block cipher.Block - if block, err = aes.NewCipher(ck); err != nil { - return nil, err - } +func firebaseScrypt(password, salt, signerKey, saltSeparator []byte, memCost, rounds, p uint64) []byte { + ck := must(scrypt.Key(password, append(salt, saltSeparator...), int(memCost), int(rounds), int(p), FirebaseScryptKeyLen)) // #nosec G115 + block := must(aes.NewCipher(ck)) cipherText := make([]byte, aes.BlockSize+len(signerKey)) + // #nosec G407 -- Firebase scrypt requires deterministic IV for consistent results. See: JaakkoL/firebase-scrypt-python@master/firebasescrypt/firebasescrypt.py#L58 stream := cipher.NewCTR(block, cipherText[:aes.BlockSize]) stream.XORKeyStream(cipherText[aes.BlockSize:], signerKey) - return cipherText[aes.BlockSize:], nil + + return cipherText[aes.BlockSize:] } // CompareHashAndPassword compares the hash and @@ -380,14 +362,11 @@ func CompareHashAndPassword(ctx context.Context, hash, password string) error { // password, using PasswordHashCost. Context can be used to cancel the hashing // if the algorithm supports it. func GenerateFromPassword(ctx context.Context, password string) (string, error) { - var hashCost int + hashCost := bcrypt.DefaultCost switch PasswordHashCost { case QuickHashCost: hashCost = bcrypt.MinCost - - default: - hashCost = bcrypt.DefaultCost } attributes := []attribute.KeyValue{ @@ -398,25 +377,20 @@ func GenerateFromPassword(ctx context.Context, password string) (string, error) generateFromPasswordSubmittedCounter.Add(ctx, 1, metric.WithAttributes(attributes...)) defer generateFromPasswordCompletedCounter.Add(ctx, 1, metric.WithAttributes(attributes...)) - hash, err := bcrypt.GenerateFromPassword([]byte(password), hashCost) - if err != nil { - return "", err - } + hash := must(bcrypt.GenerateFromPassword([]byte(password), hashCost)) return string(hash), nil } -func GeneratePassword(requiredChars []string, length int) (string, error) { +func GeneratePassword(requiredChars []string, length int) string { passwordBuilder := strings.Builder{} passwordBuilder.Grow(length) // Add required characters for _, group := range requiredChars { if len(group) > 0 { - randomIndex, err := secureRandomInt(len(group)) - if err != nil { - return "", err - } + randomIndex := secureRandomInt(len(group)) + passwordBuilder.WriteByte(group[randomIndex]) } } @@ -426,10 +400,7 @@ func GeneratePassword(requiredChars []string, length int) (string, error) { // Fill the rest of the password for passwordBuilder.Len() < length { - randomIndex, err := secureRandomInt(len(allChars)) - if err != nil { - return "", err - } + randomIndex := secureRandomInt(len(allChars)) passwordBuilder.WriteByte(allChars[randomIndex]) } @@ -438,12 +409,10 @@ func GeneratePassword(requiredChars []string, length int) (string, error) { // Secure shuffling for i := len(passwordBytes) - 1; i > 0; i-- { - j, err := secureRandomInt(i + 1) - if err != nil { - return "", err - } + j := secureRandomInt(i + 1) + passwordBytes[i], passwordBytes[j] = passwordBytes[j], passwordBytes[i] } - return string(passwordBytes), nil + return string(passwordBytes) } diff --git a/internal/crypto/password_test.go b/internal/crypto/password_test.go index 3f210810c..289c9fe5b 100644 --- a/internal/crypto/password_test.go +++ b/internal/crypto/password_test.go @@ -23,6 +23,35 @@ func TestArgon2(t *testing.T) { for _, example := range examples { assert.Error(t, CompareHashAndPassword(context.Background(), example, "test1")) } + + negativeExamples := []string{ + // 2d + "$argon2d$v=19$m=16,t=2,p=1$bGJRWThNOHJJTVBSdHl2dQ$NfEnUOuUpb7F2fQkgFUG4g", + // v=16 + "$argon2id$v=16$m=16,t=2,p=1$bGJRWThNOHJJTVBSdHl2dQ$NfEnUOuUpb7F2fQkgFUG4g", + // data + "$argon2id$v=19$m=16,t=2,p=1,data=abc$bGJRWThNOHJJTVBSdHl2dQ$NfEnUOuUpb7F2fQkgFUG4g", + // keyid + "$argon2id$v=19$m=16,t=2,p=1,keyid=abc$bGJRWThNOHJJTVBSdHl2dQ$NfEnUOuUpb7F2fQkgFUG4g", + // m larger than 32 bits + "$argon2id$v=19$m=4294967297,t=2,p=1$bGJRWThNOHJJTVBSdHl2dQ$NfEnUOuUpb7F2fQkgFUG4g", + // t larger than 32 bits + "$argon2id$v=19$m=16,t=4294967297,p=1$bGJRWThNOHJJTVBSdHl2dQ$NfEnUOuUpb7F2fQkgFUG4g", + // p larger than 8 bits + "$argon2id$v=19$m=16,t=2,p=256$bGJRWThNOHJJTVBSdHl2dQ$NfEnUOuUpb7F2fQkgFUG4g", + // salt not Base64 + "$argon2id$v=19$m=16,t=2,p=1$!!!$NfEnUOuUpb7F2fQkgFUG4g", + // hash not Base64 + "$argon2id$v=19$m=16,t=2,p=1$bGJRWThNOHJJTVBSdHl2dQ$!!!", + // salt empty + "$argon2id$v=19$m=16,t=2,p=1$$NfEnUOuUpb7F2fQkgFUG4g", + // hash empty + "$argon2id$v=19$m=16,t=2,p=1$bGJRWThNOHJJTVBSdHl2dQ$", + } + + for _, example := range negativeExamples { + assert.Error(t, CompareHashAndPassword(context.Background(), example, "test")) + } } func TestGeneratePassword(t *testing.T) { @@ -30,30 +59,22 @@ func TestGeneratePassword(t *testing.T) { name string requiredChars []string length int - wantErr bool }{ { name: "Valid password generation", requiredChars: []string{"ABC", "123", "@#$"}, length: 12, - wantErr: false, }, { name: "Empty required chars", requiredChars: []string{}, length: 8, - wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := GeneratePassword(tt.requiredChars, tt.length) - - if (err != nil) != tt.wantErr { - t.Errorf("GeneratePassword() error = %v, wantErr %v", err, tt.wantErr) - return - } + got := GeneratePassword(tt.requiredChars, tt.length) if len(got) != tt.length { t.Errorf("GeneratePassword() returned password of length %d, want %d", len(got), tt.length) @@ -78,10 +99,8 @@ func TestGeneratePassword(t *testing.T) { // Check for duplicates passwords passwords := make(map[string]bool) for i := 0; i < 30; i++ { - p, err := GeneratePassword([]string{"ABC", "123", "@#$"}, 30) - if err != nil { - t.Errorf("GeneratePassword() unexpected error: %v", err) - } + p := GeneratePassword([]string{"ABC", "123", "@#$"}, 30) + if passwords[p] { t.Errorf("GeneratePassword() generated duplicate password: %s", p) } @@ -89,37 +108,71 @@ func TestGeneratePassword(t *testing.T) { } } -type scryptTestCase struct { - name string - hash string - password string - shouldPass bool +func TestFirebaseScrypt(t *testing.T) { + // all of these use the `mytestpassword` string as the valid one + + examples := []string{ + "$fbscrypt$v=1,n=14,r=8,p=1,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + } + + for _, example := range examples { + assert.NoError(t, CompareHashAndPassword(context.Background(), example, "mytestpassword")) + } + + for _, example := range examples { + assert.Error(t, CompareHashAndPassword(context.Background(), example, "mytestpassword1")) + } + + negativeExamples := []string{ + // v not 1 + "$fbscrypt$v=2,n=14,r=8,p=1,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + // n not 32 bits + "$fbscrypt$v=1,n=4294967297,r=8,p=1,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + // n is 0 + "$fbscrypt$v=1,n=0,r=8,p=1,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + // rounds is not 64 bits + "$fbscrypt$v=1,n=14,r=18446744073709551617,p=1,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + // rounds is 0 + "$fbscrypt$v=1,n=14,r=0,p=1,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + // threads is not 8 bits + "$fbscrypt$v=1,n=14,r=8,p=256,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + // threads is 0 + "$fbscrypt$v=1,n=14,r=8,p=0,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + // hash is not base64 + "$fbscrypt$v=1,n=14,r=8,p=1,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$!!!", + // salt is not base64 + "$fbscrypt$v=1,n=14,r=8,p=1,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$!!!$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + // signer key is not base64 + "$fbscrypt$v=1,n=14,r=8,p=1,ss=Bw==,sk=!!!$C0sHCg9ek77hsg==$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + // salt separator is not base64 + "$fbscrypt$v=1,n=14,r=8,p=1,ss=!!!,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + } + + for _, example := range negativeExamples { + assert.Error(t, CompareHashAndPassword(context.Background(), example, "mytestpassword")) + } } -func TestScrypt(t *testing.T) { - testCases := []scryptTestCase{ - { - name: "Firebase Scrypt: appropriate hash", - hash: "$fbscrypt$v=1,n=14,r=8,p=1,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", - password: "mytestpassword", - shouldPass: true, - }, - { - name: "Firebase Scrypt: incorrect hash", - hash: "$fbscrypt$v=1,n=14,r=8,p=1,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$ZGlmZmVyZW50aGFzaA==", - password: "mytestpassword", - shouldPass: false, - }, +func TestBcrypt(t *testing.T) { + // all use the `test` password + + examples := []string{ + "$2y$04$mIJxfrCaEI3GukZe11CiXublhEFanu5.ododkll1WphfSp6pn4zIu", + "$2y$10$srNl09aPtc2qr.0Vl.NtjekJRt/NxRxYQm3qd3OvfcKsJgVnr6.Ve", } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - err := CompareHashAndPassword(context.Background(), tc.hash, tc.password) - if tc.shouldPass { - assert.NoError(t, err, "Expected test case to pass, but it failed") - } else { - assert.Error(t, err, "Expected test case to fail, but it passed") - } - }) + for _, example := range examples { + assert.NoError(t, CompareHashAndPassword(context.Background(), example, "test")) + } + + for _, example := range examples { + assert.Error(t, CompareHashAndPassword(context.Background(), example, "test1")) + } + + negativeExamples := []string{ + "not-a-hash", + } + for _, example := range negativeExamples { + assert.Error(t, CompareHashAndPassword(context.Background(), example, "test")) } } diff --git a/internal/crypto/utils.go b/internal/crypto/utils.go new file mode 100644 index 000000000..a6b38b8e8 --- /dev/null +++ b/internal/crypto/utils.go @@ -0,0 +1,9 @@ +package crypto + +func must[T any](a T, err error) T { + if err != nil { + panic(err) + } + + return a +} diff --git a/internal/crypto/utils_test.go b/internal/crypto/utils_test.go new file mode 100644 index 000000000..1aeeab80c --- /dev/null +++ b/internal/crypto/utils_test.go @@ -0,0 +1,14 @@ +package crypto + +import ( + "testing" + + "github.com/pkg/errors" + "github.com/stretchr/testify/require" +) + +func TestMust(t *testing.T) { + require.Panics(t, func() { + must(123, errors.New("panic")) + }) +}