diff --git a/batch_test.go b/batch_test.go index 44b52663f..cfa99a7ca 100644 --- a/batch_test.go +++ b/batch_test.go @@ -39,7 +39,7 @@ func TestBatch_Errors(t *testing.T) { session := createSession(t) defer session.Close() - if session.cfg.ProtoVersion < protoVersion2 { + if session.cfg.ProtoVersion < internal.ProtoVersion2 { t.Skip("atomic batches not supported. Please use Cassandra >= 2.0") } @@ -58,7 +58,7 @@ func TestBatch_WithTimestamp(t *testing.T) { session := createSession(t) defer session.Close() - if session.cfg.ProtoVersion < protoVersion3 { + if session.cfg.ProtoVersion < internal.ProtoVersion3 { t.Skip("Batch timestamps are only available on protocol >= 3") } diff --git a/cassandra_test.go b/cassandra_test.go index ec6969190..c92764a10 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -1120,7 +1120,7 @@ func TestSmallInt(t *testing.T) { session := createSession(t) defer session.Close() - if session.cfg.ProtoVersion < protoVersion4 { + if session.cfg.ProtoVersion < internal.ProtoVersion4 { t.Skip("smallint is only supported in cassandra 2.2+") } @@ -2146,7 +2146,7 @@ func TestGetTableMetadata(t *testing.T) { if testTable == nil { t.Fatal("Expected table metadata for name 'test_table_metadata'") } - if session.cfg.ProtoVersion == protoVersion1 { + if session.cfg.ProtoVersion == internal.ProtoVersion1 { if testTable.KeyValidator != "org.apache.cassandra.db.marshal.Int32Type" { t.Errorf("Expected test_table_metadata key validator to be 'org.apache.cassandra.db.marshal.Int32Type' but was '%s'", testTable.KeyValidator) } @@ -2813,7 +2813,7 @@ func TestNegativeStream(t *testing.T) { const stream = -50 writer := frameWriterFunc(func(f *framer, streamID int) error { - f.writeHeader(0, opOptions, stream) + f.writeHeader(0, internal.OpOptions, stream) return f.finish() }) @@ -3116,7 +3116,7 @@ func TestUnmarshallNestedTypes(t *testing.T) { session := createSession(t) defer session.Close() - if session.cfg.ProtoVersion < protoVersion3 { + if session.cfg.ProtoVersion < internal.ProtoVersion3 { t.Skip("can not have frozen types in cassandra < 2.1.3") } diff --git a/cluster.go b/cluster.go index 413695ca4..cadaf3b3f 100644 --- a/cluster.go +++ b/cluster.go @@ -174,7 +174,7 @@ type ClusterConfig struct { ReconnectInterval time.Duration // The maximum amount of time to wait for schema agreement in a cluster after - // receiving a schema change frame. (default: 60s) + // receiving a schema change internal.Frame. (default: 60s) MaxWaitSchemaAgreement time.Duration // HostFilter will filter all incoming events for host, any which don't pass diff --git a/conn.go b/conn.go index ae02bd71c..2af094dcf 100644 --- a/conn.go +++ b/conn.go @@ -39,24 +39,11 @@ import ( "sync/atomic" "time" + "github.com/gocql/gocql/internal" "github.com/gocql/gocql/internal/lru" "github.com/gocql/gocql/internal/streams" ) -// approve the authenticator with the list of allowed authenticators. If the provided list is empty, -// the given authenticator is allowed. -func approve(authenticator string, approvedAuthenticators []string) bool { - if len(approvedAuthenticators) == 0 { - return true - } - for _, s := range approvedAuthenticators { - if authenticator == s { - return true - } - } - return false -} - // JoinHostPort is a utility to return an address string that can be used // by `gocql.Conn` to form a connection with a host. func JoinHostPort(addr string, port int) string { @@ -85,7 +72,7 @@ type PasswordAuthenticator struct { } func (p PasswordAuthenticator) Challenge(req []byte) ([]byte, Authenticator, error) { - if !approve(string(req), p.AllowedAuthenticators) { + if !internal.Approve(string(req), p.AllowedAuthenticators) { return nil, nil, fmt.Errorf("unexpected authenticator %q", req) } resp := make([]byte, 2+len(p.Username)+len(p.Password)) @@ -188,7 +175,7 @@ type Conn struct { frameObserver FrameHeaderObserver streamObserver StreamObserver - headerBuf [maxFrameHeaderSize]byte + headerBuf [internal.MaxFrameHeaderSize]byte streams *streams.IDGenerator mu sync.Mutex @@ -406,7 +393,7 @@ func (s *startupCoordinator) setupConn(ctx context.Context) error { return nil } -func (s *startupCoordinator) write(ctx context.Context, frame frameBuilder) (frame, error) { +func (s *startupCoordinator) write(ctx context.Context, frame frameBuilder) (internal.Frame, error) { select { case s.frameTicker <- struct{}{}: case <-ctx.Done(): @@ -585,8 +572,8 @@ func (c *Conn) serve(ctx context.Context) { c.closeWithError(err) } -func (c *Conn) discardFrame(head frameHeader) error { - _, err := io.CopyN(ioutil.Discard, c, int64(head.length)) +func (c *Conn) discardFrame(head internal.FrameHeader) error { + _, err := io.CopyN(ioutil.Discard, c, int64(head.Length)) if err != nil { return err } @@ -594,14 +581,14 @@ func (c *Conn) discardFrame(head frameHeader) error { } type protocolError struct { - frame frame + frame internal.Frame } func (p *protocolError) Error() string { if err, ok := p.frame.(error); ok { return err.Error() } - return fmt.Sprintf("gocql: received unexpected frame on stream %d: %v", p.frame.Header().stream, p.frame) + return fmt.Sprintf("gocql: received unexpected frame on stream %d: %v", p.frame.Header().Stream, p.frame) } func (c *Conn) heartBeat(ctx context.Context) { @@ -670,20 +657,20 @@ func (c *Conn) recv(ctx context.Context) error { if c.frameObserver != nil { c.frameObserver.ObserveFrameHeader(context.Background(), ObservedFrameHeader{ - Version: protoVersion(head.version), - Flags: head.flags, - Stream: int16(head.stream), - Opcode: frameOp(head.op), - Length: int32(head.length), + Version: internal.ProtoVersion(head.Version), + Flags: head.Flags, + Stream: int16(head.Stream), + Opcode: internal.FrameOp(head.Op), + Length: int32(head.Length), Start: headStartTime, End: headEndTime, Host: c.host, }) } - if head.stream > c.streams.NumStreams { - return fmt.Errorf("gocql: frame header stream is beyond call expected bounds: %d", head.stream) - } else if head.stream == -1 { + if head.Stream > c.streams.NumStreams { + return fmt.Errorf("gocql: frame header stream is beyond call expected bounds: %d", head.Stream) + } else if head.Stream == -1 { // TODO: handle cassandra event frames, we shouldnt get any currently framer := newFramer(c.compressor, c.version) if err := framer.readFrame(c, &head); err != nil { @@ -691,7 +678,7 @@ func (c *Conn) recv(ctx context.Context) error { } go c.session.handleEvent(framer) return nil - } else if head.stream <= 0 { + } else if head.Stream <= 0 { // reserved stream that we dont use, probably due to a protocol error // or a bug in Cassandra, this should be an error, parse it and return. framer := newFramer(c.compressor, c.version) @@ -714,14 +701,14 @@ func (c *Conn) recv(ctx context.Context) error { c.mu.Unlock() return ErrConnectionClosed } - call, ok := c.calls[head.stream] - delete(c.calls, head.stream) + call, ok := c.calls[head.Stream] + delete(c.calls, head.Stream) c.mu.Unlock() if call == nil || !ok { c.logger.Printf("gocql: received response for stream which has no handler: header=%v\n", head) return c.discardFrame(head) - } else if head.stream != call.streamID { - panic(fmt.Sprintf("call has incorrect streamID: got %d expected %d", call.streamID, head.stream)) + } else if head.Stream != call.streamID { + panic(fmt.Sprintf("call has incorrect streamID: got %d expected %d", call.streamID, head.Stream)) } framer := newFramer(c.compressor, c.version) @@ -1150,7 +1137,7 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram // requests on the stream to prevent nil pointer dereferences in recv(). defer c.releaseStream(call) - if v := resp.framer.header.version.version(); v != c.version { + if v := resp.framer.header.Version.Version(); v != c.version { return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version) } @@ -1244,7 +1231,7 @@ func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) prep := &writePrepareFrame{ statement: stmt, } - if c.version > protoVersion4 { + if c.version > internal.ProtoVersion4 { prep.keyspace = c.currentKeyspace } @@ -1276,7 +1263,7 @@ func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) flight.preparedStatment = &preparedStatment{ // defensively copy as we will recycle the underlying buffer after we // return. - id: copyBytes(x.preparedID), + id: internal.CopyBytes(x.preparedID), // the type info's should _not_ have a reference to the framers read buffer, // therefore we can just copy them directly. request: x.reqMeta, @@ -1303,12 +1290,12 @@ func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) } func marshalQueryValue(typ TypeInfo, value interface{}, dst *queryValues) error { - if named, ok := value.(*namedValue); ok { - dst.name = named.name - value = named.value + if named, ok := value.(*internal.NamedValue); ok { + dst.name = named.Name + value = named.Value } - if _, ok := value.(unsetColumn); !ok { + if _, ok := value.(internal.UnsetColumn); !ok { val, err := Marshal(typ, value) if err != nil { return err @@ -1338,7 +1325,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { if qry.pageSize > 0 { params.pageSize = qry.pageSize } - if c.version > protoVersion4 { + if c.version > internal.ProtoVersion4 { params.keyspace = c.currentKeyspace } @@ -1431,7 +1418,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { if params.skipMeta { if info != nil { iter.meta = info.response - iter.meta.pagingState = copyBytes(x.meta.pagingState) + iter.meta.pagingState = internal.CopyBytes(x.meta.pagingState) } else { return &Iter{framer: framer, err: errors.New("gocql: did not receive metadata but prepared info is nil")} } @@ -1442,7 +1429,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { if x.meta.morePages() && !qry.disableAutoPage { newQry := new(Query) *newQry = *qry - newQry.pageState = copyBytes(x.meta.pagingState) + newQry.pageState = internal.CopyBytes(x.meta.pagingState) newQry.metrics = &queryMetrics{m: make(map[string]*hostMetrics)} iter.next = &nextIter{ @@ -1531,7 +1518,7 @@ func (c *Conn) UseKeyspace(keyspace string) error { } func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { - if c.version == protoVersion1 { + if c.version == internal.ProtoVersion1 { return &Iter{err: ErrUnsupported} } diff --git a/conn_test.go b/conn_test.go index 8706683ff..894553b53 100644 --- a/conn_test.go +++ b/conn_test.go @@ -35,6 +35,7 @@ import ( "crypto/x509" "errors" "fmt" + "github.com/gocql/gocql/internal" "io" "io/ioutil" "math/rand" @@ -50,26 +51,26 @@ import ( ) const ( - defaultProto = protoVersion2 + defaultProto = internal.ProtoVersion2 ) func TestApprove(t *testing.T) { tests := map[bool]bool{ - approve("org.apache.cassandra.auth.PasswordAuthenticator", []string{}): true, - approve("org.apache.cassandra.auth.MutualTlsWithPasswordFallbackAuthenticator", []string{}): true, - approve("org.apache.cassandra.auth.MutualTlsAuthenticator", []string{}): true, - approve("com.instaclustr.cassandra.auth.SharedSecretAuthenticator", []string{}): true, - approve("com.datastax.bdp.cassandra.auth.DseAuthenticator", []string{}): true, - approve("io.aiven.cassandra.auth.AivenAuthenticator", []string{}): true, - approve("com.amazon.helenus.auth.HelenusAuthenticator", []string{}): true, - approve("com.ericsson.bss.cassandra.ecaudit.auth.AuditAuthenticator", []string{}): true, - approve("com.scylladb.auth.SaslauthdAuthenticator", []string{}): true, - approve("com.scylladb.auth.TransitionalAuthenticator", []string{}): true, - approve("com.instaclustr.cassandra.auth.InstaclustrPasswordAuthenticator", []string{}): true, - approve("com.apache.cassandra.auth.FakeAuthenticator", []string{}): true, - approve("com.apache.cassandra.auth.FakeAuthenticator", nil): true, - approve("com.apache.cassandra.auth.FakeAuthenticator", []string{"com.apache.cassandra.auth.FakeAuthenticator"}): true, - approve("com.apache.cassandra.auth.FakeAuthenticator", []string{"com.apache.cassandra.auth.NotFakeAuthenticator"}): false, + internal.Approve("org.apache.cassandra.auth.PasswordAuthenticator", []string{}): true, + internal.Approve("org.apache.cassandra.auth.MutualTlsWithPasswordFallbackAuthenticator", []string{}): true, + internal.Approve("org.apache.cassandra.auth.MutualTlsAuthenticator", []string{}): true, + internal.Approve("com.instaclustr.cassandra.auth.SharedSecretAuthenticator", []string{}): true, + internal.Approve("com.datastax.bdp.cassandra.auth.DseAuthenticator", []string{}): true, + internal.Approve("io.aiven.cassandra.auth.AivenAuthenticator", []string{}): true, + internal.Approve("com.amazon.helenus.auth.HelenusAuthenticator", []string{}): true, + internal.Approve("com.ericsson.bss.cassandra.ecaudit.auth.AuditAuthenticator", []string{}): true, + internal.Approve("com.scylladb.auth.SaslauthdAuthenticator", []string{}): true, + internal.Approve("com.scylladb.auth.TransitionalAuthenticator", []string{}): true, + internal.Approve("com.instaclustr.cassandra.auth.InstaclustrPasswordAuthenticator", []string{}): true, + internal.Approve("com.apache.cassandra.auth.FakeAuthenticator", []string{}): true, + internal.Approve("com.apache.cassandra.auth.FakeAuthenticator", nil): true, + internal.Approve("com.apache.cassandra.auth.FakeAuthenticator", []string{"com.apache.cassandra.auth.FakeAuthenticator"}): true, + internal.Approve("com.apache.cassandra.auth.FakeAuthenticator", []string{"com.apache.cassandra.auth.NotFakeAuthenticator"}): false, } for k, v := range tests { if k != v { @@ -92,7 +93,7 @@ func TestJoinHostPort(t *testing.T) { } } -func testCluster(proto protoVersion, addresses ...string) *ClusterConfig { +func testCluster(proto internal.ProtoVersion, addresses ...string) *ClusterConfig { cluster := NewCluster(addresses...) cluster.ProtoVersion = int(proto) cluster.disableControlConn = true @@ -142,7 +143,7 @@ func TestSSLSimpleNoClientCert(t *testing.T) { } } -func createTestSslCluster(addr string, proto protoVersion, useClientCert bool) *ClusterConfig { +func createTestSslCluster(addr string, proto internal.ProtoVersion, useClientCert bool) *ClusterConfig { cluster := testCluster(proto, addr) sslOpts := &SslOptions{ CaPath: "testdata/pki/ca.crt", @@ -176,7 +177,7 @@ func TestClosed(t *testing.T) { } } -func newTestSession(proto protoVersion, addresses ...string) (*Session, error) { +func newTestSession(proto internal.ProtoVersion, addresses ...string) (*Session, error) { return testCluster(proto, addresses...).CreateSession() } @@ -184,8 +185,8 @@ func TestDNSLookupConnected(t *testing.T) { log := &testLogger{} // Override the defaul DNS resolver and restore at the end - failDNS = true - defer func() { failDNS = false }() + internal.FailDNS = true + defer func() { internal.FailDNS = false }() srv := NewTestServer(t, defaultProto, context.Background()) defer srv.Stop() @@ -211,8 +212,8 @@ func TestDNSLookupError(t *testing.T) { log := &testLogger{} // Override the defaul DNS resolver and restore at the end - failDNS = true - defer func() { failDNS = false }() + internal.FailDNS = true + defer func() { internal.FailDNS = false }() cluster := NewCluster("cassandra1.invalid", "cassandra2.invalid") cluster.Logger = log @@ -698,9 +699,9 @@ func TestStream0(t *testing.T) { const expErr = "gocql: received unexpected frame on stream 0" var buf bytes.Buffer - f := newFramer(nil, protoVersion4) - f.writeHeader(0, opResult, 0) - f.writeInt(resultKindVoid) + f := newFramer(nil, internal.ProtoVersion4) + f.writeHeader(0, internal.OpResult, 0) + f.writeInt(internal.ResultKindVoid) f.buf[0] |= 0x80 if err := f.finish(); err != nil { t.Fatal(err) @@ -711,7 +712,7 @@ func TestStream0(t *testing.T) { conn := &Conn{ r: bufio.NewReader(&buf), - streams: streams.New(protoVersion4), + streams: streams.New(internal.ProtoVersion4), logger: &defaultLogger{}, } @@ -757,7 +758,7 @@ func TestContext_CanceledBeforeExec(t *testing.T) { addr: "127.0.0.1:0", protocol: defaultProto, recvHook: func(f *framer) { - if f.header.op == opStartup || f.header.op == opOptions { + if f.header.Op == internal.OpStartup || f.header.Op == internal.OpOptions { // ignore statup and heartbeat messages return } @@ -989,7 +990,7 @@ func TestFrameHeaderObserver(t *testing.T) { } frames := observer.getFrames() - expFrames := []frameOp{opSupported, opReady, opResult} + expFrames := []internal.FrameOp{internal.OpSupported, internal.OpReady, internal.OpResult} if len(frames) != len(expFrames) { t.Fatalf("Expected to receive %d frames, instead received %d", len(expFrames), len(frames)) } @@ -1030,7 +1031,7 @@ func (nts newTestServerOpts) newServer(t testing.TB, ctx context.Context) *TestS } headerSize := 8 - if nts.protocol > protoVersion2 { + if nts.protocol > internal.ProtoVersion2 { headerSize = 9 } @@ -1077,7 +1078,7 @@ func NewSSLTestServer(t testing.TB, protocol uint8, ctx context.Context) *TestSe } headerSize := 8 - if protocol > protoVersion2 { + if protocol > internal.ProtoVersion2 { headerSize = 9 } @@ -1195,8 +1196,8 @@ func (srv *TestServer) process(conn net.Conn, reqFrame *framer) { } respFrame := newFramer(nil, reqFrame.proto) - switch head.op { - case opStartup: + switch head.Op { + case internal.OpStartup: if atomic.LoadInt32(&srv.TimeoutOnStartup) > 0 { // Do not respond to startup command // wait until we get a cancel signal @@ -1205,11 +1206,11 @@ func (srv *TestServer) process(conn net.Conn, reqFrame *framer) { return } } - respFrame.writeHeader(0, opReady, head.stream) - case opOptions: - respFrame.writeHeader(0, opSupported, head.stream) + respFrame.writeHeader(0, internal.OpReady, head.Stream) + case internal.OpOptions: + respFrame.writeHeader(0, internal.OpSupported, head.Stream) respFrame.writeShort(0) - case opQuery: + case internal.OpQuery: query := reqFrame.readLongString() first := query if n := strings.Index(query, " "); n > 0 { @@ -1218,22 +1219,22 @@ func (srv *TestServer) process(conn net.Conn, reqFrame *framer) { switch strings.ToLower(first) { case "kill": atomic.AddInt64(&srv.nKillReq, 1) - respFrame.writeHeader(0, opError, head.stream) + respFrame.writeHeader(0, internal.OpError, head.Stream) respFrame.writeInt(0x1001) respFrame.writeString("query killed") case "use": - respFrame.writeInt(resultKindKeyspace) + respFrame.writeInt(internal.ResultKindKeyspace) respFrame.writeString(strings.TrimSpace(query[3:])) case "void": - respFrame.writeHeader(0, opResult, head.stream) - respFrame.writeInt(resultKindVoid) + respFrame.writeHeader(0, internal.OpResult, head.Stream) + respFrame.writeInt(internal.ResultKindVoid) case "timeout": <-srv.ctx.Done() return case "slow": go func() { - respFrame.writeHeader(0, opResult, head.stream) - respFrame.writeInt(resultKindVoid) + respFrame.writeHeader(0, internal.OpResult, head.Stream) + respFrame.writeInt(internal.ResultKindVoid) respFrame.buf[0] = srv.protocol | 0x80 select { case <-srv.ctx.Done(): @@ -1247,25 +1248,25 @@ func (srv *TestServer) process(conn net.Conn, reqFrame *framer) { case "speculative": atomic.AddInt64(&srv.nKillReq, 1) if atomic.LoadInt64(&srv.nKillReq) > 3 { - respFrame.writeHeader(0, opResult, head.stream) - respFrame.writeInt(resultKindVoid) + respFrame.writeHeader(0, internal.OpResult, head.Stream) + respFrame.writeInt(internal.ResultKindVoid) respFrame.writeString("speculative query success on the node " + srv.Address) } else { - respFrame.writeHeader(0, opError, head.stream) + respFrame.writeHeader(0, internal.OpError, head.Stream) respFrame.writeInt(0x1001) respFrame.writeString("speculative error") rand.Seed(time.Now().UnixNano()) <-time.After(time.Millisecond * 120) } default: - respFrame.writeHeader(0, opResult, head.stream) - respFrame.writeInt(resultKindVoid) + respFrame.writeHeader(0, internal.OpResult, head.Stream) + respFrame.writeInt(internal.ResultKindVoid) } - case opError: - respFrame.writeHeader(0, opError, head.stream) + case internal.OpError: + respFrame.writeHeader(0, internal.OpError, head.Stream) respFrame.buf = append(respFrame.buf, reqFrame.buf...) default: - respFrame.writeHeader(0, opError, head.stream) + respFrame.writeHeader(0, internal.OpError, head.Stream) respFrame.writeInt(0) respFrame.writeString("not supported") } @@ -1295,10 +1296,10 @@ func (srv *TestServer) readFrame(conn net.Conn) (*framer, error) { } // should be a request frame - if head.version.response() { - return nil, fmt.Errorf("expected to read a request frame got version: %v", head.version) - } else if head.version.version() != srv.protocol { - return nil, fmt.Errorf("expected to read protocol version 0x%x got 0x%x", srv.protocol, head.version.version()) + if head.Version.Response() { + return nil, fmt.Errorf("expected to read a request frame got version: %v", head.Version) + } else if head.Version.Version() != srv.protocol { + return nil, fmt.Errorf("expected to read protocol version 0x%x got 0x%x", srv.protocol, head.Version.Version()) } return framer, nil diff --git a/control.go b/control.go index b30b44ea3..fa0c93a1f 100644 --- a/control.go +++ b/control.go @@ -37,6 +37,8 @@ import ( "sync" "sync/atomic" "time" + + "github.com/gocql/gocql/internal" ) var ( @@ -50,7 +52,7 @@ func init() { panic(fmt.Sprintf("unable to seed random number generator: %v", err)) } - randr = rand.New(rand.NewSource(int64(readInt(b)))) + randr = rand.New(rand.NewSource(int64(internal.ReadInt(b)))) } const ( @@ -199,7 +201,7 @@ func parseProtocolFromError(err error) int { matches := protocolSupportRe.FindAllStringSubmatch(err.Error(), -1) if len(matches) != 1 || len(matches[0]) != 2 { if verr, ok := err.(*protocolError); ok { - return int(verr.frame.Header().version.version()) + return int(verr.frame.Header().Version.Version()) } return 0 } @@ -459,7 +461,7 @@ func (c *controlConn) getConn() *connHost { return c.conn.Load().(*connHost) } -func (c *controlConn) writeFrame(w frameBuilder) (frame, error) { +func (c *controlConn) writeFrame(w frameBuilder) (internal.Frame, error) { ch := c.getConn() if ch == nil { return nil, errNoControl diff --git a/control_test.go b/control_test.go index 9713718e6..227782ed3 100644 --- a/control_test.go +++ b/control_test.go @@ -25,6 +25,7 @@ package gocql import ( + "github.com/gocql/gocql/internal" "net" "testing" ) @@ -72,8 +73,8 @@ func TestParseProtocol(t *testing.T) { { err: &protocolError{ frame: errorFrame{ - frameHeader: frameHeader{ - version: 0x83, + FrameHeader: internal.FrameHeader{ + Version: 0x83, }, code: 0x10, message: "Invalid or unsupported protocol version: 5", diff --git a/errors.go b/errors.go index d64c46208..040fffb1a 100644 --- a/errors.go +++ b/errors.go @@ -24,7 +24,10 @@ package gocql -import "fmt" +import ( + "fmt" + "github.com/gocql/gocql/internal" +) // See CQL Binary Protocol v5, section 8 for more details. // https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec @@ -118,7 +121,7 @@ type RequestError interface { } type errorFrame struct { - frameHeader + internal.FrameHeader code int message string diff --git a/events.go b/events.go index 93b001acc..1ed425ea8 100644 --- a/events.go +++ b/events.go @@ -25,6 +25,7 @@ package gocql import ( + "github.com/gocql/gocql/internal" "net" "sync" "time" @@ -34,15 +35,15 @@ type eventDebouncer struct { name string timer *time.Timer mu sync.Mutex - events []frame + events []internal.Frame - callback func([]frame) + callback func([]internal.Frame) quit chan struct{} logger StdLogger } -func newEventDebouncer(name string, eventHandler func([]frame), logger StdLogger) *eventDebouncer { +func newEventDebouncer(name string, eventHandler func([]internal.Frame), logger StdLogger) *eventDebouncer { e := &eventDebouncer{ name: name, quit: make(chan struct{}), @@ -89,10 +90,10 @@ func (e *eventDebouncer) flush() { // the callback multiple times, probably a bad idea. In this case we could drop // frames? go e.callback(e.events) - e.events = make([]frame, 0, eventBufferSize) + e.events = make([]internal.Frame, 0, eventBufferSize) } -func (e *eventDebouncer) debounce(frame frame) { +func (e *eventDebouncer) debounce(frame internal.Frame) { e.mu.Lock() e.timer.Reset(eventDebounceTime) @@ -129,7 +130,7 @@ func (s *Session) handleEvent(framer *framer) { } } -func (s *Session) handleSchemaEvent(frames []frame) { +func (s *Session) handleSchemaEvent(frames []internal.Frame) { // TODO: debounce events for _, frame := range frames { switch f := frame.(type) { @@ -163,7 +164,7 @@ func (s *Session) handleKeyspaceChange(keyspace, change string) { // Processing topology change events before status change events ensures // that a NEW_NODE event is not dropped in favor of a newer UP event (which // would itself be dropped/ignored, as the node is not yet known). -func (s *Session) handleNodeEvent(frames []frame) { +func (s *Session) handleNodeEvent(frames []internal.Frame) { type nodeEvent struct { change string host net.IP diff --git a/events_test.go b/events_test.go index 537c51885..119419707 100644 --- a/events_test.go +++ b/events_test.go @@ -25,6 +25,7 @@ package gocql import ( + "github.com/gocql/gocql/internal" "net" "sync" "testing" @@ -36,7 +37,7 @@ func TestEventDebounce(t *testing.T) { wg.Add(1) eventsSeen := 0 - debouncer := newEventDebouncer("testDebouncer", func(events []frame) { + debouncer := newEventDebouncer("testDebouncer", func(events []internal.Frame) { defer wg.Done() eventsSeen += len(events) }, &defaultLogger{}) diff --git a/example_test.go b/example_test.go index 35ea051a0..14ec4c896 100644 --- a/example_test.go +++ b/example_test.go @@ -27,9 +27,8 @@ package gocql_test import ( "context" "fmt" - "log" - gocql "github.com/gocql/gocql" + "log" ) func Example() { diff --git a/frame.go b/frame.go index d374ae574..7996a6282 100644 --- a/frame.go +++ b/frame.go @@ -34,9 +34,9 @@ import ( "runtime" "strings" "time" -) -type unsetColumn struct{} + "github.com/gocql/gocql/internal" +) // UnsetValue represents a value used in a query binding that will be ignored by Cassandra. // @@ -45,150 +45,17 @@ type unsetColumn struct{} // want to update some fields, where before you needed to make another prepared statement. // // UnsetValue is only available when using the version 4 of the protocol. -var UnsetValue = unsetColumn{} - -type namedValue struct { - name string - value interface{} -} +var UnsetValue = internal.UnsetColumn{} // NamedValue produce a value which will bind to the named parameter in a query func NamedValue(name string, value interface{}) interface{} { - return &namedValue{ - name: name, - value: value, + return &internal.NamedValue{ + Name: name, + Value: value, } } -const ( - protoDirectionMask = 0x80 - protoVersionMask = 0x7F - protoVersion1 = 0x01 - protoVersion2 = 0x02 - protoVersion3 = 0x03 - protoVersion4 = 0x04 - protoVersion5 = 0x05 - - maxFrameSize = 256 * 1024 * 1024 -) - -type protoVersion byte - -func (p protoVersion) request() bool { - return p&protoDirectionMask == 0x00 -} - -func (p protoVersion) response() bool { - return p&protoDirectionMask == 0x80 -} - -func (p protoVersion) version() byte { - return byte(p) & protoVersionMask -} - -func (p protoVersion) String() string { - dir := "REQ" - if p.response() { - dir = "RESP" - } - - return fmt.Sprintf("[version=%d direction=%s]", p.version(), dir) -} - -type frameOp byte - -const ( - // header ops - opError frameOp = 0x00 - opStartup frameOp = 0x01 - opReady frameOp = 0x02 - opAuthenticate frameOp = 0x03 - opOptions frameOp = 0x05 - opSupported frameOp = 0x06 - opQuery frameOp = 0x07 - opResult frameOp = 0x08 - opPrepare frameOp = 0x09 - opExecute frameOp = 0x0A - opRegister frameOp = 0x0B - opEvent frameOp = 0x0C - opBatch frameOp = 0x0D - opAuthChallenge frameOp = 0x0E - opAuthResponse frameOp = 0x0F - opAuthSuccess frameOp = 0x10 -) - -func (f frameOp) String() string { - switch f { - case opError: - return "ERROR" - case opStartup: - return "STARTUP" - case opReady: - return "READY" - case opAuthenticate: - return "AUTHENTICATE" - case opOptions: - return "OPTIONS" - case opSupported: - return "SUPPORTED" - case opQuery: - return "QUERY" - case opResult: - return "RESULT" - case opPrepare: - return "PREPARE" - case opExecute: - return "EXECUTE" - case opRegister: - return "REGISTER" - case opEvent: - return "EVENT" - case opBatch: - return "BATCH" - case opAuthChallenge: - return "AUTH_CHALLENGE" - case opAuthResponse: - return "AUTH_RESPONSE" - case opAuthSuccess: - return "AUTH_SUCCESS" - default: - return fmt.Sprintf("UNKNOWN_OP_%d", f) - } -} - -const ( - // result kind - resultKindVoid = 1 - resultKindRows = 2 - resultKindKeyspace = 3 - resultKindPrepared = 4 - resultKindSchemaChanged = 5 - - // rows flags - flagGlobalTableSpec int = 0x01 - flagHasMorePages int = 0x02 - flagNoMetaData int = 0x04 - - // query flags - flagValues byte = 0x01 - flagSkipMetaData byte = 0x02 - flagPageSize byte = 0x04 - flagWithPagingState byte = 0x08 - flagWithSerialConsistency byte = 0x10 - flagDefaultTimestamp byte = 0x20 - flagWithNameValues byte = 0x40 - flagWithKeyspace byte = 0x80 - - // prepare flags - flagWithPreparedKeyspace uint32 = 0x01 - - // header flags - flagCompress byte = 0x01 - flagTracing byte = 0x02 - flagCustomPayload byte = 0x04 - flagWarning byte = 0x08 - flagBetaProtocol byte = 0x10 -) +//TODO: We should move protoVersion, frameOp, proto version etc. to internal type Consistency uint16 @@ -321,44 +188,15 @@ func (s *SerialConsistency) UnmarshalText(text []byte) error { return nil } -const ( - apacheCassandraTypePrefix = "org.apache.cassandra.db.marshal." -) - var ( ErrFrameTooBig = errors.New("frame length is bigger than the maximum allowed") ) -const maxFrameHeaderSize = 9 - -func readInt(p []byte) int32 { - return int32(p[0])<<24 | int32(p[1])<<16 | int32(p[2])<<8 | int32(p[3]) -} - -type frameHeader struct { - version protoVersion - flags byte - stream int - op frameOp - length int - warnings []string -} - -func (f frameHeader) String() string { - return fmt.Sprintf("[header version=%s flags=0x%x stream=%d op=%s length=%d]", f.version, f.flags, f.stream, f.op, f.length) -} - -func (f frameHeader) Header() frameHeader { - return f -} - -const defaultBufSize = 128 - type ObservedFrameHeader struct { - Version protoVersion + Version internal.ProtoVersion Flags byte Stream int16 - Opcode frameOp + Opcode internal.FrameOp Length int32 // StartHeader is the time we started reading the frame header off the network connection. @@ -390,7 +228,7 @@ type framer struct { compres Compressor headSize int // if this frame was read then the header will be here - header *frameHeader + header *internal.FrameHeader // if tracing flag is set this is not nil traceID []byte @@ -405,23 +243,23 @@ type framer struct { } func newFramer(compressor Compressor, version byte) *framer { - buf := make([]byte, defaultBufSize) + buf := make([]byte, internal.DefaultBufSize) f := &framer{ buf: buf[:0], readBuffer: buf, } var flags byte if compressor != nil { - flags |= flagCompress + flags |= internal.FlagCompress } - if version == protoVersion5 { - flags |= flagBetaProtocol + if version == internal.ProtoVersion5 { + flags |= internal.FlagBetaProtocol } - version &= protoVersionMask + version &= internal.ProtoVersionMask headSize := 8 - if version > protoVersion2 { + if version > internal.ProtoVersion2 { headSize = 9 } @@ -436,53 +274,49 @@ func newFramer(compressor Compressor, version byte) *framer { return f } -type frame interface { - Header() frameHeader -} - -func readHeader(r io.Reader, p []byte) (head frameHeader, err error) { +func readHeader(r io.Reader, p []byte) (head internal.FrameHeader, err error) { _, err = io.ReadFull(r, p[:1]) if err != nil { - return frameHeader{}, err + return internal.FrameHeader{}, err } - version := p[0] & protoVersionMask + version := p[0] & internal.ProtoVersionMask - if version < protoVersion1 || version > protoVersion5 { - return frameHeader{}, fmt.Errorf("gocql: unsupported protocol response version: %d", version) + if version < internal.ProtoVersion1 || version > internal.ProtoVersion5 { + return internal.FrameHeader{}, fmt.Errorf("gocql: unsupported protocol response version: %d", version) } headSize := 9 - if version < protoVersion3 { + if version < internal.ProtoVersion3 { headSize = 8 } _, err = io.ReadFull(r, p[1:headSize]) if err != nil { - return frameHeader{}, err + return internal.FrameHeader{}, err } p = p[:headSize] - head.version = protoVersion(p[0]) - head.flags = p[1] + head.Version = internal.ProtoVersion(p[0]) + head.Flags = p[1] - if version > protoVersion2 { + if version > internal.ProtoVersion2 { if len(p) != 9 { - return frameHeader{}, fmt.Errorf("not enough bytes to read header require 9 got: %d", len(p)) + return internal.FrameHeader{}, fmt.Errorf("not enough bytes to read header require 9 got: %d", len(p)) } - head.stream = int(int16(p[2])<<8 | int16(p[3])) - head.op = frameOp(p[4]) - head.length = int(readInt(p[5:])) + head.Stream = int(int16(p[2])<<8 | int16(p[3])) + head.Op = internal.FrameOp(p[4]) + head.Length = int(internal.ReadInt(p[5:])) } else { if len(p) != 8 { - return frameHeader{}, fmt.Errorf("not enough bytes to read header require 8 got: %d", len(p)) + return internal.FrameHeader{}, fmt.Errorf("not enough bytes to read header require 8 got: %d", len(p)) } - head.stream = int(int8(p[2])) - head.op = frameOp(p[3]) - head.length = int(readInt(p[4:])) + head.Stream = int(int8(p[2])) + head.Op = internal.FrameOp(p[3]) + head.Length = int(internal.ReadInt(p[4:])) } return head, nil @@ -490,41 +324,41 @@ func readHeader(r io.Reader, p []byte) (head frameHeader, err error) { // explicitly enables tracing for the framers outgoing requests func (f *framer) trace() { - f.flags |= flagTracing + f.flags |= internal.FlagTracing } // explicitly enables the custom payload flag func (f *framer) payload() { - f.flags |= flagCustomPayload + f.flags |= internal.FlagCustomPayload } // reads a frame form the wire into the framers buffer -func (f *framer) readFrame(r io.Reader, head *frameHeader) error { - if head.length < 0 { - return fmt.Errorf("frame body length can not be less than 0: %d", head.length) - } else if head.length > maxFrameSize { +func (f *framer) readFrame(r io.Reader, head *internal.FrameHeader) error { + if head.Length < 0 { + return fmt.Errorf("frame body length can not be less than 0: %d", head.Length) + } else if head.Length > internal.MaxFrameSize { // need to free up the connection to be used again - _, err := io.CopyN(ioutil.Discard, r, int64(head.length)) + _, err := io.CopyN(ioutil.Discard, r, int64(head.Length)) if err != nil { return fmt.Errorf("error whilst trying to discard frame with invalid length: %v", err) } return ErrFrameTooBig } - if cap(f.readBuffer) >= head.length { - f.buf = f.readBuffer[:head.length] + if cap(f.readBuffer) >= head.Length { + f.buf = f.readBuffer[:head.Length] } else { - f.readBuffer = make([]byte, head.length) + f.readBuffer = make([]byte, head.Length) f.buf = f.readBuffer } // assume the underlying reader takes care of timeouts and retries n, err := io.ReadFull(r, f.buf) if err != nil { - return fmt.Errorf("unable to read frame body: read %d/%d bytes: %v", n, head.length, err) + return fmt.Errorf("unable to read frame body: read %d/%d bytes: %v", n, head.Length, err) } - if head.flags&flagCompress == flagCompress { + if head.Flags&internal.FlagCompress == internal.FlagCompress { if f.compres == nil { return NewErrProtocol("no compressor available with compressed frame body") } @@ -539,7 +373,7 @@ func (f *framer) readFrame(r io.Reader, head *frameHeader) error { return nil } -func (f *framer) parseFrame() (frame frame, err error) { +func (f *framer) parseFrame() (frame internal.Frame, err error) { defer func() { if r := recover(); r != nil { if _, ok := r.(runtime.Error); ok { @@ -549,53 +383,53 @@ func (f *framer) parseFrame() (frame frame, err error) { } }() - if f.header.version.request() { - return nil, NewErrProtocol("got a request frame from server: %v", f.header.version) + if f.header.Version.Request() { + return nil, NewErrProtocol("got a request frame from server: %v", f.header.Version) } - if f.header.flags&flagTracing == flagTracing { + if f.header.Flags&internal.FlagTracing == internal.FlagTracing { f.readTrace() } - if f.header.flags&flagWarning == flagWarning { - f.header.warnings = f.readStringList() + if f.header.Flags&internal.FlagWarning == internal.FlagWarning { + f.header.Warnings = f.readStringList() } - if f.header.flags&flagCustomPayload == flagCustomPayload { + if f.header.Flags&internal.FlagCustomPayload == internal.FlagCustomPayload { f.customPayload = f.readBytesMap() } // assumes that the frame body has been read into rbuf - switch f.header.op { - case opError: + switch f.header.Op { + case internal.OpError: frame = f.parseErrorFrame() - case opReady: + case internal.OpReady: frame = f.parseReadyFrame() - case opResult: + case internal.OpResult: frame, err = f.parseResultFrame() - case opSupported: + case internal.OpSupported: frame = f.parseSupportedFrame() - case opAuthenticate: + case internal.OpAuthenticate: frame = f.parseAuthenticateFrame() - case opAuthChallenge: + case internal.OpAuthChallenge: frame = f.parseAuthChallengeFrame() - case opAuthSuccess: + case internal.OpAuthSuccess: frame = f.parseAuthSuccessFrame() - case opEvent: + case internal.OpEvent: frame = f.parseEventFrame() default: - return nil, NewErrProtocol("unknown op in frame header: %s", f.header.op) + return nil, NewErrProtocol("unknown op in frame header: %s", f.header.Op) } return } -func (f *framer) parseErrorFrame() frame { +func (f *framer) parseErrorFrame() internal.Frame { code := f.readInt() msg := f.readString() errD := errorFrame{ - frameHeader: *f.header, + FrameHeader: *f.header, code: code, message: msg, } @@ -647,7 +481,7 @@ func (f *framer) parseErrorFrame() frame { stmtId := f.readShortBytes() return &RequestErrUnprepared{ errorFrame: errD, - StatementId: copyBytes(stmtId), // defensively copy + StatementId: internal.CopyBytes(stmtId), // defensively copy } case ErrCodeReadFailure: res := &RequestErrReadFailure{ @@ -656,7 +490,7 @@ func (f *framer) parseErrorFrame() frame { res.Consistency = f.readConsistency() res.Received = f.readInt() res.BlockFor = f.readInt() - if f.proto > protoVersion4 { + if f.proto > internal.ProtoVersion4 { res.ErrorMap = f.readErrorMap() res.NumFailures = len(res.ErrorMap) } else { @@ -672,7 +506,7 @@ func (f *framer) parseErrorFrame() frame { res.Consistency = f.readConsistency() res.Received = f.readInt() res.BlockFor = f.readInt() - if f.proto > protoVersion4 { + if f.proto > internal.ProtoVersion4 { res.ErrorMap = f.readErrorMap() res.NumFailures = len(res.ErrorMap) } else { @@ -721,14 +555,14 @@ func (f *framer) readErrorMap() (errMap ErrorMap) { return } -func (f *framer) writeHeader(flags byte, op frameOp, stream int) { +func (f *framer) writeHeader(flags byte, op internal.FrameOp, stream int) { f.buf = f.buf[:0] f.buf = append(f.buf, f.proto, flags, ) - if f.proto > protoVersion2 { + if f.proto > internal.ProtoVersion2 { f.buf = append(f.buf, byte(stream>>8), byte(stream), @@ -751,7 +585,7 @@ func (f *framer) writeHeader(flags byte, op frameOp, stream int) { func (f *framer) setLength(length int) { p := 4 - if f.proto > protoVersion2 { + if f.proto > internal.ProtoVersion2 { p = 5 } @@ -762,13 +596,13 @@ func (f *framer) setLength(length int) { } func (f *framer) finish() error { - if len(f.buf) > maxFrameSize { + if len(f.buf) > internal.MaxFrameSize { // huge app frame, lets remove it so it doesn't bloat the heap - f.buf = make([]byte, defaultBufSize) + f.buf = make([]byte, internal.DefaultBufSize) return ErrFrameTooBig } - if f.buf[1]&flagCompress == flagCompress { + if f.buf[1]&internal.FlagCompress == internal.FlagCompress { if f.compres == nil { panic("compress flag set with no compressor") } @@ -797,26 +631,26 @@ func (f *framer) readTrace() { } type readyFrame struct { - frameHeader + internal.FrameHeader } -func (f *framer) parseReadyFrame() frame { +func (f *framer) parseReadyFrame() internal.Frame { return &readyFrame{ - frameHeader: *f.header, + FrameHeader: *f.header, } } type supportedFrame struct { - frameHeader + internal.FrameHeader supported map[string][]string } // TODO: if we move the body buffer onto the frameHeader then we only need a single // framer, and can move the methods onto the header. -func (f *framer) parseSupportedFrame() frame { +func (f *framer) parseSupportedFrame() internal.Frame { return &supportedFrame{ - frameHeader: *f.header, + FrameHeader: *f.header, supported: f.readStringMultiMap(), } @@ -831,7 +665,7 @@ func (w writeStartupFrame) String() string { } func (w *writeStartupFrame) buildFrame(f *framer, streamID int) error { - f.writeHeader(f.flags&^flagCompress, opStartup, streamID) + f.writeHeader(f.flags&^internal.FlagCompress, internal.OpStartup, streamID) f.writeStringMap(w.opts) return f.finish() @@ -847,19 +681,19 @@ func (w *writePrepareFrame) buildFrame(f *framer, streamID int) error { if len(w.customPayload) > 0 { f.payload() } - f.writeHeader(f.flags, opPrepare, streamID) + f.writeHeader(f.flags, internal.OpPrepare, streamID) f.writeCustomPayload(&w.customPayload) f.writeLongString(w.statement) var flags uint32 = 0 if w.keyspace != "" { - if f.proto > protoVersion4 { - flags |= flagWithPreparedKeyspace + if f.proto > internal.ProtoVersion4 { + flags |= internal.FlagWithPreparedKeyspace } else { panic(fmt.Errorf("the keyspace can only be set with protocol 5 or higher")) } } - if f.proto > protoVersion4 { + if f.proto > internal.ProtoVersion4 { f.writeUint(flags) } if w.keyspace != "" { @@ -959,7 +793,7 @@ func (f *framer) parsePreparedMetadata() preparedMetadata { } meta.actualColCount = meta.colCount - if f.proto >= protoVersion4 { + if f.proto >= internal.ProtoVersion4 { pkeyCount := f.readInt() pkeys := make([]int, pkeyCount) for i := 0; i < pkeyCount; i++ { @@ -968,15 +802,15 @@ func (f *framer) parsePreparedMetadata() preparedMetadata { meta.pkeyColumns = pkeys } - if meta.flags&flagHasMorePages == flagHasMorePages { - meta.pagingState = copyBytes(f.readBytes()) + if meta.flags&internal.FlagHasMorePages == internal.FlagHasMorePages { + meta.pagingState = internal.CopyBytes(f.readBytes()) } - if meta.flags&flagNoMetaData == flagNoMetaData { + if meta.flags&internal.FlagNoMetaData == internal.FlagNoMetaData { return meta } - globalSpec := meta.flags&flagGlobalTableSpec == flagGlobalTableSpec + globalSpec := meta.flags&internal.FlagGlobalTableSpec == internal.FlagGlobalTableSpec if globalSpec { meta.keyspace = f.readString() meta.table = f.readString() @@ -1020,7 +854,7 @@ type resultMetadata struct { } func (r *resultMetadata) morePages() bool { - return r.flags&flagHasMorePages == flagHasMorePages + return r.flags&internal.FlagHasMorePages == internal.FlagHasMorePages } func (r resultMetadata) String() string { @@ -1056,16 +890,16 @@ func (f *framer) parseResultMetadata() resultMetadata { } meta.actualColCount = meta.colCount - if meta.flags&flagHasMorePages == flagHasMorePages { - meta.pagingState = copyBytes(f.readBytes()) + if meta.flags&internal.FlagHasMorePages == internal.FlagHasMorePages { + meta.pagingState = internal.CopyBytes(f.readBytes()) } - if meta.flags&flagNoMetaData == flagNoMetaData { + if meta.flags&internal.FlagNoMetaData == internal.FlagNoMetaData { return meta } var keyspace, table string - globalSpec := meta.flags&flagGlobalTableSpec == flagGlobalTableSpec + globalSpec := meta.flags&internal.FlagGlobalTableSpec == internal.FlagGlobalTableSpec if globalSpec { keyspace = f.readString() table = f.readString() @@ -1095,26 +929,26 @@ func (f *framer) parseResultMetadata() resultMetadata { } type resultVoidFrame struct { - frameHeader + internal.FrameHeader } func (f *resultVoidFrame) String() string { return "[result_void]" } -func (f *framer) parseResultFrame() (frame, error) { +func (f *framer) parseResultFrame() (internal.Frame, error) { kind := f.readInt() switch kind { - case resultKindVoid: - return &resultVoidFrame{frameHeader: *f.header}, nil - case resultKindRows: + case internal.ResultKindVoid: + return &resultVoidFrame{FrameHeader: *f.header}, nil + case internal.ResultKindRows: return f.parseResultRows(), nil - case resultKindKeyspace: + case internal.ResultKindKeyspace: return f.parseResultSetKeyspace(), nil - case resultKindPrepared: + case internal.ResultKindPrepared: return f.parseResultPrepared(), nil - case resultKindSchemaChanged: + case internal.ResultKindSchemaChanged: return f.parseResultSchemaChange(), nil } @@ -1122,7 +956,7 @@ func (f *framer) parseResultFrame() (frame, error) { } type resultRowsFrame struct { - frameHeader + internal.FrameHeader meta resultMetadata // dont parse the rows here as we only need to do it once @@ -1133,7 +967,7 @@ func (f *resultRowsFrame) String() string { return fmt.Sprintf("[result_rows meta=%v]", f.meta) } -func (f *framer) parseResultRows() frame { +func (f *framer) parseResultRows() internal.Frame { result := &resultRowsFrame{} result.meta = f.parseResultMetadata() @@ -1146,7 +980,7 @@ func (f *framer) parseResultRows() frame { } type resultKeyspaceFrame struct { - frameHeader + internal.FrameHeader keyspace string } @@ -1154,29 +988,29 @@ func (r *resultKeyspaceFrame) String() string { return fmt.Sprintf("[result_keyspace keyspace=%s]", r.keyspace) } -func (f *framer) parseResultSetKeyspace() frame { +func (f *framer) parseResultSetKeyspace() internal.Frame { return &resultKeyspaceFrame{ - frameHeader: *f.header, + FrameHeader: *f.header, keyspace: f.readString(), } } type resultPreparedFrame struct { - frameHeader + internal.FrameHeader preparedID []byte reqMeta preparedMetadata respMeta resultMetadata } -func (f *framer) parseResultPrepared() frame { +func (f *framer) parseResultPrepared() internal.Frame { frame := &resultPreparedFrame{ - frameHeader: *f.header, + FrameHeader: *f.header, preparedID: f.readShortBytes(), reqMeta: f.parsePreparedMetadata(), } - if f.proto < protoVersion2 { + if f.proto < internal.ProtoVersion2 { return frame } @@ -1186,7 +1020,7 @@ func (f *framer) parseResultPrepared() frame { } type schemaChangeKeyspace struct { - frameHeader + internal.FrameHeader change string keyspace string @@ -1197,7 +1031,7 @@ func (f schemaChangeKeyspace) String() string { } type schemaChangeTable struct { - frameHeader + internal.FrameHeader change string keyspace string @@ -1209,7 +1043,7 @@ func (f schemaChangeTable) String() string { } type schemaChangeType struct { - frameHeader + internal.FrameHeader change string keyspace string @@ -1217,7 +1051,7 @@ type schemaChangeType struct { } type schemaChangeFunction struct { - frameHeader + internal.FrameHeader change string keyspace string @@ -1226,7 +1060,7 @@ type schemaChangeFunction struct { } type schemaChangeAggregate struct { - frameHeader + internal.FrameHeader change string keyspace string @@ -1234,22 +1068,22 @@ type schemaChangeAggregate struct { args []string } -func (f *framer) parseResultSchemaChange() frame { - if f.proto <= protoVersion2 { +func (f *framer) parseResultSchemaChange() internal.Frame { + if f.proto <= internal.ProtoVersion2 { change := f.readString() keyspace := f.readString() table := f.readString() if table != "" { return &schemaChangeTable{ - frameHeader: *f.header, + FrameHeader: *f.header, change: change, keyspace: keyspace, object: table, } } else { return &schemaChangeKeyspace{ - frameHeader: *f.header, + FrameHeader: *f.header, change: change, keyspace: keyspace, } @@ -1262,7 +1096,7 @@ func (f *framer) parseResultSchemaChange() frame { switch target { case "KEYSPACE": frame := &schemaChangeKeyspace{ - frameHeader: *f.header, + FrameHeader: *f.header, change: change, } @@ -1271,7 +1105,7 @@ func (f *framer) parseResultSchemaChange() frame { return frame case "TABLE": frame := &schemaChangeTable{ - frameHeader: *f.header, + FrameHeader: *f.header, change: change, } @@ -1281,7 +1115,7 @@ func (f *framer) parseResultSchemaChange() frame { return frame case "TYPE": frame := &schemaChangeType{ - frameHeader: *f.header, + FrameHeader: *f.header, change: change, } @@ -1291,7 +1125,7 @@ func (f *framer) parseResultSchemaChange() frame { return frame case "FUNCTION": frame := &schemaChangeFunction{ - frameHeader: *f.header, + FrameHeader: *f.header, change: change, } @@ -1302,7 +1136,7 @@ func (f *framer) parseResultSchemaChange() frame { return frame case "AGGREGATE": frame := &schemaChangeAggregate{ - frameHeader: *f.header, + FrameHeader: *f.header, change: change, } @@ -1319,7 +1153,7 @@ func (f *framer) parseResultSchemaChange() frame { } type authenticateFrame struct { - frameHeader + internal.FrameHeader class string } @@ -1328,15 +1162,15 @@ func (a *authenticateFrame) String() string { return fmt.Sprintf("[authenticate class=%q]", a.class) } -func (f *framer) parseAuthenticateFrame() frame { +func (f *framer) parseAuthenticateFrame() internal.Frame { return &authenticateFrame{ - frameHeader: *f.header, + FrameHeader: *f.header, class: f.readString(), } } type authSuccessFrame struct { - frameHeader + internal.FrameHeader data []byte } @@ -1345,15 +1179,15 @@ func (a *authSuccessFrame) String() string { return fmt.Sprintf("[auth_success data=%q]", a.data) } -func (f *framer) parseAuthSuccessFrame() frame { +func (f *framer) parseAuthSuccessFrame() internal.Frame { return &authSuccessFrame{ - frameHeader: *f.header, + FrameHeader: *f.header, data: f.readBytes(), } } type authChallengeFrame struct { - frameHeader + internal.FrameHeader data []byte } @@ -1362,15 +1196,15 @@ func (a *authChallengeFrame) String() string { return fmt.Sprintf("[auth_challenge data=%q]", a.data) } -func (f *framer) parseAuthChallengeFrame() frame { +func (f *framer) parseAuthChallengeFrame() internal.Frame { return &authChallengeFrame{ - frameHeader: *f.header, + FrameHeader: *f.header, data: f.readBytes(), } } type statusChangeEventFrame struct { - frameHeader + internal.FrameHeader change string host net.IP @@ -1383,7 +1217,7 @@ func (t statusChangeEventFrame) String() string { // essentially the same as statusChange type topologyChangeEventFrame struct { - frameHeader + internal.FrameHeader change string host net.IP @@ -1394,18 +1228,18 @@ func (t topologyChangeEventFrame) String() string { return fmt.Sprintf("[topology_change change=%s host=%v port=%v]", t.change, t.host, t.port) } -func (f *framer) parseEventFrame() frame { +func (f *framer) parseEventFrame() internal.Frame { eventType := f.readString() switch eventType { case "TOPOLOGY_CHANGE": - frame := &topologyChangeEventFrame{frameHeader: *f.header} + frame := &topologyChangeEventFrame{FrameHeader: *f.header} frame.change = f.readString() frame.host, frame.port = f.readInet() return frame case "STATUS_CHANGE": - frame := &statusChangeEventFrame{frameHeader: *f.header} + frame := &statusChangeEventFrame{FrameHeader: *f.header} frame.change = f.readString() frame.host, frame.port = f.readInet() @@ -1432,7 +1266,7 @@ func (a *writeAuthResponseFrame) buildFrame(framer *framer, streamID int) error } func (f *framer) writeAuthResponseFrame(streamID int, data []byte) error { - f.writeHeader(f.flags, opAuthResponse, streamID) + f.writeHeader(f.flags, internal.OpAuthResponse, streamID) f.writeBytes(data) return f.finish() } @@ -1468,50 +1302,50 @@ func (q queryParams) String() string { func (f *framer) writeQueryParams(opts *queryParams) { f.writeConsistency(opts.consistency) - if f.proto == protoVersion1 { + if f.proto == internal.ProtoVersion1 { return } var flags byte if len(opts.values) > 0 { - flags |= flagValues + flags |= internal.FlagValues } if opts.skipMeta { - flags |= flagSkipMetaData + flags |= internal.FlagSkipMetaData } if opts.pageSize > 0 { - flags |= flagPageSize + flags |= internal.FlagPageSize } if len(opts.pagingState) > 0 { - flags |= flagWithPagingState + flags |= internal.FlagWithPagingState } if opts.serialConsistency > 0 { - flags |= flagWithSerialConsistency + flags |= internal.FlagWithSerialConsistency } names := false // protoV3 specific things - if f.proto > protoVersion2 { + if f.proto > internal.ProtoVersion2 { if opts.defaultTimestamp { - flags |= flagDefaultTimestamp + flags |= internal.FlagDefaultTimestamp } if len(opts.values) > 0 && opts.values[0].name != "" { - flags |= flagWithNameValues + flags |= internal.FlagWithNameValues names = true } } if opts.keyspace != "" { - if f.proto > protoVersion4 { - flags |= flagWithKeyspace + if f.proto > internal.ProtoVersion4 { + flags |= internal.FlagWithKeyspace } else { panic(fmt.Errorf("the keyspace can only be set with protocol 5 or higher")) } } - if f.proto > protoVersion4 { + if f.proto > internal.ProtoVersion4 { f.writeUint(uint32(flags)) } else { f.writeByte(flags) @@ -1544,7 +1378,7 @@ func (f *framer) writeQueryParams(opts *queryParams) { f.writeConsistency(Consistency(opts.serialConsistency)) } - if f.proto > protoVersion2 && opts.defaultTimestamp { + if f.proto > internal.ProtoVersion2 && opts.defaultTimestamp { // timestamp in microseconds var ts int64 if opts.defaultTimestampValue != 0 { @@ -1580,7 +1414,7 @@ func (f *framer) writeQueryFrame(streamID int, statement string, params *queryPa if len(customPayload) > 0 { f.payload() } - f.writeHeader(f.flags, opQuery, streamID) + f.writeHeader(f.flags, internal.OpQuery, streamID) f.writeCustomPayload(&customPayload) f.writeLongString(statement) f.writeQueryParams(params) @@ -1618,10 +1452,10 @@ func (f *framer) writeExecuteFrame(streamID int, preparedID []byte, params *quer if len(*customPayload) > 0 { f.payload() } - f.writeHeader(f.flags, opExecute, streamID) + f.writeHeader(f.flags, internal.OpExecute, streamID) f.writeCustomPayload(customPayload) f.writeShortBytes(preparedID) - if f.proto > protoVersion1 { + if f.proto > internal.ProtoVersion1 { f.writeQueryParams(params) } else { n := len(params.values) @@ -1669,7 +1503,7 @@ func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame, customPayload if len(customPayload) > 0 { f.payload() } - f.writeHeader(f.flags, opBatch, streamID) + f.writeHeader(f.flags, internal.OpBatch, streamID) f.writeCustomPayload(&customPayload) f.writeByte(byte(w.typ)) @@ -1691,13 +1525,13 @@ func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame, customPayload f.writeShort(uint16(len(b.values))) for j := range b.values { col := b.values[j] - if f.proto > protoVersion2 && col.name != "" { + if f.proto > internal.ProtoVersion2 && col.name != "" { // TODO: move this check into the caller and set a flag on writeBatchFrame // to indicate using named values - if f.proto <= protoVersion5 { + if f.proto <= internal.ProtoVersion5 { return fmt.Errorf("gocql: named query values are not supported in batches, please see https://issues.apache.org/jira/browse/CASSANDRA-10246") } - flags |= flagWithNameValues + flags |= internal.FlagWithNameValues f.writeString(col.name) } if col.isUnset { @@ -1710,15 +1544,15 @@ func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame, customPayload f.writeConsistency(w.consistency) - if f.proto > protoVersion2 { + if f.proto > internal.ProtoVersion2 { if w.serialConsistency > 0 { - flags |= flagWithSerialConsistency + flags |= internal.FlagWithSerialConsistency } if w.defaultTimestamp { - flags |= flagDefaultTimestamp + flags |= internal.FlagDefaultTimestamp } - if f.proto > protoVersion4 { + if f.proto > internal.ProtoVersion4 { f.writeUint(uint32(flags)) } else { f.writeByte(flags) @@ -1749,7 +1583,7 @@ func (w *writeOptionsFrame) buildFrame(framer *framer, streamID int) error { } func (f *framer) writeOptionsFrame(stream int, _ *writeOptionsFrame) error { - f.writeHeader(f.flags&^flagCompress, opOptions, stream) + f.writeHeader(f.flags&^internal.FlagCompress, internal.OpOptions, stream) return f.finish() } @@ -1762,7 +1596,7 @@ func (w *writeRegisterFrame) buildFrame(framer *framer, streamID int) error { } func (f *framer) writeRegisterFrame(streamID int, w *writeRegisterFrame) error { - f.writeHeader(f.flags, opRegister, streamID) + f.writeHeader(f.flags, internal.OpRegister, streamID) f.writeStringList(w.events) return f.finish() @@ -1940,52 +1774,9 @@ func (f *framer) writeByte(b byte) { f.buf = append(f.buf, b) } -func appendBytes(p []byte, d []byte) []byte { - if d == nil { - return appendInt(p, -1) - } - p = appendInt(p, int32(len(d))) - p = append(p, d...) - return p -} - -func appendShort(p []byte, n uint16) []byte { - return append(p, - byte(n>>8), - byte(n), - ) -} - -func appendInt(p []byte, n int32) []byte { - return append(p, byte(n>>24), - byte(n>>16), - byte(n>>8), - byte(n)) -} - -func appendUint(p []byte, n uint32) []byte { - return append(p, byte(n>>24), - byte(n>>16), - byte(n>>8), - byte(n)) -} - -func appendLong(p []byte, n int64) []byte { - return append(p, - byte(n>>56), - byte(n>>48), - byte(n>>40), - byte(n>>32), - byte(n>>24), - byte(n>>16), - byte(n>>8), - byte(n), - ) -} - func (f *framer) writeCustomPayload(customPayload *map[string][]byte) { if len(*customPayload) > 0 { - if f.proto < protoVersion4 { + if f.proto < internal.ProtoVersion4 { panic("Custom payload is not supported with version V3 or less") } f.writeBytesMap(*customPayload) @@ -1994,19 +1785,19 @@ func (f *framer) writeCustomPayload(customPayload *map[string][]byte) { // these are protocol level binary types func (f *framer) writeInt(n int32) { - f.buf = appendInt(f.buf, n) + f.buf = internal.AppendInt(f.buf, n) } func (f *framer) writeUint(n uint32) { - f.buf = appendUint(f.buf, n) + f.buf = internal.AppendUint(f.buf, n) } func (f *framer) writeShort(n uint16) { - f.buf = appendShort(f.buf, n) + f.buf = internal.AppendShort(f.buf, n) } func (f *framer) writeLong(n int64) { - f.buf = appendLong(f.buf, n) + f.buf = internal.AppendLong(f.buf, n) } func (f *framer) writeString(s string) { diff --git a/frame_test.go b/frame_test.go index 170cba710..e2e6c51ae 100644 --- a/frame_test.go +++ b/frame_test.go @@ -26,6 +26,7 @@ package gocql import ( "bytes" + "github.com/gocql/gocql/internal" "os" "testing" ) @@ -65,7 +66,7 @@ func TestFuzzBugs(t *testing.T) { continue } - framer := newFramer(nil, byte(head.version)) + framer := newFramer(nil, byte(head.Version)) err = framer.readFrame(r, &head) if err != nil { continue @@ -88,8 +89,8 @@ func TestFrameWriteTooLong(t *testing.T) { framer := newFramer(nil, 2) - framer.writeHeader(0, opStartup, 1) - framer.writeBytes(make([]byte, maxFrameSize+1)) + framer.writeHeader(0, internal.OpStartup, 1) + framer.writeBytes(make([]byte, internal.MaxFrameSize+1)) err := framer.finish() if err != ErrFrameTooBig { t.Fatalf("expected to get %v got %v", ErrFrameTooBig, err) @@ -102,16 +103,16 @@ func TestFrameReadTooLong(t *testing.T) { } r := &bytes.Buffer{} - r.Write(make([]byte, maxFrameSize+1)) + r.Write(make([]byte, internal.MaxFrameSize+1)) // write a new header right after this frame to verify that we can read it - r.Write([]byte{0x02, 0x00, 0x00, byte(opReady), 0x00, 0x00, 0x00, 0x00}) + r.Write([]byte{0x02, 0x00, 0x00, byte(internal.OpReady), 0x00, 0x00, 0x00, 0x00}) framer := newFramer(nil, 2) - head := frameHeader{ - version: 2, - op: opReady, - length: r.Len() - 8, + head := internal.FrameHeader{ + Version: 2, + Op: internal.OpReady, + Length: r.Len() - 8, } err := framer.readFrame(r, &head) @@ -123,7 +124,7 @@ func TestFrameReadTooLong(t *testing.T) { if err != nil { t.Fatal(err) } - if head.op != opReady { - t.Fatalf("expected to get header %v got %v", opReady, head.op) + if head.Op != internal.OpReady { + t.Fatalf("expected to get header %v got %v", internal.OpReady, head.Op) } } diff --git a/framer_bench_test.go b/framer_bench_test.go index bce3742c2..54088fcad 100644 --- a/framer_bench_test.go +++ b/framer_bench_test.go @@ -26,6 +26,7 @@ package gocql import ( "compress/gzip" + "github.com/gocql/gocql/internal" "io/ioutil" "os" "testing" @@ -56,10 +57,10 @@ func BenchmarkParseRowsFrame(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { framer := &framer{ - header: &frameHeader{ - version: protoVersion4 | 0x80, - op: opResult, - length: len(data), + header: &internal.FrameHeader{ + Version: internal.ProtoVersion4 | 0x80, + Op: internal.OpResult, + Length: len(data), }, buf: data, } diff --git a/go.mod b/go.mod index 0aea881ec..b982e2c3d 100644 --- a/go.mod +++ b/go.mod @@ -23,7 +23,7 @@ require ( github.com/golang/snappy v0.0.3 github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed github.com/kr/pretty v0.1.0 // indirect - github.com/stretchr/testify v1.3.0 // indirect + github.com/stretchr/testify v1.3.0 gopkg.in/inf.v0 v0.9.1 ) diff --git a/helpers.go b/helpers.go index f2faee9e0..321726146 100644 --- a/helpers.go +++ b/helpers.go @@ -32,6 +32,8 @@ import ( "strings" "time" + "github.com/gocql/gocql/internal" + "gopkg.in/inf.v0" ) @@ -101,10 +103,6 @@ func goType(t TypeInfo) (reflect.Type, error) { } } -func dereference(i interface{}) interface{} { - return reflect.Indirect(reflect.ValueOf(i)).Interface() -} - func getCassandraBaseType(name string) Type { switch name { case "ascii": @@ -176,7 +174,7 @@ func getCassandraType(name string, logger StdLogger) TypeInfo { Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "list<"), logger), } } else if strings.HasPrefix(name, "map<") { - names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "map<")) + names := internal.SplitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "map<")) if len(names) != 2 { logger.Printf("Error parsing map type, it has %d subelements, expecting 2\n", len(names)) return NativeType{ @@ -189,7 +187,7 @@ func getCassandraType(name string, logger StdLogger) TypeInfo { Elem: getCassandraType(names[1], logger), } } else if strings.HasPrefix(name, "tuple<") { - names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "tuple<")) + names := internal.SplitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "tuple<")) types := make([]TypeInfo, len(names)) for i, name := range names { @@ -207,36 +205,8 @@ func getCassandraType(name string, logger StdLogger) TypeInfo { } } -func splitCompositeTypes(name string) []string { - if !strings.Contains(name, "<") { - return strings.Split(name, ", ") - } - var parts []string - lessCount := 0 - segment := "" - for _, char := range name { - if char == ',' && lessCount == 0 { - if segment != "" { - parts = append(parts, strings.TrimSpace(segment)) - } - segment = "" - continue - } - segment += string(char) - if char == '<' { - lessCount++ - } else if char == '>' { - lessCount-- - } - } - if segment != "" { - parts = append(parts, strings.TrimSpace(segment)) - } - return parts -} - func apacheToCassandraType(t string) string { - t = strings.Replace(t, apacheCassandraTypePrefix, "", -1) + t = strings.Replace(t, internal.ApacheCassandraTypePrefix, "", -1) t = strings.Replace(t, "(", "<", -1) t = strings.Replace(t, ")", ">", -1) types := strings.FieldsFunc(t, func(r rune) bool { @@ -250,7 +220,7 @@ func apacheToCassandraType(t string) string { } func getApacheCassandraType(class string) Type { - switch strings.TrimPrefix(class, apacheCassandraTypePrefix) { + switch strings.TrimPrefix(class, internal.ApacheCassandraTypePrefix) { case "AsciiType": return TypeAscii case "LongType": @@ -304,7 +274,7 @@ func getApacheCassandraType(class string) Type { func (r *RowData) rowMap(m map[string]interface{}) { for i, column := range r.Columns { - val := dereference(r.Values[i]) + val := internal.Dereference(r.Values[i]) if valVal := reflect.ValueOf(val); valVal.Kind() == reflect.Slice { valCopy := reflect.MakeSlice(valVal.Type(), valVal.Len(), valVal.Cap()) reflect.Copy(valCopy, valVal) @@ -451,18 +421,9 @@ func (iter *Iter) MapScan(m map[string]interface{}) bool { return false } -func copyBytes(p []byte) []byte { - b := make([]byte, len(p)) - copy(b, p) - return b -} - -var failDNS = false - func LookupIP(host string) ([]net.IP, error) { - if failDNS { + if internal.FailDNS { return nil, &net.DNSError{} } return net.LookupIP(host) - } diff --git a/internal/frame.go b/internal/frame.go new file mode 100644 index 000000000..6139d5ddb --- /dev/null +++ b/internal/frame.go @@ -0,0 +1,216 @@ +package internal + +import "fmt" + +type NamedValue struct { + Name string + Value interface{} +} + +type UnsetColumn struct{} + +func ReadInt(p []byte) int32 { + return int32(p[0])<<24 | int32(p[1])<<16 | int32(p[2])<<8 | int32(p[3]) +} + +func AppendBytes(p []byte, d []byte) []byte { + if d == nil { + return AppendInt(p, -1) + } + p = AppendInt(p, int32(len(d))) + p = append(p, d...) + return p +} + +func AppendShort(p []byte, n uint16) []byte { + return append(p, + byte(n>>8), + byte(n), + ) +} + +func AppendInt(p []byte, n int32) []byte { + return append(p, byte(n>>24), + byte(n>>16), + byte(n>>8), + byte(n)) +} + +func AppendUint(p []byte, n uint32) []byte { + return append(p, byte(n>>24), + byte(n>>16), + byte(n>>8), + byte(n)) +} + +func AppendLong(p []byte, n int64) []byte { + return append(p, + byte(n>>56), + byte(n>>48), + byte(n>>40), + byte(n>>32), + byte(n>>24), + byte(n>>16), + byte(n>>8), + byte(n), + ) +} + +const ( + ProtoDirectionMask = 0x80 + ProtoVersionMask = 0x7F + ProtoVersion1 = 0x01 + ProtoVersion2 = 0x02 + ProtoVersion3 = 0x03 + ProtoVersion4 = 0x04 + ProtoVersion5 = 0x05 + + MaxFrameSize = 256 * 1024 * 1024 +) + +type ProtoVersion byte + +func (p ProtoVersion) Request() bool { + return p&ProtoDirectionMask == 0x00 +} + +func (p ProtoVersion) Response() bool { + return p&ProtoDirectionMask == 0x80 +} + +func (p ProtoVersion) Version() byte { + return byte(p) & ProtoVersionMask +} + +func (p ProtoVersion) String() string { + dir := "REQ" + if p.Response() { + dir = "RESP" + } + + return fmt.Sprintf("[version=%d direction=%s]", p.Version(), dir) +} + +type FrameOp byte + +const ( + // header ops + OpError FrameOp = 0x00 + OpStartup FrameOp = 0x01 + OpReady FrameOp = 0x02 + OpAuthenticate FrameOp = 0x03 + OpOptions FrameOp = 0x05 + OpSupported FrameOp = 0x06 + OpQuery FrameOp = 0x07 + OpResult FrameOp = 0x08 + OpPrepare FrameOp = 0x09 + OpExecute FrameOp = 0x0A + OpRegister FrameOp = 0x0B + OpEvent FrameOp = 0x0C + OpBatch FrameOp = 0x0D + OpAuthChallenge FrameOp = 0x0E + OpAuthResponse FrameOp = 0x0F + OpAuthSuccess FrameOp = 0x10 +) + +func (f FrameOp) String() string { + switch f { + case OpError: + return "ERROR" + case OpStartup: + return "STARTUP" + case OpReady: + return "READY" + case OpAuthenticate: + return "AUTHENTICATE" + case OpOptions: + return "OPTIONS" + case OpSupported: + return "SUPPORTED" + case OpQuery: + return "QUERY" + case OpResult: + return "RESULT" + case OpPrepare: + return "PREPARE" + case OpExecute: + return "EXECUTE" + case OpRegister: + return "REGISTER" + case OpEvent: + return "EVENT" + case OpBatch: + return "BATCH" + case OpAuthChallenge: + return "AUTH_CHALLENGE" + case OpAuthResponse: + return "AUTH_RESPONSE" + case OpAuthSuccess: + return "AUTH_SUCCESS" + default: + return fmt.Sprintf("UNKNOWN_OP_%d", f) + } +} + +const ( + // result kind + ResultKindVoid = 1 + ResultKindRows = 2 + ResultKindKeyspace = 3 + ResultKindPrepared = 4 + ResultKindSchemaChanged = 5 + + // rows flags + FlagGlobalTableSpec int = 0x01 + FlagHasMorePages int = 0x02 + FlagNoMetaData int = 0x04 + + // query flags + FlagValues byte = 0x01 + FlagSkipMetaData byte = 0x02 + FlagPageSize byte = 0x04 + FlagWithPagingState byte = 0x08 + FlagWithSerialConsistency byte = 0x10 + FlagDefaultTimestamp byte = 0x20 + FlagWithNameValues byte = 0x40 + FlagWithKeyspace byte = 0x80 + + // prepare flags + FlagWithPreparedKeyspace uint32 = 0x01 + + // header flags + FlagCompress byte = 0x01 + FlagTracing byte = 0x02 + FlagCustomPayload byte = 0x04 + FlagWarning byte = 0x08 + FlagBetaProtocol byte = 0x10 +) + +const ( + ApacheCassandraTypePrefix = "org.apache.cassandra.db.marshal." +) + +const MaxFrameHeaderSize = 9 + +type Frame interface { + Header() FrameHeader +} + +type FrameHeader struct { + Version ProtoVersion + Flags byte + Stream int + Op FrameOp + Length int + Warnings []string +} + +func (f FrameHeader) String() string { + return fmt.Sprintf("[header version=%s flags=0x%x stream=%d op=%s length=%d]", f.Version, f.Flags, f.Stream, f.Op, f.Length) +} + +func (f FrameHeader) Header() FrameHeader { + return f +} + +const DefaultBufSize = 128 diff --git a/internal/helpers.go b/internal/helpers.go new file mode 100644 index 000000000..831e1230c --- /dev/null +++ b/internal/helpers.go @@ -0,0 +1,46 @@ +package internal + +import ( + "reflect" + "strings" +) + +func SplitCompositeTypes(name string) []string { + if !strings.Contains(name, "<") { + return strings.Split(name, ", ") + } + var parts []string + lessCount := 0 + segment := "" + for _, char := range name { + if char == ',' && lessCount == 0 { + if segment != "" { + parts = append(parts, strings.TrimSpace(segment)) + } + segment = "" + continue + } + segment += string(char) + if char == '<' { + lessCount++ + } else if char == '>' { + lessCount-- + } + } + if segment != "" { + parts = append(parts, strings.TrimSpace(segment)) + } + return parts +} + +func CopyBytes(p []byte) []byte { + b := make([]byte, len(p)) + copy(b, p) + return b +} + +func Dereference(i interface{}) interface{} { + return reflect.Indirect(reflect.ValueOf(i)).Interface() +} + +var FailDNS = false diff --git a/internal/marshal.go b/internal/marshal.go new file mode 100644 index 000000000..d9abad36c --- /dev/null +++ b/internal/marshal.go @@ -0,0 +1,779 @@ +package internal + +import ( + "encoding/binary" + "errors" + "fmt" + "gopkg.in/inf.v0" + "math" + "math/big" + "math/bits" + "net" + "reflect" + "strconv" + "time" +) + +var ( + bigOne = big.NewInt(1) + EmptyValue reflect.Value +) + +const MillisecondsInADay int64 = 24 * 60 * 60 * 1000 + +func EncInt(x int32) []byte { + return []byte{byte(x >> 24), byte(x >> 16), byte(x >> 8), byte(x)} +} + +func DecInt(x []byte) int32 { + if len(x) != 4 { + return 0 + } + return int32(x[0])<<24 | int32(x[1])<<16 | int32(x[2])<<8 | int32(x[3]) +} + +func EncShort(x int16) []byte { + p := make([]byte, 2) + p[0] = byte(x >> 8) + p[1] = byte(x) + return p +} + +func DecShort(p []byte) int16 { + if len(p) != 2 { + return 0 + } + return int16(p[0])<<8 | int16(p[1]) +} + +func DecTiny(p []byte) int8 { + if len(p) != 1 { + return 0 + } + return int8(p[0]) +} + +func EncBigInt(x int64) []byte { + return []byte{byte(x >> 56), byte(x >> 48), byte(x >> 40), byte(x >> 32), + byte(x >> 24), byte(x >> 16), byte(x >> 8), byte(x)} +} + +func BytesToInt64(data []byte) (ret int64) { + for i := range data { + ret |= int64(data[i]) << (8 * uint(len(data)-i-1)) + } + return ret +} + +func BytesToUint64(data []byte) (ret uint64) { + for i := range data { + ret |= uint64(data[i]) << (8 * uint(len(data)-i-1)) + } + return ret +} + +func DecBigInt(data []byte) int64 { + if len(data) != 8 { + return 0 + } + return int64(data[0])<<56 | int64(data[1])<<48 | + int64(data[2])<<40 | int64(data[3])<<32 | + int64(data[4])<<24 | int64(data[5])<<16 | + int64(data[6])<<8 | int64(data[7]) +} + +func EncBool(v bool) []byte { + if v { + return []byte{1} + } + return []byte{0} +} + +func DecBool(v []byte) bool { + if len(v) == 0 { + return false + } + return v[0] != 0 +} + +// decBigInt2C sets the value of n to the big-endian two's complement +// value stored in the given data. If data[0]&80 != 0, the number +// is negative. If data is empty, the result will be 0. +func DecBigInt2C(data []byte, n *big.Int) *big.Int { + if n == nil { + n = new(big.Int) + } + n.SetBytes(data) + if len(data) > 0 && data[0]&0x80 > 0 { + n.Sub(n, new(big.Int).Lsh(bigOne, uint(len(data))*8)) + } + return n +} + +// EncBigInt2C returns the big-endian two's complement +// form of n. +func EncBigInt2C(n *big.Int) []byte { + switch n.Sign() { + case 0: + return []byte{0} + case 1: + b := n.Bytes() + if b[0]&0x80 > 0 { + b = append([]byte{0}, b...) + } + return b + case -1: + length := uint(n.BitLen()/8+1) * 8 + b := new(big.Int).Add(n, new(big.Int).Lsh(bigOne, length)).Bytes() + // When the most significant bit is on a byte + // boundary, we can get some extra significant + // bits, so strip them off when that happens. + if len(b) >= 2 && b[0] == 0xff && b[1]&0x80 != 0 { + b = b[1:] + } + return b + } + return nil +} + +func DecVints(data []byte) (int32, int32, int64, error) { + month, i, err := DecVint(data, 0) + if err != nil { + return 0, 0, 0, fmt.Errorf("failed to extract month: %s", err.Error()) + } + days, i, err := DecVint(data, i) + if err != nil { + return 0, 0, 0, fmt.Errorf("failed to extract days: %s", err.Error()) + } + nanos, _, err := DecVint(data, i) + if err != nil { + return 0, 0, 0, fmt.Errorf("failed to extract nanoseconds: %s", err.Error()) + } + return int32(month), int32(days), nanos, err +} + +func DecVint(data []byte, start int) (int64, int, error) { + if len(data) <= start { + return 0, 0, errors.New("unexpected eof") + } + firstByte := data[start] + if firstByte&0x80 == 0 { + return decIntZigZag(uint64(firstByte)), start + 1, nil + } + numBytes := bits.LeadingZeros32(uint32(^firstByte)) - 24 + ret := uint64(firstByte & (0xff >> uint(numBytes))) + if len(data) < start+numBytes+1 { + return 0, 0, fmt.Errorf("data expect to have %d bytes, but it has only %d", start+numBytes+1, len(data)) + } + for i := start; i < start+numBytes; i++ { + ret <<= 8 + ret |= uint64(data[i+1] & 0xff) + } + return decIntZigZag(ret), start + numBytes + 1, nil +} + +func decIntZigZag(n uint64) int64 { + return int64((n >> 1) ^ -(n & 1)) +} + +func encIntZigZag(n int64) uint64 { + return uint64((n >> 63) ^ (n << 1)) +} + +func EncVints(months int32, seconds int32, nanos int64) []byte { + buf := append(EncVint(int64(months)), EncVint(int64(seconds))...) + return append(buf, EncVint(nanos)...) +} + +func EncVint(v int64) []byte { + vEnc := encIntZigZag(v) + lead0 := bits.LeadingZeros64(vEnc) + numBytes := (639 - lead0*9) >> 6 + + // It can be 1 or 0 is v ==0 + if numBytes <= 1 { + return []byte{byte(vEnc)} + } + extraBytes := numBytes - 1 + var buf = make([]byte, numBytes) + for i := extraBytes; i >= 0; i-- { + buf[i] = byte(vEnc) + vEnc >>= 8 + } + buf[0] |= byte(^(0xff >> uint(extraBytes))) + return buf +} + +// TODO: move to internal +func ReadBytes(p []byte) ([]byte, []byte) { + // TODO: really should use a framer + size := ReadInt(p) + p = p[4:] + if size < 0 { + return nil, p + } + return p[:size], p[size:] +} + +func MarshalVarchar(info, value interface{}) ([]byte, error) { + switch v := value.(type) { + case UnsetColumn: + return nil, nil + case string: + return []byte(v), nil + case []byte: + return v, nil + } + + if value == nil { + return nil, nil + } + + rv := reflect.ValueOf(value) + t := rv.Type() + k := t.Kind() + switch { + case k == reflect.String: + return []byte(rv.String()), nil + case k == reflect.Slice && t.Elem().Kind() == reflect.Uint8: + return rv.Bytes(), nil + } + return nil, fmt.Errorf("can not marshal %T into %s", value, info) +} + +func MarshalBool(info, value interface{}) ([]byte, error) { + switch v := value.(type) { + case UnsetColumn: + return nil, nil + case bool: + return EncBool(v), nil + } + + if value == nil { + return nil, nil + } + + rv := reflect.ValueOf(value) + switch rv.Type().Kind() { + case reflect.Bool: + return EncBool(rv.Bool()), nil + } + return nil, fmt.Errorf("can not marshal %T into %s", value, info) +} + +func MarshalTinyInt(info, value interface{}) ([]byte, error) { + switch v := value.(type) { + case UnsetColumn: + return nil, nil + case int8: + return []byte{byte(v)}, nil + case uint8: + return []byte{byte(v)}, nil + case int16: + if v > math.MaxInt8 || v < math.MinInt8 { + return nil, fmt.Errorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case uint16: + if v > math.MaxUint8 { + return nil, fmt.Errorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case int: + if v > math.MaxInt8 || v < math.MinInt8 { + return nil, fmt.Errorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case int32: + if v > math.MaxInt8 || v < math.MinInt8 { + return nil, fmt.Errorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case int64: + if v > math.MaxInt8 || v < math.MinInt8 { + return nil, fmt.Errorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case uint: + if v > math.MaxUint8 { + return nil, fmt.Errorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case uint32: + if v > math.MaxUint8 { + return nil, fmt.Errorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case uint64: + if v > math.MaxUint8 { + return nil, fmt.Errorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case string: + n, err := strconv.ParseInt(v, 10, 8) + if err != nil { + return nil, fmt.Errorf("can not marshal %T into %s: %v", value, info, err) + } + return []byte{byte(n)}, nil + } + + if value == nil { + return nil, nil + } + + switch rv := reflect.ValueOf(value); rv.Type().Kind() { + case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: + v := rv.Int() + if v > math.MaxInt8 || v < math.MinInt8 { + return nil, fmt.Errorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: + v := rv.Uint() + if v > math.MaxUint8 { + return nil, fmt.Errorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case reflect.Ptr: + if rv.IsNil() { + return nil, nil + } + } + + return nil, fmt.Errorf("can not marshal %T into %s", value, info) +} + +func MarshalSmallInt(info, value interface{}) ([]byte, error) { + switch v := value.(type) { + case UnsetColumn: + return nil, nil + case int16: + return EncShort(v), nil + case uint16: + return EncShort(int16(v)), nil + case int8: + return EncShort(int16(v)), nil + case uint8: + return EncShort(int16(v)), nil + case int: + if v > math.MaxInt16 || v < math.MinInt16 { + return nil, fmt.Errorf("marshal smallint: value %d out of range", v) + } + return EncShort(int16(v)), nil + case int32: + if v > math.MaxInt16 || v < math.MinInt16 { + return nil, fmt.Errorf("marshal smallint: value %d out of range", v) + } + return EncShort(int16(v)), nil + case int64: + if v > math.MaxInt16 || v < math.MinInt16 { + return nil, fmt.Errorf("marshal smallint: value %d out of range", v) + } + return EncShort(int16(v)), nil + case uint: + if v > math.MaxUint16 { + return nil, fmt.Errorf("marshal smallint: value %d out of range", v) + } + return EncShort(int16(v)), nil + case uint32: + if v > math.MaxUint16 { + return nil, fmt.Errorf("marshal smallint: value %d out of range", v) + } + return EncShort(int16(v)), nil + case uint64: + if v > math.MaxUint16 { + return nil, fmt.Errorf("marshal smallint: value %d out of range", v) + } + return EncShort(int16(v)), nil + case string: + n, err := strconv.ParseInt(v, 10, 16) + if err != nil { + return nil, fmt.Errorf("can not marshal %T into %s: %v", value, info, err) + } + return EncShort(int16(n)), nil + } + + if value == nil { + return nil, nil + } + + switch rv := reflect.ValueOf(value); rv.Type().Kind() { + case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: + v := rv.Int() + if v > math.MaxInt16 || v < math.MinInt16 { + return nil, fmt.Errorf("marshal smallint: value %d out of range", v) + } + return EncShort(int16(v)), nil + case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: + v := rv.Uint() + if v > math.MaxUint16 { + return nil, fmt.Errorf("marshal smallint: value %d out of range", v) + } + return EncShort(int16(v)), nil + case reflect.Ptr: + if rv.IsNil() { + return nil, nil + } + } + + return nil, fmt.Errorf("can not marshal %T into %s", value, info) +} + +func MarshalInt(info, value interface{}) ([]byte, error) { + switch v := value.(type) { + case UnsetColumn: + return nil, nil + case int: + if v > math.MaxInt32 || v < math.MinInt32 { + return nil, fmt.Errorf("marshal int: value %d out of range", v) + } + return EncInt(int32(v)), nil + case uint: + if v > math.MaxUint32 { + return nil, fmt.Errorf("marshal int: value %d out of range", v) + } + return EncInt(int32(v)), nil + case int64: + if v > math.MaxInt32 || v < math.MinInt32 { + return nil, fmt.Errorf("marshal int: value %d out of range", v) + } + return EncInt(int32(v)), nil + case uint64: + if v > math.MaxUint32 { + return nil, fmt.Errorf("marshal int: value %d out of range", v) + } + return EncInt(int32(v)), nil + case int32: + return EncInt(v), nil + case uint32: + return EncInt(int32(v)), nil + case int16: + return EncInt(int32(v)), nil + case uint16: + return EncInt(int32(v)), nil + case int8: + return EncInt(int32(v)), nil + case uint8: + return EncInt(int32(v)), nil + case string: + i, err := strconv.ParseInt(v, 10, 32) + if err != nil { + return nil, fmt.Errorf("can not marshal string to int: %s", err) + } + return EncInt(int32(i)), nil + } + + if value == nil { + return nil, nil + } + + switch rv := reflect.ValueOf(value); rv.Type().Kind() { + case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: + v := rv.Int() + if v > math.MaxInt32 || v < math.MinInt32 { + return nil, fmt.Errorf("marshal int: value %d out of range", v) + } + return EncInt(int32(v)), nil + case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: + v := rv.Uint() + if v > math.MaxInt32 { + return nil, fmt.Errorf("marshal int: value %d out of range", v) + } + return EncInt(int32(v)), nil + case reflect.Ptr: + if rv.IsNil() { + return nil, nil + } + } + + return nil, fmt.Errorf("can not marshal %T into %s", value, info) +} + +func MarshalBigInt(info, value interface{}) ([]byte, error) { + switch v := value.(type) { + case UnsetColumn: + return nil, nil + case int: + return EncBigInt(int64(v)), nil + case uint: + if uint64(v) > math.MaxInt64 { + return nil, fmt.Errorf("marshal bigint: value %d out of range", v) + } + return EncBigInt(int64(v)), nil + case int64: + return EncBigInt(v), nil + case uint64: + return EncBigInt(int64(v)), nil + case int32: + return EncBigInt(int64(v)), nil + case uint32: + return EncBigInt(int64(v)), nil + case int16: + return EncBigInt(int64(v)), nil + case uint16: + return EncBigInt(int64(v)), nil + case int8: + return EncBigInt(int64(v)), nil + case uint8: + return EncBigInt(int64(v)), nil + case big.Int: + return EncBigInt2C(&v), nil + case string: + i, err := strconv.ParseInt(value.(string), 10, 64) + if err != nil { + return nil, fmt.Errorf("can not marshal string to bigint: %s", err) + } + return EncBigInt(i), nil + } + + if value == nil { + return nil, nil + } + + rv := reflect.ValueOf(value) + switch rv.Type().Kind() { + case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: + v := rv.Int() + return EncBigInt(v), nil + case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: + v := rv.Uint() + if v > math.MaxInt64 { + return nil, fmt.Errorf("marshal bigint: value %d out of range", v) + } + return EncBigInt(int64(v)), nil + } + return nil, fmt.Errorf("can not marshal %T into %s", value, info) +} + +func MarshalFloat(info, value interface{}) ([]byte, error) { + switch v := value.(type) { + case UnsetColumn: + return nil, nil + case float32: + return EncInt(int32(math.Float32bits(v))), nil + } + + if value == nil { + return nil, nil + } + + rv := reflect.ValueOf(value) + switch rv.Type().Kind() { + case reflect.Float32: + return EncInt(int32(math.Float32bits(float32(rv.Float())))), nil + } + return nil, fmt.Errorf("can not marshal %T into %s", value, info) +} + +func MarshalDouble(info, value interface{}) ([]byte, error) { + switch v := value.(type) { + case UnsetColumn: + return nil, nil + case float64: + return EncBigInt(int64(math.Float64bits(v))), nil + } + if value == nil { + return nil, nil + } + rv := reflect.ValueOf(value) + switch rv.Type().Kind() { + case reflect.Float64: + return EncBigInt(int64(math.Float64bits(rv.Float()))), nil + } + return nil, fmt.Errorf("can not marshal %T into %s", value, info) +} + +func MarshalDecimal(info, value interface{}) ([]byte, error) { + if value == nil { + return nil, nil + } + + switch v := value.(type) { + case UnsetColumn: + return nil, nil + case inf.Dec: + unscaled := EncBigInt2C(v.UnscaledBig()) + if unscaled == nil { + return nil, fmt.Errorf("can not marshal %T into %s", value, info) + } + + buf := make([]byte, 4+len(unscaled)) + copy(buf[0:4], EncInt(int32(v.Scale()))) + copy(buf[4:], unscaled) + return buf, nil + } + return nil, fmt.Errorf("can not marshal %T into %s", value, info) +} + +func MarshalTime(info, value interface{}) ([]byte, error) { + switch v := value.(type) { + case UnsetColumn: + return nil, nil + case int64: + return EncBigInt(v), nil + case time.Duration: + return EncBigInt(v.Nanoseconds()), nil + } + + if value == nil { + return nil, nil + } + + rv := reflect.ValueOf(value) + switch rv.Type().Kind() { + case reflect.Int64: + return EncBigInt(rv.Int()), nil + } + return nil, fmt.Errorf("can not marshal %T into %s", value, info) +} + +func MarshalTimestamp(info, value interface{}) ([]byte, error) { + switch v := value.(type) { + case UnsetColumn: + return nil, nil + case int64: + return EncBigInt(v), nil + case time.Time: + if v.IsZero() { + return []byte{}, nil + } + x := int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6) + return EncBigInt(x), nil + } + + if value == nil { + return nil, nil + } + + rv := reflect.ValueOf(value) + switch rv.Type().Kind() { + case reflect.Int64: + return EncBigInt(rv.Int()), nil + } + return nil, fmt.Errorf("can not marshal %T into %s", value, info) +} + +func MarshalVarint(info, value interface{}) ([]byte, error) { + var ( + retBytes []byte + err error + ) + + switch v := value.(type) { + case UnsetColumn: + return nil, nil + case uint64: + if v > uint64(math.MaxInt64) { + retBytes = make([]byte, 9) + binary.BigEndian.PutUint64(retBytes[1:], v) + } else { + retBytes = make([]byte, 8) + binary.BigEndian.PutUint64(retBytes, v) + } + default: + retBytes, err = MarshalBigInt(info, value) + } + + if err == nil { + // trim down to most significant byte + i := 0 + for ; i < len(retBytes)-1; i++ { + b0 := retBytes[i] + if b0 != 0 && b0 != 0xFF { + break + } + + b1 := retBytes[i+1] + if b0 == 0 && b1 != 0 { + if b1&0x80 == 0 { + i++ + } + break + } + + if b0 == 0xFF && b1 != 0xFF { + if b1&0x80 > 0 { + i++ + } + break + } + } + retBytes = retBytes[i:] + } + + return retBytes, err +} + +func MarshalInet(info, value interface{}) ([]byte, error) { + // we return either the 4 or 16 byte representation of an + // ip address here otherwise the db value will be prefixed + // with the remaining byte values e.g. ::ffff:127.0.0.1 and not 127.0.0.1 + switch val := value.(type) { + case UnsetColumn: + return nil, nil + case net.IP: + t := val.To4() + if t == nil { + return val.To16(), nil + } + return t, nil + case string: + b := net.ParseIP(val) + if b != nil { + t := b.To4() + if t == nil { + return b.To16(), nil + } + return t, nil + } + return nil, fmt.Errorf("cannot marshal. invalid ip string %s", val) + } + + if value == nil { + return nil, nil + } + + return nil, fmt.Errorf("cannot marshal %T into %s", value, info) +} + +func MarshalDate(info, value interface{}) ([]byte, error) { + var timestamp int64 + switch v := value.(type) { + case UnsetColumn: + return nil, nil + case int64: + timestamp = v + x := timestamp/MillisecondsInADay + int64(1<<31) + return EncInt(int32(x)), nil + case time.Time: + if v.IsZero() { + return []byte{}, nil + } + timestamp = int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6) + x := timestamp/MillisecondsInADay + int64(1<<31) + return EncInt(int32(x)), nil + case *time.Time: + if v.IsZero() { + return []byte{}, nil + } + timestamp = int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6) + x := timestamp/MillisecondsInADay + int64(1<<31) + return EncInt(int32(x)), nil + case string: + if v == "" { + return []byte{}, nil + } + t, err := time.Parse("2006-01-02", v) + if err != nil { + return nil, fmt.Errorf("can not marshal %T into %s, date layout must be '2006-01-02'", value, info) + } + timestamp = int64(t.UTC().Unix()*1e3) + int64(t.UTC().Nanosecond()/1e6) + x := timestamp/MillisecondsInADay + int64(1<<31) + return EncInt(int32(x)), nil + } + + if value == nil { + return nil, nil + } + return nil, fmt.Errorf("can not marshal %T into %s", value, info) +} diff --git a/internal/session.go b/internal/session.go new file mode 100644 index 000000000..1777fadd3 --- /dev/null +++ b/internal/session.go @@ -0,0 +1,12 @@ +package internal + +import "strings" + +// TODO: move to internal +func IsUseStatement(stmt string) bool { + if len(stmt) < 3 { + return false + } + + return strings.EqualFold(stmt[0:3], "use") +} diff --git a/internal/token.go b/internal/token.go new file mode 100644 index 000000000..0ad52519f --- /dev/null +++ b/internal/token.go @@ -0,0 +1,111 @@ +package internal + +import ( + "crypto/md5" + "fmt" + "github.com/gocql/gocql/internal/murmur" + "math/big" + "strconv" +) + +// a Token Partitioner +type Partitioner interface { + Name() string + Hash([]byte) Token + ParseString(string) Token +} + +// a Token +type Token interface { + fmt.Stringer + Less(Token) bool +} + +// murmur3 partitioner and Token +type Murmur3Partitioner struct{} +type Murmur3Token int64 + +func (p Murmur3Partitioner) Name() string { + return "Murmur3Partitioner" +} + +func (p Murmur3Partitioner) Hash(partitionKey []byte) Token { + h1 := murmur.Murmur3H1(partitionKey) + return Murmur3Token(h1) +} + +// murmur3 little-endian, 128-bit hash, but returns only h1 +func (p Murmur3Partitioner) ParseString(str string) Token { + val, _ := strconv.ParseInt(str, 10, 64) + return Murmur3Token(val) +} + +func (m Murmur3Token) String() string { + return strconv.FormatInt(int64(m), 10) +} + +func (m Murmur3Token) Less(Token Token) bool { + return m < Token.(Murmur3Token) +} + +// order preserving partitioner and Token +type OrderedPartitioner struct{} +type OrderedToken string + +func (p OrderedPartitioner) Name() string { + return "OrderedPartitioner" +} + +func (p OrderedPartitioner) Hash(partitionKey []byte) Token { + // the partition key is the Token + return OrderedToken(partitionKey) +} + +func (p OrderedPartitioner) ParseString(str string) Token { + return OrderedToken(str) +} + +func (o OrderedToken) String() string { + return string(o) +} + +func (o OrderedToken) Less(Token Token) bool { + return o < Token.(OrderedToken) +} + +// random partitioner and Token +type RandomPartitioner struct{} +type RandomToken big.Int + +func (r RandomPartitioner) Name() string { + return "RandomPartitioner" +} + +// 2 ** 128 +var maxHashInt, _ = new(big.Int).SetString("340282366920938463463374607431768211456", 10) + +func (p RandomPartitioner) Hash(partitionKey []byte) Token { + sum := md5.Sum(partitionKey) + val := new(big.Int) + val.SetBytes(sum[:]) + if sum[0] > 127 { + val.Sub(val, maxHashInt) + val.Abs(val) + } + + return (*RandomToken)(val) +} + +func (p RandomPartitioner) ParseString(str string) Token { + val := new(big.Int) + val.SetString(str, 10) + return (*RandomToken)(val) +} + +func (r *RandomToken) String() string { + return (*big.Int)(r).String() +} + +func (r *RandomToken) Less(Token Token) bool { + return -1 == (*big.Int)(r).Cmp((*big.Int)(Token.(*RandomToken))) +} diff --git a/marshal.go b/marshal.go index 4d0adb923..15c693a84 100644 --- a/marshal.go +++ b/marshal.go @@ -29,21 +29,16 @@ import ( "encoding/binary" "errors" "fmt" + "gopkg.in/inf.v0" "math" "math/big" - "math/bits" "net" "reflect" "strconv" "strings" "time" - "gopkg.in/inf.v0" -) - -var ( - bigOne = big.NewInt(1) - emptyValue reflect.Value + "github.com/gocql/gocql/internal" ) var ( @@ -111,7 +106,7 @@ type Unmarshaler interface { // duration | gocql.Duration | // duration | string | parsed with time.ParseDuration func Marshal(info TypeInfo, value interface{}) ([]byte, error) { - if info.Version() < protoVersion1 { + if info.Version() < internal.ProtoVersion1 { panic("protocol version not set") } @@ -129,29 +124,32 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) { return v.MarshalCQL(info) } + // TODO: move to internal + // Notice: a lot of marshal functions could be moved to internal package, + // if the Marshaler case and internal.UnsetColumn cases will be moved to this level switch info.Type() { case TypeVarchar, TypeAscii, TypeBlob, TypeText: - return marshalVarchar(info, value) + return internal.MarshalVarchar(info, value) case TypeBoolean: - return marshalBool(info, value) + return internal.MarshalBool(info, value) case TypeTinyInt: - return marshalTinyInt(info, value) + return internal.MarshalTinyInt(info, value) case TypeSmallInt: - return marshalSmallInt(info, value) + return internal.MarshalSmallInt(info, value) case TypeInt: - return marshalInt(info, value) + return internal.MarshalInt(info, value) case TypeBigInt, TypeCounter: - return marshalBigInt(info, value) + return internal.MarshalBigInt(info, value) case TypeFloat: - return marshalFloat(info, value) + return internal.MarshalFloat(info, value) case TypeDouble: - return marshalDouble(info, value) + return internal.MarshalDouble(info, value) case TypeDecimal: - return marshalDecimal(info, value) + return internal.MarshalDecimal(info, value) case TypeTime: - return marshalTime(info, value) + return internal.MarshalTime(info, value) case TypeTimestamp: - return marshalTimestamp(info, value) + return internal.MarshalTimestamp(info, value) case TypeList, TypeSet: return marshalList(info, value) case TypeMap: @@ -159,15 +157,15 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) { case TypeUUID, TypeTimeUUID: return marshalUUID(info, value) case TypeVarint: - return marshalVarint(info, value) + return internal.MarshalVarint(info, value) case TypeInet: - return marshalInet(info, value) + return internal.MarshalInet(info, value) case TypeTuple: return marshalTuple(info, value) case TypeUDT: return marshalUDT(info, value) case TypeDate: - return marshalDate(info, value) + return internal.MarshalDate(info, value) case TypeDuration: return marshalDuration(info, value) } @@ -308,34 +306,6 @@ func unmarshalNullable(info TypeInfo, data []byte, value interface{}) error { return Unmarshal(info, data, newValue.Interface()) } -func marshalVarchar(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case string: - return []byte(v), nil - case []byte: - return v, nil - } - - if value == nil { - return nil, nil - } - - rv := reflect.ValueOf(value) - t := rv.Type() - k := t.Kind() - switch { - case k == reflect.String: - return []byte(rv.String()), nil - case k == reflect.Slice && t.Elem().Kind() == reflect.Uint8: - return rv.Bytes(), nil - } - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - func unmarshalVarchar(info TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case Unmarshaler: @@ -375,363 +345,20 @@ func unmarshalVarchar(info TypeInfo, data []byte, value interface{}) error { return unmarshalErrorf("can not unmarshal %s into %T", info, value) } -func marshalSmallInt(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case int16: - return encShort(v), nil - case uint16: - return encShort(int16(v)), nil - case int8: - return encShort(int16(v)), nil - case uint8: - return encShort(int16(v)), nil - case int: - if v > math.MaxInt16 || v < math.MinInt16 { - return nil, marshalErrorf("marshal smallint: value %d out of range", v) - } - return encShort(int16(v)), nil - case int32: - if v > math.MaxInt16 || v < math.MinInt16 { - return nil, marshalErrorf("marshal smallint: value %d out of range", v) - } - return encShort(int16(v)), nil - case int64: - if v > math.MaxInt16 || v < math.MinInt16 { - return nil, marshalErrorf("marshal smallint: value %d out of range", v) - } - return encShort(int16(v)), nil - case uint: - if v > math.MaxUint16 { - return nil, marshalErrorf("marshal smallint: value %d out of range", v) - } - return encShort(int16(v)), nil - case uint32: - if v > math.MaxUint16 { - return nil, marshalErrorf("marshal smallint: value %d out of range", v) - } - return encShort(int16(v)), nil - case uint64: - if v > math.MaxUint16 { - return nil, marshalErrorf("marshal smallint: value %d out of range", v) - } - return encShort(int16(v)), nil - case string: - n, err := strconv.ParseInt(v, 10, 16) - if err != nil { - return nil, marshalErrorf("can not marshal %T into %s: %v", value, info, err) - } - return encShort(int16(n)), nil - } - - if value == nil { - return nil, nil - } - - switch rv := reflect.ValueOf(value); rv.Type().Kind() { - case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: - v := rv.Int() - if v > math.MaxInt16 || v < math.MinInt16 { - return nil, marshalErrorf("marshal smallint: value %d out of range", v) - } - return encShort(int16(v)), nil - case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: - v := rv.Uint() - if v > math.MaxUint16 { - return nil, marshalErrorf("marshal smallint: value %d out of range", v) - } - return encShort(int16(v)), nil - case reflect.Ptr: - if rv.IsNil() { - return nil, nil - } - } - - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - -func marshalTinyInt(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case int8: - return []byte{byte(v)}, nil - case uint8: - return []byte{byte(v)}, nil - case int16: - if v > math.MaxInt8 || v < math.MinInt8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case uint16: - if v > math.MaxUint8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case int: - if v > math.MaxInt8 || v < math.MinInt8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case int32: - if v > math.MaxInt8 || v < math.MinInt8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case int64: - if v > math.MaxInt8 || v < math.MinInt8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case uint: - if v > math.MaxUint8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case uint32: - if v > math.MaxUint8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case uint64: - if v > math.MaxUint8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case string: - n, err := strconv.ParseInt(v, 10, 8) - if err != nil { - return nil, marshalErrorf("can not marshal %T into %s: %v", value, info, err) - } - return []byte{byte(n)}, nil - } - - if value == nil { - return nil, nil - } - - switch rv := reflect.ValueOf(value); rv.Type().Kind() { - case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: - v := rv.Int() - if v > math.MaxInt8 || v < math.MinInt8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: - v := rv.Uint() - if v > math.MaxUint8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case reflect.Ptr: - if rv.IsNil() { - return nil, nil - } - } - - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - -func marshalInt(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case int: - if v > math.MaxInt32 || v < math.MinInt32 { - return nil, marshalErrorf("marshal int: value %d out of range", v) - } - return encInt(int32(v)), nil - case uint: - if v > math.MaxUint32 { - return nil, marshalErrorf("marshal int: value %d out of range", v) - } - return encInt(int32(v)), nil - case int64: - if v > math.MaxInt32 || v < math.MinInt32 { - return nil, marshalErrorf("marshal int: value %d out of range", v) - } - return encInt(int32(v)), nil - case uint64: - if v > math.MaxUint32 { - return nil, marshalErrorf("marshal int: value %d out of range", v) - } - return encInt(int32(v)), nil - case int32: - return encInt(v), nil - case uint32: - return encInt(int32(v)), nil - case int16: - return encInt(int32(v)), nil - case uint16: - return encInt(int32(v)), nil - case int8: - return encInt(int32(v)), nil - case uint8: - return encInt(int32(v)), nil - case string: - i, err := strconv.ParseInt(v, 10, 32) - if err != nil { - return nil, marshalErrorf("can not marshal string to int: %s", err) - } - return encInt(int32(i)), nil - } - - if value == nil { - return nil, nil - } - - switch rv := reflect.ValueOf(value); rv.Type().Kind() { - case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: - v := rv.Int() - if v > math.MaxInt32 || v < math.MinInt32 { - return nil, marshalErrorf("marshal int: value %d out of range", v) - } - return encInt(int32(v)), nil - case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: - v := rv.Uint() - if v > math.MaxInt32 { - return nil, marshalErrorf("marshal int: value %d out of range", v) - } - return encInt(int32(v)), nil - case reflect.Ptr: - if rv.IsNil() { - return nil, nil - } - } - - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - -func encInt(x int32) []byte { - return []byte{byte(x >> 24), byte(x >> 16), byte(x >> 8), byte(x)} -} - -func decInt(x []byte) int32 { - if len(x) != 4 { - return 0 - } - return int32(x[0])<<24 | int32(x[1])<<16 | int32(x[2])<<8 | int32(x[3]) -} - -func encShort(x int16) []byte { - p := make([]byte, 2) - p[0] = byte(x >> 8) - p[1] = byte(x) - return p -} - -func decShort(p []byte) int16 { - if len(p) != 2 { - return 0 - } - return int16(p[0])<<8 | int16(p[1]) -} - -func decTiny(p []byte) int8 { - if len(p) != 1 { - return 0 - } - return int8(p[0]) -} - -func marshalBigInt(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case int: - return encBigInt(int64(v)), nil - case uint: - if uint64(v) > math.MaxInt64 { - return nil, marshalErrorf("marshal bigint: value %d out of range", v) - } - return encBigInt(int64(v)), nil - case int64: - return encBigInt(v), nil - case uint64: - return encBigInt(int64(v)), nil - case int32: - return encBigInt(int64(v)), nil - case uint32: - return encBigInt(int64(v)), nil - case int16: - return encBigInt(int64(v)), nil - case uint16: - return encBigInt(int64(v)), nil - case int8: - return encBigInt(int64(v)), nil - case uint8: - return encBigInt(int64(v)), nil - case big.Int: - return encBigInt2C(&v), nil - case string: - i, err := strconv.ParseInt(value.(string), 10, 64) - if err != nil { - return nil, marshalErrorf("can not marshal string to bigint: %s", err) - } - return encBigInt(i), nil - } - - if value == nil { - return nil, nil - } - - rv := reflect.ValueOf(value) - switch rv.Type().Kind() { - case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: - v := rv.Int() - return encBigInt(v), nil - case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: - v := rv.Uint() - if v > math.MaxInt64 { - return nil, marshalErrorf("marshal bigint: value %d out of range", v) - } - return encBigInt(int64(v)), nil - } - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - -func encBigInt(x int64) []byte { - return []byte{byte(x >> 56), byte(x >> 48), byte(x >> 40), byte(x >> 32), - byte(x >> 24), byte(x >> 16), byte(x >> 8), byte(x)} -} - -func bytesToInt64(data []byte) (ret int64) { - for i := range data { - ret |= int64(data[i]) << (8 * uint(len(data)-i-1)) - } - return ret -} - -func bytesToUint64(data []byte) (ret uint64) { - for i := range data { - ret |= uint64(data[i]) << (8 * uint(len(data)-i-1)) - } - return ret -} - func unmarshalBigInt(info TypeInfo, data []byte, value interface{}) error { - return unmarshalIntlike(info, decBigInt(data), data, value) + return unmarshalIntlike(info, internal.DecBigInt(data), data, value) } func unmarshalInt(info TypeInfo, data []byte, value interface{}) error { - return unmarshalIntlike(info, int64(decInt(data)), data, value) + return unmarshalIntlike(info, int64(internal.DecInt(data)), data, value) } func unmarshalSmallInt(info TypeInfo, data []byte, value interface{}) error { - return unmarshalIntlike(info, int64(decShort(data)), data, value) + return unmarshalIntlike(info, int64(internal.DecShort(data)), data, value) } func unmarshalTinyInt(info TypeInfo, data []byte, value interface{}) error { - return unmarshalIntlike(info, int64(decTiny(data)), data, value) + return unmarshalIntlike(info, int64(internal.DecTiny(data)), data, value) } func unmarshalVarint(info TypeInfo, data []byte, value interface{}) error { @@ -740,7 +367,7 @@ func unmarshalVarint(info TypeInfo, data []byte, value interface{}) error { return unmarshalIntlike(info, 0, data, value) case *uint64: if len(data) == 9 && data[0] == 0 { - *v = bytesToUint64(data[1:]) + *v = internal.BytesToUint64(data[1:]) return nil } } @@ -749,64 +376,13 @@ func unmarshalVarint(info TypeInfo, data []byte, value interface{}) error { return unmarshalErrorf("unmarshal int: varint value %v out of range for %T (use big.Int)", data, value) } - int64Val := bytesToInt64(data) + int64Val := internal.BytesToInt64(data) if len(data) > 0 && len(data) < 8 && data[0]&0x80 > 0 { int64Val -= (1 << uint(len(data)*8)) } return unmarshalIntlike(info, int64Val, data, value) } -func marshalVarint(info TypeInfo, value interface{}) ([]byte, error) { - var ( - retBytes []byte - err error - ) - - switch v := value.(type) { - case unsetColumn: - return nil, nil - case uint64: - if v > uint64(math.MaxInt64) { - retBytes = make([]byte, 9) - binary.BigEndian.PutUint64(retBytes[1:], v) - } else { - retBytes = make([]byte, 8) - binary.BigEndian.PutUint64(retBytes, v) - } - default: - retBytes, err = marshalBigInt(info, value) - } - - if err == nil { - // trim down to most significant byte - i := 0 - for ; i < len(retBytes)-1; i++ { - b0 := retBytes[i] - if b0 != 0 && b0 != 0xFF { - break - } - - b1 := retBytes[i+1] - if b0 == 0 && b1 != 0 { - if b1&0x80 == 0 { - i++ - } - break - } - - if b0 == 0xFF && b1 != 0xFF { - if b1&0x80 > 0 { - i++ - } - break - } - } - retBytes = retBytes[i:] - } - - return retBytes, err -} - func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interface{}) error { switch v := value.(type) { case *int: @@ -899,7 +475,7 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac *v = uint8(int64Val) & 0xFF return nil case *big.Int: - decBigInt2C(data, v) + internal.DecBigInt2C(data, v) return nil case *string: *v = strconv.FormatInt(int64Val, 10) @@ -1009,51 +585,12 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac return unmarshalErrorf("can not unmarshal %s into %T", info, value) } -func decBigInt(data []byte) int64 { - if len(data) != 8 { - return 0 - } - return int64(data[0])<<56 | int64(data[1])<<48 | - int64(data[2])<<40 | int64(data[3])<<32 | - int64(data[4])<<24 | int64(data[5])<<16 | - int64(data[6])<<8 | int64(data[7]) -} - -func marshalBool(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case bool: - return encBool(v), nil - } - - if value == nil { - return nil, nil - } - - rv := reflect.ValueOf(value) - switch rv.Type().Kind() { - case reflect.Bool: - return encBool(rv.Bool()), nil - } - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - -func encBool(v bool) []byte { - if v { - return []byte{1} - } - return []byte{0} -} - func unmarshalBool(info TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case Unmarshaler: return v.UnmarshalCQL(info, data) case *bool: - *v = decBool(data) + *v = internal.DecBool(data) return nil } rv := reflect.ValueOf(value) @@ -1063,47 +600,18 @@ func unmarshalBool(info TypeInfo, data []byte, value interface{}) error { rv = rv.Elem() switch rv.Type().Kind() { case reflect.Bool: - rv.SetBool(decBool(data)) + rv.SetBool(internal.DecBool(data)) return nil } return unmarshalErrorf("can not unmarshal %s into %T", info, value) } -func decBool(v []byte) bool { - if len(v) == 0 { - return false - } - return v[0] != 0 -} - -func marshalFloat(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case float32: - return encInt(int32(math.Float32bits(v))), nil - } - - if value == nil { - return nil, nil - } - - rv := reflect.ValueOf(value) - switch rv.Type().Kind() { - case reflect.Float32: - return encInt(int32(math.Float32bits(float32(rv.Float())))), nil - } - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - func unmarshalFloat(info TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case Unmarshaler: return v.UnmarshalCQL(info, data) case *float32: - *v = math.Float32frombits(uint32(decInt(data))) + *v = math.Float32frombits(uint32(internal.DecInt(data))) return nil } rv := reflect.ValueOf(value) @@ -1113,38 +621,18 @@ func unmarshalFloat(info TypeInfo, data []byte, value interface{}) error { rv = rv.Elem() switch rv.Type().Kind() { case reflect.Float32: - rv.SetFloat(float64(math.Float32frombits(uint32(decInt(data))))) + rv.SetFloat(float64(math.Float32frombits(uint32(internal.DecInt(data))))) return nil } return unmarshalErrorf("can not unmarshal %s into %T", info, value) } -func marshalDouble(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case float64: - return encBigInt(int64(math.Float64bits(v))), nil - } - if value == nil { - return nil, nil - } - rv := reflect.ValueOf(value) - switch rv.Type().Kind() { - case reflect.Float64: - return encBigInt(int64(math.Float64bits(rv.Float()))), nil - } - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - func unmarshalDouble(info TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case Unmarshaler: return v.UnmarshalCQL(info, data) case *float64: - *v = math.Float64frombits(uint64(decBigInt(data))) + *v = math.Float64frombits(uint64(internal.DecBigInt(data))) return nil } rv := reflect.ValueOf(value) @@ -1154,36 +642,12 @@ func unmarshalDouble(info TypeInfo, data []byte, value interface{}) error { rv = rv.Elem() switch rv.Type().Kind() { case reflect.Float64: - rv.SetFloat(math.Float64frombits(uint64(decBigInt(data)))) + rv.SetFloat(math.Float64frombits(uint64(internal.DecBigInt(data)))) return nil } return unmarshalErrorf("can not unmarshal %s into %T", info, value) } -func marshalDecimal(info TypeInfo, value interface{}) ([]byte, error) { - if value == nil { - return nil, nil - } - - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case inf.Dec: - unscaled := encBigInt2C(v.UnscaledBig()) - if unscaled == nil { - return nil, marshalErrorf("can not marshal %T into %s", value, info) - } - - buf := make([]byte, 4+len(unscaled)) - copy(buf[0:4], encInt(int32(v.Scale()))) - copy(buf[4:], unscaled) - return buf, nil - } - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - func unmarshalDecimal(info TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case Unmarshaler: @@ -1192,115 +656,23 @@ func unmarshalDecimal(info TypeInfo, data []byte, value interface{}) error { if len(data) < 4 { return unmarshalErrorf("inf.Dec needs at least 4 bytes, while value has only %d", len(data)) } - scale := decInt(data[0:4]) - unscaled := decBigInt2C(data[4:], nil) + scale := internal.DecInt(data[0:4]) + unscaled := internal.DecBigInt2C(data[4:], nil) *v = *inf.NewDecBig(unscaled, inf.Scale(scale)) return nil } return unmarshalErrorf("can not unmarshal %s into %T", info, value) } -// decBigInt2C sets the value of n to the big-endian two's complement -// value stored in the given data. If data[0]&80 != 0, the number -// is negative. If data is empty, the result will be 0. -func decBigInt2C(data []byte, n *big.Int) *big.Int { - if n == nil { - n = new(big.Int) - } - n.SetBytes(data) - if len(data) > 0 && data[0]&0x80 > 0 { - n.Sub(n, new(big.Int).Lsh(bigOne, uint(len(data))*8)) - } - return n -} - -// encBigInt2C returns the big-endian two's complement -// form of n. -func encBigInt2C(n *big.Int) []byte { - switch n.Sign() { - case 0: - return []byte{0} - case 1: - b := n.Bytes() - if b[0]&0x80 > 0 { - b = append([]byte{0}, b...) - } - return b - case -1: - length := uint(n.BitLen()/8+1) * 8 - b := new(big.Int).Add(n, new(big.Int).Lsh(bigOne, length)).Bytes() - // When the most significant bit is on a byte - // boundary, we can get some extra significant - // bits, so strip them off when that happens. - if len(b) >= 2 && b[0] == 0xff && b[1]&0x80 != 0 { - b = b[1:] - } - return b - } - return nil -} - -func marshalTime(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case int64: - return encBigInt(v), nil - case time.Duration: - return encBigInt(v.Nanoseconds()), nil - } - - if value == nil { - return nil, nil - } - - rv := reflect.ValueOf(value) - switch rv.Type().Kind() { - case reflect.Int64: - return encBigInt(rv.Int()), nil - } - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - -func marshalTimestamp(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case int64: - return encBigInt(v), nil - case time.Time: - if v.IsZero() { - return []byte{}, nil - } - x := int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6) - return encBigInt(x), nil - } - - if value == nil { - return nil, nil - } - - rv := reflect.ValueOf(value) - switch rv.Type().Kind() { - case reflect.Int64: - return encBigInt(rv.Int()), nil - } - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - func unmarshalTime(info TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case Unmarshaler: return v.UnmarshalCQL(info, data) case *int64: - *v = decBigInt(data) + *v = internal.DecBigInt(data) return nil case *time.Duration: - *v = time.Duration(decBigInt(data)) + *v = time.Duration(internal.DecBigInt(data)) return nil } @@ -1311,7 +683,7 @@ func unmarshalTime(info TypeInfo, data []byte, value interface{}) error { rv = rv.Elem() switch rv.Type().Kind() { case reflect.Int64: - rv.SetInt(decBigInt(data)) + rv.SetInt(internal.DecBigInt(data)) return nil } return unmarshalErrorf("can not unmarshal %s into %T", info, value) @@ -1322,14 +694,14 @@ func unmarshalTimestamp(info TypeInfo, data []byte, value interface{}) error { case Unmarshaler: return v.UnmarshalCQL(info, data) case *int64: - *v = decBigInt(data) + *v = internal.DecBigInt(data) return nil case *time.Time: if len(data) == 0 { *v = time.Time{} return nil } - x := decBigInt(data) + x := internal.DecBigInt(data) sec := x / 1000 nsec := (x - sec*1000) * 1000000 *v = time.Unix(sec, nsec).In(time.UTC) @@ -1343,58 +715,12 @@ func unmarshalTimestamp(info TypeInfo, data []byte, value interface{}) error { rv = rv.Elem() switch rv.Type().Kind() { case reflect.Int64: - rv.SetInt(decBigInt(data)) + rv.SetInt(internal.DecBigInt(data)) return nil } return unmarshalErrorf("can not unmarshal %s into %T", info, value) } -const millisecondsInADay int64 = 24 * 60 * 60 * 1000 - -func marshalDate(info TypeInfo, value interface{}) ([]byte, error) { - var timestamp int64 - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case int64: - timestamp = v - x := timestamp/millisecondsInADay + int64(1<<31) - return encInt(int32(x)), nil - case time.Time: - if v.IsZero() { - return []byte{}, nil - } - timestamp = int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6) - x := timestamp/millisecondsInADay + int64(1<<31) - return encInt(int32(x)), nil - case *time.Time: - if v.IsZero() { - return []byte{}, nil - } - timestamp = int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6) - x := timestamp/millisecondsInADay + int64(1<<31) - return encInt(int32(x)), nil - case string: - if v == "" { - return []byte{}, nil - } - t, err := time.Parse("2006-01-02", v) - if err != nil { - return nil, marshalErrorf("can not marshal %T into %s, date layout must be '2006-01-02'", value, info) - } - timestamp = int64(t.UTC().Unix()*1e3) + int64(t.UTC().Nanosecond()/1e6) - x := timestamp/millisecondsInADay + int64(1<<31) - return encInt(int32(x)), nil - } - - if value == nil { - return nil, nil - } - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - func unmarshalDate(info TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case Unmarshaler: @@ -1406,7 +732,7 @@ func unmarshalDate(info TypeInfo, data []byte, value interface{}) error { } var origin uint32 = 1 << 31 var current uint32 = binary.BigEndian.Uint32(data) - timestamp := (int64(current) - int64(origin)) * millisecondsInADay + timestamp := (int64(current) - int64(origin)) * internal.MillisecondsInADay *v = time.UnixMilli(timestamp).In(time.UTC) return nil case *string: @@ -1416,7 +742,7 @@ func unmarshalDate(info TypeInfo, data []byte, value interface{}) error { } var origin uint32 = 1 << 31 var current uint32 = binary.BigEndian.Uint32(data) - timestamp := (int64(current) - int64(origin)) * millisecondsInADay + timestamp := (int64(current) - int64(origin)) * internal.MillisecondsInADay *v = time.UnixMilli(timestamp).In(time.UTC).Format("2006-01-02") return nil } @@ -1427,20 +753,20 @@ func marshalDuration(info TypeInfo, value interface{}) ([]byte, error) { switch v := value.(type) { case Marshaler: return v.MarshalCQL(info) - case unsetColumn: + case internal.UnsetColumn: return nil, nil case int64: - return encVints(0, 0, v), nil + return internal.EncVints(0, 0, v), nil case time.Duration: - return encVints(0, 0, v.Nanoseconds()), nil + return internal.EncVints(0, 0, v.Nanoseconds()), nil case string: d, err := time.ParseDuration(v) if err != nil { return nil, err } - return encVints(0, 0, d.Nanoseconds()), nil + return internal.EncVints(0, 0, d.Nanoseconds()), nil case Duration: - return encVints(v.Months, v.Days, v.Nanoseconds), nil + return internal.EncVints(v.Months, v.Days, v.Nanoseconds), nil } if value == nil { @@ -1450,7 +776,7 @@ func marshalDuration(info TypeInfo, value interface{}) ([]byte, error) { rv := reflect.ValueOf(value) switch rv.Type().Kind() { case reflect.Int64: - return encBigInt(rv.Int()), nil + return internal.EncBigInt(rv.Int()), nil } return nil, marshalErrorf("can not marshal %T into %s", value, info) } @@ -1468,7 +794,7 @@ func unmarshalDuration(info TypeInfo, data []byte, value interface{}) error { } return nil } - months, days, nanos, err := decVints(data) + months, days, nanos, err := internal.DecVints(data) if err != nil { return unmarshalErrorf("failed to unmarshal %s into %T: %s", info, value, err.Error()) } @@ -1482,76 +808,10 @@ func unmarshalDuration(info TypeInfo, data []byte, value interface{}) error { return unmarshalErrorf("can not unmarshal %s into %T", info, value) } -func decVints(data []byte) (int32, int32, int64, error) { - month, i, err := decVint(data, 0) - if err != nil { - return 0, 0, 0, fmt.Errorf("failed to extract month: %s", err.Error()) - } - days, i, err := decVint(data, i) - if err != nil { - return 0, 0, 0, fmt.Errorf("failed to extract days: %s", err.Error()) - } - nanos, _, err := decVint(data, i) - if err != nil { - return 0, 0, 0, fmt.Errorf("failed to extract nanoseconds: %s", err.Error()) - } - return int32(month), int32(days), nanos, err -} - -func decVint(data []byte, start int) (int64, int, error) { - if len(data) <= start { - return 0, 0, errors.New("unexpected eof") - } - firstByte := data[start] - if firstByte&0x80 == 0 { - return decIntZigZag(uint64(firstByte)), start + 1, nil - } - numBytes := bits.LeadingZeros32(uint32(^firstByte)) - 24 - ret := uint64(firstByte & (0xff >> uint(numBytes))) - if len(data) < start+numBytes+1 { - return 0, 0, fmt.Errorf("data expect to have %d bytes, but it has only %d", start+numBytes+1, len(data)) - } - for i := start; i < start+numBytes; i++ { - ret <<= 8 - ret |= uint64(data[i+1] & 0xff) - } - return decIntZigZag(ret), start + numBytes + 1, nil -} - -func decIntZigZag(n uint64) int64 { - return int64((n >> 1) ^ -(n & 1)) -} - -func encIntZigZag(n int64) uint64 { - return uint64((n >> 63) ^ (n << 1)) -} - -func encVints(months int32, seconds int32, nanos int64) []byte { - buf := append(encVint(int64(months)), encVint(int64(seconds))...) - return append(buf, encVint(nanos)...) -} - -func encVint(v int64) []byte { - vEnc := encIntZigZag(v) - lead0 := bits.LeadingZeros64(vEnc) - numBytes := (639 - lead0*9) >> 6 - - // It can be 1 or 0 is v ==0 - if numBytes <= 1 { - return []byte{byte(vEnc)} - } - extraBytes := numBytes - 1 - var buf = make([]byte, numBytes) - for i := extraBytes; i >= 0; i-- { - buf[i] = byte(vEnc) - vEnc >>= 8 - } - buf[0] |= byte(^(0xff >> uint(extraBytes))) - return buf -} - +// TODO: move to internal +// just pass the CollectionType.proto to this method instead of CollectionType func writeCollectionSize(info CollectionType, n int, buf *bytes.Buffer) error { - if info.proto > protoVersion2 { + if info.proto > internal.ProtoVersion2 { if n > math.MaxInt32 { return marshalErrorf("marshal: collection too large") } @@ -1580,7 +840,7 @@ func marshalList(info TypeInfo, value interface{}) ([]byte, error) { if value == nil { return nil, nil - } else if _, ok := value.(unsetColumn); ok { + } else if _, ok := value.(internal.UnsetColumn); ok { return nil, nil } @@ -1607,7 +867,7 @@ func marshalList(info TypeInfo, value interface{}) ([]byte, error) { } itemLen := len(item) // Set the value to null for supported protocols - if item == nil && listInfo.proto > protoVersion2 { + if item == nil && listInfo.proto > internal.ProtoVersion2 { itemLen = -1 } if err := writeCollectionSize(listInfo, itemLen, buf); err != nil { @@ -1631,7 +891,7 @@ func marshalList(info TypeInfo, value interface{}) ([]byte, error) { } func readCollectionSize(info CollectionType, data []byte) (size, read int, err error) { - if info.proto > protoVersion2 { + if info.proto > internal.ProtoVersion2 { if len(data) < 4 { return 0, 0, unmarshalErrorf("unmarshal list: unexpected eof") } @@ -1717,7 +977,7 @@ func marshalMap(info TypeInfo, value interface{}) ([]byte, error) { if value == nil { return nil, nil - } else if _, ok := value.(unsetColumn); ok { + } else if _, ok := value.(internal.UnsetColumn); ok { return nil, nil } @@ -1747,7 +1007,7 @@ func marshalMap(info TypeInfo, value interface{}) ([]byte, error) { } itemLen := len(item) // Set the key to null for supported protocols - if item == nil && mapInfo.proto > protoVersion2 { + if item == nil && mapInfo.proto > internal.ProtoVersion2 { itemLen = -1 } if err := writeCollectionSize(mapInfo, itemLen, buf); err != nil { @@ -1761,7 +1021,7 @@ func marshalMap(info TypeInfo, value interface{}) ([]byte, error) { } itemLen = len(item) // Set the value to null for supported protocols - if item == nil && mapInfo.proto > protoVersion2 { + if item == nil && mapInfo.proto > internal.ProtoVersion2 { itemLen = -1 } if err := writeCollectionSize(mapInfo, itemLen, buf); err != nil { @@ -1847,7 +1107,7 @@ func unmarshalMap(info TypeInfo, data []byte, value interface{}) error { func marshalUUID(info TypeInfo, value interface{}) ([]byte, error) { switch val := value.(type) { - case unsetColumn: + case internal.UnsetColumn: return nil, nil case UUID: return val.Bytes(), nil @@ -1936,38 +1196,6 @@ func unmarshalTimeUUID(info TypeInfo, data []byte, value interface{}) error { } } -func marshalInet(info TypeInfo, value interface{}) ([]byte, error) { - // we return either the 4 or 16 byte representation of an - // ip address here otherwise the db value will be prefixed - // with the remaining byte values e.g. ::ffff:127.0.0.1 and not 127.0.0.1 - switch val := value.(type) { - case unsetColumn: - return nil, nil - case net.IP: - t := val.To4() - if t == nil { - return val.To16(), nil - } - return t, nil - case string: - b := net.ParseIP(val) - if b != nil { - t := b.To4() - if t == nil { - return b.To16(), nil - } - return t, nil - } - return nil, marshalErrorf("cannot marshal. invalid ip string %s", val) - } - - if value == nil { - return nil, nil - } - - return nil, marshalErrorf("cannot marshal %T into %s", value, info) -} - func unmarshalInet(info TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case Unmarshaler: @@ -1976,7 +1204,7 @@ func unmarshalInet(info TypeInfo, data []byte, value interface{}) error { if x := len(data); !(x == 4 || x == 16) { return unmarshalErrorf("cannot unmarshal %s into %T: invalid sized IP: got %d bytes not 4 or 16", info, value, x) } - buf := copyBytes(data) + buf := internal.CopyBytes(data) ip := net.IP(buf) if v4 := ip.To4(); v4 != nil { *v = v4 @@ -2003,7 +1231,7 @@ func unmarshalInet(info TypeInfo, data []byte, value interface{}) error { func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) { tuple := info.(TupleTypeInfo) switch v := value.(type) { - case unsetColumn: + case internal.UnsetColumn: return nil, unmarshalErrorf("Invalid request: UnsetValue is unsupported for tuples") case []interface{}: if len(v) != len(tuple.Elems) { @@ -2013,7 +1241,7 @@ func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) { var buf []byte for i, elem := range v { if elem == nil { - buf = appendInt(buf, int32(-1)) + buf = internal.AppendInt(buf, int32(-1)) continue } @@ -2023,7 +1251,7 @@ func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) { } n := len(data) - buf = appendInt(buf, int32(n)) + buf = internal.AppendInt(buf, int32(n)) buf = append(buf, data...) } @@ -2045,7 +1273,7 @@ func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) { field := rv.Field(i) if field.Kind() == reflect.Ptr && field.IsNil() { - buf = appendInt(buf, int32(-1)) + buf = internal.AppendInt(buf, int32(-1)) continue } @@ -2055,7 +1283,7 @@ func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) { } n := len(data) - buf = appendInt(buf, int32(n)) + buf = internal.AppendInt(buf, int32(n)) buf = append(buf, data...) } @@ -2071,7 +1299,7 @@ func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) { item := rv.Index(i) if item.Kind() == reflect.Ptr && item.IsNil() { - buf = appendInt(buf, int32(-1)) + buf = internal.AppendInt(buf, int32(-1)) continue } @@ -2081,7 +1309,7 @@ func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) { } n := len(data) - buf = appendInt(buf, int32(n)) + buf = internal.AppendInt(buf, int32(n)) buf = append(buf, data...) } @@ -2091,16 +1319,6 @@ func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) { return nil, marshalErrorf("cannot marshal %T into %s", value, tuple) } -func readBytes(p []byte) ([]byte, []byte) { - // TODO: really should use a framer - size := readInt(p) - p = p[4:] - if size < 0 { - return nil, p - } - return p[:size], p[size:] -} - // currently only support unmarshal into a list of values, this makes it possible // to support tuples without changing the query API. In the future this can be extend // to allow unmarshalling into custom tuple types. @@ -2116,7 +1334,7 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error { // each element inside data is a [bytes] var p []byte if len(data) >= 4 { - p, data = readBytes(data) + p, data = internal.ReadBytes(data) } err := Unmarshal(elem, p, v[i]) if err != nil { @@ -2145,7 +1363,7 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error { for i, elem := range tuple.Elems { var p []byte if len(data) >= 4 { - p, data = readBytes(data) + p, data = internal.ReadBytes(data) } v, err := elem.NewWithError() @@ -2182,7 +1400,7 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error { for i, elem := range tuple.Elems { var p []byte if len(data) >= 4 { - p, data = readBytes(data) + p, data = internal.ReadBytes(data) } v, err := elem.NewWithError() @@ -2236,7 +1454,7 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) { switch v := value.(type) { case Marshaler: return v.MarshalCQL(info) - case unsetColumn: + case internal.UnsetColumn: return nil, unmarshalErrorf("invalid request: UnsetValue is unsupported for user defined types") case UDTMarshaler: var buf []byte @@ -2246,7 +1464,7 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) { return nil, err } - buf = appendBytes(buf, data) + buf = internal.AppendBytes(buf, data) } return buf, nil @@ -2265,7 +1483,7 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) { } } - buf = appendBytes(buf, data) + buf = internal.AppendBytes(buf, data) } return buf, nil @@ -2309,7 +1527,7 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) { } } - buf = appendBytes(buf, data) + buf = internal.AppendBytes(buf, data) } return buf, nil @@ -2331,7 +1549,7 @@ func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error { } var p []byte - p, data = readBytes(data) + p, data = internal.ReadBytes(data) if err := v.UnmarshalUDT(e.Name, e.Type, p); err != nil { return err } @@ -2374,7 +1592,7 @@ func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error { val := reflect.New(valType) var p []byte - p, data = readBytes(data) + p, data = internal.ReadBytes(data) if err := Unmarshal(e.Type, p, val.Interface()); err != nil { return err @@ -2424,12 +1642,12 @@ func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error { } var p []byte - p, data = readBytes(data) + p, data = internal.ReadBytes(data) f, ok := fields[e.Name] if !ok { f = k.FieldByName(e.Name) - if f == emptyValue { + if f == internal.EmptyValue { // skip fields which exist in the UDT but not in // the struct passed in continue diff --git a/marshal_test.go b/marshal_test.go index 6c139e6bc..31b1b812d 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -38,6 +38,8 @@ import ( "testing" "time" + "github.com/gocql/gocql/internal" + "gopkg.in/inf.v0" ) @@ -1075,7 +1077,7 @@ var marshalTests = []struct { }, { NativeType{proto: 2, typ: TypeTime}, - encBigInt(1000), + internal.EncBigInt(1000), time.Duration(1000), nil, nil, @@ -1692,7 +1694,7 @@ var typeLookupTest = []struct { } func testType(t *testing.T, cassType string, expectedType Type) { - if computedType := getApacheCassandraType(apacheCassandraTypePrefix + cassType); computedType != expectedType { + if computedType := getApacheCassandraType(internal.ApacheCassandraTypePrefix + cassType); computedType != expectedType { t.Errorf("Cassandra custom type lookup for %s failed. Expected %s, got %s.", cassType, expectedType.String(), computedType.String()) } } @@ -1726,7 +1728,7 @@ func TestMarshalPointer(t *testing.T) { func TestMarshalTime(t *testing.T) { durationS := "1h10m10s" duration, _ := time.ParseDuration(durationS) - expectedData := encBigInt(duration.Nanoseconds()) + expectedData := internal.EncBigInt(duration.Nanoseconds()) var marshalTimeTests = []struct { Info TypeInfo Data []byte @@ -1758,7 +1760,7 @@ func TestMarshalTime(t *testing.T) { } if !bytes.Equal(data, test.Data) { t.Errorf("marshalTest[%d]: expected %x (%v), got %x (%v) for time %s", i, - test.Data, decInt(test.Data), data, decInt(data), test.Value) + test.Data, internal.DecInt(test.Data), data, internal.DecInt(data), test.Value) } } } @@ -1824,7 +1826,7 @@ func TestMarshalTimestamp(t *testing.T) { } if !bytes.Equal(data, test.Data) { t.Errorf("marshalTest[%d]: expected %x (%v), got %x (%v) for time %s", i, - test.Data, decBigInt(test.Data), data, decBigInt(data), test.Value) + test.Data, internal.DecBigInt(test.Data), data, internal.DecBigInt(data), test.Value) } } } @@ -1961,7 +1963,7 @@ func TestMarshalTuple(t *testing.T) { if !bytes.Equal(data, tc.expected) { t.Errorf("marshalTest: expected %x (%v), got %x (%v)", - tc.expected, decBigInt(tc.expected), data, decBigInt(data)) + tc.expected, internal.DecBigInt(tc.expected), data, internal.DecBigInt(data)) return } @@ -2244,7 +2246,7 @@ func TestUnmarshalDate(t *testing.T) { func TestMarshalDate(t *testing.T) { now := time.Now().UTC() timestamp := now.UnixNano() / int64(time.Millisecond) - expectedData := encInt(int32(timestamp/86400000 + int64(1<<31))) + expectedData := internal.EncInt(int32(timestamp/86400000 + int64(1<<31))) var marshalDateTests = []struct { Info TypeInfo @@ -2282,17 +2284,17 @@ func TestMarshalDate(t *testing.T) { } if !bytes.Equal(data, test.Data) { t.Errorf("marshalTest[%d]: expected %x (%v), got %x (%v) for time %s", i, - test.Data, decInt(test.Data), data, decInt(data), test.Value) + test.Data, internal.DecInt(test.Data), data, internal.DecInt(data), test.Value) } } } func TestLargeDate(t *testing.T) { farFuture := time.Date(999999, time.December, 31, 0, 0, 0, 0, time.UTC) - expectedFutureData := encInt(int32(farFuture.UnixMilli()/86400000 + int64(1<<31))) + expectedFutureData := internal.EncInt(int32(farFuture.UnixMilli()/86400000 + int64(1<<31))) farPast := time.Date(-999999, time.January, 1, 0, 0, 0, 0, time.UTC) - expectedPastData := encInt(int32(farPast.UnixMilli()/86400000 + int64(1<<31))) + expectedPastData := internal.EncInt(int32(farPast.UnixMilli()/86400000 + int64(1<<31))) var marshalDateTests = []struct { Data []byte @@ -2323,7 +2325,7 @@ func TestLargeDate(t *testing.T) { } if !bytes.Equal(data, test.Data) { t.Errorf("largeDateTest[%d]: expected %x (%v), got %x (%v) for time %s", i, - test.Data, decInt(test.Data), data, decInt(data), test.Value) + test.Data, internal.DecInt(test.Data), data, internal.DecInt(data), test.Value) } var date time.Time @@ -2354,7 +2356,7 @@ func BenchmarkUnmarshalVarchar(b *testing.B) { func TestMarshalDuration(t *testing.T) { durationS := "1h10m10s" duration, _ := time.ParseDuration(durationS) - expectedData := append([]byte{0, 0}, encVint(duration.Nanoseconds())...) + expectedData := append([]byte{0, 0}, internal.EncVint(duration.Nanoseconds())...) var marshalDurationTests = []struct { Info TypeInfo Data []byte @@ -2391,7 +2393,7 @@ func TestMarshalDuration(t *testing.T) { } if !bytes.Equal(data, test.Data) { t.Errorf("marshalTest[%d]: expected %x (%v), got %x (%v) for time %s", i, - test.Data, decInt(test.Data), data, decInt(data), test.Value) + test.Data, internal.DecInt(test.Data), data, internal.DecInt(data), test.Value) } } } diff --git a/metadata.go b/metadata.go index 6eb798f8a..b1fdb6116 100644 --- a/metadata.go +++ b/metadata.go @@ -32,6 +32,7 @@ import ( "encoding/hex" "encoding/json" "fmt" + "github.com/gocql/gocql/internal" "strconv" "strings" "sync" @@ -410,7 +411,7 @@ func compileMetadata( table.OrderedColumns = append(table.OrderedColumns, col.Name) } - if protoVersion == protoVersion1 { + if protoVersion == internal.ProtoVersion1 { compileV1Metadata(tables, logger) } else { compileV2Metadata(tables, logger) @@ -669,7 +670,7 @@ func getTableMetadata(session *Session, keyspaceName string) ([]TableMetadata, e } return r } - } else if session.cfg.ProtoVersion == protoVersion1 { + } else if session.cfg.ProtoVersion == internal.ProtoVersion1 { // we have key aliases stmt = ` SELECT @@ -948,14 +949,14 @@ func getColumnMetadata(session *Session, keyspaceName string) ([]ColumnMetadata, } func getTypeInfo(t string, logger StdLogger) TypeInfo { - if strings.HasPrefix(t, apacheCassandraTypePrefix) { + if strings.HasPrefix(t, internal.ApacheCassandraTypePrefix) { t = apacheToCassandraType(t) } return getCassandraType(t, logger) } func getViewsMetadata(session *Session, keyspaceName string) ([]ViewMetadata, error) { - if session.cfg.ProtoVersion == protoVersion1 { + if session.cfg.ProtoVersion == internal.ProtoVersion1 { return nil, nil } var tableName string @@ -1069,7 +1070,7 @@ func getMaterializedViewsMetadata(session *Session, keyspaceName string) ([]Mate } func getFunctionsMetadata(session *Session, keyspaceName string) ([]FunctionMetadata, error) { - if session.cfg.ProtoVersion == protoVersion1 || !session.hasAggregatesAndFunctions { + if session.cfg.ProtoVersion == internal.ProtoVersion1 || !session.hasAggregatesAndFunctions { return nil, nil } var tableName string @@ -1124,7 +1125,7 @@ func getFunctionsMetadata(session *Session, keyspaceName string) ([]FunctionMeta } func getAggregatesMetadata(session *Session, keyspaceName string) ([]AggregateMetadata, error) { - if session.cfg.ProtoVersion == protoVersion1 || !session.hasAggregatesAndFunctions { + if session.cfg.ProtoVersion == internal.ProtoVersion1 || !session.hasAggregatesAndFunctions { return nil, nil } var tableName string diff --git a/policies_test.go b/policies_test.go index 231c2a7e2..d097e4d9c 100644 --- a/policies_test.go +++ b/policies_test.go @@ -37,6 +37,7 @@ import ( "testing" "time" + "github.com/gocql/gocql/internal" "github.com/hailocab/go-hostpool" ) @@ -122,10 +123,10 @@ func TestHostPolicy_TokenAware_SimpleStrategy(t *testing.T) { // It's handy to have as reference here. assertDeepEqual(t, "replicas", map[string]tokenRingReplicas{ "myKeyspace": { - {orderedToken("00"), []*HostInfo{hosts[0], hosts[1]}}, - {orderedToken("25"), []*HostInfo{hosts[1], hosts[2]}}, - {orderedToken("50"), []*HostInfo{hosts[2], hosts[3]}}, - {orderedToken("75"), []*HostInfo{hosts[3], hosts[0]}}, + {internal.OrderedToken("00"), []*HostInfo{hosts[0], hosts[1]}}, + {internal.OrderedToken("25"), []*HostInfo{hosts[1], hosts[2]}}, + {internal.OrderedToken("50"), []*HostInfo{hosts[2], hosts[3]}}, + {internal.OrderedToken("75"), []*HostInfo{hosts[3], hosts[0]}}, }, }, policyInternal.getMetadataReadOnly().replicas) @@ -567,18 +568,18 @@ func TestHostPolicy_TokenAware(t *testing.T) { // It's handy to have as reference here. assertDeepEqual(t, "replicas", map[string]tokenRingReplicas{ "myKeyspace": { - {orderedToken("05"), []*HostInfo{hosts[0], hosts[1], hosts[2]}}, - {orderedToken("10"), []*HostInfo{hosts[1], hosts[2], hosts[3]}}, - {orderedToken("15"), []*HostInfo{hosts[2], hosts[3], hosts[4]}}, - {orderedToken("20"), []*HostInfo{hosts[3], hosts[4], hosts[5]}}, - {orderedToken("25"), []*HostInfo{hosts[4], hosts[5], hosts[6]}}, - {orderedToken("30"), []*HostInfo{hosts[5], hosts[6], hosts[7]}}, - {orderedToken("35"), []*HostInfo{hosts[6], hosts[7], hosts[8]}}, - {orderedToken("40"), []*HostInfo{hosts[7], hosts[8], hosts[9]}}, - {orderedToken("45"), []*HostInfo{hosts[8], hosts[9], hosts[10]}}, - {orderedToken("50"), []*HostInfo{hosts[9], hosts[10], hosts[11]}}, - {orderedToken("55"), []*HostInfo{hosts[10], hosts[11], hosts[0]}}, - {orderedToken("60"), []*HostInfo{hosts[11], hosts[0], hosts[1]}}, + {internal.OrderedToken("05"), []*HostInfo{hosts[0], hosts[1], hosts[2]}}, + {internal.OrderedToken("10"), []*HostInfo{hosts[1], hosts[2], hosts[3]}}, + {internal.OrderedToken("15"), []*HostInfo{hosts[2], hosts[3], hosts[4]}}, + {internal.OrderedToken("20"), []*HostInfo{hosts[3], hosts[4], hosts[5]}}, + {internal.OrderedToken("25"), []*HostInfo{hosts[4], hosts[5], hosts[6]}}, + {internal.OrderedToken("30"), []*HostInfo{hosts[5], hosts[6], hosts[7]}}, + {internal.OrderedToken("35"), []*HostInfo{hosts[6], hosts[7], hosts[8]}}, + {internal.OrderedToken("40"), []*HostInfo{hosts[7], hosts[8], hosts[9]}}, + {internal.OrderedToken("45"), []*HostInfo{hosts[8], hosts[9], hosts[10]}}, + {internal.OrderedToken("50"), []*HostInfo{hosts[9], hosts[10], hosts[11]}}, + {internal.OrderedToken("55"), []*HostInfo{hosts[10], hosts[11], hosts[0]}}, + {internal.OrderedToken("60"), []*HostInfo{hosts[11], hosts[0], hosts[1]}}, }, }, policyInternal.getMetadataReadOnly().replicas) @@ -658,18 +659,18 @@ func TestHostPolicy_TokenAware_NetworkStrategy(t *testing.T) { // It's handy to have as reference here. assertDeepEqual(t, "replicas", map[string]tokenRingReplicas{ keyspace: { - {orderedToken("05"), []*HostInfo{hosts[0], hosts[1], hosts[2], hosts[3], hosts[4], hosts[5]}}, - {orderedToken("10"), []*HostInfo{hosts[1], hosts[2], hosts[3], hosts[4], hosts[5], hosts[6]}}, - {orderedToken("15"), []*HostInfo{hosts[2], hosts[3], hosts[4], hosts[5], hosts[6], hosts[7]}}, - {orderedToken("20"), []*HostInfo{hosts[3], hosts[4], hosts[5], hosts[6], hosts[7], hosts[8]}}, - {orderedToken("25"), []*HostInfo{hosts[4], hosts[5], hosts[6], hosts[7], hosts[8], hosts[9]}}, - {orderedToken("30"), []*HostInfo{hosts[5], hosts[6], hosts[7], hosts[8], hosts[9], hosts[10]}}, - {orderedToken("35"), []*HostInfo{hosts[6], hosts[7], hosts[8], hosts[9], hosts[10], hosts[11]}}, - {orderedToken("40"), []*HostInfo{hosts[7], hosts[8], hosts[9], hosts[10], hosts[11], hosts[0]}}, - {orderedToken("45"), []*HostInfo{hosts[8], hosts[9], hosts[10], hosts[11], hosts[0], hosts[1]}}, - {orderedToken("50"), []*HostInfo{hosts[9], hosts[10], hosts[11], hosts[0], hosts[1], hosts[2]}}, - {orderedToken("55"), []*HostInfo{hosts[10], hosts[11], hosts[0], hosts[1], hosts[2], hosts[3]}}, - {orderedToken("60"), []*HostInfo{hosts[11], hosts[0], hosts[1], hosts[2], hosts[3], hosts[4]}}, + {internal.OrderedToken("05"), []*HostInfo{hosts[0], hosts[1], hosts[2], hosts[3], hosts[4], hosts[5]}}, + {internal.OrderedToken("10"), []*HostInfo{hosts[1], hosts[2], hosts[3], hosts[4], hosts[5], hosts[6]}}, + {internal.OrderedToken("15"), []*HostInfo{hosts[2], hosts[3], hosts[4], hosts[5], hosts[6], hosts[7]}}, + {internal.OrderedToken("20"), []*HostInfo{hosts[3], hosts[4], hosts[5], hosts[6], hosts[7], hosts[8]}}, + {internal.OrderedToken("25"), []*HostInfo{hosts[4], hosts[5], hosts[6], hosts[7], hosts[8], hosts[9]}}, + {internal.OrderedToken("30"), []*HostInfo{hosts[5], hosts[6], hosts[7], hosts[8], hosts[9], hosts[10]}}, + {internal.OrderedToken("35"), []*HostInfo{hosts[6], hosts[7], hosts[8], hosts[9], hosts[10], hosts[11]}}, + {internal.OrderedToken("40"), []*HostInfo{hosts[7], hosts[8], hosts[9], hosts[10], hosts[11], hosts[0]}}, + {internal.OrderedToken("45"), []*HostInfo{hosts[8], hosts[9], hosts[10], hosts[11], hosts[0], hosts[1]}}, + {internal.OrderedToken("50"), []*HostInfo{hosts[9], hosts[10], hosts[11], hosts[0], hosts[1], hosts[2]}}, + {internal.OrderedToken("55"), []*HostInfo{hosts[10], hosts[11], hosts[0], hosts[1], hosts[2], hosts[3]}}, + {internal.OrderedToken("60"), []*HostInfo{hosts[11], hosts[0], hosts[1], hosts[2], hosts[3], hosts[4]}}, }, }, policyInternal.getMetadataReadOnly().replicas) @@ -799,18 +800,18 @@ func TestHostPolicy_TokenAware_RackAware(t *testing.T) { // It's handy to have as reference here. assertDeepEqual(t, "replicas", map[string]tokenRingReplicas{ "myKeyspace": { - {orderedToken("05"), []*HostInfo{hosts[0], hosts[1], hosts[2], hosts[3]}}, - {orderedToken("10"), []*HostInfo{hosts[1], hosts[2], hosts[3], hosts[4]}}, - {orderedToken("15"), []*HostInfo{hosts[2], hosts[3], hosts[4], hosts[5]}}, - {orderedToken("20"), []*HostInfo{hosts[3], hosts[4], hosts[5], hosts[6]}}, - {orderedToken("25"), []*HostInfo{hosts[4], hosts[5], hosts[6], hosts[7]}}, - {orderedToken("30"), []*HostInfo{hosts[5], hosts[6], hosts[7], hosts[8]}}, - {orderedToken("35"), []*HostInfo{hosts[6], hosts[7], hosts[8], hosts[9]}}, - {orderedToken("40"), []*HostInfo{hosts[7], hosts[8], hosts[9], hosts[10]}}, - {orderedToken("45"), []*HostInfo{hosts[8], hosts[9], hosts[10], hosts[11]}}, - {orderedToken("50"), []*HostInfo{hosts[9], hosts[10], hosts[11], hosts[0]}}, - {orderedToken("55"), []*HostInfo{hosts[10], hosts[11], hosts[0], hosts[1]}}, - {orderedToken("60"), []*HostInfo{hosts[11], hosts[0], hosts[1], hosts[2]}}, + {internal.OrderedToken("05"), []*HostInfo{hosts[0], hosts[1], hosts[2], hosts[3]}}, + {internal.OrderedToken("10"), []*HostInfo{hosts[1], hosts[2], hosts[3], hosts[4]}}, + {internal.OrderedToken("15"), []*HostInfo{hosts[2], hosts[3], hosts[4], hosts[5]}}, + {internal.OrderedToken("20"), []*HostInfo{hosts[3], hosts[4], hosts[5], hosts[6]}}, + {internal.OrderedToken("25"), []*HostInfo{hosts[4], hosts[5], hosts[6], hosts[7]}}, + {internal.OrderedToken("30"), []*HostInfo{hosts[5], hosts[6], hosts[7], hosts[8]}}, + {internal.OrderedToken("35"), []*HostInfo{hosts[6], hosts[7], hosts[8], hosts[9]}}, + {internal.OrderedToken("40"), []*HostInfo{hosts[7], hosts[8], hosts[9], hosts[10]}}, + {internal.OrderedToken("45"), []*HostInfo{hosts[8], hosts[9], hosts[10], hosts[11]}}, + {internal.OrderedToken("50"), []*HostInfo{hosts[9], hosts[10], hosts[11], hosts[0]}}, + {internal.OrderedToken("55"), []*HostInfo{hosts[10], hosts[11], hosts[0], hosts[1]}}, + {internal.OrderedToken("60"), []*HostInfo{hosts[11], hosts[0], hosts[1], hosts[2]}}, }, }, policyInternal.getMetadataReadOnly().replicas) diff --git a/session.go b/session.go index d04a13672..c133d75a6 100644 --- a/session.go +++ b/session.go @@ -38,6 +38,7 @@ import ( "time" "unicode" + "github.com/gocql/gocql/internal" "github.com/gocql/gocql/internal/lru" ) @@ -824,7 +825,6 @@ type queryMetrics struct { totalAttempts int } -// preFilledQueryMetrics initializes new queryMetrics based on per-host supplied data. func preFilledQueryMetrics(m map[string]*hostMetrics) *queryMetrics { qm := &queryMetrics{m: m} for _, hm := range qm.m { @@ -1303,18 +1303,10 @@ func (q *Query) Exec() error { return q.Iter().Close() } -func isUseStatement(stmt string) bool { - if len(stmt) < 3 { - return false - } - - return strings.EqualFold(stmt[0:3], "use") -} - // Iter executes the query and returns an iterator capable of iterating // over all results. func (q *Query) Iter() *Iter { - if isUseStatement(q.stmt) { + if internal.IsUseStatement(q.stmt) { return &Iter{err: ErrUseStmt} } // if the query was specifically run on a connection then re-use that @@ -1656,7 +1648,7 @@ func (iter *Iter) GetCustomPayload() map[string][]byte { // This is only available starting with CQL Protocol v4. func (iter *Iter) Warnings() []string { if iter.framer != nil { - return iter.framer.header.warnings + return iter.framer.header.Warnings } return nil } diff --git a/session_test.go b/session_test.go index 8633f9957..4943d78dd 100644 --- a/session_test.go +++ b/session_test.go @@ -32,6 +32,8 @@ import ( "fmt" "net" "testing" + + "github.com/gocql/gocql/internal" ) func TestSessionAPI(t *testing.T) { @@ -341,7 +343,7 @@ func TestIsUseStatement(t *testing.T) { } for _, tc := range testCases { - v := isUseStatement(tc.input) + v := internal.IsUseStatement(tc.input) if v != tc.exp { t.Fatalf("expected %v but got %v for statement %q", tc.exp, v, tc.input) } diff --git a/token.go b/token.go index 7502ea713..664a6d21a 100644 --- a/token.go +++ b/token.go @@ -30,120 +30,16 @@ package gocql import ( "bytes" - "crypto/md5" "fmt" - "math/big" "sort" "strconv" "strings" - "github.com/gocql/gocql/internal/murmur" + "github.com/gocql/gocql/internal" ) -// a token partitioner -type partitioner interface { - Name() string - Hash([]byte) token - ParseString(string) token -} - -// a token -type token interface { - fmt.Stringer - Less(token) bool -} - -// murmur3 partitioner and token -type murmur3Partitioner struct{} -type murmur3Token int64 - -func (p murmur3Partitioner) Name() string { - return "Murmur3Partitioner" -} - -func (p murmur3Partitioner) Hash(partitionKey []byte) token { - h1 := murmur.Murmur3H1(partitionKey) - return murmur3Token(h1) -} - -// murmur3 little-endian, 128-bit hash, but returns only h1 -func (p murmur3Partitioner) ParseString(str string) token { - val, _ := strconv.ParseInt(str, 10, 64) - return murmur3Token(val) -} - -func (m murmur3Token) String() string { - return strconv.FormatInt(int64(m), 10) -} - -func (m murmur3Token) Less(token token) bool { - return m < token.(murmur3Token) -} - -// order preserving partitioner and token -type orderedPartitioner struct{} -type orderedToken string - -func (p orderedPartitioner) Name() string { - return "OrderedPartitioner" -} - -func (p orderedPartitioner) Hash(partitionKey []byte) token { - // the partition key is the token - return orderedToken(partitionKey) -} - -func (p orderedPartitioner) ParseString(str string) token { - return orderedToken(str) -} - -func (o orderedToken) String() string { - return string(o) -} - -func (o orderedToken) Less(token token) bool { - return o < token.(orderedToken) -} - -// random partitioner and token -type randomPartitioner struct{} -type randomToken big.Int - -func (r randomPartitioner) Name() string { - return "RandomPartitioner" -} - -// 2 ** 128 -var maxHashInt, _ = new(big.Int).SetString("340282366920938463463374607431768211456", 10) - -func (p randomPartitioner) Hash(partitionKey []byte) token { - sum := md5.Sum(partitionKey) - val := new(big.Int) - val.SetBytes(sum[:]) - if sum[0] > 127 { - val.Sub(val, maxHashInt) - val.Abs(val) - } - - return (*randomToken)(val) -} - -func (p randomPartitioner) ParseString(str string) token { - val := new(big.Int) - val.SetString(str, 10) - return (*randomToken)(val) -} - -func (r *randomToken) String() string { - return (*big.Int)(r).String() -} - -func (r *randomToken) Less(token token) bool { - return -1 == (*big.Int)(r).Cmp((*big.Int)(token.(*randomToken))) -} - type hostToken struct { - token token + token internal.Token host *HostInfo } @@ -153,7 +49,7 @@ func (ht hostToken) String() string { // a data structure for organizing the relationship between tokens and hosts type tokenRing struct { - partitioner partitioner + partitioner internal.Partitioner // tokens map token range to primary replica. // The elements in tokens are sorted by token ascending. @@ -171,11 +67,11 @@ func newTokenRing(partitioner string, hosts []*HostInfo) (*tokenRing, error) { } if strings.HasSuffix(partitioner, "Murmur3Partitioner") { - tokenRing.partitioner = murmur3Partitioner{} + tokenRing.partitioner = internal.Murmur3Partitioner{} } else if strings.HasSuffix(partitioner, "OrderedPartitioner") { - tokenRing.partitioner = orderedPartitioner{} + tokenRing.partitioner = internal.OrderedPartitioner{} } else if strings.HasSuffix(partitioner, "RandomPartitioner") { - tokenRing.partitioner = randomPartitioner{} + tokenRing.partitioner = internal.RandomPartitioner{} } else { return nil, fmt.Errorf("unsupported partitioner '%s'", partitioner) } @@ -226,7 +122,7 @@ func (t *tokenRing) String() string { return string(buf.Bytes()) } -func (t *tokenRing) GetHostForToken(token token) (host *HostInfo, endToken token) { +func (t *tokenRing) GetHostForToken(token internal.Token) (host *HostInfo, endToken internal.Token) { if t == nil || len(t.tokens) == 0 { return nil, nil } diff --git a/token_test.go b/token_test.go index 90e0d4fd8..46af7c2d4 100644 --- a/token_test.go +++ b/token_test.go @@ -36,11 +36,13 @@ import ( "sort" "strconv" "testing" + + "github.com/gocql/gocql/internal" ) // Tests of the murmur3Patitioner func TestMurmur3Partitioner(t *testing.T) { - token := murmur3Partitioner{}.ParseString("-1053604476080545076") + token := internal.Murmur3Partitioner{}.ParseString("-1053604476080545076") if "-1053604476080545076" != token.String() { t.Errorf("Expected '-1053604476080545076' but was '%s'", token) @@ -48,32 +50,32 @@ func TestMurmur3Partitioner(t *testing.T) { // at least verify that the partitioner // doesn't return nil - pk, _ := marshalInt(nil, 1) - token = murmur3Partitioner{}.Hash(pk) + pk, _ := internal.MarshalInt(nil, 1) + token = internal.Murmur3Partitioner{}.Hash(pk) if token == nil { t.Fatal("token was nil") } } -// Tests of the murmur3Token +// Tests of the internal.Murmur3Token func TestMurmur3Token(t *testing.T) { - if murmur3Token(42).Less(murmur3Token(42)) { + if internal.Murmur3Token(42).Less(internal.Murmur3Token(42)) { t.Errorf("Expected Less to return false, but was true") } - if !murmur3Token(-42).Less(murmur3Token(42)) { + if !internal.Murmur3Token(-42).Less(internal.Murmur3Token(42)) { t.Errorf("Expected Less to return true, but was false") } - if murmur3Token(42).Less(murmur3Token(-42)) { + if internal.Murmur3Token(42).Less(internal.Murmur3Token(-42)) { t.Errorf("Expected Less to return false, but was true") } } -// Tests of the orderedPartitioner +// Tests of the internal.OrderedPartitioner func TestOrderedPartitioner(t *testing.T) { // at least verify that the partitioner // doesn't return nil - p := orderedPartitioner{} - pk, _ := marshalInt(nil, 1) + p := internal.OrderedPartitioner{} + pk, _ := internal.MarshalInt(nil, 1) token := p.Hash(pk) if token == nil { t.Fatal("token was nil") @@ -82,34 +84,34 @@ func TestOrderedPartitioner(t *testing.T) { str := token.String() parsedToken := p.ParseString(str) - if !bytes.Equal([]byte(token.(orderedToken)), []byte(parsedToken.(orderedToken))) { + if !bytes.Equal([]byte(token.(internal.OrderedToken)), []byte(parsedToken.(internal.OrderedToken))) { t.Errorf("Failed to convert to and from a string %s expected %x but was %x", str, - []byte(token.(orderedToken)), - []byte(parsedToken.(orderedToken)), + []byte(token.(internal.OrderedToken)), + []byte(parsedToken.(internal.OrderedToken)), ) } } -// Tests of the orderedToken +// Tests of the internal.OrderedToken func TestOrderedToken(t *testing.T) { - if orderedToken([]byte{0, 0, 4, 2}).Less(orderedToken([]byte{0, 0, 4, 2})) { + if internal.OrderedToken([]byte{0, 0, 4, 2}).Less(internal.OrderedToken([]byte{0, 0, 4, 2})) { t.Errorf("Expected Less to return false, but was true") } - if !orderedToken([]byte{0, 0, 3}).Less(orderedToken([]byte{0, 0, 4, 2})) { + if !internal.OrderedToken([]byte{0, 0, 3}).Less(internal.OrderedToken([]byte{0, 0, 4, 2})) { t.Errorf("Expected Less to return true, but was false") } - if orderedToken([]byte{0, 0, 4, 2}).Less(orderedToken([]byte{0, 0, 3})) { + if internal.OrderedToken([]byte{0, 0, 4, 2}).Less(internal.OrderedToken([]byte{0, 0, 3})) { t.Errorf("Expected Less to return false, but was true") } } -// Tests of the randomPartitioner +// Tests of the internal.RandomPartitioner func TestRandomPartitioner(t *testing.T) { // at least verify that the partitioner // doesn't return nil - p := randomPartitioner{} - pk, _ := marshalInt(nil, 1) + p := internal.RandomPartitioner{} + pk, _ := internal.MarshalInt(nil, 1) token := p.Hash(pk) if token == nil { t.Fatal("token was nil") @@ -118,7 +120,7 @@ func TestRandomPartitioner(t *testing.T) { str := token.String() parsedToken := p.ParseString(str) - if (*big.Int)(token.(*randomToken)).Cmp((*big.Int)(parsedToken.(*randomToken))) != 0 { + if (*big.Int)(token.(*internal.RandomToken)).Cmp((*big.Int)(parsedToken.(*internal.RandomToken))) != 0 { t.Errorf("Failed to convert to and from a string %s expected %v but was %v", str, token, @@ -132,7 +134,7 @@ func TestRandomPartitionerMatchesReference(t *testing.T) { // >>> from cassandra.metadata import MD5Token // >>> MD5Token.hash_fn("test") // 12707736894140473154801792860916528374L - var p randomPartitioner + var p internal.RandomPartitioner expect := "12707736894140473154801792860916528374" actual := p.Hash([]byte("test")).String() if actual != expect { @@ -141,23 +143,23 @@ func TestRandomPartitionerMatchesReference(t *testing.T) { } } -// Tests of the randomToken +// Tests of the internal.RandomToken func TestRandomToken(t *testing.T) { - if ((*randomToken)(big.NewInt(42))).Less((*randomToken)(big.NewInt(42))) { + if ((*internal.RandomToken)(big.NewInt(42))).Less((*internal.RandomToken)(big.NewInt(42))) { t.Errorf("Expected Less to return false, but was true") } - if !((*randomToken)(big.NewInt(41))).Less((*randomToken)(big.NewInt(42))) { + if !((*internal.RandomToken)(big.NewInt(41))).Less((*internal.RandomToken)(big.NewInt(42))) { t.Errorf("Expected Less to return true, but was false") } - if ((*randomToken)(big.NewInt(42))).Less((*randomToken)(big.NewInt(41))) { + if ((*internal.RandomToken)(big.NewInt(42))).Less((*internal.RandomToken)(big.NewInt(41))) { t.Errorf("Expected Less to return false, but was true") } } type intToken int -func (i intToken) String() string { return strconv.Itoa(int(i)) } -func (i intToken) Less(token token) bool { return i < token.(intToken) } +func (i intToken) String() string { return strconv.Itoa(int(i)) } +func (i intToken) Less(token internal.Token) bool { return i < token.(intToken) } // Test of the token ring implementation based on example at the start of this // page of documentation: @@ -260,7 +262,7 @@ func TestTokenRing_Murmur3(t *testing.T) { t.Fatalf("Failed to create token ring due to error: %v", err) } - p := murmur3Partitioner{} + p := internal.Murmur3Partitioner{} for _, host := range hosts { actual, _ := ring.GetHostForToken(p.ParseString(host.tokens[0])) @@ -291,7 +293,7 @@ func TestTokenRing_Ordered(t *testing.T) { t.Fatalf("Failed to create token ring due to error: %v", err) } - p := orderedPartitioner{} + p := internal.OrderedPartitioner{} var actual *HostInfo for _, host := range hosts { @@ -322,7 +324,7 @@ func TestTokenRing_Random(t *testing.T) { t.Fatalf("Failed to create token ring due to error: %v", err) } - p := randomPartitioner{} + p := internal.RandomPartitioner{} var actual *HostInfo for _, host := range hosts { diff --git a/topology.go b/topology.go index 2fc38a887..5fd36dee1 100644 --- a/topology.go +++ b/topology.go @@ -29,11 +29,13 @@ import ( "sort" "strconv" "strings" + + "github.com/gocql/gocql/internal" ) type hostTokens struct { // token is end (inclusive) of token range these hosts belong to - token token + token internal.Token hosts []*HostInfo } @@ -48,7 +50,7 @@ func (h tokenRingReplicas) Less(i, j int) bool { return h[i].token.Less(h[j].tok func (h tokenRingReplicas) Len() int { return len(h) } func (h tokenRingReplicas) Swap(i, j int) { h[i], h[j] = h[j], h[i] } -func (h tokenRingReplicas) replicasFor(t token) *hostTokens { +func (h tokenRingReplicas) replicasFor(t internal.Token) *hostTokens { if len(h) == 0 { return nil } diff --git a/topology_test.go b/topology_test.go index fe8473e98..0c714deec 100644 --- a/topology_test.go +++ b/topology_test.go @@ -26,6 +26,7 @@ package gocql import ( "fmt" + "github.com/gocql/gocql/internal" "sort" "testing" ) @@ -136,7 +137,7 @@ func TestPlacementStrategy_NetworkStrategy(t *testing.T) { h := &HostInfo{hostId: fmt.Sprintf("%s:%s:%d", dc, rack, j), dataCenter: dc, rack: rack} token := hostToken{ - token: orderedToken([]byte(h.hostId)), + token: internal.OrderedToken([]byte(h.hostId)), host: h, } diff --git a/tuple_test.go b/tuple_test.go index 296c56feb..4dddd4141 100644 --- a/tuple_test.go +++ b/tuple_test.go @@ -35,7 +35,7 @@ import ( func TestTupleSimple(t *testing.T) { session := createSession(t) defer session.Close() - if session.cfg.ProtoVersion < protoVersion3 { + if session.cfg.ProtoVersion < internal.ProtoVersion3 { t.Skip("tuple types are only available of proto>=3") } @@ -79,7 +79,7 @@ func TestTupleSimple(t *testing.T) { func TestTuple_NullTuple(t *testing.T) { session := createSession(t) defer session.Close() - if session.cfg.ProtoVersion < protoVersion3 { + if session.cfg.ProtoVersion < internal.ProtoVersion3 { t.Skip("tuple types are only available of proto>=3") } @@ -117,7 +117,7 @@ func TestTuple_NullTuple(t *testing.T) { func TestTuple_TupleNotSet(t *testing.T) { session := createSession(t) defer session.Close() - if session.cfg.ProtoVersion < protoVersion3 { + if session.cfg.ProtoVersion < internal.ProtoVersion3 { t.Skip("tuple types are only available of proto>=3") } @@ -170,7 +170,7 @@ func TestTuple_TupleNotSet(t *testing.T) { func TestTupleMapScan(t *testing.T) { session := createSession(t) defer session.Close() - if session.cfg.ProtoVersion < protoVersion3 { + if session.cfg.ProtoVersion < internal.ProtoVersion3 { t.Skip("tuple types are only available of proto>=3") } @@ -203,7 +203,7 @@ func TestTupleMapScan(t *testing.T) { func TestTupleMapScanNil(t *testing.T) { session := createSession(t) defer session.Close() - if session.cfg.ProtoVersion < protoVersion3 { + if session.cfg.ProtoVersion < internal.ProtoVersion3 { t.Skip("tuple types are only available of proto>=3") } err := createTable(session, `CREATE TABLE gocql_test.tuple_map_scan_nil( @@ -234,7 +234,7 @@ func TestTupleMapScanNil(t *testing.T) { func TestTupleMapScanNotSet(t *testing.T) { session := createSession(t) defer session.Close() - if session.cfg.ProtoVersion < protoVersion3 { + if session.cfg.ProtoVersion < internal.ProtoVersion3 { t.Skip("tuple types are only available of proto>=3") } err := createTable(session, `CREATE TABLE gocql_test.tuple_map_scan_not_set( @@ -266,7 +266,7 @@ func TestTupleLastFieldEmpty(t *testing.T) { // Regression test - empty value used to be treated as NULL value in the last tuple field session := createSession(t) defer session.Close() - if session.cfg.ProtoVersion < protoVersion3 { + if session.cfg.ProtoVersion < internal.ProtoVersion3 { t.Skip("tuple types are only available of proto>=3") } err := createTable(session, `CREATE TABLE gocql_test.tuple_last_field_empty( @@ -304,7 +304,7 @@ func TestTupleLastFieldEmpty(t *testing.T) { func TestTuple_NestedCollection(t *testing.T) { session := createSession(t) defer session.Close() - if session.cfg.ProtoVersion < protoVersion3 { + if session.cfg.ProtoVersion < internal.ProtoVersion3 { t.Skip("tuple types are only available of proto>=3") } @@ -356,7 +356,7 @@ func TestTuple_NestedCollection(t *testing.T) { func TestTuple_NullableNestedCollection(t *testing.T) { session := createSession(t) defer session.Close() - if session.cfg.ProtoVersion < protoVersion3 { + if session.cfg.ProtoVersion < internal.ProtoVersion3 { t.Skip("tuple types are only available of proto>=3") } diff --git a/udt_test.go b/udt_test.go index f1980f243..70f9cc7b1 100644 --- a/udt_test.go +++ b/udt_test.go @@ -72,7 +72,7 @@ func TestUDT_Marshaler(t *testing.T) { session := createSession(t) defer session.Close() - if session.cfg.ProtoVersion < protoVersion3 { + if session.cfg.ProtoVersion < internal.ProtoVersion3 { t.Skip("UDT are only available on protocol >= 3") } @@ -129,7 +129,7 @@ func TestUDT_Reflect(t *testing.T) { session := createSession(t) defer session.Close() - if session.cfg.ProtoVersion < protoVersion3 { + if session.cfg.ProtoVersion < internal.ProtoVersion3 { t.Skip("UDT are only available on protocol >= 3") } @@ -188,7 +188,7 @@ func TestUDT_NullObject(t *testing.T) { session := createSession(t) defer session.Close() - if session.cfg.ProtoVersion < protoVersion3 { + if session.cfg.ProtoVersion < internal.ProtoVersion3 { t.Skip("UDT are only available on protocol >= 3") } @@ -242,7 +242,7 @@ func TestMapScanUDT(t *testing.T) { session := createSession(t) defer session.Close() - if session.cfg.ProtoVersion < protoVersion3 { + if session.cfg.ProtoVersion < internal.ProtoVersion3 { t.Skip("UDT are only available on protocol >= 3") } @@ -329,7 +329,7 @@ func TestUDT_MissingField(t *testing.T) { session := createSession(t) defer session.Close() - if session.cfg.ProtoVersion < protoVersion3 { + if session.cfg.ProtoVersion < internal.ProtoVersion3 { t.Skip("UDT are only available on protocol >= 3") } @@ -379,7 +379,7 @@ func TestUDT_EmptyCollections(t *testing.T) { session := createSession(t) defer session.Close() - if session.cfg.ProtoVersion < protoVersion3 { + if session.cfg.ProtoVersion < internal.ProtoVersion3 { t.Skip("UDT are only available on protocol >= 3") } @@ -435,7 +435,7 @@ func TestUDT_UpdateField(t *testing.T) { session := createSession(t) defer session.Close() - if session.cfg.ProtoVersion < protoVersion3 { + if session.cfg.ProtoVersion < internal.ProtoVersion3 { t.Skip("UDT are only available on protocol >= 3") } @@ -492,7 +492,7 @@ func TestUDT_ScanNullUDT(t *testing.T) { session := createSession(t) defer session.Close() - if session.cfg.ProtoVersion < protoVersion3 { + if session.cfg.ProtoVersion < internal.ProtoVersion3 { t.Skip("UDT are only available on protocol >= 3") }