Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🐛 bug: Fix square bracket notation in Multipart FormData #3235

Merged
merged 13 commits into from
Dec 31, 2024
103 changes: 100 additions & 3 deletions bind_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/json"
"errors"
"fmt"
"mime/multipart"
"net/http/httptest"
"reflect"
"testing"
Expand Down Expand Up @@ -886,7 +887,8 @@ func Test_Bind_Body(t *testing.T) {
reqBody := []byte(`{"name":"john"}`)

type Demo struct {
Name string `json:"name" xml:"name" form:"name" query:"name"`
Name string `json:"name" xml:"name" form:"name" query:"name"`
Names []string `json:"names" xml:"names" form:"names" query:"names"`
}

// Helper function to test compressed bodies
Expand Down Expand Up @@ -996,6 +998,48 @@ func Test_Bind_Body(t *testing.T) {
Data []Demo `query:"data"`
}

t.Run("MultipartCollectionQueryDotNotation", func(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Reset()

buf := &bytes.Buffer{}
writer := multipart.NewWriter(buf)
require.NoError(t, writer.WriteField("data.0.name", "john"))
require.NoError(t, writer.WriteField("data.1.name", "doe"))
require.NoError(t, writer.Close())

c.Request().Header.SetContentType(writer.FormDataContentType())
c.Request().SetBody(buf.Bytes())
c.Request().Header.SetContentLength(len(c.Body()))

cq := new(CollectionQuery)
require.NoError(t, c.Bind().Body(cq))
require.Len(t, cq.Data, 2)
require.Equal(t, "john", cq.Data[0].Name)
require.Equal(t, "doe", cq.Data[1].Name)
})

t.Run("MultipartCollectionQuerySquareBrackets", func(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Reset()

buf := &bytes.Buffer{}
writer := multipart.NewWriter(buf)
require.NoError(t, writer.WriteField("data[0][name]", "john"))
require.NoError(t, writer.WriteField("data[1][name]", "doe"))
require.NoError(t, writer.Close())

c.Request().Header.SetContentType(writer.FormDataContentType())
c.Request().SetBody(buf.Bytes())
c.Request().Header.SetContentLength(len(c.Body()))

cq := new(CollectionQuery)
require.NoError(t, c.Bind().Body(cq))
require.Len(t, cq.Data, 2)
require.Equal(t, "john", cq.Data[0].Name)
require.Equal(t, "doe", cq.Data[1].Name)
})

t.Run("CollectionQuerySquareBrackets", func(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Reset()
Expand Down Expand Up @@ -1192,9 +1236,57 @@ func Benchmark_Bind_Body_MultipartForm(b *testing.B) {
Name string `form:"name"`
}

body := []byte("--b\r\nContent-Disposition: form-data; name=\"name\"\r\n\r\njohn\r\n--b--")
buf := &bytes.Buffer{}
writer := multipart.NewWriter(buf)
require.NoError(b, writer.WriteField("name", "john"))
require.NoError(b, writer.Close())
body := buf.Bytes()

c.Request().SetBody(body)
c.Request().Header.SetContentType(MIMEMultipartForm + `;boundary=` + writer.Boundary())
c.Request().Header.SetContentLength(len(body))
d := new(Demo)

b.ReportAllocs()
b.ResetTimer()

for n := 0; n < b.N; n++ {
err = c.Bind().Body(d)
}

require.NoError(b, err)
require.Equal(b, "john", d.Name)
}

// go test -v -run=^$ -bench=Benchmark_Bind_Body_MultipartForm_Nested -benchmem -count=4
func Benchmark_Bind_Body_MultipartForm_Nested(b *testing.B) {
var err error

app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{})

type Person struct {
Name string `form:"name"`
Age int `form:"age"`
}

type Demo struct {
Name string `form:"name"`
Persons []Person `form:"persons"`
}

buf := &bytes.Buffer{}
writer := multipart.NewWriter(buf)
require.NoError(b, writer.WriteField("name", "john"))
require.NoError(b, writer.WriteField("persons.0.name", "john"))
require.NoError(b, writer.WriteField("persons[0][age]", "10"))
require.NoError(b, writer.WriteField("persons[1][name]", "doe"))
require.NoError(b, writer.WriteField("persons.1.age", "20"))
require.NoError(b, writer.Close())
body := buf.Bytes()

c.Request().SetBody(body)
c.Request().Header.SetContentType(MIMEMultipartForm + `;boundary="b"`)
c.Request().Header.SetContentType(MIMEMultipartForm + `;boundary=` + writer.Boundary())
c.Request().Header.SetContentLength(len(body))
d := new(Demo)

Expand All @@ -1204,8 +1296,13 @@ func Benchmark_Bind_Body_MultipartForm(b *testing.B) {
for n := 0; n < b.N; n++ {
err = c.Bind().Body(d)
}

require.NoError(b, err)
require.Equal(b, "john", d.Name)
require.Equal(b, "john", d.Persons[0].Name)
require.Equal(b, 10, d.Persons[0].Age)
require.Equal(b, "doe", d.Persons[1].Name)
require.Equal(b, 20, d.Persons[1].Age)
}

// go test -v -run=^$ -bench=Benchmark_Bind_Body_Form_Map -benchmem -count=4
Expand Down
13 changes: 1 addition & 12 deletions binder/cookie.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package binder

import (
"reflect"
"strings"

"github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp"
)
Expand All @@ -30,15 +27,7 @@ func (b *CookieBinding) Bind(req *fasthttp.Request, out any) error {

k := utils.UnsafeString(key)
v := utils.UnsafeString(val)

if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) {
values := strings.Split(v, ",")
for i := 0; i < len(values); i++ {
data[k] = append(data[k], values[i])
}
} else {
data[k] = append(data[k], v)
}
err = formatBindData(out, data, k, v, b.EnableSplitting, false)
})

