diff --git a/compressor_test.go b/compressor_test.go index 835093932..24730c68d 100644 --- a/compressor_test.go +++ b/compressor_test.go @@ -2,9 +2,13 @@ package gocql import ( "bytes" + "strings" "testing" "github.com/golang/snappy" + "github.com/google/go-cmp/cmp" + + "github.com/gocql/gocql/lz4" ) func TestSnappyCompressor(t *testing.T) { @@ -36,3 +40,63 @@ func TestSnappyCompressor(t *testing.T) { t.Fatal("failed to match the expected decoded value with the result decoded value.") } } + +func TestBlobCompressor(t *testing.T) { + session := createSession(t) + defer session.Close() + // TypeVarchar, TypeAscii, TypeBlob, TypeText + if err := createTable(session, `CREATE TABLE gocql_test.test_blob_compressor ( + testuuid timeuuid PRIMARY KEY, + testvarchar varchar, + testblob blob, + testascii ascii, + testtext text, + )`); err != nil { + t.Fatal("create table:", err) + } + m := make(map[string]interface{}) + + BlobCompressor = lz4.NewBlobCompressor(100) + + originalBlob := strings.Repeat("1234567890", 20) + + m["testuuid"] = TimeUUID() + m["testvarchar"] = originalBlob + m["testblob"] = []byte(originalBlob) + m["testascii"] = originalBlob + m["testtext"] = originalBlob + sliceMap := []map[string]interface{}{m} + if err := session.Query(`INSERT INTO gocql_test.test_blob_compressor (testuuid, testvarchar, testblob, testascii, testtext) VALUES (?, ?, ?, ?, ?)`, + m["testuuid"], m["testvarchar"], m["testblob"], m["testascii"], m["testtext"]).Exec(); err != nil { + t.Fatal("insert:", err) + } + if returned, retErr := session.Query(`SELECT * FROM test_blob_compressor`).Iter().SliceMap(); retErr != nil { + t.Fatal("select:", retErr) + } else { + if diff := cmp.Diff(sliceMap, returned); diff != "" { + t.Fatal("mismatch in returned map", diff) + } + } + + // Test for Iter.MapScan() + { + testMap := make(map[string]interface{}) + if !session.Query(`SELECT * FROM test_blob_compressor`).Iter().MapScan(testMap) { + t.Fatal("MapScan failed to work with one row") + } + if diff := cmp.Diff(sliceMap[0], testMap); diff != "" { + t.Fatal("mismatch in returned map", diff) + } + } + + // Test for Query.MapScan() + { + testMap := make(map[string]interface{}) + if session.Query(`SELECT * FROM test_blob_compressor`).MapScan(testMap) != nil { + t.Fatal("MapScan failed to work with one row") + } + if diff := cmp.Diff(sliceMap[0], testMap); diff != "" { + t.Fatal("mismatch in returned map", diff) + } + } +} diff --git a/go.mod b/go.mod index b8402b63a..796b14b64 100644 --- a/go.mod +++ b/go.mod @@ -3,21 +3,23 @@ module github.com/gocql/gocql require ( github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 // indirect github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 // indirect + github.com/gocql/gocql/lz4 v0.0.0-20240625120741-974fa1211cce // indirect github.com/golang/snappy v0.0.3 github.com/google/go-cmp v0.4.0 github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed github.com/kr/pretty v0.1.0 // indirect - github.com/stretchr/testify v1.3.0 // indirect golang.org/x/net v0.0.0-20220526153639-5463443f8c37 gopkg.in/inf.v0 v0.9.1 sigs.k8s.io/yaml v1.3.0 ) retract ( - v1.8.0 // tag from kiwicom/gocql added by mistake to scylladb/gocql - v1.8.1 // tag from kiwicom/gocql added by mistake to scylladb/gocql - v1.9.0 // tag from kiwicom/gocql added by mistake to scylladb/gocql - v1.10.0 // tag from kiwicom/gocql added by mistake to scylladb/gocql + v1.10.0 // tag from kiwicom/gocql added by mistake to scylladb/gocql + v1.9.0 // tag from kiwicom/gocql added by mistake to scylladb/gocql + v1.8.1 // tag from kiwicom/gocql added by mistake to scylladb/gocql + v1.8.0 // tag from kiwicom/gocql added by mistake to scylladb/gocql ) +replace github.com/gocql/gocql/lz4 => ./lz4/ + go 1.13 diff --git a/go.sum b/go.sum index a85b7c363..e2b537b9e 100644 --- a/go.sum +++ b/go.sum @@ -16,11 +16,14 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/pierrec/lz4/v4 v4.1.8 h1:ieHkV+i2BRzngO4Wd/3HGowuZStgq6QkPsD1eolNAO4= +github.com/pierrec/lz4/v4 v4.1.8/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= golang.org/x/net v0.0.0-20220526153639-5463443f8c37 h1:lUkvobShwKsOesNfWWlCS5q7fnbG1MEliIzwu886fn8= golang.org/x/net v0.0.0-20220526153639-5463443f8c37/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -36,5 +39,6 @@ gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= sigs.k8s.io/yaml v1.3.0 h1:a2VclLzOGrwOHDiV8EfBGhvjHvP46CtW5j6POvhYGGo= sigs.k8s.io/yaml v1.3.0/go.mod h1:GeOyir5tyXNByN85N/dRIT9es5UQNerPYEKK56eTBm8= diff --git a/lz4/blob_compressor.go b/lz4/blob_compressor.go new file mode 100644 index 000000000..129ea27e2 --- /dev/null +++ b/lz4/blob_compressor.go @@ -0,0 +1,62 @@ +package lz4 + +import ( + "bytes" + "encoding/binary" + "fmt" + + "github.com/pierrec/lz4/v4" +) + +type BlobCompressor struct { + prefix []byte + prefixPlusLen int + limit int +} + +var defaultPrefix = []byte{0x01, 0x11, 0x22, 0x33} + +func NewBlobCompressor(limit int) BlobCompressor { + return BlobCompressor{ + prefix: defaultPrefix, + prefixPlusLen: len(defaultPrefix) + 4, + limit: limit, + } +} + +func (c BlobCompressor) Compress(data []byte) ([]byte, error) { + if len(data) < c.limit { + return data, nil + } + buf := make([]byte, len(c.prefix)+lz4.CompressBlockBound(len(data)+4)) + copy(buf, c.prefix) + + var compressor lz4.Compressor + + n, err := compressor.CompressBlock(data, buf[c.prefixPlusLen:]) + // According to lz4.CompressBlock doc, it doesn't fail as long as the dst + // buffer length is at least lz4.CompressBlockBound(len(data))) bytes, but + // we check for error anyway just to be thorough. + if err != nil { + return nil, err + } + binary.BigEndian.PutUint32(buf[len(c.prefix):], uint32(len(data))) + + return buf[:c.prefixPlusLen+n], nil +} + +func (c BlobCompressor) Decompress(data []byte) ([]byte, error) { + if !bytes.HasPrefix(data, c.prefix) { + return data, nil + } + if len(data) < 4+len(c.prefix) { + return nil, fmt.Errorf("compressed data should be >4, got=%d", len(data)) + } + uncompressedLength := binary.BigEndian.Uint32(data[len(c.prefix):]) + if uncompressedLength == 0 { + return nil, nil + } + buf := make([]byte, uncompressedLength) + n, err := lz4.UncompressBlock(data[c.prefixPlusLen:], buf) + return buf[:n], err +} diff --git a/marshal.go b/marshal.go index 6edd78ac0..bc0ec007b 100644 --- a/marshal.go +++ b/marshal.go @@ -42,6 +42,13 @@ type Unmarshaler interface { UnmarshalCQL(info TypeInfo, data []byte) error } +type blobCompressor interface { + Compress([]byte) ([]byte, error) + Decompress([]byte) ([]byte, error) +} + +var BlobCompressor blobCompressor + // Marshal returns the CQL encoding of the value for the Cassandra // internal type described by the info parameter. // @@ -110,8 +117,10 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) { } switch info.Type() { - case TypeVarchar, TypeAscii, TypeBlob, TypeText: + case TypeBlob, TypeText: return marshalVarchar(info, value) + case TypeAscii, TypeVarchar: + return marshalVarcharRaw(info, value) case TypeBoolean: return marshalBool(info, value) case TypeTinyInt: @@ -212,8 +221,10 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error { } switch info.Type() { - case TypeVarchar, TypeAscii, TypeBlob, TypeText: + case TypeBlob, TypeText: return unmarshalVarchar(info, data, value) + case TypeVarchar, TypeAscii: + return unmarshalVarcharRaw(info, data, value) case TypeBoolean: return unmarshalBool(info, data, value) case TypeInt: @@ -289,6 +300,17 @@ func unmarshalNullable(info TypeInfo, data []byte, value interface{}) error { } func marshalVarchar(info TypeInfo, value interface{}) ([]byte, error) { + v, err := marshalVarcharRaw(info, value) + if err != nil { + return nil, err + } + if BlobCompressor == nil { + return v, nil + } + return BlobCompressor.Compress(v) +} + +func marshalVarcharRaw(info TypeInfo, value interface{}) ([]byte, error) { switch v := value.(type) { case Marshaler: return v.MarshalCQL(info) @@ -316,7 +338,17 @@ func marshalVarchar(info TypeInfo, value interface{}) ([]byte, error) { return nil, marshalErrorf("can not marshal %T into %s", value, info) } -func unmarshalVarchar(info TypeInfo, data []byte, value interface{}) error { +func unmarshalVarchar(info TypeInfo, data []byte, value interface{}) (err error) { + if BlobCompressor != nil { + data, err = BlobCompressor.Decompress(data) + if err != nil { + return err + } + } + return unmarshalVarcharRaw(info, data, value) +} + +func unmarshalVarcharRaw(info TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case Unmarshaler: return v.UnmarshalCQL(info, data)