From 59f7323070c5bd99858fde80232f9fb5b08e597a Mon Sep 17 00:00:00 2001 From: Alexey Vilenski Date: Wed, 7 Oct 2020 21:06:36 +0000 Subject: [PATCH] xsrftoken: add custom timeout support for valid func Added new function 'ValidFor' with custom token timeout support. Function 'Valid' will use default token timeout. Fixes golang/go#41438 Change-Id: I5cf0388aeed7ca34edcb0d3493c3e79c8ce19938 GitHub-Last-Rev: 3e3b5817964aebf5b804ceec6694ed500f439c1e GitHub-Pull-Request: golang/net#86 Reviewed-on: https://go-review.googlesource.com/c/net/+/260317 Run-TryBot: Ian Lance Taylor TryBot-Result: Go Bot Reviewed-by: Ian Lance Taylor Trust: Filippo Valsorda --- xsrftoken/xsrf.go | 13 ++++++++++--- xsrftoken/xsrf_test.go | 31 +++++++++++++++++++++---------- 2 files changed, 31 insertions(+), 13 deletions(-) diff --git a/xsrftoken/xsrf.go b/xsrftoken/xsrf.go index 4f66adfca..3ca5d5b9f 100644 --- a/xsrftoken/xsrf.go +++ b/xsrftoken/xsrf.go @@ -54,12 +54,19 @@ func generateTokenAtTime(key, userID, actionID string, now time.Time) string { } // Valid reports whether a token is a valid, unexpired token returned by Generate. +// The token is considered to be expired and invalid if it is older than the default Timeout. func Valid(token, key, userID, actionID string) bool { - return validTokenAtTime(token, key, userID, actionID, time.Now()) + return validTokenAtTime(token, key, userID, actionID, time.Now(), Timeout) +} + +// ValidFor reports whether a token is a valid, unexpired token returned by Generate. +// The token is considered to be expired and invalid if it is older than the timeout duration. +func ValidFor(token, key, userID, actionID string, timeout time.Duration) bool { + return validTokenAtTime(token, key, userID, actionID, time.Now(), timeout) } // validTokenAtTime reports whether a token is valid at the given time. -func validTokenAtTime(token, key, userID, actionID string, now time.Time) bool { +func validTokenAtTime(token, key, userID, actionID string, now time.Time, timeout time.Duration) bool { if len(key) == 0 { panic("zero length xsrf secret key") } @@ -75,7 +82,7 @@ func validTokenAtTime(token, key, userID, actionID string, now time.Time) bool { issueTime := time.Unix(0, millis*1e6) // Check that the token is not expired. - if now.Sub(issueTime) >= Timeout { + if now.Sub(issueTime) >= timeout { return false } diff --git a/xsrftoken/xsrf_test.go b/xsrftoken/xsrf_test.go index fc0a48a85..60ff84a62 100644 --- a/xsrftoken/xsrf_test.go +++ b/xsrftoken/xsrf_test.go @@ -23,13 +23,22 @@ var ( func TestValidToken(t *testing.T) { tok := generateTokenAtTime(key, userID, actionID, now) - if !validTokenAtTime(tok, key, userID, actionID, oneMinuteFromNow) { + if !validTokenAtTime(tok, key, userID, actionID, oneMinuteFromNow, Timeout) { t.Error("One second later: Expected token to be valid") } - if !validTokenAtTime(tok, key, userID, actionID, now.Add(Timeout-1*time.Nanosecond)) { + if !validTokenAtTime(tok, key, userID, actionID, now.Add(Timeout-1*time.Nanosecond), Timeout) { t.Error("Just before timeout: Expected token to be valid") } - if !validTokenAtTime(tok, key, userID, actionID, now.Add(-1*time.Minute+1*time.Millisecond)) { + if !validTokenAtTime(tok, key, userID, actionID, now.Add(-1*time.Minute+1*time.Millisecond), Timeout) { + t.Error("One minute in the past: Expected token to be valid") + } + if !validTokenAtTime(tok, key, userID, actionID, oneMinuteFromNow, time.Hour) { + t.Error("One second later: Expected token to be valid") + } + if !validTokenAtTime(tok, key, userID, actionID, now.Add(time.Minute-1*time.Nanosecond), time.Minute) { + t.Error("Just before timeout: Expected token to be valid") + } + if !validTokenAtTime(tok, key, userID, actionID, now.Add(-1*time.Minute+1*time.Millisecond), time.Hour) { t.Error("One minute in the past: Expected token to be valid") } } @@ -69,17 +78,19 @@ func TestInvalidToken(t *testing.T) { invalidTokenTests := []struct { name, key, userID, actionID string t time.Time + timeout time.Duration }{ - {"Bad key", "foobar", userID, actionID, oneMinuteFromNow}, - {"Bad userID", key, "foobar", actionID, oneMinuteFromNow}, - {"Bad actionID", key, userID, "foobar", oneMinuteFromNow}, - {"Expired", key, userID, actionID, now.Add(Timeout + 1*time.Millisecond)}, - {"More than 1 minute from the future", key, userID, actionID, now.Add(-1*time.Nanosecond - 1*time.Minute)}, + {"Bad key", "foobar", userID, actionID, oneMinuteFromNow, Timeout}, + {"Bad userID", key, "foobar", actionID, oneMinuteFromNow, Timeout}, + {"Bad actionID", key, userID, "foobar", oneMinuteFromNow, Timeout}, + {"Expired", key, userID, actionID, now.Add(Timeout + 1*time.Millisecond), Timeout}, + {"More than 1 minute from the future", key, userID, actionID, now.Add(-1*time.Nanosecond - 1*time.Minute), Timeout}, + {"Expired with 1 minute timeout", key, userID, actionID, now.Add(time.Minute + 1*time.Millisecond), time.Minute}, } tok := generateTokenAtTime(key, userID, actionID, now) for _, itt := range invalidTokenTests { - if validTokenAtTime(tok, itt.key, itt.userID, itt.actionID, itt.t) { + if validTokenAtTime(tok, itt.key, itt.userID, itt.actionID, itt.t, itt.timeout) { t.Errorf("%v: Expected token to be invalid", itt.name) } } @@ -98,7 +109,7 @@ func TestValidateBadData(t *testing.T) { } for _, bdt := range badDataTests { - if validTokenAtTime(bdt.tok, key, userID, actionID, oneMinuteFromNow) { + if validTokenAtTime(bdt.tok, key, userID, actionID, oneMinuteFromNow, Timeout) { t.Errorf("%v: Expected token to be invalid", bdt.name) } }