if err != nil {
Expand Down
29 changes: 11 additions & 18 deletions binder/form.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package binder

import (
"reflect"
"strings"

"github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp"
)
Expand Down Expand Up @@ -37,19 +34,7 @@

k := utils.UnsafeString(key)
v := utils.UnsafeString(val)

if strings.Contains(k, "[") {
k, err = parseParamSquareBrackets(k)
}

if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) {
values := strings.Split(v, ",")
for i := 0; i < len(values); i++ {
data[k] = append(data[k], values[i])
}
} else {
data[k] = append(data[k], v)
}
err = formatBindData(out, data, k, v, b.EnableSplitting, true)
})

if err != nil {
Expand All @@ -61,12 +46,20 @@

// bindMultipart parses the request body and returns the result.
func (b *FormBinding) bindMultipart(req *fasthttp.Request, out any) error {
data, err := req.MultipartForm()
multipartForm, err := req.MultipartForm()
if err != nil {
return err
}

return parse(b.Name(), out, data.Value)
data := make(map[string][]string)
for key, values := range multipartForm.Value {
err = formatBindData(out, data, key, values, b.EnableSplitting, true)
if err != nil {
return err
}

Check warning on line 59 in binder/form.go

View check run for this annotation

Codecov / codecov/patch

binder/form.go#L58-L59

Added lines #L58 - L59 were not covered by tests
}

return parse(b.Name(), out, data)
}

// Reset resets the FormBinding binder.
Expand Down
16 changes: 15 additions & 1 deletion binder/form_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,14 @@ func Test_FormBinder_BindMultipart(t *testing.T) {
}
require.Equal(t, "form", b.Name())

type Post struct {
Title string `form:"title"`
}

type User struct {
Name string `form:"name"`
Names []string `form:"names"`
Posts []Post `form:"posts"`
Age int `form:"age"`
}
var user User
Expand All @@ -106,9 +111,13 @@ func Test_FormBinder_BindMultipart(t *testing.T) {
mw := multipart.NewWriter(buf)

require.NoError(t, mw.WriteField("name", "john"))
require.NoError(t, mw.WriteField("names", "john"))
require.NoError(t, mw.WriteField("names", "john,eric"))
require.NoError(t, mw.WriteField("names", "doe"))
require.NoError(t, mw.WriteField("age", "42"))
require.NoError(t, mw.WriteField("posts[0][title]", "post1"))
require.NoError(t, mw.WriteField("posts[1][title]", "post2"))
require.NoError(t, mw.WriteField("posts[2][title]", "post3"))

require.NoError(t, mw.Close())

req.Header.SetContentType(mw.FormDataContentType())
Expand All @@ -125,6 +134,11 @@ func Test_FormBinder_BindMultipart(t *testing.T) {
require.Equal(t, 42, user.Age)
require.Contains(t, user.Names, "john")
require.Contains(t, user.Names, "doe")
require.Contains(t, user.Names, "eric")
require.Len(t, user.Posts, 3)
require.Equal(t, "post1", user.Posts[0].Title)
require.Equal(t, "post2", user.Posts[1].Title)
require.Equal(t, "post3", user.Posts[2].Title)
}

func Benchmark_FormBinder_BindMultipart(b *testing.B) {
Expand Down
22 changes: 10 additions & 12 deletions binder/header.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package binder

import (
"reflect"
"strings"

"github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp"
)
Expand All @@ -21,20 +18,21 @@
// Bind parses the request header and returns the result.
func (b *HeaderBinding) Bind(req *fasthttp.Request, out any) error {
data := make(map[string][]string)
var err error
req.Header.VisitAll(func(key, val []byte) {
if err != nil {
return
}

Check warning on line 25 in binder/header.go

View check run for this annotation

Codecov / codecov/patch

binder/header.go#L24-L25

Added lines #L24 - L25 were not covered by tests

k := utils.UnsafeString(key)
v := utils.UnsafeString(val)

if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) {
values := strings.Split(v, ",")
for i := 0; i < len(values); i++ {
data[k] = append(data[k], values[i])
}
} else {
data[k] = append(data[k], v)
}
err = formatBindData(out, data, k, v, b.EnableSplitting, false)
})

if err != nil {
return err
}

Check warning on line 34 in binder/header.go

View check run for this annotation

Codecov / codecov/patch

binder/header.go#L33-L34

Added lines #L33 - L34 were not covered by tests

return parse(b.Name(), out, data)
}

Expand Down
38 changes: 37 additions & 1 deletion binder/mapping.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@
func parseToMap(ptr any, data map[string][]string) error {
elem := reflect.TypeOf(ptr).Elem()

switch elem.Kind() { //nolint:exhaustive // it's not necessary to check all types
switch elem.Kind() {
case reflect.Slice:
newMap, ok := ptr.(map[string][]string)
if !ok {
Expand All @@ -130,6 +130,8 @@
}
newMap[k] = v[len(v)-1]
}
default:
return nil // it's not necessary to check all types

Check warning on line 134 in binder/mapping.go

View check run for this annotation

Codecov / codecov/patch

binder/mapping.go#L133-L134

Added lines #L133 - L134 were not covered by tests
Comment on lines +133 to +134
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Consider returning an error for unsupported types.

The default case silently ignores unsupported types, which could hide potential issues. Consider returning an error to inform users when they attempt to parse unsupported types.

-	default:
-		return nil // it's not necessary to check all types
+	default:
+		return fmt.Errorf("unsupported type %v for map binding", elem.Kind())
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
default:
return nil // it's not necessary to check all types
default:
return fmt.Errorf("unsupported type %v for map binding", elem.Kind())

}

return nil
Expand Down Expand Up @@ -247,3 +249,37 @@
}
return content
}

func formatBindData[T any](out any, data map[string][]string, key string, value T, enableSplitting, supportBracketNotation bool) error { //nolint:revive // it's okay
var err error
if supportBracketNotation && strings.Contains(key, "[") {
key, err = parseParamSquareBrackets(key)
ReneWerner87 marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return err
}

Check warning on line 259 in binder/mapping.go

View check run for this annotation

Codecov / codecov/patch

binder/mapping.go#L258-L259

Added lines #L258 - L259 were not covered by tests
}

switch v := any(value).(type) {
case string:
assignBindData(out, data, key, v, enableSplitting)
case []string:
for _, val := range v {
assignBindData(out, data, key, val, enableSplitting)
}
default:
return fmt.Errorf("unsupported value type: %T", value)

Check warning on line 270 in binder/mapping.go

View check run for this annotation

Codecov / codecov/patch

binder/mapping.go#L269-L270

Added lines #L269 - L270 were not covered by tests
}

return err
}
ReneWerner87 marked this conversation as resolved.
Show resolved Hide resolved

func assignBindData(out any, data map[string][]string, key, value string, enableSplitting bool) { //nolint:revive // it's okay
if enableSplitting && strings.Contains(value, ",") && equalFieldType(out, reflect.Slice, key) {
values := strings.Split(value, ",")
for i := 0; i < len(values); i++ {
data[key] = append(data[key], values[i])
}
} else {
data[key] = append(data[key], value)
}
}
ReneWerner87 marked this conversation as resolved.
Show resolved Hide resolved
17 changes: 1 addition & 16 deletions binder/query.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package binder

import (
"reflect"
"strings"

"github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp"
)
Expand All @@ -30,19 +27,7 @@ func (b *QueryBinding) Bind(reqCtx *fasthttp.Request, out any) error {

k := utils.UnsafeString(key)
v := utils.UnsafeString(val)

if strings.Contains(k, "[") {
k, err = parseParamSquareBrackets(k)
}

if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) {
values := strings.Split(v, ",")
for i := 0; i < len(values); i++ {
data[k] = append(data[k], values[i])
}
} else {
data[k] = append(data[k], v)
}
err = formatBindData(out, data, k, v, b.EnableSplitting, true)
})

if err != nil {
Expand Down
Loading
Loading