Skip to content

Commit 0b0d44b

Browse files
author
José Nieto
authored
Merge pull request #377 from upper/issue-370
Fixes #370
2 parents 4d7953a + 8ae1524 commit 0b0d44b

File tree

9 files changed

+510
-38
lines changed

9 files changed

+510
-38
lines changed

internal/sqladapter/collection.go

+5-2
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,8 @@ func (c *collection) InsertReturning(item interface{}) error {
163163
}
164164

165165
// Fetch the row that was just interted into newItem
166-
if err = col.Find(id).One(newItem); err != nil {
166+
err = col.Find(id).One(newItem)
167+
if err != nil {
167168
goto cancel
168169
}
169170

@@ -184,14 +185,16 @@ func (c *collection) InsertReturning(item interface{}) error {
184185
itemV.SetMapIndex(keyV, newItemV.MapIndex(keyV))
185186
}
186187
default:
187-
panic("default")
188+
err = fmt.Errorf("InsertReturning: expecting a pointer to map or struct, got %T", newItem)
189+
goto cancel
188190
}
189191

190192
if !inTx {
191193
// This is only executed if t.Database() was **not** a transaction and if
192194
// sess was created with sess.NewTransaction().
193195
return tx.Commit()
194196
}
197+
195198
return err
196199

197200
cancel:

internal/sqladapter/sqladapter.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,11 @@ func IsKeyValue(v interface{}) bool {
3333
return true
3434
}
3535
switch v.(type) {
36-
case int64, int, uint, uint64, driver.Valuer:
36+
case int64, int, uint, uint64,
37+
[]int64, []int, []uint, []uint64,
38+
[]byte, []string,
39+
[]interface{},
40+
driver.Valuer:
3741
return true
3842
}
3943
return false

internal/sqladapter/testing/adapter.go.tpl

