Skip to content

Commit e13dbac

Browse files
committed
Merge pull request #37 from elgris/select_case
Support of CASE operator in SELECT query
2 parents 761840a + 5eadc24 commit e13dbac

File tree

5 files changed

+303
-2
lines changed

5 files changed

+303
-2
lines changed

case.go

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
package squirrel
2+
3+
import (
4+
"bytes"
5+
"errors"
6+
7+
"github.com/lann/builder"
8+
)
9+
10+
func init() {
11+
builder.Register(CaseBuilder{}, caseData{})
12+
}
13+
14+
// sqlizerBuffer is a helper that allows to write many Sqlizers one by one
15+
// without constant checks for errors that may come from Sqlizer
16+
type sqlizerBuffer struct {
17+
bytes.Buffer
18+
args []interface{}
19+
err error
20+
}
21+
22+
// WriteSql converts Sqlizer to SQL strings and writes it to buffer
23+
func (b *sqlizerBuffer) WriteSql(item Sqlizer) {
24+
if b.err != nil {
25+
return
26+
}
27+
28+
var str string
29+
var args []interface{}
30+
str, args, b.err = item.ToSql()
31+
32+
if b.err != nil {
33+
return
34+
}
35+
36+
b.WriteString(str)
37+
b.WriteByte(' ')
38+
b.args = append(b.args, args...)
39+
}
40+
41+
func (b *sqlizerBuffer) ToSql() (string, []interface{}, error) {
42+
return b.String(), b.args, b.err
43+
}
44+
45+
// whenPart is a helper structure to describe SQLs "WHEN ... THEN ..." expression
46+
type whenPart struct {
47+
when Sqlizer
48+
then Sqlizer
49+
}
50+
51+
func newWhenPart(when interface{}, then interface{}) whenPart {
52+
return whenPart{newPart(when), newPart(then)}
53+
}
54+
55+
// caseData holds all the data required to build a CASE SQL construct
56+
type caseData struct {
57+
What Sqlizer
58+
WhenParts []whenPart
59+
Else Sqlizer
60+
}
61+
62+
// ToSql implements Sqlizer
63+
func (d *caseData) ToSql() (sqlStr string, args []interface{}, err error) {
64+
if len(d.WhenParts) == 0 {
65+
err = errors.New("case expression must contain at lease one WHEN clause")
66+
67+
return
68+
}
69+
70+
sql := sqlizerBuffer{}
71+
72+
sql.WriteString("CASE ")
73+
if d.What != nil {
74+
sql.WriteSql(d.What)
75+
}
76+
77+
for _, p := range d.WhenParts {
78+
sql.WriteString("WHEN ")
79+
sql.WriteSql(p.when)
80+
sql.WriteString("THEN ")
81+
sql.WriteSql(p.then)
82+
}
83+
84+
if d.Else != nil {
85+
sql.WriteString("ELSE ")
86+
sql.WriteSql(d.Else)
87+
}
88+
89+
sql.WriteString("END")
90+
91+
return sql.ToSql()
92+
}
93+
94+
// CaseBuilder builds SQL CASE construct which could be used as parts of queries.
95+
type CaseBuilder builder.Builder
96+
97+
// ToSql builds the query into a SQL string and bound args.
98+
func (b CaseBuilder) ToSql() (string, []interface{}, error) {
99+
data := builder.GetStruct(b).(caseData)
100+
return data.ToSql()
101+
}
102+
103+
// what sets optional value for CASE construct "CASE [value] ..."
104+
func (b CaseBuilder) what(expr interface{}) CaseBuilder {
105+
return builder.Set(b, "What", newPart(expr)).(CaseBuilder)
106+
}
107+
108+
// When adds "WHEN ... THEN ..." part to CASE construct
109+
func (b CaseBuilder) When(when interface{}, then interface{}) CaseBuilder {
110+
// TODO: performance hint: replace slice of WhenPart with just slice of parts
111+
// where even indices of the slice belong to "when"s and odd indices belong to "then"s
112+
return builder.Append(b, "WhenParts", newWhenPart(when, then)).(CaseBuilder)
113+
}
114+
115+
// What sets optional "ELSE ..." part for CASE construct
116+
func (b CaseBuilder) Else(expr interface{}) CaseBuilder {
117+
return builder.Set(b, "Else", newPart(expr)).(CaseBuilder)
118+
}

