Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tweaked scan method to work with most sql drivers // Added Must #44

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 8 additions & 40 deletions ksuid.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package ksuid
import (
"bytes"
"crypto/rand"
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
Expand Down Expand Up @@ -58,6 +57,14 @@ var (
Max = KSUID{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}
)

// Must returns ksuid if err is nil and panics otherwise.
func Must(ksuid KSUID, err error) KSUID {
if err != nil {
panic(err)
}
return ksuid
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding this function seems unrelated to the modification of Scan itself, maybe we can split it in a separate PR?


// Append appends the string representation of i to b, returning a slice to a
// potentially larger memory area.
func (i KSUID) Append(b []byte) []byte {
Expand Down Expand Up @@ -134,45 +141,6 @@ func (i *KSUID) UnmarshalBinary(b []byte) error {
return nil
}

// Value converts the KSUID into a SQL driver value which can be used to
// directly use the KSUID as parameter to a SQL query.
func (i KSUID) Value() (driver.Value, error) {
if i.IsNil() {
return nil, nil
}
return i.String(), nil
}

// Scan implements the sql.Scanner interface. It supports converting from
// string, []byte, or nil into a KSUID value. Attempting to convert from
// another type will return an error.
func (i *KSUID) Scan(src interface{}) error {
switch v := src.(type) {
case nil:
return i.scan(nil)
case []byte:
return i.scan(v)
case string:
return i.scan([]byte(v))
default:
return fmt.Errorf("Scan: unable to scan type %T into KSUID", v)
}
}

func (i *KSUID) scan(b []byte) error {
switch len(b) {
case 0:
*i = Nil
return nil
case byteLength:
return i.UnmarshalBinary(b)
case stringEncodedLength:
return i.UnmarshalText(b)
default:
return errSize
}
}

// Parse decodes a string-encoded representation of a KSUID object
func Parse(s string) (KSUID, error) {
if len(s) != stringEncodedLength {
Expand Down
2 changes: 1 addition & 1 deletion ksuid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ func TestGetTimestamp(t *testing.T) {
x, _ := NewRandomWithTime(nowTime)
xTime := int64(x.Timestamp())
unix := nowTime.Unix()
if xTime != unix - epochStamp {
if xTime != unix-epochStamp {
t.Fatal(xTime, "!=", unix)
}
}
Expand Down
56 changes: 56 additions & 0 deletions sql.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package ksuid

import (
"database/sql/driver"
"fmt"
)

// Scan implements the sql.Scanner interface. It supports converting from
// string, []byte, or nil into a KSUID value. Attempting to convert from
// another type will return an error.
func (i *KSUID) Scan(src interface{}) error {
switch src := src.(type) {
case nil:
return nil
case string:
// if an empty KSUID comes from a table, we return a null KSUID
if src == "" {
return nil
}
k, err := Parse(src)
if err != nil {
return fmt.Errorf("Scan: %v", err)
}
*i = k
case []byte:
// if an empty KSUID comes from a table, we return a null KSUID
if len(src) == 0 {
return nil
}
// assumes a simple slice of bytes if [byteLength] bytes
if len(src) == byteLength {
copy((*i)[:], src)
return nil
}

if len(src) == stringEncodedLength {
return i.Scan(string(src))
}

return i.Scan(string(src))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These recursive calls to Scan cause heap allocations when packing the string value into an interface{}, this is why I have the unexported scan methods to use internally.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@achille-roussel thanks for your time to review my request!

I see your Point.

Suggested change
return i.Scan(string(src))
return i.UnmarshalText(src)

Using the UnmarshalText is better. Both functions (scan old, scan new) are now allocating the same. Beside that the new Scan is consuming 2,79% less memory + is slightly faster.


default:
return fmt.Errorf("Scan: unable to scan type %T into KSUID", src)
}
return nil
}

// Value implements sql.Valuer so that KSUIDs can be written to databases
// transparently. Currently, KSUIDs map to strings. Please consult database-specific
// driver documentation for matching types.
func (i *KSUID) Value() (driver.Value, error) {
if i == nil || i.IsNil() {
return nil, nil
}
return i.String(), nil
}
113 changes: 113 additions & 0 deletions sql_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package ksuid

import (
"strings"
"testing"
)

func TestScan(t *testing.T) {
stringTest := "1al9byIH8Ze6OLkD5tZqmByJkSX"
badTypeTest := 6
invalidTest := "1al9byIH8Ze6OLkD5tZqmByJk"

byteTest := make([]byte, byteLength)
byteTestKSUID := Must(Parse(stringTest))
copy(byteTest, byteTestKSUID[:])
textTest := []byte(stringTest)

// valid tests

var ksuid KSUID
err := (&ksuid).Scan(stringTest)
if err != nil {
t.Fatal(err)
}

err = (&ksuid).Scan([]byte(stringTest))
if err != nil {
t.Fatal(err)
}
err = (&ksuid).Scan(byteTest)
if err != nil {
t.Fatal(err)
}

err = (&ksuid).Scan(textTest)
if err != nil {
t.Fatal(err)
}

// bad type tests

err = (&ksuid).Scan(badTypeTest)
if err == nil {
t.Error("int correctly parsed and shouldn't have")
}
if !strings.Contains(err.Error(), "unable to scan type") {
t.Error("attempting to parse an int returned an incorrect error message")
}

// invalid/incomplete ksuids
err = (&ksuid).Scan(invalidTest)
if err == nil {
t.Error("invalid uuid was parsed without error")
}
if !strings.Contains(err.Error(), "Valid encoded KSUIDs") {
t.Error("attempting to parse an invalid KSUID returned an incorrect error message")
}

err = (&ksuid).Scan(byteTest[:len(byteTest)-2])
if err == nil {
t.Error("invalid byte ksuid was parsed without error")
}
if !strings.Contains(err.Error(), "Valid encoded KSUIDs") {
t.Error("attempting to parse an invalid byte KSUID returned an incorrect error message")
}

// empty tests

ksuid = KSUID{}
var emptySlice []byte
err = (&ksuid).Scan(emptySlice)
if err != nil {
t.Fatal(err)
}

for _, v := range ksuid {
if v != 0 {
t.Error("KSUID was not nil after scanning empty byte slice")
}
}

ksuid = KSUID{}
var emptyString string
err = (&ksuid).Scan(emptyString)
if err != nil {
t.Fatal(err)
}
for _, v := range ksuid {
if v != 0 {
t.Error("KSUID was not nil after scanning empty string")
}
}

ksuid = KSUID{}
err = (&ksuid).Scan(nil)
if err != nil {
t.Fatal(err)
}
for _, v := range ksuid {
if v != 0 {
t.Error("KSUID was not nil after scanning nil")
}
}
}

func TestValue(t *testing.T) {
stringTest := "1al9byIH8Ze6OLkD5tZqmByJkSX"
ksuid := Must(Parse(stringTest))
val, _ := ksuid.Value()
if val != stringTest {
t.Error("Value() did ot return expected string")
}
}