Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions select.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,21 @@ type selectData struct {
Limit string
Offset string
Suffixes []Sqlizer

// Unions stuff

Union []Sqlizer
UnionAll []Sqlizer

// Unions can have their own OFFSET, LIMIT and ORDER BY clauses.
// Example:
// (SELECT a, b FROM test OFFSET 1 LIMIT 1 ORDER BY a)
// UNION
// (SELECT a, b FROM test OFFSET 2 LIMIT 2 ORDER BY b)
// OFFSET 1 LIMIT 1 ORDER BY b DESC, a ASC
UnionOffset string
UnionLimit string
UnionOrderByParts []Sqlizer
}

func (d *selectData) Exec() (sql.Result, error) {
Expand Down Expand Up @@ -78,6 +93,11 @@ func (d *selectData) toSqlRaw() (sqlStr string, args []interface{}, err error) {
sql.WriteString(" ")
}

hasUnion := len(d.Union) > 0 || len(d.UnionAll) > 0
if hasUnion {
sql.WriteRune('(')
}

sql.WriteString("SELECT ")

if len(d.Options) > 0 {
Expand Down Expand Up @@ -156,6 +176,45 @@ func (d *selectData) toSqlRaw() (sqlStr string, args []interface{}, err error) {
}
}

if hasUnion {
sql.WriteRune(')')
}

if len(d.Union) > 0 {
sql.WriteString(" UNION ")
args, err = appendToSql(d.Union, sql, " UNION ", args)
if err != nil {
return
}
}
if len(d.UnionAll) > 0 {
sql.WriteString(" UNION ALL ")
args, err = appendToSql(d.UnionAll, sql, " UNION ALL ", args)
if err != nil {
return
}
}

if len(d.Union) > 0 || len(d.UnionAll) > 0 {
if len(d.UnionOrderByParts) > 0 {
sql.WriteString(" ORDER BY ")
args, err = appendToSql(d.UnionOrderByParts, sql, ", ", args)
if err != nil {
return
}
}

if len(d.UnionLimit) > 0 {
sql.WriteString(" LIMIT ")
sql.WriteString(d.UnionLimit)
}

if len(d.UnionOffset) > 0 {
sql.WriteString(" OFFSET ")
sql.WriteString(d.UnionOffset)
}
}

sqlStr = sql.String()
return
}
Expand Down Expand Up @@ -363,6 +422,11 @@ func (b SelectBuilder) OrderByClause(pred interface{}, args ...interface{}) Sele
return builder.Append(b, "OrderByParts", newPart(pred, args...)).(SelectBuilder)
}

// UnionOrderByClause adds ORDER BY clause to the UNION (ALL) query.
func (b SelectBuilder) UnionOrderByClause(pred interface{}, args ...interface{}) SelectBuilder {
return builder.Append(b, "UnionOrderByParts", newPart(pred, args...)).(SelectBuilder)
}