case_test.go

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
package squirrel
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
)
8+
9+
func TestCaseWithVal(t *testing.T) {
10+
caseStmt := Case("number").
11+
When("1", "one").
12+
When("2", "two").
13+
Else(Expr("?", "big number"))
14+
15+
qb := Select().
16+
Column(caseStmt).
17+
From("table")
18+
sql, args, err := qb.ToSql()
19+
20+
assert.NoError(t, err)
21+
22+
expectedSql := "SELECT CASE number " +
23+
"WHEN 1 THEN one " +
24+
"WHEN 2 THEN two " +
25+
"ELSE ? " +
26+
"END " +
27+
"FROM table"
28+
assert.Equal(t, expectedSql, sql)
29+
30+
expectedArgs := []interface{}{"big number"}
31+
assert.Equal(t, expectedArgs, args)
32+
}
33+
34+
func TestCaseWithComplexVal(t *testing.T) {
35+
caseStmt := Case("? > ?", 10, 5).
36+
When("true", "'T'")
37+
38+
qb := Select().
39+
Column(Alias(caseStmt, "complexCase")).
40+
From("table")
41+
sql, args, err := qb.ToSql()
42+
43+
assert.NoError(t, err)
44+
45+
expectedSql := "SELECT (CASE ? > ? " +
46+
"WHEN true THEN 'T' " +
47+
"END) AS complexCase " +
48+
"FROM table"
49+
assert.Equal(t, expectedSql, sql)
50+
51+
expectedArgs := []interface{}{10, 5}
52+
assert.Equal(t, expectedArgs, args)
53+
}
54+
55+
func TestCaseWithNoVal(t *testing.T) {
56+
caseStmt := Case().
57+
When(Eq{"x": 0}, "x is zero").
58+
When(Expr("x > ?", 1), Expr("CONCAT('x is greater than ', ?)", 2))
59+
60+
qb := Select().Column(caseStmt).From("table")
61+
sql, args, err := qb.ToSql()
62+
63+
assert.NoError(t, err)
64+
65+
expectedSql := "SELECT CASE " +
66+
"WHEN x = ? THEN x is zero " +
67+
"WHEN x > ? THEN CONCAT('x is greater than ', ?) " +
68+
"END " +
69+
"FROM table"
70+
71+
assert.Equal(t, expectedSql, sql)
72+
73+
expectedArgs := []interface{}{0, 1, 2}
74+
assert.Equal(t, expectedArgs, args)
75+
}
76+
77+
func TestCaseWithExpr(t *testing.T) {
78+
caseStmt := Case(Expr("x = ?", true)).
79+
When("true", Expr("?", "it's true!")).
80+
Else("42")
81+
82+
qb := Select().Column(caseStmt).From("table")
83+
sql, args, err := qb.ToSql()
84+
85+
assert.NoError(t, err)
86+
87+
expectedSql := "SELECT CASE x = ? " +
88+
"WHEN true THEN ? " +
89+
"ELSE 42 " +
90+
"END " +
91+
"FROM table"
92+
93+
assert.Equal(t, expectedSql, sql)
94+
95+
expectedArgs := []interface{}{true, "it's true!"}
96+
assert.Equal(t, expectedArgs, args)
97+
}
98+
99+
func TestMultipleCase(t *testing.T) {
100+
caseStmtNoval := Case(Expr("x = ?", true)).
101+
When("true", Expr("?", "it's true!")).
102+
Else("42")
103+
caseStmtExpr := Case().
104+
When(Eq{"x": 0}, "'x is zero'").
105+
When(Expr("x > ?", 1), Expr("CONCAT('x is greater than ', ?)", 2))
106+
107+
qb := Select().
108+
Column(Alias(caseStmtNoval, "case_noval")).
109+
Column(Alias(caseStmtExpr, "case_expr")).
110+
From("table")
111+
112+
sql, args, err := qb.ToSql()
113+
114+
assert.NoError(t, err)
115+
116+
expectedSql := "SELECT " +
117+
"(CASE x = ? WHEN true THEN ? ELSE 42 END) AS case_noval, " +
118+
"(CASE WHEN x = ? THEN 'x is zero' WHEN x > ? THEN CONCAT('x is greater than ', ?) END) AS case_expr " +
119+
"FROM table"
120+
121+
assert.Equal(t, expectedSql, sql)
122+
123+
expectedArgs := []interface{}{
124+
true, "it's true!",
125+
0, 1, 2,
126+
}
127+
assert.Equal(t, expectedArgs, args)
128+
}
129+
130+
func TestCaseWithNoWhenClause(t *testing.T) {
131+
caseStmt := Case("something").
132+
Else("42")
133+
134+
qb := Select().Column(caseStmt).From("table")
135+
136+
_, _, err := qb.ToSql()
137+
138+
assert.Error(t, err)
139+
140+
assert.Equal(t, "case expression must contain at lease one WHEN clause", err.Error())
141+
}

