Skip to content

Commit 51cf352

Browse files
author
José Carlos
authored
Merge pull request #307 from upper/issue-297
Honor omitempty on InsertInto(). Closes #297
2 parents 8232c84 + f611a71 commit 51cf352

File tree

6 files changed

+173
-45
lines changed

6 files changed

+173
-45
lines changed

Diff for: Makefile

+3-5
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,9 @@ DB_HOST ?= 127.0.0.1
55
export DB_HOST
66

77
test:
8-
go test -v -benchtime=500ms -bench=. ./lib/... & \
9-
go test -v -benchtime=500ms -bench=. ./internal/... & \
10-
wait && \
8+
go test -v -benchtime=500ms -bench=. ./lib/... && \
9+
go test -v -benchtime=500ms -bench=. ./internal/... && \
1110
for ADAPTER in postgresql mysql sqlite ql mongo; do \
12-
$(MAKE) -C $$ADAPTER test & \
11+
$(MAKE) -C $$ADAPTER test; \
1312
done && \
14-
wait && \
1513
go test -v

Diff for: lib/sqlbuilder/builder.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ func columnFragments(template *templateWithUtils, columns []interface{}) ([]exql
326326
for i := 0; i < l; i++ {
327327
switch v := columns[i].(type) {
328328
case *selector:
329-
expanded, rawArgs := expandPlaceholders(v.statement().Compile(v.stringer.t), v.Arguments())
329+
expanded, rawArgs := Preprocess(v.statement().Compile(v.stringer.t), v.Arguments())
330330
f[i] = exql.RawValue(expanded)
331331
args = append(args, rawArgs...)
332332
case db.Function:
@@ -336,11 +336,11 @@ func columnFragments(template *templateWithUtils, columns []interface{}) ([]exql
336336
} else {
337337
fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")"
338338
}
339-
expanded, fnArgs := expandPlaceholders(fnName, fnArgs)
339+
expanded, fnArgs := Preprocess(fnName, fnArgs)
340340
f[i] = exql.RawValue(expanded)
341341
args = append(args, fnArgs...)
342342
case db.RawValue:
343-
expanded, rawArgs := expandPlaceholders(v.Raw(), v.Arguments())
343+
expanded, rawArgs := Preprocess(v.Raw(), v.Arguments())
344344
f[i] = exql.RawValue(expanded)
345345
args = append(args, rawArgs...)
346346
case exql.Fragment:

Diff for: lib/sqlbuilder/builder_test.go

+83
Original file line numberDiff line numberDiff line change
@@ -746,6 +746,89 @@ func TestInsert(t *testing.T) {
746746
)
747747
}
748748

749+
{
750+
type artistStruct struct {
751+
ID int `db:"id,omitempty"`
752+
Name string `db:"name,omitempty"`
753+
}
754+
755+
assert.Equal(
756+
`INSERT INTO "artist" ("name") VALUES ($1)`,
757+
b.InsertInto("artist").
758+
Values(artistStruct{Name: "Chavela Vargas"}).
759+
String(),
760+
)
761+
762+
assert.Equal(
763+
`INSERT INTO "artist" ("id") VALUES ($1)`,
764+
b.InsertInto("artist").
765+
Values(artistStruct{ID: 1}).
766+
String(),
767+
)
768+
}
769+
770+
{
771+
type artistStruct struct {
772+
ID int `db:"id,omitempty"`
773+
Name string `db:"name,omitempty"`
774+
}
775+
776+
{
777+
q := b.InsertInto("artist").Values(artistStruct{Name: "Chavela Vargas"})
778+
779+
assert.Equal(
780+
`INSERT INTO "artist" ("name") VALUES ($1)`,
781+
q.String(),
782+
)
783+
assert.Equal(
784+
[]interface{}{"Chavela Vargas"},
785+
q.Arguments(),
786+
)
787+
}
788+
789+
{
790+
q := b.InsertInto("artist").Values(artistStruct{Name: "Chavela Vargas"}).Values(artistStruct{Name: "Alondra de la Parra"})
791+
792+
assert.Equal(
793+
`INSERT INTO "artist" ("name") VALUES ($1), ($2)`,
794+
q.String(),
795+
)
796+
assert.Equal(
797+
[]interface{}{"Chavela Vargas", "Alondra de la Parra"},
798+
q.Arguments(),
799+
)
800+
}
801+
802+
{
803+
q := b.InsertInto("artist").Values(artistStruct{ID: 1})
804+
805+
assert.Equal(
806+
`INSERT INTO "artist" ("id") VALUES ($1)`,
807+
q.String(),
808+
)
809+
810+
assert.Equal(
811+
[]interface{}{1},
812+
q.Arguments(),
813+
)
814+
}
815+
816+
{
817+
q := b.InsertInto("artist").Values(artistStruct{ID: 1}).Values(artistStruct{ID: 2})
818+
819+
assert.Equal(
820+
`INSERT INTO "artist" ("id") VALUES ($1), ($2)`,
821+
q.String(),
822+
)
823+
824+
assert.Equal(
825+
[]interface{}{1, 2},
826+
q.Arguments(),
827+
)
828+
}
829+
830+
}
831+
749832
{
750833
intRef := func(i int) *int {
751834
if i == 0 {

Diff for: lib/sqlbuilder/convert.go

+4-9
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,6 @@ func Preprocess(in string, args []interface{}) (string, []interface{}) {
7878
return expandQuery(in, args, preprocessFn)
7979
}
8080

81-
func expandPlaceholders(in string, args []interface{}) (string, []interface{}) {
82-
// TODO: Remove after immutable query builder
83-
return in, args
84-
}
85-
8681
// ToWhereWithArguments converts the given parameters into a exql.Where
8782
// value.
8883
func (tu *templateWithUtils) ToWhereWithArguments(term interface{}) (where exql.Where, args []interface{}) {
@@ -93,7 +88,7 @@ func (tu *templateWithUtils) ToWhereWithArguments(term interface{}) (where exql.
9388
if len(t) > 0 {
9489
if s, ok := t[0].(string); ok {
9590
if strings.ContainsAny(s, "?") || len(t) == 1 {
96-
s, args = expandPlaceholders(s, t[1:])
91+
s, args = Preprocess(s, t[1:])
9792
where.Conditions = []exql.Fragment{exql.RawValue(s)}
9893
} else {
9994
var val interface{}
@@ -122,7 +117,7 @@ func (tu *templateWithUtils) ToWhereWithArguments(term interface{}) (where exql.
122117
}
123118
return
124119
case db.RawValue:
125-
r, v := expandPlaceholders(t.Raw(), t.Arguments())
120+
r, v := Preprocess(t.Raw(), t.Arguments())
126121
where.Conditions = []exql.Fragment{exql.RawValue(r)}
127122
args = append(args, v...)
128123
return
@@ -294,11 +289,11 @@ func (tu *templateWithUtils) ToColumnValues(term interface{}) (cv exql.ColumnVal
294289
// A function with one or more arguments.
295290
fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")"
296291
}
297-
expanded, fnArgs := expandPlaceholders(fnName, fnArgs)
292+
expanded, fnArgs := Preprocess(fnName, fnArgs)
298293
columnValue.Value = exql.RawValue(expanded)
299294
args = append(args, fnArgs...)
300295
case db.RawValue:
301-
expanded, rawArgs := expandPlaceholders(value.Raw(), value.Arguments())
296+
expanded, rawArgs := Preprocess(value.Raw(), value.Arguments())
302297
columnValue.Value = exql.RawValue(expanded)
303298
args = append(args, rawArgs...)
304299
default:

Diff for: lib/sqlbuilder/insert.go

+78-26
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,19 @@ package sqlbuilder
22

33
import (
44
"database/sql"
5+
"sync"
56

67
"upper.io/db.v2/internal/sqladapter/exql"
78
)
89

910
type inserter struct {
1011
*stringer
11-
builder *sqlBuilder
12-
table string
13-
values []*exql.Values
12+
builder *sqlBuilder
13+
table string
14+
15+
enqueuedValues [][]interface{}
16+
mu sync.Mutex
17+
1418
returning []exql.Fragment
1519
columns []exql.Fragment
1620
arguments []interface{}
@@ -28,6 +32,7 @@ func (qi *inserter) Batch(n int) *BatchInserter {
2832
}
2933

3034
func (qi *inserter) Arguments() []interface{} {
35+
_ = qi.statement()
3136
return qi.arguments
3237
}
3338

@@ -69,34 +74,77 @@ func (qi *inserter) Columns(columns ...string) Inserter {
6974
}
7075

7176
func (qi *inserter) Values(values ...interface{}) Inserter {
72-
if len(values) == 1 {
73-
ff, vv, err := Map(values[0], &MapOptions{IncludeZeroed: true, IncludeNil: true})
74-
if err == nil {
75-
columns, vals, arguments, _ := qi.builder.t.ToColumnsValuesAndArguments(ff, vv)
76-
77-
qi.arguments = append(qi.arguments, arguments...)
78-
qi.values = append(qi.values, vals)
79-
if len(qi.columns) == 0 {
80-
for _, c := range columns.Columns {
81-
qi.columns = append(qi.columns, c)
77+
qi.mu.Lock()
78+
defer qi.mu.Unlock()
79+
80+
if qi.enqueuedValues == nil {
81+
qi.enqueuedValues = [][]interface{}{}
82+
}
83+
qi.enqueuedValues = append(qi.enqueuedValues, values)
84+
return qi
85+
}
86+
87+
func (qi *inserter) processValues() (values []*exql.Values, arguments []interface{}) {
88+
// TODO: simplify with immutable queries
89+
var insertNils bool
90+
91+
for _, enqueuedValue := range qi.enqueuedValues {
92+
if len(enqueuedValue) == 1 {
93+
ff, vv, err := Map(enqueuedValue[0], nil)
94+
if err == nil {
95+
columns, vals, args, _ := qi.builder.t.ToColumnsValuesAndArguments(ff, vv)
96+
97+
values, arguments = append(values, vals), append(arguments, args...)
98+
99+
if len(qi.columns) == 0 {
100+
for _, c := range columns.Columns {
101+
qi.columns = append(qi.columns, c)
102+
}
103+
} else {
104+
if len(qi.columns) != len(columns.Columns) {
105+
insertNils = true
106+
break
107+
}
82108
}
109+
continue
83110
}
84-
return qi
85111
}
86-
}
87112

88-
if len(qi.columns) == 0 || len(values) == len(qi.columns) {
89-
qi.arguments = append(qi.arguments, values...)
113+
if len(qi.columns) == 0 || len(enqueuedValue) == len(qi.columns) {
114+
arguments = append(arguments, enqueuedValue...)
90115

91-
l := len(values)
92-
placeholders := make([]exql.Fragment, l)
93-
for i := 0; i < l; i++ {
94-
placeholders[i] = exql.RawValue(`?`)
116+
l := len(enqueuedValue)
117+
placeholders := make([]exql.Fragment, l)
118+
for i := 0; i < l; i++ {
119+
placeholders[i] = exql.RawValue(`?`)
120+
}
121+
values = append(values, exql.NewValueGroup(placeholders...))
95122
}
96-
qi.values = append(qi.values, exql.NewValueGroup(placeholders...))
97123
}
98124

99-
return qi
125+
if insertNils {
126+
values, arguments = values[0:0], arguments[0:0]
127+
128+
for _, enqueuedValue := range qi.enqueuedValues {
129+
if len(enqueuedValue) == 1 {
130+
ff, vv, err := Map(enqueuedValue[0], &MapOptions{IncludeZeroed: true, IncludeNil: true})
131+
if err == nil {
132+
columns, vals, args, _ := qi.builder.t.ToColumnsValuesAndArguments(ff, vv)
133+
values, arguments = append(values, vals), append(arguments, args...)
134+
135+
if len(qi.columns) != len(columns.Columns) {
136+
qi.columns = qi.columns[0:0]
137+
for _, c := range columns.Columns {
138+
qi.columns = append(qi.columns, c)
139+
}
140+
}
141+
}
142+
continue
143+
}
144+
}
145+
}
146+
147+
return
100148
}
101149

102150
func (qi *inserter) statement() *exql.Statement {
@@ -105,14 +153,18 @@ func (qi *inserter) statement() *exql.Statement {
105153
Table: exql.TableWithName(qi.table),
106154
}
107155

108-
if len(qi.values) > 0 {
109-
stmt.Values = exql.JoinValueGroups(qi.values...)
110-
}
156+
values, arguments := qi.processValues()
157+
158+
qi.arguments = arguments
111159

112160
if len(qi.columns) > 0 {
113161
stmt.Columns = exql.JoinColumns(qi.columns...)
114162
}
115163

164+
if len(values) > 0 {
165+
stmt.Values = exql.JoinValueGroups(values...)
166+
}
167+
116168
if len(qi.returning) > 0 {
117169
stmt.Returning = exql.ReturningColumns(qi.returning...)
118170
}

Diff for: lib/sqlbuilder/select.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ func (qs *selector) OrderBy(columns ...interface{}) Selector {
156156

157157
switch value := columns[i].(type) {
158158
case db.RawValue:
159-
col, args := expandPlaceholders(value.Raw(), value.Arguments())
159+
col, args := Preprocess(value.Raw(), value.Arguments())
160160
sort = &exql.SortColumn{
161161
Column: exql.RawValue(col),
162162
}
@@ -170,7 +170,7 @@ func (qs *selector) OrderBy(columns ...interface{}) Selector {
170170
} else {
171171
fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")"
172172
}
173-
expanded, fnArgs := expandPlaceholders(fnName, fnArgs)
173+
expanded, fnArgs := Preprocess(fnName, fnArgs)
174174
sort = &exql.SortColumn{
175175
Column: exql.RawValue(expanded),
176176
}

0 commit comments

Comments
 (0)