diff --git a/ksuid.go b/ksuid.go index dbe1f9c..b8f2eab 100644 --- a/ksuid.go +++ b/ksuid.go @@ -3,7 +3,6 @@ package ksuid import ( "bytes" "crypto/rand" - "database/sql/driver" "encoding/binary" "fmt" "io" @@ -58,6 +57,14 @@ var ( Max = KSUID{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255} ) +// Must returns ksuid if err is nil and panics otherwise. +func Must(ksuid KSUID, err error) KSUID { + if err != nil { + panic(err) + } + return ksuid +} + // Append appends the string representation of i to b, returning a slice to a // potentially larger memory area. func (i KSUID) Append(b []byte) []byte { @@ -134,45 +141,6 @@ func (i *KSUID) UnmarshalBinary(b []byte) error { return nil } -// Value converts the KSUID into a SQL driver value which can be used to -// directly use the KSUID as parameter to a SQL query. -func (i KSUID) Value() (driver.Value, error) { - if i.IsNil() { - return nil, nil - } - return i.String(), nil -} - -// Scan implements the sql.Scanner interface. It supports converting from -// string, []byte, or nil into a KSUID value. Attempting to convert from -// another type will return an error. -func (i *KSUID) Scan(src interface{}) error { - switch v := src.(type) { - case nil: - return i.scan(nil) - case []byte: - return i.scan(v) - case string: - return i.scan([]byte(v)) - default: - return fmt.Errorf("Scan: unable to scan type %T into KSUID", v) - } -} - -func (i *KSUID) scan(b []byte) error { - switch len(b) { - case 0: - *i = Nil - return nil - case byteLength: - return i.UnmarshalBinary(b) - case stringEncodedLength: - return i.UnmarshalText(b) - default: - return errSize - } -} - // Parse decodes a string-encoded representation of a KSUID object func Parse(s string) (KSUID, error) { if len(s) != stringEncodedLength { diff --git a/ksuid_test.go b/ksuid_test.go index 3620187..e8f9435 100644 --- a/ksuid_test.go +++ b/ksuid_test.go @@ -309,7 +309,7 @@ func TestGetTimestamp(t *testing.T) { x, _ := NewRandomWithTime(nowTime) xTime := int64(x.Timestamp()) unix := nowTime.Unix() - if xTime != unix - epochStamp { + if xTime != unix-epochStamp { t.Fatal(xTime, "!=", unix) } } diff --git a/sql.go b/sql.go new file mode 100644 index 0000000..af8a9c8 --- /dev/null +++ b/sql.go @@ -0,0 +1,56 @@ +package ksuid + +import ( + "database/sql/driver" + "fmt" +) + +// Scan implements the sql.Scanner interface. It supports converting from +// string, []byte, or nil into a KSUID value. Attempting to convert from +// another type will return an error. +func (i *KSUID) Scan(src interface{}) error { + switch src := src.(type) { + case nil: + return nil + case string: + // if an empty KSUID comes from a table, we return a null KSUID + if src == "" { + return nil + } + k, err := Parse(src) + if err != nil { + return fmt.Errorf("Scan: %v", err) + } + *i = k + case []byte: + // if an empty KSUID comes from a table, we return a null KSUID + if len(src) == 0 { + return nil + } + // assumes a simple slice of bytes if [byteLength] bytes + if len(src) == byteLength { + copy((*i)[:], src) + return nil + } + + if len(src) == stringEncodedLength { + return i.Scan(string(src)) + } + + return i.Scan(string(src)) + + default: + return fmt.Errorf("Scan: unable to scan type %T into KSUID", src) + } + return nil +} + +// Value implements sql.Valuer so that KSUIDs can be written to databases +// transparently. Currently, KSUIDs map to strings. Please consult database-specific +// driver documentation for matching types. +func (i *KSUID) Value() (driver.Value, error) { + if i == nil || i.IsNil() { + return nil, nil + } + return i.String(), nil +} diff --git a/sql_test.go b/sql_test.go new file mode 100644 index 0000000..68996fb --- /dev/null +++ b/sql_test.go @@ -0,0 +1,113 @@ +package ksuid + +import ( + "strings" + "testing" +) + +func TestScan(t *testing.T) { + stringTest := "1al9byIH8Ze6OLkD5tZqmByJkSX" + badTypeTest := 6 + invalidTest := "1al9byIH8Ze6OLkD5tZqmByJk" + + byteTest := make([]byte, byteLength) + byteTestKSUID := Must(Parse(stringTest)) + copy(byteTest, byteTestKSUID[:]) + textTest := []byte(stringTest) + + // valid tests + + var ksuid KSUID + err := (&ksuid).Scan(stringTest) + if err != nil { + t.Fatal(err) + } + + err = (&ksuid).Scan([]byte(stringTest)) + if err != nil { + t.Fatal(err) + } + err = (&ksuid).Scan(byteTest) + if err != nil { + t.Fatal(err) + } + + err = (&ksuid).Scan(textTest) + if err != nil { + t.Fatal(err) + } + + // bad type tests + + err = (&ksuid).Scan(badTypeTest) + if err == nil { + t.Error("int correctly parsed and shouldn't have") + } + if !strings.Contains(err.Error(), "unable to scan type") { + t.Error("attempting to parse an int returned an incorrect error message") + } + + // invalid/incomplete ksuids + err = (&ksuid).Scan(invalidTest) + if err == nil { + t.Error("invalid uuid was parsed without error") + } + if !strings.Contains(err.Error(), "Valid encoded KSUIDs") { + t.Error("attempting to parse an invalid KSUID returned an incorrect error message") + } + + err = (&ksuid).Scan(byteTest[:len(byteTest)-2]) + if err == nil { + t.Error("invalid byte ksuid was parsed without error") + } + if !strings.Contains(err.Error(), "Valid encoded KSUIDs") { + t.Error("attempting to parse an invalid byte KSUID returned an incorrect error message") + } + + // empty tests + + ksuid = KSUID{} + var emptySlice []byte + err = (&ksuid).Scan(emptySlice) + if err != nil { + t.Fatal(err) + } + + for _, v := range ksuid { + if v != 0 { + t.Error("KSUID was not nil after scanning empty byte slice") + } + } + + ksuid = KSUID{} + var emptyString string + err = (&ksuid).Scan(emptyString) + if err != nil { + t.Fatal(err) + } + for _, v := range ksuid { + if v != 0 { + t.Error("KSUID was not nil after scanning empty string") + } + } + + ksuid = KSUID{} + err = (&ksuid).Scan(nil) + if err != nil { + t.Fatal(err) + } + for _, v := range ksuid { + if v != 0 { + t.Error("KSUID was not nil after scanning nil") + } + } +} + +func TestValue(t *testing.T) { + stringTest := "1al9byIH8Ze6OLkD5tZqmByJkSX" + ksuid := Must(Parse(stringTest)) + val, _ := ksuid.Value() + if val != stringTest { + t.Error("Value() did ot return expected string") + } +}