Skip to content

Commit

Permalink
Merge branch 'main' into change-args-names-to-be-case-sensitive
Browse files Browse the repository at this point in the history
  • Loading branch information
jcchavezs authored May 16, 2024
2 parents 4225223 + 1c3776a commit 3da3bcd
Show file tree
Hide file tree
Showing 11 changed files with 444 additions and 202 deletions.
15 changes: 14 additions & 1 deletion http/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ func (i *rwInterceptor) WriteHeader(statusCode int) {

i.statusCode = statusCode
if it := i.tx.ProcessResponseHeaders(statusCode, i.proto); it != nil {
i.cleanHeaders()
i.Header().Set("Content-Length", "0")
i.statusCode = obtainStatusCodeFromInterruptionOrDefault(it, i.statusCode)
i.flushWriteHeader()
Expand All @@ -65,6 +66,13 @@ func (i *rwInterceptor) flushWriteHeader() {
}
}

// cleanHeaders removes all headers from the response
func (i *rwInterceptor) cleanHeaders() {
for k := range i.w.Header() {
i.w.Header().Del(k)
}
}

// Write buffers the response body until the request body limit is reach or an
// interruption is triggered, this buffer is later used to analyse the body in
// the response processor.
Expand All @@ -88,7 +96,10 @@ func (i *rwInterceptor) Write(b []byte) (int, error) {
// to it, otherwise we just send it to the response writer.
it, n, err := i.tx.WriteResponseBody(b)
if it != nil {
i.overrideWriteHeader(it.Status)
// if there is an interruption we must clean the headers and override the status code
i.cleanHeaders()
i.Header().Set("Content-Length", "0")
i.overrideWriteHeader(obtainStatusCodeFromInterruptionOrDefault(it, i.statusCode))
// We only flush the status code after an interruption.
i.flushWriteHeader()
return 0, nil
Expand Down Expand Up @@ -153,6 +164,8 @@ func wrap(w http.ResponseWriter, r *http.Request, tx types.Transaction) (
i.flushWriteHeader()
return err
} else if it != nil {
// if there is an interruption we must clean the headers and override the status code
i.cleanHeaders()
i.Header().Set("Content-Length", "0")
i.overrideWriteHeader(obtainStatusCodeFromInterruptionOrDefault(it, i.statusCode))
i.flushWriteHeader()
Expand Down
1 change: 0 additions & 1 deletion http/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,5 @@ func obtainStatusCodeFromInterruptionOrDefault(it *types.Interruption, defaultSt

return statusCode
}

return defaultStatusCode
}
145 changes: 91 additions & 54 deletions http/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,42 +238,53 @@ type httpTest struct {
respBody string
expectedProto string
expectedStatus int
expectedRespHeadersKeys []string
expectedRespBody string
}

var expectedNoBlockingHeaders = []string{"Content-Type", "Content-Length", "Coraza-Middleware", "Date"}

// When an interruption occour, we are expecting that no response headers are sent back to the client.
var expectedBlockingHeaders = []string{"Content-Length", "Date"}

func TestHttpServer(t *testing.T) {
tests := map[string]httpTest{
"no blocking": {
reqURI: "/hello",
expectedProto: "HTTP/1.1",
expectedStatus: 201,
reqURI: "/hello",
expectedProto: "HTTP/1.1",
expectedStatus: 201,
expectedRespHeadersKeys: expectedNoBlockingHeaders,
},
"no blocking HTTP/2": {
http2: true,
reqURI: "/hello",
expectedProto: "HTTP/2.0",
expectedStatus: 201,
http2: true,
reqURI: "/hello",
expectedProto: "HTTP/2.0",
expectedStatus: 201,
expectedRespHeadersKeys: expectedNoBlockingHeaders,
},
"args blocking": {
reqURI: "/hello?id=0",
expectedProto: "HTTP/1.1",
expectedStatus: 403,
reqURI: "/hello?id=0",
expectedProto: "HTTP/1.1",
expectedStatus: 403,
expectedRespHeadersKeys: expectedBlockingHeaders,
},
"request body blocking": {
reqURI: "/hello",
reqBody: "eval('cat /etc/passwd')",
expectedProto: "HTTP/1.1",
expectedStatus: 403,
reqURI: "/hello",
reqBody: "eval('cat /etc/passwd')",
expectedProto: "HTTP/1.1",
expectedStatus: 403,
expectedRespHeadersKeys: expectedBlockingHeaders,
},
"request body larger than limit (process partial)": {
reqURI: "/hello",
reqBody: "eval('cat /etc/passwd')",
echoReqBody: true,
// Coraza only sees eva, not eval
reqBodyLimit: 3,
expectedProto: "HTTP/1.1",
expectedStatus: 201,
expectedRespBody: "eval('cat /etc/passwd')",
reqBodyLimit: 3,
expectedProto: "HTTP/1.1",
expectedStatus: 201,
expectedRespHeadersKeys: expectedNoBlockingHeaders,
expectedRespBody: "eval('cat /etc/passwd')",
},
"request body larger than limit (reject)": {
reqURI: "/hello",
Expand All @@ -283,37 +294,43 @@ func TestHttpServer(t *testing.T) {
shouldRejectOnBodyLimit: true,
expectedProto: "HTTP/1.1",
expectedStatus: 413,
expectedRespHeadersKeys: expectedBlockingHeaders,
expectedRespBody: "",
},
"response headers blocking": {
reqURI: "/hello",
respHeaders: map[string]string{"foo": "bar"},
expectedProto: "HTTP/1.1",
expectedStatus: 401,
reqURI: "/hello",
respHeaders: map[string]string{"foo": "bar"},
expectedProto: "HTTP/1.1",
expectedStatus: 401,
expectedRespHeadersKeys: expectedBlockingHeaders,
},
"response body not blocking": {
reqURI: "/hello",
respBody: "true negative response body",
expectedProto: "HTTP/1.1",
expectedStatus: 201,
expectedRespBody: "true negative response body",
reqURI: "/hello",
respBody: "true negative response body",
expectedProto: "HTTP/1.1",
expectedStatus: 201,
expectedRespHeadersKeys: expectedNoBlockingHeaders,
expectedRespBody: "true negative response body",
},
"response body blocking": {
reqURI: "/hello",
respBody: "password=xxxx",
expectedProto: "HTTP/1.1",
expectedStatus: 403,
expectedRespBody: "", // blocking at response body phase means returning it empty
reqURI: "/hello",
respBody: "password=xxxx",
expectedProto: "HTTP/1.1",
expectedStatus: 403,
expectedRespBody: "", // blocking at response body phase means returning it empty
expectedRespHeadersKeys: expectedBlockingHeaders,
},
"allow": {
reqURI: "/allow_me",
expectedProto: "HTTP/1.1",
expectedStatus: 201,
reqURI: "/allow_me",
expectedProto: "HTTP/1.1",
expectedStatus: 201,
expectedRespHeadersKeys: expectedNoBlockingHeaders,
},
"deny passes over allow due to ordering": {
reqURI: "/allow_me?id=0",
expectedProto: "HTTP/1.1",
expectedStatus: 403,
reqURI: "/allow_me?id=0",
expectedProto: "HTTP/1.1",
expectedStatus: 403,
expectedRespHeadersKeys: expectedBlockingHeaders,
},
}

Expand Down Expand Up @@ -357,26 +374,29 @@ func TestHttpServer(t *testing.T) {
func TestHttpServerWithRuleEngineOff(t *testing.T) {
tests := map[string]httpTest{
"no blocking true negative": {
reqURI: "/hello",
expectedProto: "HTTP/1.1",
expectedStatus: 201,
respBody: "Hello!",
expectedRespBody: "Hello!",
reqURI: "/hello",
expectedProto: "HTTP/1.1",
expectedStatus: 201,
respBody: "Hello!",
expectedRespHeadersKeys: expectedNoBlockingHeaders,
expectedRespBody: "Hello!",
},
"no blocking true positive header phase": {
reqURI: "/hello?id=0",
expectedProto: "HTTP/1.1",
expectedStatus: 201,
respBody: "Downstream works!",
expectedRespBody: "Downstream works!",
reqURI: "/hello?id=0",
expectedProto: "HTTP/1.1",
expectedStatus: 201,
respBody: "Downstream works!",
expectedRespHeadersKeys: expectedNoBlockingHeaders,
expectedRespBody: "Downstream works!",
},
"no blocking true positive body phase": {
reqURI: "/hello",
reqBody: "eval('cat /etc/passwd')",
expectedProto: "HTTP/1.1",
expectedStatus: 201,
respBody: "Waf is Off!",
expectedRespBody: "Waf is Off!",
reqURI: "/hello",
reqBody: "eval('cat /etc/passwd')",
expectedProto: "HTTP/1.1",
expectedStatus: 201,
respBody: "Waf is Off!",
expectedRespHeadersKeys: expectedNoBlockingHeaders,
expectedRespBody: "Waf is Off!",
},
}
logger := debuglog.Default().
Expand Down Expand Up @@ -458,6 +478,10 @@ func runAgainstWAF(t *testing.T, tCase httpTest, waf coraza.WAF) {
t.Errorf("unexpected status code, want: %d, have: %d", want, have)
}

if !keysExistInMap(t, tCase.expectedRespHeadersKeys, res.Header) {
t.Errorf("unexpected response headers, expected keys: %v, headers: %v", tCase.expectedRespHeadersKeys, res.Header)
}

resBody, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("unexpected error when reading the response body: %v", err)
Expand All @@ -480,6 +504,19 @@ func runAgainstWAF(t *testing.T, tCase httpTest, waf coraza.WAF) {
}
}

func keysExistInMap(t *testing.T, keys []string, m map[string][]string) bool {
t.Helper()
if len(keys) != len(m) {
return false
}
for _, key := range keys {
if _, ok := m[key]; !ok {
return false
}
}
return true
}

func TestObtainStatusCodeFromInterruptionOrDefault(t *testing.T) {
tCases := map[string]struct {
interruptionCode int
Expand Down
64 changes: 43 additions & 21 deletions internal/corazawaf/rule_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,27 +65,49 @@ func TestNoMatchEvaluate(t *testing.T) {
}

func TestNoMatchEvaluateBecauseOfException(t *testing.T) {
r := NewRule()
r.Msg, _ = macro.NewMacro("Message")
r.LogData, _ = macro.NewMacro("Data Message")
r.ID_ = 1
if err := r.AddVariable(variables.ArgsGet, "", false); err != nil {
t.Error(err)
}
dummyEqOp := &dummyEqOperator{}
r.SetOperator(dummyEqOp, "@eq", "0")
action := &dummyDenyAction{}
_ = r.AddAction("dummyDeny", action)
tx := NewWAF().NewTransaction()
tx.AddGetRequestArgument("test", "0")
tx.RemoveRuleTargetByID(1, variables.ArgsGet, "test")
var matchedValues []types.MatchData
matchdata := r.doEvaluate(debuglog.Noop(), types.PhaseRequestHeaders, tx, &matchedValues, 0, tx.transformationCache)
if len(matchdata) != 0 {
t.Errorf("Expected 0 matchdata, got %d", len(matchdata))
}
if tx.interruption != nil {
t.Errorf("Expected interruption not triggered because of RemoveRuleTargetByID")
testCases := []struct {
name string
variable variables.RuleVariable
}{
{
name: "Test ArgsGet target exception",
variable: variables.ArgsGet,
},
{
name: "Test Args target exception",
variable: variables.Args,
},
{
name: "Test ArgsNames target exception",
variable: variables.ArgsNames,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
r := NewRule()
r.Msg, _ = macro.NewMacro("Message")
r.LogData, _ = macro.NewMacro("Data Message")
r.ID_ = 1
if err := r.AddVariable(tc.variable, "", false); err != nil {
t.Error(err)
}
dummyEqOp := &dummyEqOperator{}
r.SetOperator(dummyEqOp, "@eq", "0")
action := &dummyDenyAction{}
_ = r.AddAction("dummyDeny", action)
tx := NewWAF().NewTransaction()
tx.AddGetRequestArgument("test", "0")
tx.RemoveRuleTargetByID(1, tc.variable, "test")
var matchedValues []types.MatchData
matchdata := r.doEvaluate(debuglog.Noop(), types.PhaseRequestHeaders, tx, &matchedValues, 0, tx.transformationCache)
if len(matchdata) != 0 {
t.Errorf("Expected 0 matchdata, got %d", len(matchdata))
}
if tx.interruption != nil {
t.Errorf("Expected interruption not triggered because of RemoveRuleTargetByID")
}
})
}
}

Expand Down
Loading

0 comments on commit 3da3bcd

Please sign in to comment.