-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsql_escape.go
82 lines (78 loc) · 1.79 KB
/
sql_escape.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
package gotemplate
import (
"encoding/json"
"fmt"
"reflect"
"strconv"
"strings"
)
// sqlEscape is the template function that renders q as a SQL literal.
// An untyped nil becomes the keyword NULL; any other value (including an
// interface holding a typed nil pointer, which is not == nil) is rendered
// by sqlEscapeType according to its dynamic type.
func sqlEscape(q interface{}) string {
	if q != nil {
		return sqlEscapeType(reflect.ValueOf(q))
	}
	return "NULL"
}
// sqlEscapeType renders value as a SQL literal, choosing the encoding from
// value's reflect.Kind:
//
//   - an invalid (zero) Value, a nil pointer and a nil interface become NULL;
//   - non-nil pointers and interfaces are unwrapped and their target rendered,
//     so e.g. the elements of an []interface{} are encoded by their dynamic
//     type instead of as JSON text;
//   - strings are quoted and escaped with escapeString;
//   - slices become a ", "-separated list of their rendered elements
//     (suitable for an IN (...) list); an empty slice yields "";
//   - signed and unsigned integers and floats are written as bare numbers;
//   - booleans become TRUE or FALSE;
//   - anything else is JSON-encoded and the JSON text is quoted as a string.
//
// It panics if JSON encoding fails (e.g. for complex numbers or channels).
// NOTE(review): text/template converts a panic raised inside a template
// function call into an execution error — confirm that holds for the
// template engine these helpers are registered with.
func sqlEscapeType(value reflect.Value) string {
	switch value.Kind() {
	case reflect.Invalid:
		// Produced by reflect.ValueOf(nil) and by Elem() on a nil
		// pointer/interface.
		return "NULL"
	case reflect.Interface, reflect.Ptr:
		// Elem returns the zero Value for a nil pointer/interface, which the
		// Invalid case above turns into NULL.
		return sqlEscapeType(value.Elem())
	case reflect.String:
		return escapeString(value.String())
	case reflect.Slice:
		vals := make([]string, value.Len())
		for i := range vals {
			vals[i] = sqlEscapeType(value.Index(i))
		}
		return strings.Join(vals, ", ")
	case reflect.Bool:
		if value.Bool() {
			return "TRUE"
		}
		return "FALSE"
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return strconv.FormatInt(value.Int(), 10)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return strconv.FormatUint(value.Uint(), 10)
	case reflect.Float32:
		return strconv.FormatFloat(value.Float(), 'f', -1, 32)
	case reflect.Float64:
		return strconv.FormatFloat(value.Float(), 'f', -1, 64)
	}

	// Marshal through a pointer when the value is addressable (the target of
	// a dereferenced pointer, or a slice element) so that MarshalJSON methods
	// declared on *T are still honoured after the pointer was unwrapped above.
	target := value
	if value.CanAddr() {
		target = value.Addr()
	}
	b, err := json.Marshal(target.Interface())
	if err != nil {
		panic(err)
	}
	return escapeString(string(b))
}
// escapeString returns source as a single-quoted SQL string literal.
//
// Adapted from https://gist.github.com/siddontang/8875771. NUL, newline,
// carriage return, Ctrl-Z (\x1a) and backslash are written as MySQL backslash
// escapes (\0, \n, \r, \Z, \\). A single quote is doubled ('') instead of
// being backslash-escaped: doubling is the standard SQL form and MySQL accepts
// it too, and unlike \' it cannot end the literal early when the server does
// not treat backslash as an escape character (e.g. MySQL's
// NO_BACKSLASH_ESCAPES mode), which would otherwise allow SQL injection.
//
// The input is processed byte by byte. For UTF-8 input this is safe because
// multi-byte sequences never contain the ASCII bytes matched below, so they
// are copied through unchanged.
func escapeString(source string) string {
	// Worst case every byte expands to two.
	dest := make([]byte, 0, 2*len(source))
	for i := 0; i < len(source); i++ {
		switch c := source[i]; c {
		case 0: // must be escaped for MySQL
			dest = append(dest, '\\', '0')
		case '\n': // must be escaped for logs
			dest = append(dest, '\\', 'n')
		case '\r':
			dest = append(dest, '\\', 'r')
		case '\\':
			dest = append(dest, '\\', '\\')
		case '\'':
			dest = append(dest, '\'', '\'')
		case '\032': // Ctrl-Z; causes problems on Win32
			dest = append(dest, '\\', 'Z')
		default:
			dest = append(dest, c)
		}
	}
	return fmt.Sprintf("'%s'", dest)
}