Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func TestBatch_Errors(t *testing.T) {
session := createSession(t)
defer session.Close()

if session.cfg.ProtoVersion < protoVersion2 {
if session.cfg.ProtoVersion < internal.ProtoVersion2 {
t.Skip("atomic batches not supported. Please use Cassandra >= 2.0")
}

Expand All @@ -58,7 +58,7 @@ func TestBatch_WithTimestamp(t *testing.T) {
session := createSession(t)
defer session.Close()

if session.cfg.ProtoVersion < protoVersion3 {
if session.cfg.ProtoVersion < internal.ProtoVersion3 {
t.Skip("Batch timestamps are only available on protocol >= 3")
}

Expand Down
8 changes: 4 additions & 4 deletions cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1120,7 +1120,7 @@ func TestSmallInt(t *testing.T) {
session := createSession(t)
defer session.Close()

if session.cfg.ProtoVersion < protoVersion4 {
if session.cfg.ProtoVersion < internal.ProtoVersion4 {
t.Skip("smallint is only supported in cassandra 2.2+")
}

Expand Down Expand Up @@ -2146,7 +2146,7 @@ func TestGetTableMetadata(t *testing.T) {
if testTable == nil {
t.Fatal("Expected table metadata for name 'test_table_metadata'")
}
if session.cfg.ProtoVersion == protoVersion1 {
if session.cfg.ProtoVersion == internal.ProtoVersion1 {
if testTable.KeyValidator != "org.apache.cassandra.db.marshal.Int32Type" {
t.Errorf("Expected test_table_metadata key validator to be 'org.apache.cassandra.db.marshal.Int32Type' but was '%s'", testTable.KeyValidator)
}
Expand Down Expand Up @@ -2813,7 +2813,7 @@ func TestNegativeStream(t *testing.T) {

const stream = -50
writer := frameWriterFunc(func(f *framer, streamID int) error {
f.writeHeader(0, opOptions, stream)
f.writeHeader(0, internal.OpOptions, stream)
return f.finish()
})

Expand Down Expand Up @@ -3116,7 +3116,7 @@ func TestUnmarshallNestedTypes(t *testing.T) {
session := createSession(t)
defer session.Close()

if session.cfg.ProtoVersion < protoVersion3 {
if session.cfg.ProtoVersion < internal.ProtoVersion3 {
t.Skip("can not have frozen types in cassandra < 2.1.3")
}

Expand Down
2 changes: 1 addition & 1 deletion cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ type ClusterConfig struct {
ReconnectInterval time.Duration

// The maximum amount of time to wait for schema agreement in a cluster after
// receiving a schema change frame. (default: 60s)
// receiving a schema change internal.Frame. (default: 60s)
MaxWaitSchemaAgreement time.Duration

// HostFilter will filter all incoming events for host, any which don't pass
Expand Down
77 changes: 32 additions & 45 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,24 +39,11 @@ import (
"sync/atomic"
"time"

"github.com/gocql/gocql/internal"
"github.com/gocql/gocql/internal/lru"
"github.com/gocql/gocql/internal/streams"
)

// approve the authenticator with the list of allowed authenticators. If the provided list is empty,
// the given authenticator is allowed.
func approve(authenticator string, approvedAuthenticators []string) bool {
if len(approvedAuthenticators) == 0 {
return true
}
for _, s := range approvedAuthenticators {
if authenticator == s {
return true
}
}
return false
}

// JoinHostPort is a utility to return an address string that can be used
// by `gocql.Conn` to form a connection with a host.
func JoinHostPort(addr string, port int) string {
Expand Down Expand Up @@ -85,7 +72,7 @@ type PasswordAuthenticator struct {
}

func (p PasswordAuthenticator) Challenge(req []byte) ([]byte, Authenticator, error) {
if !approve(string(req), p.AllowedAuthenticators) {
if !internal.Approve(string(req), p.AllowedAuthenticators) {
return nil, nil, fmt.Errorf("unexpected authenticator %q", req)
}
resp := make([]byte, 2+len(p.Username)+len(p.Password))
Expand Down Expand Up @@ -188,7 +175,7 @@ type Conn struct {
frameObserver FrameHeaderObserver
streamObserver StreamObserver

headerBuf [maxFrameHeaderSize]byte
headerBuf [internal.MaxFrameHeaderSize]byte

streams *streams.IDGenerator
mu sync.Mutex
Expand Down Expand Up @@ -406,7 +393,7 @@ func (s *startupCoordinator) setupConn(ctx context.Context) error {
return nil
}

func (s *startupCoordinator) write(ctx context.Context, frame frameBuilder) (frame, error) {
func (s *startupCoordinator) write(ctx context.Context, frame frameBuilder) (internal.Frame, error) {
select {
case s.frameTicker <- struct{}{}:
case <-ctx.Done():
Expand Down Expand Up @@ -585,23 +572,23 @@ func (c *Conn) serve(ctx context.Context) {
c.closeWithError(err)
}

func (c *Conn) discardFrame(head frameHeader) error {
_, err := io.CopyN(ioutil.Discard, c, int64(head.length))
func (c *Conn) discardFrame(head internal.FrameHeader) error {
_, err := io.CopyN(ioutil.Discard, c, int64(head.Length))
if err != nil {
return err
}
return nil
}

type protocolError struct {
frame frame
frame internal.Frame
}

func (p *protocolError) Error() string {
if err, ok := p.frame.(error); ok {
return err.Error()
}
return fmt.Sprintf("gocql: received unexpected frame on stream %d: %v", p.frame.Header().stream, p.frame)
return fmt.Sprintf("gocql: received unexpected frame on stream %d: %v", p.frame.Header().Stream, p.frame)
}

func (c *Conn) heartBeat(ctx context.Context) {
Expand Down Expand Up @@ -670,28 +657,28 @@ func (c *Conn) recv(ctx context.Context) error {

if c.frameObserver != nil {
c.frameObserver.ObserveFrameHeader(context.Background(), ObservedFrameHeader{
Version: protoVersion(head.version),
Flags: head.flags,
Stream: int16(head.stream),
Opcode: frameOp(head.op),
Length: int32(head.length),
Version: internal.ProtoVersion(head.Version),
Flags: head.Flags,
Stream: int16(head.Stream),
Opcode: internal.FrameOp(head.Op),
Length: int32(head.Length),
Start: headStartTime,
End: headEndTime,
Host: c.host,
})
}

if head.stream > c.streams.NumStreams {
return fmt.Errorf("gocql: frame header stream is beyond call expected bounds: %d", head.stream)
} else if head.stream == -1 {
if head.Stream > c.streams.NumStreams {
return fmt.Errorf("gocql: frame header stream is beyond call expected bounds: %d", head.Stream)
} else if head.Stream == -1 {
// TODO: handle cassandra event frames, we shouldnt get any currently
framer := newFramer(c.compressor, c.version)
if err := framer.readFrame(c, &head); err != nil {
return err
}
go c.session.handleEvent(framer)
return nil
} else if head.stream <= 0 {
} else if head.Stream <= 0 {
// reserved stream that we dont use, probably due to a protocol error
// or a bug in Cassandra, this should be an error, parse it and return.
framer := newFramer(c.compressor, c.version)
Expand All @@ -714,14 +701,14 @@ func (c *Conn) recv(ctx context.Context) error {
c.mu.Unlock()
return ErrConnectionClosed
}
call, ok := c.calls[head.stream]
delete(c.calls, head.stream)
call, ok := c.calls[head.Stream]
delete(c.calls, head.Stream)
c.mu.Unlock()
if call == nil || !ok {
c.logger.Printf("gocql: received response for stream which has no handler: header=%v\n", head)
return c.discardFrame(head)
} else if head.stream != call.streamID {
panic(fmt.Sprintf("call has incorrect streamID: got %d expected %d", call.streamID, head.stream))
} else if head.Stream != call.streamID {
panic(fmt.Sprintf("call has incorrect streamID: got %d expected %d", call.streamID, head.Stream))
}

framer := newFramer(c.compressor, c.version)
Expand Down Expand Up @@ -1150,7 +1137,7 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram
// requests on the stream to prevent nil pointer dereferences in recv().
defer c.releaseStream(call)

if v := resp.framer.header.version.version(); v != c.version {
if v := resp.framer.header.Version.Version(); v != c.version {
return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version)
}

Expand Down Expand Up @@ -1244,7 +1231,7 @@ func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer)
prep := &writePrepareFrame{
statement: stmt,
}
if c.version > protoVersion4 {
if c.version > internal.ProtoVersion4 {
prep.keyspace = c.currentKeyspace
}

Expand Down Expand Up @@ -1276,7 +1263,7 @@ func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer)
flight.preparedStatment = &preparedStatment{
// defensively copy as we will recycle the underlying buffer after we
// return.
id: copyBytes(x.preparedID),
id: internal.CopyBytes(x.preparedID),
// the type info's should _not_ have a reference to the framers read buffer,
// therefore we can just copy them directly.
request: x.reqMeta,
Expand All @@ -1303,12 +1290,12 @@ func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer)
}

func marshalQueryValue(typ TypeInfo, value interface{}, dst *queryValues) error {
if named, ok := value.(*namedValue); ok {
dst.name = named.name
value = named.value
if named, ok := value.(*internal.NamedValue); ok {
dst.name = named.Name
value = named.Value
}

if _, ok := value.(unsetColumn); !ok {
if _, ok := value.(internal.UnsetColumn); !ok {
val, err := Marshal(typ, value)
if err != nil {
return err
Expand Down Expand Up @@ -1338,7 +1325,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter {
if qry.pageSize > 0 {
params.pageSize = qry.pageSize
}
if c.version > protoVersion4 {
if c.version > internal.ProtoVersion4 {
params.keyspace = c.currentKeyspace
}

Expand Down Expand Up @@ -1431,7 +1418,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter {
if params.skipMeta {
if info != nil {
iter.meta = info.response
iter.meta.pagingState = copyBytes(x.meta.pagingState)
iter.meta.pagingState = internal.CopyBytes(x.meta.pagingState)
} else {
return &Iter{framer: framer, err: errors.New("gocql: did not receive metadata but prepared info is nil")}
}
Expand All @@ -1442,7 +1429,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter {
if x.meta.morePages() && !qry.disableAutoPage {
newQry := new(Query)
*newQry = *qry
newQry.pageState = copyBytes(x.meta.pagingState)
newQry.pageState = internal.CopyBytes(x.meta.pagingState)
newQry.metrics = &queryMetrics{m: make(map[string]*hostMetrics)}

iter.next = &nextIter{
Expand Down Expand Up @@ -1531,7 +1518,7 @@ func (c *Conn) UseKeyspace(keyspace string) error {
}

func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter {
if c.version == protoVersion1 {
if c.version == internal.ProtoVersion1 {
return &Iter{err: ErrUnsupported}
}

Expand Down
Loading