Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 50 additions & 10 deletions jar.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,15 @@ type Options struct {
// (useful for tests). If this is true, the value of Filename will be
// ignored.
NoPersist bool

// Filter specifies the filter to be used when deciding whether each cookie
// should be persisted to the filesystem.
Filter CookieFilter
}

// CookieFilter is a type for deciding whether a Cookie should be persisted
type CookieFilter interface {
IsPersistent(*http.Cookie) bool
}

// Jar implements the http.CookieJar interface from the net/http package.
Expand All @@ -87,6 +96,9 @@ type Jar struct {
// entries is a set of entries, keyed by their eTLD+1 and subkeyed by
// their name/domain/path.
entries map[string]map[string]entry

// Filter from Options
filter CookieFilter
}

var noOptions Options
Expand All @@ -108,6 +120,11 @@ func newAtTime(o *Options, now time.Time) (*Jar, error) {
if o == nil {
o = &noOptions
}
jar.filter = DefaultFilter
if o.Filter != nil {
jar.filter = o.Filter
}

if jar.psList = o.PublicSuffixList; jar.psList == nil {
jar.psList = publicsuffix.List
}
Expand Down Expand Up @@ -144,10 +161,10 @@ type entry struct {
Path string
Secure bool
HttpOnly bool
Persistent bool
HostOnly bool
Expires time.Time
Creation time.Time
MaxAge int
LastAccess time.Time

// Updated records when the cookie was updated.
Expand Down Expand Up @@ -203,6 +220,10 @@ func (e *entry) pathMatch(requestPath string) bool {
return false
}

func (e *entry) isExpiredAfter(t time.Time) bool {
return !e.Expires.IsZero() && t.After(e.Expires)
}

// hasDotSuffix reports whether s ends in "."+suffix.
func hasDotSuffix(s, suffix string) bool {
return len(s) > len(suffix) && s[len(s)-len(suffix)-1] == '.' && s[len(s)-len(suffix):] == suffix
Expand Down Expand Up @@ -279,7 +300,7 @@ func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) {

var selected []entry
for id, e := range submap {
if !e.Expires.After(now) {
if e.isExpiredAfter(now) {
// Save some space by deleting the value when the cookie
// expires. We can't delete the cookie itself because then
// we wouldn't know that the cookie had expired when
Expand Down Expand Up @@ -310,7 +331,7 @@ func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) {
// have Domain, Expires, HttpOnly, Name, Secure, Path, and Value filled
// out. Expired cookies will not be returned. This function does not
// modify the cookie jar.
func (j *Jar) AllCookies() (cookies []*http.Cookie) {
func (j *Jar) AllCookies() []*http.Cookie {
return j.allCookies(time.Now())
}

Expand All @@ -321,7 +342,7 @@ func (j *Jar) allCookies(now time.Time) []*http.Cookie {
defer j.mu.Unlock()
for _, submap := range j.entries {
for _, e := range submap {
if !e.Expires.After(now) {
if e.isExpiredAfter(now) {
// Do not return expired cookies.
continue
}
Expand Down Expand Up @@ -393,7 +414,7 @@ var expiryRemovalDuration = 24 * time.Hour
func (j *Jar) deleteExpired(now time.Time) {
for tld, submap := range j.entries {
for id, e := range submap {
if !e.Expires.After(now) && !e.Updated.Add(expiryRemovalDuration).After(now) {
if e.isExpiredAfter(now) && !e.Updated.Add(expiryRemovalDuration).After(now) {
delete(submap, id)
}
}
Expand Down Expand Up @@ -575,6 +596,27 @@ func defaultPath(path string) string {
return path[:i] // Path is either of form "/abc/xyz" or "/abc/xyz/".
}

// CookieFilterFunc implements CookieFilter by calling the underlying func
type CookieFilterFunc func(*http.Cookie) bool

// IsPersistent implements CookieFilter for arbitrary funcs
func (cff CookieFilterFunc) IsPersistent(c *http.Cookie) bool {
return cff(c)
}

var (
// DefaultFilter is the previous behavior which does not persist session
// cookies.
DefaultFilter = CookieFilterFunc(func(c *http.Cookie) bool {
return c.MaxAge != 0 || !c.Expires.IsZero()
})

// AllowAllFilter does not check any cookie properties before persisting.
AllowAllFilter = CookieFilterFunc(func(_ *http.Cookie) bool {
return true
})
)

// newEntry creates an entry from a http.Cookie c. now is the current
// time and is compared to c.Expires to determine deletion of c. defPath
// and host are the default-path and the canonical host name of the URL
Expand All @@ -587,6 +629,7 @@ func defaultPath(path string) string {
// A malformed c.Domain will result in an error.
func (j *Jar) newEntry(c *http.Cookie, now time.Time, defPath, host string) (e entry, err error) {
e.Name = c.Name
e.MaxAge = c.MaxAge
if c.Path == "" || c.Path[0] != '/' {
e.Path = defPath
} else {
Expand All @@ -597,19 +640,16 @@ func (j *Jar) newEntry(c *http.Cookie, now time.Time, defPath, host string) (e e
if err != nil {
return e, err
}

// MaxAge takes precedence over Expires.
if c.MaxAge != 0 {
e.Persistent = true
e.Expires = now.Add(time.Duration(c.MaxAge) * time.Second)
if c.MaxAge < 0 {
return e, nil
}
} else if c.Expires.IsZero() {
e.Expires = endOfTime
} else {
e.Persistent = true
e.Expires = c.Expires
if !c.Expires.After(now) {
if e.isExpiredAfter(now) {
return e, nil
}
}
Expand Down
71 changes: 70 additions & 1 deletion jar_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package cookiejar

import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
Expand Down Expand Up @@ -1811,7 +1812,7 @@ func allCookies(jar *Jar, now time.Time) string {
var cs []string
for _, submap := range jar.entries {
for _, cookie := range submap {
if !cookie.Expires.After(now) {
if !cookie.Expires.IsZero() && now.After(cookie.Expires) {
continue
}
cs = append(cs, cookie.Name+"="+cookie.Value)
Expand Down Expand Up @@ -2067,6 +2068,74 @@ func TestRemoveAllHostIP(t *testing.T) {
testRemoveAllHost(t, mustParseURL("https://10.1.1.1"), "10.1.1.1", true)
}

func TestFilter(t *testing.T) {
j := newTestJar("")

google := mustParseURL("https://www.google.com")

j.SetCookies(
google,
[]*http.Cookie{
&http.Cookie{
Name: "test-cookie",
Value: "test-value",
Expires: time.Now().Add(24 * time.Hour),
},
&http.Cookie{
Name: "test-cookie2",
Value: "test-value",
},
},
)

es, err := jsonRoundTrip(j)
if err != nil {
t.Fatalf("json failed: %v", err)
}

if len(es) != 1 {
t.Errorf("expected only one entry, got %d", len(es))
}

j.filter = CookieFilterFunc(func(_ *http.Cookie) bool {
return false
})

es, err = jsonRoundTrip(j)
if err != nil {
t.Fatalf("json failed: %v", err)
}

if len(es) > 0 {
t.Errorf("expected zero entries")
}

j.filter = AllowAllFilter

es, err = jsonRoundTrip(j)
if err != nil {
t.Fatalf("json failed: %v", err)
}

if len(es) < 2 {
t.Errorf("got fewer than two entries with AllowAllFilter")
}
}

func jsonRoundTrip(j *Jar) ([]entry, error) {
bs, err := j.MarshalJSON()
if err != nil {
return nil, err
}

var es []entry
if json.Unmarshal(bs, &es) != nil {
return nil, err
}

return es, nil
}

func testRemoveAllHost(t *testing.T, setURL *url.URL, removeHost string, shouldRemove bool) {
jar := newTestJar("")
google := mustParseURL("https://www.google.com")
Expand Down
12 changes: 11 additions & 1 deletion serialize.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"encoding/json"
"io"
"log"
"net/http"
"os"
"path/filepath"
"sort"
Expand Down Expand Up @@ -137,7 +138,16 @@ func (j *Jar) allPersistentEntries() []entry {
var entries []entry
for _, submap := range j.entries {
for _, e := range submap {
if e.Persistent {
if j.filter.IsPersistent(&http.Cookie{
Domain: e.Domain,
Expires: e.Expires,
HttpOnly: e.HttpOnly,
MaxAge: e.MaxAge,
Name: e.Name,
Path: e.Path,
Secure: e.Secure,
Value: e.Value,
}) {
entries = append(entries, e)
}
}
Expand Down