Skip to content

Commit e50ea8c

Browse files
committed
Merge branch 'feature/gob' into main
2 parents 6b6eac2 + c271ca9 commit e50ea8c

File tree

3 files changed

+76
-29
lines changed

3 files changed

+76
-29
lines changed

criteria.go

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package clover
22

33
import (
44
"regexp"
5+
"strings"
56

67
"github.com/ostafen/clover/encoding"
78
)
@@ -62,13 +63,12 @@ func (f *field) IsNilOrNotExists() *Criteria {
6263
}
6364

6465
func (f *field) Eq(value interface{}) *Criteria {
65-
normalizedValue, err := encoding.Normalize(value)
66-
if err != nil {
67-
return &falseCriteria
68-
}
69-
7066
return &Criteria{
7167
p: func(doc *Document) bool {
68+
normalizedValue, err := encoding.Normalize(getFieldOrValue(doc, value))
69+
if err != nil {
70+
return false
71+
}
7272
if !doc.Has(f.name) {
7373
return false
7474
}
@@ -78,52 +78,48 @@ func (f *field) Eq(value interface{}) *Criteria {
7878
}
7979

8080
func (f *field) Gt(value interface{}) *Criteria {
81-
normValue, err := encoding.Normalize(value)
82-
if err != nil {
83-
return &falseCriteria
84-
}
85-
8681
return &Criteria{
8782
p: func(doc *Document) bool {
83+
normValue, err := encoding.Normalize(getFieldOrValue(doc, value))
84+
if err != nil {
85+
return false
86+
}
8887
return compareValues(doc.Get(f.name), normValue) > 0
8988
},
9089
}
9190
}
9291

9392
func (f *field) GtEq(value interface{}) *Criteria {
94-
normValue, err := encoding.Normalize(value)
95-
if err != nil {
96-
return &falseCriteria
97-
}
98-
9993
return &Criteria{
10094
p: func(doc *Document) bool {
95+
normValue, err := encoding.Normalize(getFieldOrValue(doc, value))
96+
if err != nil {
97+
return false
98+
}
10199
return compareValues(doc.Get(f.name), normValue) >= 0
102100
},
103101
}
104102
}
105103

106104
func (f *field) Lt(value interface{}) *Criteria {
107-
normValue, err := encoding.Normalize(value)
108-
if err != nil {
109-
return &falseCriteria
110-
}
111-
112105
return &Criteria{
113106
p: func(doc *Document) bool {
107+
normValue, err := encoding.Normalize(getFieldOrValue(doc, value))
108+
if err != nil {
109+
return false
110+
}
114111
return compareValues(doc.Get(f.name), normValue) < 0
115112
},
116113
}
117114
}
118115

119116
func (f *field) LtEq(value interface{}) *Criteria {
120-
normValue, err := encoding.Normalize(value)
121-
if err != nil {
122-
return &falseCriteria
123-
}
124-
125117
return &Criteria{
126118
p: func(doc *Document) bool {
119+
normValue, err := encoding.Normalize(getFieldOrValue(doc, value))
120+
if err != nil {
121+
return false
122+
}
127123
return compareValues(doc.Get(f.name), normValue) <= 0
128124
},
129125
}
@@ -142,8 +138,9 @@ func (f *field) In(values ...interface{}) *Criteria {
142138
return &Criteria{
143139
p: func(doc *Document) bool {
144140
docValue := doc.Get(f.name)
145-
for _, value := range normValues.([]interface{}) {
146-
if compareValues(value, docValue) == 0 {
141+
for _, v := range values {
142+
normValue, err := encoding.Normalize(getFieldOrValue(doc, v))
143+
if err == nil && compareValues(normValue, docValue) == 0 {
147144
return true
148145
}
149146
}
@@ -164,7 +161,7 @@ func (f *field) Contains(elems ...interface{}) *Criteria {
164161

165162
for _, elem := range elems {
166163
found := false
167-
normElem, err := encoding.Normalize(elem)
164+
normElem, err := encoding.Normalize(getFieldOrValue(doc, elem))
168165

169166
if err == nil {
170167
for _, val := range slice {
@@ -241,3 +238,15 @@ func (c *Criteria) Not() *Criteria {
241238
p: negatePredicate(c.p),
242239
}
243240
}
241+
242+
// getFieldOrValue returns dereferenced value if value denotes another document field,
243+
// otherwise returns the value itself directly
244+
func getFieldOrValue(doc *Document, value interface{}) interface{} {
245+
if cmpField, ok := value.(*field); ok {
246+
value = doc.Get(cmpField.name)
247+
} else if fStr, ok := value.(string); ok && strings.HasPrefix(fStr, "$") {
248+
fieldName := strings.TrimLeft(fStr, "$")
249+
value = doc.Get(fieldName)
250+
}
251+
return value
252+
}

db.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ func (db *DB) Save(collectionName string, doc *Document) error {
7676
// InsertOne inserts a single document to an existing collection. It returns the id of the inserted document.
7777
func (db *DB) InsertOne(collectionName string, doc *Document) (string, error) {
7878
err := db.Insert(collectionName, doc)
79-
return doc.Get(objectIdField).(string), err
79+
return doc.ObjectId(), err
8080
}
8181

8282
// Open opens a new clover database on the supplied path. If such a folder doesn't exist, it is automatically created.

db_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -688,6 +688,17 @@ func TestInCriteria(t *testing.T) {
688688
require.Fail(t, "userId is not in the correct range")
689689
}
690690
}
691+
692+
criteria := c.Field("userId").In(c.Field("id"), 6)
693+
docs, err = db.Query("todos").Where(criteria).FindAll()
694+
require.NoError(t, err)
695+
696+
require.Greater(t, len(docs), 0)
697+
for _, doc := range docs {
698+
userId := doc.Get("userId").(int64)
699+
id := doc.Get("id").(uint64)
700+
require.True(t, uint64(userId) == id || userId == 6)
701+
}
691702
})
692703
}
693704

@@ -1327,3 +1338,30 @@ func TestCompareObjects3(t *testing.T) {
13271338
require.Equal(t, docs[0].Get("data.SomeString"), "aStr")
13281339
})
13291340
}
1341+
1342+
func TestCompareDocumentFields(t *testing.T) {
1343+
runCloverTest(t, airlinesPath, nil, func(t *testing.T, db *c.DB) {
1344+
criteria := c.Field("Statistics.Flights.Diverted").Gt(c.Field("Statistics.Flights.Cancelled"))
1345+
docs, err := db.Query("airlines").Where(criteria).FindAll()
1346+
require.NoError(t, err)
1347+
1348+
require.Greater(t, len(docs), 0)
1349+
for _, doc := range docs {
1350+
diverted := doc.Get("Statistics.Flights.Diverted").(float64)
1351+
cancelled := doc.Get("Statistics.Flights.Cancelled").(float64)
1352+
require.Greater(t, diverted, cancelled)
1353+
}
1354+
1355+
//alternative syntax using $
1356+
criteria = c.Field("Statistics.Flights.Diverted").Gt("$Statistics.Flights.Cancelled")
1357+
docs, err = db.Query("airlines").Where(criteria).FindAll()
1358+
require.NoError(t, err)
1359+
1360+
require.Greater(t, len(docs), 0)
1361+
for _, doc := range docs {
1362+
diverted := doc.Get("Statistics.Flights.Diverted").(float64)
1363+
cancelled := doc.Get("Statistics.Flights.Cancelled").(float64)
1364+
require.Greater(t, diverted, cancelled)
1365+
}
1366+
})
1367+
}

0 commit comments

Comments
 (0)