Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🐛 bug: Fix square bracket notation in Multipart FormData #3235

Merged
merged 13 commits into from
Dec 31, 2024
103 changes: 100 additions & 3 deletions bind_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/json"
"errors"
"fmt"
"mime/multipart"
"net/http/httptest"
"reflect"
"testing"
Expand Down Expand Up @@ -878,7 +879,8 @@ func Test_Bind_Body(t *testing.T) {
reqBody := []byte(`{"name":"john"}`)

type Demo struct {
Name string `json:"name" xml:"name" form:"name" query:"name"`
Name string `json:"name" xml:"name" form:"name" query:"name"`
Names []string `json:"names" xml:"names" form:"names" query:"names"`
}

// Helper function to test compressed bodies
Expand Down Expand Up @@ -988,6 +990,48 @@ func Test_Bind_Body(t *testing.T) {
Data []Demo `query:"data"`
}

t.Run("MultipartCollectionQueryDotNotation", func(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Reset()

buf := &bytes.Buffer{}
writer := multipart.NewWriter(buf)
require.NoError(t, writer.WriteField("data.0.name", "john"))
require.NoError(t, writer.WriteField("data.1.name", "doe"))
require.NoError(t, writer.Close())

c.Request().Header.SetContentType(writer.FormDataContentType())
c.Request().SetBody(buf.Bytes())
c.Request().Header.SetContentLength(len(c.Body()))

cq := new(CollectionQuery)
require.NoError(t, c.Bind().Body(cq))
require.Len(t, cq.Data, 2)
require.Equal(t, "john", cq.Data[0].Name)
require.Equal(t, "doe", cq.Data[1].Name)
})

t.Run("MultipartCollectionQuerySquareBrackets", func(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Reset()

buf := &bytes.Buffer{}
writer := multipart.NewWriter(buf)
require.NoError(t, writer.WriteField("data[0][name]", "john"))
require.NoError(t, writer.WriteField("data[1][name]", "doe"))
require.NoError(t, writer.Close())

c.Request().Header.SetContentType(writer.FormDataContentType())
c.Request().SetBody(buf.Bytes())
c.Request().Header.SetContentLength(len(c.Body()))

cq := new(CollectionQuery)
require.NoError(t, c.Bind().Body(cq))
require.Len(t, cq.Data, 2)
require.Equal(t, "john", cq.Data[0].Name)
require.Equal(t, "doe", cq.Data[1].Name)
})

t.Run("CollectionQuerySquareBrackets", func(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Reset()
Expand Down Expand Up @@ -1184,9 +1228,57 @@ func Benchmark_Bind_Body_MultipartForm(b *testing.B) {
Name string `form:"name"`
}

body := []byte("--b\r\nContent-Disposition: form-data; name=\"name\"\r\n\r\njohn\r\n--b--")
buf := &bytes.Buffer{}
writer := multipart.NewWriter(buf)
require.NoError(b, writer.WriteField("name", "john"))
require.NoError(b, writer.Close())
body := buf.Bytes()

c.Request().SetBody(body)
c.Request().Header.SetContentType(MIMEMultipartForm + `;boundary=` + writer.Boundary())
c.Request().Header.SetContentLength(len(body))
d := new(Demo)

b.ReportAllocs()
b.ResetTimer()

for n := 0; n < b.N; n++ {
err = c.Bind().Body(d)
}

require.NoError(b, err)
require.Equal(b, "john", d.Name)
}

// go test -v -run=^$ -bench=Benchmark_Bind_Body_MultipartForm_Nested -benchmem -count=4
func Benchmark_Bind_Body_MultipartForm_Nested(b *testing.B) {
var err error

app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{})

type Person struct {
Name string `form:"name"`
Age int `form:"age"`
}

type Demo struct {
Name string `form:"name"`
Persons []Person `form:"persons"`
}

buf := &bytes.Buffer{}
writer := multipart.NewWriter(buf)
require.NoError(b, writer.WriteField("name", "john"))
require.NoError(b, writer.WriteField("persons.0.name", "john"))
require.NoError(b, writer.WriteField("persons[0][age]", "10"))
require.NoError(b, writer.WriteField("persons[1][name]", "doe"))
require.NoError(b, writer.WriteField("persons.1.age", "20"))
require.NoError(b, writer.Close())
body := buf.Bytes()

c.Request().SetBody(body)
c.Request().Header.SetContentType(MIMEMultipartForm + `;boundary="b"`)
c.Request().Header.SetContentType(MIMEMultipartForm + `;boundary=` + writer.Boundary())
c.Request().Header.SetContentLength(len(body))
d := new(Demo)

Expand All @@ -1196,8 +1288,13 @@ func Benchmark_Bind_Body_MultipartForm(b *testing.B) {
for n := 0; n < b.N; n++ {
err = c.Bind().Body(d)
}

require.NoError(b, err)
require.Equal(b, "john", d.Name)
require.Equal(b, "john", d.Persons[0].Name)
require.Equal(b, 10, d.Persons[0].Age)
require.Equal(b, "doe", d.Persons[1].Name)
require.Equal(b, 20, d.Persons[1].Age)
}

// go test -v -run=^$ -bench=Benchmark_Bind_Body_Form_Map -benchmem -count=4
Expand Down
22 changes: 21 additions & 1 deletion binder/form.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,25 @@
return err
}

return parse(b.Name(), out, data.Value)
temp := make(map[string][]string)
for key, values := range data.Value {
ReneWerner87 marked this conversation as resolved.
Show resolved Hide resolved
if strings.Contains(key, "[") {
k, err := parseParamSquareBrackets(key)
if err != nil {
return err
}

Check warning on line 67 in binder/form.go

View check run for this annotation

Codecov / codecov/patch

binder/form.go#L66-L67

Added lines #L66 - L67 were not covered by tests
key = k // We have to update key in case bracket notation and slice type are used at the same time
}

for _, v := range values {
if strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, key) {
temp[key] = strings.Split(v, ",")
} else {
temp[key] = append(temp[key], v)

Check warning on line 75 in binder/form.go

View check run for this annotation

Codecov / codecov/patch

binder/form.go#L72-L75

Added lines #L72 - L75 were not covered by tests
}
}

Check warning on line 77 in binder/form.go

View check run for this annotation

Codecov / codecov/patch

binder/form.go#L77

Added line #L77 was not covered by tests
}

return parse(b.Name(), out, temp)
}
Loading