+30-15
Original file line numberDiff line numberDiff line change
@@ -1053,26 +1053,41 @@ func TestCompositeKeys(t *testing.T) {
10531053
10541054
compositeKeys := sess.Collection("composite_keys")
10551055
1056-
n := rand.Intn(100000)
1056+
{
1057+
n := rand.Intn(100000)
10571058
1058-
item := itemWithCompoundKey{
1059-
"ABCDEF",
1060-
strconv.Itoa(n),
1061-
"Some value",
1062-
}
1059+
item := itemWithCompoundKey{
1060+
"ABCDEF",
1061+
strconv.Itoa(n),
1062+
"Some value",
1063+
}
10631064
1064-
id, err := compositeKeys.Insert(&item)
1065-
assert.NoError(t, err)
1066-
assert.NotZero(t, id)
1065+
id, err := compositeKeys.Insert(&item)
1066+
assert.NoError(t, err)
1067+
assert.NotZero(t, id)
10671068
1068-
var item2 itemWithCompoundKey
1069-
assert.NotEqual(t, item2.SomeVal, item.SomeVal)
1069+
var item2 itemWithCompoundKey
1070+
assert.NotEqual(t, item2.SomeVal, item.SomeVal)
10701071
1071-
// Finding by ID
1072-
err = compositeKeys.Find(id).One(&item2)
1073-
assert.NoError(t, err)
1072+
// Finding by ID
1073+
err = compositeKeys.Find(id).One(&item2)
1074+
assert.NoError(t, err)
10741075
1075-
assert.Equal(t, item2.SomeVal, item.SomeVal)
1076+
assert.Equal(t, item2.SomeVal, item.SomeVal)
1077+
}
1078+
1079+
{
1080+
n := rand.Intn(100000)
1081+
1082+
item := itemWithCompoundKey{
1083+
"ABCDEF",
1084+
strconv.Itoa(n),
1085+
"Some value",
1086+
}
1087+
1088+
err := compositeKeys.InsertReturning(&item)
1089+
assert.NoError(t, err)
1090+
}
10761091
10771092
assert.NoError(t, cleanUpCheck(sess))
10781093
assert.NoError(t, sess.Close())

internal/sqladapter/tx.go

+5-3
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,12 @@ func (b *baseTx) Committed() bool {
7878
}
7979

8080
func (b *baseTx) Commit() (err error) {
81-
if err = b.Tx.Commit(); err == nil {
82-
b.committed.Store(struct{}{})
81+
err = b.Tx.Commit()
82+
if err != nil {
83+
return err
8384
}
84-
return err
85+
b.committed.Store(struct{}{})
86+
return nil
8587
}
8688

8789
func (w *databaseTx) Commit() error {

lib/sqlbuilder/convert.go

+6
Original file line numberDiff line numberDiff line change
@@ -352,9 +352,15 @@ func (tu *templateWithUtils) toColumnValues(term interface{}) (cv exql.ColumnVal
352352
q, a := Preprocess(value.Raw(), value.Arguments())
353353
columnValue.Value = exql.RawValue(q)
354354
args = append(args, a...)
355+
case driver.Valuer:
356+
columnValue.Value = exql.RawValue("?")
357+
args = append(args, value)
355358
default:
356359
v, isSlice := toInterfaceArguments(value)
357360

361+
//valuer, ok := value.(driver.Valuer)
362+
//log.Printf("valuer: %v, ok: %v, (%v) %T", valuer, ok, value, value)
363+
358364
if isSlice {
359365
if columnValue.Operator == "" {
360366
columnValue.Operator = sqlInOperator

mssql/collection.go

+29-9
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@ package mssql
2323

2424
import (
2525
"database/sql"
26-
"log"
27-
"strings"
2826

2927
"upper.io/db.v3"
3028
"upper.io/db.v3/internal/sqladapter"
@@ -37,6 +35,8 @@ type table struct {
3735

3836
d *database
3937
name string
38+
39+
hasIdentityColumn *bool
4040
}
4141

4242
var (
@@ -76,7 +76,6 @@ func (t *table) Insert(item interface{}) (interface{}, error) {
7676
for j := 0; j < len(pKey); j++ {
7777
if pKey[j] == columnNames[i] {
7878
if columnValues[i] != nil {
79-
log.Printf("%v -- %v", pKey[j], columnValues[i])
8079
hasKeys = true
8180
break
8281
}
@@ -85,13 +84,34 @@ func (t *table) Insert(item interface{}) (interface{}, error) {
8584
}
8685

8786
if hasKeys {
88-
_, err = t.d.Exec("SET IDENTITY_INSERT " + t.Name() + " ON")
89-
// TODO: Find a way to check if the table has composite keys without an
90-
// identity property.
91-
if err != nil && !strings.Contains(err.Error(), "does not have the identity property") {
92-
return nil, err
87+
if t.hasIdentityColumn == nil {
88+
var hasIdentityColumn bool
89+
var identityColumns int
90+
91+
row, err := t.d.QueryRow("SELECT COUNT(1) FROM sys.identity_columns WHERE OBJECT_NAME(object_id) = ?", t.Name())
92+
if err != nil {
93+
return nil, err
94+
}
95+
96+
err = row.Scan(&identityColumns)
97+
if err != nil {
98+
return nil, err
99+
}
100+
101+
if identityColumns > 0 {
102+
hasIdentityColumn = true
103+
}
104+
105+
t.hasIdentityColumn = &hasIdentityColumn
106+
}
107+
108+
if *t.hasIdentityColumn {
109+
_, err = t.d.Exec("SET IDENTITY_INSERT " + t.Name() + " ON")
110+
if err != nil {
111+
return nil, err
112+
}
113+
defer t.d.Exec("SET IDENTITY_INSERT " + t.Name() + " OFF")
93114
}
94-
defer t.d.Exec("SET IDENTITY_INSERT " + t.Name() + " OFF")
95115
}
96116

97117
q := t.d.InsertInto(t.Name()).

mssql/database.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -198,10 +198,10 @@ func (d *database) NewDatabaseTx(ctx context.Context) (sqladapter.DatabaseTx, er
198198

199199
connFn := func() error {
200200
sqlTx, err := compat.BeginTx(clone.BaseDatabase.Session(), ctx, nil)
201-
if err == nil {
202-
return clone.BindTx(ctx, sqlTx)
201+
if err != nil {
202+
return err
203203
}
204-
return err
204+
return clone.BindTx(ctx, sqlTx)
205205
}
206206

207207
if err := d.BaseDatabase.WaitForConnection(connFn); err != nil {

0 commit comments

Comments
 (0)