diff --git a/registry.go b/registry.go index eabe746..1fb0aa3 100644 --- a/registry.go +++ b/registry.go @@ -106,6 +106,13 @@ func Clear() { } } +func nowFromClock(clock jwt.Clock) time.Time { + if clock == nil { + return time.Now() + } + return clock.Now() +} + // Sign will create a new JWT based on the map of input data, // the Context's configuration, and current signing key. If the // signing key name is not set, an error will be returned. @@ -139,18 +146,11 @@ func Sign(purpose string, claims map[string]string, clock jwt.Clock) (signed []b return } - var now time.Time - if clock != nil { - now = clock.Now() - } else { - now = time.Now() - } - + now := nowFromClock(clock) builder := &jwt.Builder{} builder = builder. Claim(jwt.IssuerKey, c.issuer). Claim(jwt.IssuedAtKey, now) - if c.signingValidityPeriod > 0 { builder = builder.Claim(jwt.ExpirationKey, now.Add(c.signingValidityPeriod)) } @@ -166,7 +166,7 @@ func Sign(purpose string, claims map[string]string, clock jwt.Clock) (signed []b } } - signed, err = jwt.Sign(t, jwa.HS256, key) + signed, err = jwt.Sign(t, jwa.SignatureAlgorithm(key.Algorithm()), key) return } diff --git a/registry_test.go b/registry_test.go index 3fa99d4..5d1798a 100644 --- a/registry_test.go +++ b/registry_test.go @@ -180,6 +180,7 @@ func setupKeys(t *testing.T) { func TestSign(t *testing.T) { setupKeys(t) + t.Parallel() type args struct { purpose string claims map[string]string @@ -274,6 +275,7 @@ func TestSign(t *testing.T) { func TestValidate(t *testing.T) { setupKeys(t) + t.Parallel() type args struct { purpose string signed []byte @@ -362,3 +364,29 @@ func TestValidate(t *testing.T) { }) } } + +func timeEqualsEpislion(want time.Time, got time.Time, fudge time.Duration) bool { + return got.Unix() >= want.Add(-fudge).Unix() && got.Unix() <= want.Add(fudge).Unix() +} + +func Test_nowFromClock(t *testing.T) { + type args struct { + clock jwt.Clock + } + tests := []struct { + name string + args args + want time.Time + }{ + {"nil", args{nil}, time.Now()}, + {"set at 50", args{&TimeClock{50}}, time.Unix(50, 0)}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := nowFromClock(tt.args.clock) + if !timeEqualsEpislion(tt.want, got, 1*time.Second) { + t.Errorf("nowFromClock() = %v, want %v", got, tt.want) + } + }) + } +}