// OrderBy adds ORDER BY expressions to the query.
func (b SelectBuilder) OrderBy(orderBys ...string) SelectBuilder {
for _, orderBy := range orderBys {
Expand All @@ -372,26 +436,55 @@ func (b SelectBuilder) OrderBy(orderBys ...string) SelectBuilder {
return b
}

// UnionOrderBy adds ORDER BY expressions to the UNION (ALL) query.
func (b SelectBuilder) UnionOrderBy(orderBys ...string) SelectBuilder {
for _, orderBy := range orderBys {
b = b.UnionOrderByClause(orderBy)
}

return b
}

// Limit sets a LIMIT clause on the query.
func (b SelectBuilder) Limit(limit uint64) SelectBuilder {
return builder.Set(b, "Limit", fmt.Sprintf("%d", limit)).(SelectBuilder)
}

// UnionLimit sets a LIMIT clause on the UNION (ALL) query.
func (b SelectBuilder) UnionLimit(unionLimit uint64) SelectBuilder {
return builder.Set(b, "UnionLimit", fmt.Sprintf("%d", unionLimit)).(SelectBuilder)
}

// Limit ALL allows to access all records with limit
func (b SelectBuilder) RemoveLimit() SelectBuilder {
return builder.Delete(b, "Limit").(SelectBuilder)
}

// Limit ALL allows to access all records with limit
func (b SelectBuilder) RemoveUnionLimit() SelectBuilder {
return builder.Delete(b, "UnionLimit").(SelectBuilder)
}

// Offset sets a OFFSET clause on the query.
func (b SelectBuilder) Offset(offset uint64) SelectBuilder {
return builder.Set(b, "Offset", fmt.Sprintf("%d", offset)).(SelectBuilder)
}

// UnionOffset sets a OFFSET clause on the UNION (ALL) query.
func (b SelectBuilder) UnionOffset(offset uint64) SelectBuilder {
return builder.Set(b, "UnionOffset", fmt.Sprintf("%d", offset)).(SelectBuilder)
}

// RemoveOffset removes OFFSET clause.
func (b SelectBuilder) RemoveOffset() SelectBuilder {
return builder.Delete(b, "Offset").(SelectBuilder)
}

// RemoveUnionOffset removes OFFSET clause from UNION (ALL) query.
func (b SelectBuilder) RemoveUnionOffset() SelectBuilder {
return builder.Delete(b, "UnionOffset").(SelectBuilder)
}

// Suffix adds an expression to the end of the query
func (b SelectBuilder) Suffix(sql string, args ...interface{}) SelectBuilder {
return b.SuffixExpr(Expr(sql, args...))
Expand All @@ -401,3 +494,13 @@ func (b SelectBuilder) Suffix(sql string, args ...interface{}) SelectBuilder {
func (b SelectBuilder) SuffixExpr(expr Sqlizer) SelectBuilder {
return builder.Append(b, "Suffixes", expr).(SelectBuilder)
}

// Union adds a UNION clause to the query
func (b SelectBuilder) Union(query interface{}, args ...interface{}) SelectBuilder {
return builder.Append(b, "Union", newUnionPart(query, args...)).(SelectBuilder)
}

// UnionAll adds a UNION ALL clause to the query
func (b SelectBuilder) UnionAll(query interface{}, args ...interface{}) SelectBuilder {
return builder.Append(b, "UnionAll", newUnionPart(query, args...)).(SelectBuilder)
}
8 changes: 6 additions & 2 deletions select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ func TestSelectBuilderToSql(t *testing.T) {
OrderBy("o ASC", "p DESC").
Limit(12).
Offset(13).
Union("SELECT * FROM q TABLE WHERE r=?", 15).
UnionAll(Select("*").From("s").Where("t=?", 16)).
Suffix("FETCH FIRST ? ROWS ONLY", 14)

sql, args, err := b.ToSql()
Expand All @@ -52,10 +54,12 @@ func TestSelectBuilderToSql(t *testing.T) {
"CROSS JOIN j1 JOIN j2 LEFT JOIN j3 RIGHT JOIN j4 INNER JOIN j5 CROSS JOIN j6 " +
"WHERE f = ? AND g = ? AND h = ? AND i IN (?,?,?) AND (j = ? OR (k = ? AND true)) " +
"GROUP BY l HAVING m = n ORDER BY ? DESC, o ASC, p DESC LIMIT 12 OFFSET 13 " +
"FETCH FIRST ? ROWS ONLY"
"FETCH FIRST ? ROWS ONLY" +
"UNION SELECT * FROM q TABLE WHERE r=? " +
"UNION ALL SELECT * FROM s WHERE t=?"
assert.Equal(t, expectedSql, sql)

expectedArgs := []interface{}{0, 1, 2, 3, 100, 101, 102, 103, 4, 5, 6, 7, 8, 9, 10, 11, 1, 14}
expectedArgs := []interface{}{0, 1, 2, 3, 100, 101, 102, 103, 4, 5, 6, 7, 8, 9, 10, 11, 1, 14, 15, 16}
assert.Equal(t, expectedArgs, args)
}

Expand Down
27 changes: 27 additions & 0 deletions union.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package squirrel

import (
"fmt"

"github.com/lann/builder"
)

type unionPart part

func newUnionPart(pred interface{}, args ...interface{}) Sqlizer {
return &unionPart{pred: pred, args: args}
}
func (p unionPart) ToSql() (sqlStr string, args []interface{}, err error) {
switch pred := p.pred.(type) {
case SelectBuilder:
entity := builder.GetStruct(pred).(selectData)
sqlStr, args, err = entity.ToSql()
sqlStr = "(" + sqlStr + ")"
case string:
sqlStr = pred
args = p.args
default:
err = fmt.Errorf("expected string or SelectBuilder, not %T", pred)
}
return
}
25 changes: 25 additions & 0 deletions union_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package squirrel

import (
"bytes"
"testing"

"github.com/stretchr/testify/assert"
)

func TestUnionPartsAppendToSqlWithString(t *testing.T) {
parts := []Sqlizer{
newUnionPart(Select("*").Where("col = ?", 10)),
newUnionPart("select * from TEST where col = ", "hello"),
}
sql := &bytes.Buffer{}
args, _ := appendToSql(parts, sql, " UNION ", []interface{}{})
assert.Equal(t, "SELECT * WHERE col = ? UNION select * from TEST where col = ", sql.String())
assert.Equal(t, []interface{}{10, "hello"}, args)
}

func TestUnionPartsAppendToSqlErr(t *testing.T) {
parts := []Sqlizer{newUnionPart(1)}
_, err := appendToSql(parts, &bytes.Buffer{}, "", []interface{}{})
assert.Error(t, err)
}