expr.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,28 @@ func (es exprs) AppendToSql(w io.Writer, sep string, args []interface{}) ([]inte
4444
return args, nil
4545
}
4646

47+
// aliasExpr helps to alias part of SQL query generated with underlying "expr"
48+
type aliasExpr struct {
49+
expr Sqlizer
50+
alias string
51+
}
52+
53+
// Alias allows to define alias for column in SelectBuilder. Useful when column is
54+
// defined as complex expression like IF or CASE
55+
// Ex:
56+
// .Column(Alias(caseStmt, "case_column"))
57+
func Alias(expr Sqlizer, alias string) aliasExpr {
58+
return aliasExpr{expr, alias}
59+
}
60+
61+
func (e aliasExpr) ToSql() (sql string, args []interface{}, err error) {
62+
sql, args, err = e.expr.ToSql()
63+
if err == nil {
64+
sql = fmt.Sprintf("(%s) AS %s", sql, e.alias)
65+
}
66+
return
67+
}
68+
4769
// Eq is syntactic sugar for use with Where/Having/Set methods.
4870
// Ex:
4971
// .Where(Eq{"id": 1})

select_test.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@ import (
77
)
88

99
func TestSelectBuilderToSql(t *testing.T) {
10+
subQ := Select("aa", "bb").From("dd")
1011
b := Select("a", "b").
1112
Prefix("WITH prefix AS ?", 0).
1213
Distinct().
1314
Columns("c").
1415
Column("IF(d IN ("+Placeholders(3)+"), 1, 0) as stat_column", 1, 2, 3).
1516
Column(Expr("a > ?", 100)).
16-
Column(Eq{"b": []int{101, 102, 103}}).
17+
Column(Alias(Eq{"b": []int{101, 102, 103}}, "b_alias")).
18+
Column(Alias(subQ, "subq")).
1719
From("e").
1820
JoinClause("CROSS JOIN j1").
1921
Join("j2").
@@ -36,7 +38,9 @@ func TestSelectBuilderToSql(t *testing.T) {
3638

3739
expectedSql :=
3840
"WITH prefix AS ? " +
39-
"SELECT DISTINCT a, b, c, IF(d IN (?,?,?), 1, 0) as stat_column, a > ?, b IN (?,?,?) " +
41+
"SELECT DISTINCT a, b, c, IF(d IN (?,?,?), 1, 0) as stat_column, a > ?, " +
42+
"(b IN (?,?,?)) AS b_alias, " +
43+
"(SELECT aa, bb FROM dd) AS subq " +
4044
"FROM e " +
4145
"CROSS JOIN j1 JOIN j2 LEFT JOIN j3 RIGHT JOIN j4 " +
4246
"WHERE f = ? AND g = ? AND h = ? AND i IN (?,?,?) AND (j = ? OR (k = ? AND true)) " +

statement.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,19 @@ func Update(table string) UpdateBuilder {
6565
func Delete(from string) DeleteBuilder {
6666
return StatementBuilder.Delete(from)
6767
}
68+
69+
// Case returns a new CaseBuilder
70+
// "what" represents case value
71+
func Case(what ...interface{}) CaseBuilder {
72+
b := CaseBuilder(builder.EmptyBuilder)
73+
74+
switch len(what) {
75+
case 0:
76+
case 1:
77+
b = b.what(what[0])
78+
default:
79+
b = b.what(newPart(what[0], what[1:]...))
80+
81+
}
82+
return b
83+
}

0 commit comments

Comments
 (0)