diff --git a/jar.go b/jar.go index 5ff2733..4ad290f 100644 --- a/jar.go +++ b/jar.go @@ -72,6 +72,15 @@ type Options struct { // (useful for tests). If this is true, the value of Filename will be // ignored. NoPersist bool + + // Filter specifies the filter to be used when deciding whether each cookie + // should be persisted to the filesystem. + Filter CookieFilter +} + +// CookieFilter is a type for deciding whether a Cookie should be persisted +type CookieFilter interface { + IsPersistent(*http.Cookie) bool } // Jar implements the http.CookieJar interface from the net/http package. @@ -87,6 +96,9 @@ type Jar struct { // entries is a set of entries, keyed by their eTLD+1 and subkeyed by // their name/domain/path. entries map[string]map[string]entry + + // Filter from Options + filter CookieFilter } var noOptions Options @@ -108,6 +120,11 @@ func newAtTime(o *Options, now time.Time) (*Jar, error) { if o == nil { o = &noOptions } + jar.filter = DefaultFilter + if o.Filter != nil { + jar.filter = o.Filter + } + if jar.psList = o.PublicSuffixList; jar.psList == nil { jar.psList = publicsuffix.List } @@ -144,10 +161,10 @@ type entry struct { Path string Secure bool HttpOnly bool - Persistent bool HostOnly bool Expires time.Time Creation time.Time + MaxAge int LastAccess time.Time // Updated records when the cookie was updated. @@ -203,6 +220,10 @@ func (e *entry) pathMatch(requestPath string) bool { return false } +func (e *entry) isExpiredAfter(t time.Time) bool { + return !e.Expires.IsZero() && t.After(e.Expires) +} + // hasDotSuffix reports whether s ends in "."+suffix. func hasDotSuffix(s, suffix string) bool { return len(s) > len(suffix) && s[len(s)-len(suffix)-1] == '.' && s[len(s)-len(suffix):] == suffix @@ -279,7 +300,7 @@ func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) { var selected []entry for id, e := range submap { - if !e.Expires.After(now) { + if e.isExpiredAfter(now) { // Save some space by deleting the value when the cookie // expires. We can't delete the cookie itself because then // we wouldn't know that the cookie had expired when @@ -310,7 +331,7 @@ func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) { // have Domain, Expires, HttpOnly, Name, Secure, Path, and Value filled // out. Expired cookies will not be returned. This function does not // modify the cookie jar. -func (j *Jar) AllCookies() (cookies []*http.Cookie) { +func (j *Jar) AllCookies() []*http.Cookie { return j.allCookies(time.Now()) } @@ -321,7 +342,7 @@ func (j *Jar) allCookies(now time.Time) []*http.Cookie { defer j.mu.Unlock() for _, submap := range j.entries { for _, e := range submap { - if !e.Expires.After(now) { + if e.isExpiredAfter(now) { // Do not return expired cookies. continue } @@ -393,7 +414,7 @@ var expiryRemovalDuration = 24 * time.Hour func (j *Jar) deleteExpired(now time.Time) { for tld, submap := range j.entries { for id, e := range submap { - if !e.Expires.After(now) && !e.Updated.Add(expiryRemovalDuration).After(now) { + if e.isExpiredAfter(now) && !e.Updated.Add(expiryRemovalDuration).After(now) { delete(submap, id) } } @@ -575,6 +596,27 @@ func defaultPath(path string) string { return path[:i] // Path is either of form "/abc/xyz" or "/abc/xyz/". } +// CookieFilterFunc implements CookieFilter by calling the underlying func +type CookieFilterFunc func(*http.Cookie) bool + +// IsPersistent implements CookieFilter for arbitrary funcs +func (cff CookieFilterFunc) IsPersistent(c *http.Cookie) bool { + return cff(c) +} + +var ( + // DefaultFilter is the previous behavior which does not persist session + // cookies. + DefaultFilter = CookieFilterFunc(func(c *http.Cookie) bool { + return c.MaxAge != 0 || !c.Expires.IsZero() + }) + + // AllowAllFilter does not check any cookie properties before persisting. + AllowAllFilter = CookieFilterFunc(func(_ *http.Cookie) bool { + return true + }) +) + // newEntry creates an entry from a http.Cookie c. now is the current // time and is compared to c.Expires to determine deletion of c. defPath // and host are the default-path and the canonical host name of the URL @@ -587,6 +629,7 @@ func defaultPath(path string) string { // A malformed c.Domain will result in an error. func (j *Jar) newEntry(c *http.Cookie, now time.Time, defPath, host string) (e entry, err error) { e.Name = c.Name + e.MaxAge = c.MaxAge if c.Path == "" || c.Path[0] != '/' { e.Path = defPath } else { @@ -597,19 +640,16 @@ func (j *Jar) newEntry(c *http.Cookie, now time.Time, defPath, host string) (e e if err != nil { return e, err } + // MaxAge takes precedence over Expires. if c.MaxAge != 0 { - e.Persistent = true e.Expires = now.Add(time.Duration(c.MaxAge) * time.Second) if c.MaxAge < 0 { return e, nil } - } else if c.Expires.IsZero() { - e.Expires = endOfTime } else { - e.Persistent = true e.Expires = c.Expires - if !c.Expires.After(now) { + if e.isExpiredAfter(now) { return e, nil } } diff --git a/jar_test.go b/jar_test.go index 5ea3467..49a066e 100644 --- a/jar_test.go +++ b/jar_test.go @@ -5,6 +5,7 @@ package cookiejar import ( + "encoding/json" "fmt" "io/ioutil" "net/http" @@ -1811,7 +1812,7 @@ func allCookies(jar *Jar, now time.Time) string { var cs []string for _, submap := range jar.entries { for _, cookie := range submap { - if !cookie.Expires.After(now) { + if !cookie.Expires.IsZero() && now.After(cookie.Expires) { continue } cs = append(cs, cookie.Name+"="+cookie.Value) @@ -2067,6 +2068,74 @@ func TestRemoveAllHostIP(t *testing.T) { testRemoveAllHost(t, mustParseURL("https://10.1.1.1"), "10.1.1.1", true) } +func TestFilter(t *testing.T) { + j := newTestJar("") + + google := mustParseURL("https://www.google.com") + + j.SetCookies( + google, + []*http.Cookie{ + &http.Cookie{ + Name: "test-cookie", + Value: "test-value", + Expires: time.Now().Add(24 * time.Hour), + }, + &http.Cookie{ + Name: "test-cookie2", + Value: "test-value", + }, + }, + ) + + es, err := jsonRoundTrip(j) + if err != nil { + t.Fatalf("json failed: %v", err) + } + + if len(es) != 1 { + t.Errorf("expected only one entry, got %d", len(es)) + } + + j.filter = CookieFilterFunc(func(_ *http.Cookie) bool { + return false + }) + + es, err = jsonRoundTrip(j) + if err != nil { + t.Fatalf("json failed: %v", err) + } + + if len(es) > 0 { + t.Errorf("expected zero entries") + } + + j.filter = AllowAllFilter + + es, err = jsonRoundTrip(j) + if err != nil { + t.Fatalf("json failed: %v", err) + } + + if len(es) < 2 { + t.Errorf("got fewer than two entries with AllowAllFilter") + } +} + +func jsonRoundTrip(j *Jar) ([]entry, error) { + bs, err := j.MarshalJSON() + if err != nil { + return nil, err + } + + var es []entry + if json.Unmarshal(bs, &es) != nil { + return nil, err + } + + return es, nil +} + func testRemoveAllHost(t *testing.T, setURL *url.URL, removeHost string, shouldRemove bool) { jar := newTestJar("") google := mustParseURL("https://www.google.com") diff --git a/serialize.go b/serialize.go index 2792dfb..11507e8 100644 --- a/serialize.go +++ b/serialize.go @@ -8,6 +8,7 @@ import ( "encoding/json" "io" "log" + "net/http" "os" "path/filepath" "sort" @@ -137,7 +138,16 @@ func (j *Jar) allPersistentEntries() []entry { var entries []entry for _, submap := range j.entries { for _, e := range submap { - if e.Persistent { + if j.filter.IsPersistent(&http.Cookie{ + Domain: e.Domain, + Expires: e.Expires, + HttpOnly: e.HttpOnly, + MaxAge: e.MaxAge, + Name: e.Name, + Path: e.Path, + Secure: e.Secure, + Value: e.Value, + }) { entries = append(entries, e) } }