Skip to content

Commit

Permalink
Add session tests.
Browse files Browse the repository at this point in the history
We test login-related session changes (regeneration, shows logged in from handler) and CSRF protection.

This commit also changes user_test to use t.Fatal instead of t.Error.

Ref issue #1
  • Loading branch information
Favyen Bastani committed Jun 11, 2015
1 parent 2436469 commit 1ecb075
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 11 deletions.
180 changes: 180 additions & 0 deletions session_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
package lobster

import "bytes"
import "net/http"
import "net/http/httptest"
import "net/url"
import "testing"

func TestSessionBasic(t *testing.T) {
var seenUserId int
fakeHandler := func(w http.ResponseWriter, r *http.Request, db *Database, session *Session) {
seenUserId = session.UserId
}

db := TestReset()
userId := TestUser(db)
w := httptest.NewRecorder()
session := makeSession(w, db)
req, _ := http.NewRequest("GET", "http://example.com/", nil)
req.AddCookie(&http.Cookie{Name: SESSION_COOKIE_NAME, Value: session.Id})
db.Exec("UPDATE sessions SET user_id = ?", userId)
sessionWrap(fakeHandler)(w, req, db)

if seenUserId != userId {
t.Fatalf("Expected session user id %d but got %d", userId, seenUserId)
}
}

func findResponseCookie(response *http.Response, name string) string {
for _, cookie := range response.Cookies() {
if cookie.Name == name {
return cookie.Value
}
}
return ""
}

func TestSessionLogin(t *testing.T) {
// create session, apply on fake handler, then login
// verify session ID changes after first request after login
// (check both cookie and that new handler on old session doesn't get logged in)
// then verify in next handler that session preserves user ID
// note: we use server for this test since it makes parsing Set-Cookie easier
db := TestReset()
userId := TestUser(db)

var seenUserId int
fakeHandler := func(w http.ResponseWriter, r *http.Request, db *Database, session *Session) {
seenUserId = session.UserId
}

loginHandler := func(w http.ResponseWriter, r *http.Request, db *Database, session *Session) {
session.UserId = userId
}

handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" {
db.wrapHandler(sessionWrap(fakeHandler))(w, r)
} else if r.URL.Path == "/login" {
db.wrapHandler(sessionWrap(loginHandler))(w, r)
} else {
t.Errorf("Unexpected request path %s", r.URL.Path)
}
})
server := httptest.NewServer(handler)

// initial request
response, err := http.Get(server.URL)
if err != nil {
t.Fatal(err)
} else if seenUserId != 0 {
t.Fatal("Initial request already shows logged in user")
}
initialSessionId := findResponseCookie(response, SESSION_COOKIE_NAME)
if initialSessionId == "" {
t.Fatal("No session cookie provided")
}

// login
request, _ := http.NewRequest("GET", server.URL + "/login", nil)
request.AddCookie(&http.Cookie{Name: SESSION_COOKIE_NAME, Value: initialSessionId})
response, err = new(http.Client).Do(request)
if err != nil {
t.Fatal(err)
}

// get arbitrary page, expect both to be logged in and for server to regenerate session id
request, _ = http.NewRequest("GET", server.URL, nil)
request.AddCookie(&http.Cookie{Name: SESSION_COOKIE_NAME, Value: initialSessionId})
response, err = new(http.Client).Do(request)
if err != nil {
t.Fatal(err)
} else if seenUserId != userId {
t.Fatal("First page after login with initial session should be logged in, but isn't")
}
loginSessionId := findResponseCookie(response, SESSION_COOKIE_NAME)
if loginSessionId == "" {
t.Fatal("No session cookie provided on first request after login")
} else if loginSessionId == initialSessionId {
t.Fatal("Session cookie remains the same on first request after login")
}

// verify old session not logged in
request, _ = http.NewRequest("GET", server.URL, nil)
request.AddCookie(&http.Cookie{Name: SESSION_COOKIE_NAME, Value: initialSessionId})
response, err = new(http.Client).Do(request)
if err != nil {
t.Fatal(err)
} else if seenUserId != 0 {
t.Fatal("Session from before login is logged in")
}

// verify new session is logged in
request, _ = http.NewRequest("GET", server.URL, nil)
request.AddCookie(&http.Cookie{Name: SESSION_COOKIE_NAME, Value: loginSessionId})
response, err = new(http.Client).Do(request)
if err != nil {
t.Fatal(err)
} else if seenUserId != userId {
t.Fatal("Session from after login is not logged in correctly")
}
}

func TestSessionCSRF(t *testing.T) {
// try no token, valid token, reuse token, and other session token
// only valid token should work
// on fail we expect 303 redirect
db := TestReset()
w := httptest.NewRecorder()
session := makeSession(w, db)
fakeHandler := func(w http.ResponseWriter, r *http.Request, db *Database, session *Session) {}

// no token
req, _ := http.NewRequest("POST", "http://example.com/", nil)
req.AddCookie(&http.Cookie{Name: SESSION_COOKIE_NAME, Value: session.Id})
w = httptest.NewRecorder()
sessionWrap(fakeHandler)(w, req, db)

if w.Code != 303 {
t.Error("CSRF protection allowed no token")
}

// valid token
v := url.Values{}
v.Add("token", csrfGenerate(db, session))
req, _ = http.NewRequest("POST", "http://example.com/", bytes.NewBufferString(v.Encode()))
req.AddCookie(&http.Cookie{Name: SESSION_COOKIE_NAME, Value: session.Id})
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
w = httptest.NewRecorder()
sessionWrap(fakeHandler)(w, req, db)

if w.Code != 200 {
t.Error("CSRF protection disallowed valid token")
}

// reuse token
req, _ = http.NewRequest("POST", "http://example.com/", bytes.NewBufferString(v.Encode()))
req.AddCookie(&http.Cookie{Name: SESSION_COOKIE_NAME, Value: session.Id})
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
w = httptest.NewRecorder()
sessionWrap(fakeHandler)(w, req, db)

if w.Code != 303 {
t.Error("CSRF protection allowed reused token")
}

// other session token
session2 := makeSession(w, db)
v = url.Values{}
v.Add("token", csrfGenerate(db, session2))
req, _ = http.NewRequest("POST", "http://example.com/", bytes.NewBufferString(v.Encode()))
req.AddCookie(&http.Cookie{Name: SESSION_COOKIE_NAME, Value: session.Id})
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
w = httptest.NewRecorder()
sessionWrap(fakeHandler)(w, req, db)

if w.Code != 303 {
t.Error("CSRF protection allowed token from another session")
}
}
2 changes: 1 addition & 1 deletion testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import "github.com/LunaNode/lobster/utils"

const TEST_BANDWIDTH = 1000

var testTables []string = []string{"users", "region_bandwidth", "vms", "plans", "charges"}
var testTables []string = []string{"users", "region_bandwidth", "vms", "plans", "charges", "sessions", "form_tokens"}

func TestReset() *Database {
cfg = &Config{
Expand Down
15 changes: 5 additions & 10 deletions user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,15 @@ func TestBillingBandwidth(t *testing.T) {
testForceUserBilling(db, userId)
expectedCharge := int64(cfg.Default.BandwidthOverageFee * BILLING_PRECISION) * gbUsage
if !testVerifyCharge(db, userId, "bw-test", expectedCharge) {
t.Errorf("Overage of %d GB, but didn't bill according to overage fee", gbUsage)
return
t.Fatalf("Overage of %d GB, but didn't bill according to overage fee", gbUsage)
}

// make sure we can increase both used and allocated without charging again
gigaIncrease := 500
db.Exec("UPDATE region_bandwidth SET bandwidth_used = bandwidth_used + ?, bandwidth_additional = bandwidth_additional + ? WHERE id = ?", gigaToBytes(gigaIncrease), gigaToBytes(gigaIncrease), regionBandwidthId)
testForceUserBilling(db, userId)
if !testVerifyCharge(db, userId, "bw-test", expectedCharge) {
t.Error("Billed when used/allocated increased by same amount")
return
t.Fatal("Billed when used/allocated increased by same amount")
}

// begin testing proportional billing
Expand All @@ -55,16 +53,14 @@ func TestBillingBandwidth(t *testing.T) {
db.Exec("UPDATE region_bandwidth SET bandwidth_used = bandwidth_used + ? WHERE id = ?", gigaToBytes(TEST_BANDWIDTH / 2), regionBandwidthId)
testForceUserBilling(db, userId)
if !testVerifyChargeApprox(db, userId, "bw-test", expectedCharge * 9 / 10, expectedCharge * 11 / 10) {
t.Error("User charged despite proportional virtual machine")
return
t.Fatal("User charged despite proportional virtual machine")
}

db.Exec("UPDATE region_bandwidth SET bandwidth_used = bandwidth_used + ? WHERE id = ?", gigaToBytes(TEST_BANDWIDTH / 2), regionBandwidthId)
expectedCharge += int64(cfg.Default.BandwidthOverageFee * BILLING_PRECISION) * TEST_BANDWIDTH / 2
testForceUserBilling(db, userId)
if !testVerifyChargeApprox(db, userId, "bw-test", expectedCharge * 9 / 10, expectedCharge * 11 / 10) {
t.Error("User charged differently than expected with proportional virtual machine")
return
t.Fatal("User charged differently than expected with proportional virtual machine")
}

// vm provisioned before beginning of the month should add this month's bandwidth only
Expand All @@ -75,7 +71,6 @@ func TestBillingBandwidth(t *testing.T) {
expectedCharge += int64(cfg.Default.BandwidthOverageFee * BILLING_PRECISION) * TEST_BANDWIDTH
testForceUserBilling(db, userId)
if !testVerifyChargeApprox(db, userId, "bw-test", expectedCharge * 9 / 10, expectedCharge * 11 / 10) {
t.Error("User charged differently than expected with long time ago virtual machine")
return
t.Fatal("User charged differently than expected with long time ago virtual machine")
}
}

0 comments on commit 1ecb075

Please sign in to comment.