diff --git a/cassandra_test.go b/cassandra_test.go index 3b0c61053..e17ee6922 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -32,6 +32,7 @@ import ( "context" "errors" "fmt" + "github.com/gocql/gocql/protocol" "github.com/stretchr/testify/require" "io" "math" @@ -519,23 +520,23 @@ func TestDurationType(t *testing.T) { t.Fatal("create:", err) } - durations := []Duration{ - Duration{ + durations := []protocol.Duration{ + protocol.Duration{ Months: 250, Days: 500, Nanoseconds: 300010001, }, - Duration{ + protocol.Duration{ Months: -250, Days: -500, Nanoseconds: -300010001, }, - Duration{ + protocol.Duration{ Months: 0, Days: 128, Nanoseconds: 127, }, - Duration{ + protocol.Duration{ Months: 0x7FFFFFFF, Days: 0x7FFFFFFF, Nanoseconds: 0x7FFFFFFFFFFFFFFF, @@ -547,7 +548,7 @@ func TestDurationType(t *testing.T) { } var id int - var duration Duration + var duration protocol.Duration if err := session.Query(`SELECT k, v FROM gocql_test.duration_table`).Scan(&id, &duration); err != nil { t.Fatal(err) } diff --git a/cluster.go b/cluster.go index 413695ca4..b1ed58c08 100644 --- a/cluster.go +++ b/cluster.go @@ -27,6 +27,7 @@ package gocql import ( "context" "errors" + "github.com/gocql/gocql/consistency" "net" "time" ) @@ -114,7 +115,7 @@ type ClusterConfig struct { // Default consistency level. // Default: Quorum - Consistency Consistency + Consistency consistency.Consistency // Compression algorithm. // Default: nil @@ -156,7 +157,7 @@ type ClusterConfig struct { // Consistency for the serial part of queries, values can be either SERIAL or LOCAL_SERIAL. // Default: unset - SerialConsistency SerialConsistency + SerialConsistency consistency.SerialConsistency // SslOpts configures TLS use when HostDialer is not set. // SslOpts is ignored if HostDialer is set. 
diff --git a/compressor.go b/compressor.go index f3d451a9f..6669ad82a 100644 --- a/compressor.go +++ b/compressor.go @@ -28,6 +28,7 @@ import ( "github.com/golang/snappy" ) +// Deprecated: use compressor.Compressor instead type Compressor interface { Name() string Encode(data []byte) ([]byte, error) diff --git a/compressor/compressor.go b/compressor/compressor.go new file mode 100644 index 000000000..f54906899 --- /dev/null +++ b/compressor/compressor.go @@ -0,0 +1,26 @@ +package compressor + +import "github.com/golang/snappy" + +type Compressor interface { + Name() string + Encode(data []byte) ([]byte, error) + Decode(data []byte) ([]byte, error) +} + +// SnappyCompressor implements the Compressor interface and can be used to +// compress incoming and outgoing frames. The snappy compression algorithm +// aims for very high speeds and reasonable compression. +type SnappyCompressor struct{} + +func (s SnappyCompressor) Name() string { + return "snappy" +} + +func (s SnappyCompressor) Encode(data []byte) ([]byte, error) { + return snappy.Encode(nil, data), nil +} + +func (s SnappyCompressor) Decode(data []byte) ([]byte, error) { + return snappy.Decode(nil, data) +} diff --git a/conn.go b/conn.go index ae02bd71c..930e4e9ea 100644 --- a/conn.go +++ b/conn.go @@ -30,6 +30,7 @@ import ( "crypto/tls" "errors" "fmt" + "github.com/gocql/gocql/internal" "io" "io/ioutil" "net" @@ -1276,7 +1277,7 @@ func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) flight.preparedStatment = &preparedStatment{ // defensively copy as we will recycle the underlying buffer after we // return. - id: copyBytes(x.preparedID), + id: internal.CopyBytes(x.preparedID), // the type info's should _not_ have a reference to the framers read buffer, // therefore we can just copy them directly. 
request: x.reqMeta, @@ -1431,7 +1432,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { if params.skipMeta { if info != nil { iter.meta = info.response - iter.meta.pagingState = copyBytes(x.meta.pagingState) + iter.meta.pagingState = internal.CopyBytes(x.meta.pagingState) } else { return &Iter{framer: framer, err: errors.New("gocql: did not receive metadata but prepared info is nil")} } @@ -1442,7 +1443,7 @@ if x.meta.morePages() && !qry.disableAutoPage { newQry := new(Query) *newQry = *qry - newQry.pageState = copyBytes(x.meta.pagingState) + newQry.pageState = internal.CopyBytes(x.meta.pagingState) newQry.metrics = &queryMetrics{m: make(map[string]*hostMetrics)} iter.next = &nextIter{ @@ -1659,7 +1660,7 @@ func (c *Conn) querySystemPeers(ctx context.Context, version cassVersion) *Iter err := iter.checkErrAndNotFound() if err != nil { - if errFrame, ok := err.(errorFrame); ok && errFrame.code == ErrCodeInvalid { // system.peers_v2 not found, try system.peers + if errFrame, ok := err.(errorFrame); ok && errFrame.code == internal.ErrCodeInvalid { // system.peers_v2 not found, try system.peers c.mu.Lock() c.isSchemaV2 = false c.mu.Unlock() diff --git a/consistency/consistency.go b/consistency/consistency.go new file mode 100644 index 000000000..330fa6822 --- /dev/null +++ b/consistency/consistency.go @@ -0,0 +1,137 @@ +package consistency + +import ( + "fmt" + "strings" +) + +type Consistency uint16 + +const ( + Any Consistency = 0x00 + One Consistency = 0x01 + Two Consistency = 0x02 + Three Consistency = 0x03 + Quorum Consistency = 0x04 + All Consistency = 0x05 + LocalQuorum Consistency = 0x06 + EachQuorum Consistency = 0x07 + LocalOne Consistency = 0x0A +) + +func (c Consistency) String() string { + switch c { + case Any: + return "ANY" + case One: + return "ONE" + case Two: + return "TWO" + case Three: + return "THREE" + case Quorum: + return "QUORUM" + case All: + 
return "ALL" + case LocalQuorum: + return "LOCAL_QUORUM" + case EachQuorum: + return "EACH_QUORUM" + case LocalOne: + return "LOCAL_ONE" + default: + return fmt.Sprintf("UNKNOWN_CONS_0x%x", uint16(c)) + } +} + +func (c Consistency) MarshalText() (text []byte, err error) { + return []byte(c.String()), nil +} + +func (c *Consistency) UnmarshalText(text []byte) error { + switch string(text) { + case "ANY": + *c = Any + case "ONE": + *c = One + case "TWO": + *c = Two + case "THREE": + *c = Three + case "QUORUM": + *c = Quorum + case "ALL": + *c = All + case "LOCAL_QUORUM": + *c = LocalQuorum + case "EACH_QUORUM": + *c = EachQuorum + case "LOCAL_ONE": + *c = LocalOne + default: + return fmt.Errorf("invalid consistency %q", string(text)) + } + + return nil +} + +func ParseConsistency(s string) Consistency { + var c Consistency + if err := c.UnmarshalText([]byte(strings.ToUpper(s))); err != nil { + panic(err) + } + return c +} + +// ParseConsistencyWrapper wraps gocql.ParseConsistency to provide an err +// return instead of a panic +func ParseConsistencyWrapper(s string) (consistency Consistency, err error) { + err = consistency.UnmarshalText([]byte(strings.ToUpper(s))) + return +} + +// MustParseConsistency is the same as ParseConsistency except it returns +// an error (never). It is kept here since breaking changes are not good. +// Deprecated: use ParseConsistency if you want a panic on parse error. 
+func MustParseConsistency(s string) (Consistency, error) { + c, err := ParseConsistencyWrapper(s) + if err != nil { + panic(err) + } + return c, nil +} + +type SerialConsistency uint16 + +const ( + Serial SerialConsistency = 0x08 + LocalSerial SerialConsistency = 0x09 +) + +func (s SerialConsistency) String() string { + switch s { + case Serial: + return "SERIAL" + case LocalSerial: + return "LOCAL_SERIAL" + default: + return fmt.Sprintf("UNKNOWN_SERIAL_CONS_0x%x", uint16(s)) + } +} + +func (s SerialConsistency) MarshalText() (text []byte, err error) { + return []byte(s.String()), nil +} + +func (s *SerialConsistency) UnmarshalText(text []byte) error { + switch string(text) { + case "SERIAL": + *s = Serial + case "LOCAL_SERIAL": + *s = LocalSerial + default: + return fmt.Errorf("invalid consistency %q", string(text)) + } + + return nil +} diff --git a/errors.go b/errors.go index d64c46208..944db8f9e 100644 --- a/errors.go +++ b/errors.go @@ -24,7 +24,9 @@ package gocql -import "fmt" +import ( + "github.com/gocql/gocql/internal" +) // See CQL Binary Protocol v5, section 8 for more details. // https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec @@ -32,191 +34,127 @@ const ( // ErrCodeServer indicates unexpected error on server-side. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1246-L1247 - ErrCodeServer = 0x0000 + // Deprecated: use gocql_errors.ErrCodeServer instead. + ErrCodeServer = internal.ErrCodeServer // ErrCodeProtocol indicates a protocol violation by some client message. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1248-L1250 - ErrCodeProtocol = 0x000A + // Deprecated: use gocql_errors.ErrCodeProtocol instead. + ErrCodeProtocol = internal.ErrCodeProtocol // ErrCodeCredentials indicates missing required authentication. 
// // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1251-L1254 - ErrCodeCredentials = 0x0100 + // Deprecated: use gocql_errors.ErrCodeCredentials instead. + ErrCodeCredentials = internal.ErrCodeCredentials // ErrCodeUnavailable indicates unavailable error. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1255-L1265 - ErrCodeUnavailable = 0x1000 + // Deprecated: use gocql_errors.ErrCodeUnavailable instead. + ErrCodeUnavailable = internal.ErrCodeUnavailable // ErrCodeOverloaded returned in case of request on overloaded node coordinator. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1266-L1267 - ErrCodeOverloaded = 0x1001 + // Deprecated: use gocql_errors.ErrCodeOverloaded instead. + ErrCodeOverloaded = internal.ErrCodeOverloaded // ErrCodeBootstrapping returned from the coordinator node in bootstrapping phase. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1268-L1269 - ErrCodeBootstrapping = 0x1002 + // Deprecated: use gocql_errors.ErrCodeBootstrapping instead. + ErrCodeBootstrapping = internal.ErrCodeBootstrapping // ErrCodeTruncate indicates truncation exception. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1270 - ErrCodeTruncate = 0x1003 + // Deprecated: use gocql_errors.ErrCodeTruncate instead. + ErrCodeTruncate = internal.ErrCodeTruncate // ErrCodeWriteTimeout returned in case of timeout during the request write. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1271-L1304 - ErrCodeWriteTimeout = 0x1100 + // Deprecated: use gocql_errors.ErrCodeWriteTimeout instead. + ErrCodeWriteTimeout = internal.ErrCodeWriteTimeout // ErrCodeReadTimeout returned in case of timeout during the request read. 
// // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1305-L1321 - ErrCodeReadTimeout = 0x1200 + // Deprecated: use gocql_errors.ErrCodeReadTimeout instead. + ErrCodeReadTimeout = internal.ErrCodeReadTimeout // ErrCodeReadFailure indicates request read error which is not covered by ErrCodeReadTimeout. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1322-L1340 - ErrCodeReadFailure = 0x1300 + // Deprecated: use gocql_errors.ErrCodeReadFailure instead. + ErrCodeReadFailure = internal.ErrCodeReadFailure // ErrCodeFunctionFailure indicates an error in user-defined function. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1341-L1347 - ErrCodeFunctionFailure = 0x1400 + // Deprecated: use gocql_errors.ErrCodeFunctionFailure instead. + ErrCodeFunctionFailure = internal.ErrCodeFunctionFailure // ErrCodeWriteFailure indicates request write error which is not covered by ErrCodeWriteTimeout. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1348-L1385 - ErrCodeWriteFailure = 0x1500 + // Deprecated: use gocql_errors.ErrCodeWriteFailure instead. + ErrCodeWriteFailure = internal.ErrCodeWriteFailure // ErrCodeCDCWriteFailure is defined, but not yet documented in CQLv5 protocol. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1386 - ErrCodeCDCWriteFailure = 0x1600 + // Deprecated: use gocql_errors.ErrCodeCDCWriteFailure instead. + ErrCodeCDCWriteFailure = internal.ErrCodeCDCWriteFailure // ErrCodeCASWriteUnknown indicates only partially completed CAS operation. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1387-L1397 - ErrCodeCASWriteUnknown = 0x1700 + // Deprecated: use gocql_errors.ErrCodeCASWriteUnknown instead. + ErrCodeCASWriteUnknown = internal.ErrCodeCASWriteUnknown // ErrCodeSyntax indicates the syntax error in the query. 
// // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1399 - ErrCodeSyntax = 0x2000 + // Deprecated: use gocql_errors.ErrCodeSyntax instead. + ErrCodeSyntax = internal.ErrCodeSyntax // ErrCodeUnauthorized indicates access rights violation by user on performed operation. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1400-L1401 - ErrCodeUnauthorized = 0x2100 + // Deprecated: use gocql_errors.ErrCodeUnauthorized instead. + ErrCodeUnauthorized = internal.ErrCodeUnauthorized // ErrCodeInvalid indicates invalid query error which is not covered by ErrCodeSyntax. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1402 - ErrCodeInvalid = 0x2200 + // Deprecated: use gocql_errors.ErrCodeInvalid instead. + ErrCodeInvalid = internal.ErrCodeInvalid // ErrCodeConfig indicates the configuration error. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1403 - ErrCodeConfig = 0x2300 + // Deprecated: use gocql_errors.ErrCodeConfig instead. + ErrCodeConfig = internal.ErrCodeConfig // ErrCodeAlreadyExists is returned for the requests creating the existing keyspace/table. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1404-L1413 - ErrCodeAlreadyExists = 0x2400 + // Deprecated: use gocql_errors.ErrCodeAlreadyExists instead. + ErrCodeAlreadyExists = internal.ErrCodeAlreadyExists // ErrCodeUnprepared returned from the host for prepared statement which is unknown. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1414-L1417 - ErrCodeUnprepared = 0x2500 + // Deprecated: use gocql_errors.ErrCodeUnprepared instead. 
+ ErrCodeUnprepared = internal.ErrCodeUnprepared ) -type RequestError interface { - Code() int - Message() string - Error() string -} - -type errorFrame struct { - frameHeader - - code int - message string -} - -func (e errorFrame) Code() int { - return e.code -} - -func (e errorFrame) Message() string { - return e.message -} - -func (e errorFrame) Error() string { - return e.Message() -} - -func (e errorFrame) String() string { - return fmt.Sprintf("[error code=%x message=%q]", e.code, e.message) -} - -type RequestErrUnavailable struct { - errorFrame - Consistency Consistency - Required int - Alive int -} - -func (e *RequestErrUnavailable) String() string { - return fmt.Sprintf("[request_error_unavailable consistency=%s required=%d alive=%d]", e.Consistency, e.Required, e.Alive) -} - -type ErrorMap map[string]uint16 - -type RequestErrWriteTimeout struct { - errorFrame - Consistency Consistency - Received int - BlockFor int - WriteType string -} - -type RequestErrWriteFailure struct { - errorFrame - Consistency Consistency - Received int - BlockFor int - NumFailures int - WriteType string - ErrorMap ErrorMap -} - -type RequestErrCDCWriteFailure struct { - errorFrame -} - -type RequestErrReadTimeout struct { - errorFrame - Consistency Consistency - Received int - BlockFor int - DataPresent byte -} - -type RequestErrAlreadyExists struct { - errorFrame - Keyspace string - Table string -} - -type RequestErrUnprepared struct { - errorFrame - StatementId []byte -} - -type RequestErrReadFailure struct { - errorFrame - Consistency Consistency - Received int - BlockFor int - NumFailures int - DataPresent bool - ErrorMap ErrorMap -} - -type RequestErrFunctionFailure struct { - errorFrame - Keyspace string - Function string - ArgTypes []string -} - -// RequestErrCASWriteUnknown is distinct error for ErrCodeCasWriteUnknown. 
-// -// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1387-L1397 -type RequestErrCASWriteUnknown struct { - errorFrame - Consistency Consistency - Received int - BlockFor int -} +type RequestError = internal.RequestError + +type errorFrame = internal.ErrorFrame + +type RequestErrUnavailable = internal.RequestErrUnavailable + +type ErrorMap = internal.ErrorMap + +type RequestErrWriteTimeout = internal.RequestErrWriteTimeout + +type RequestErrWriteFailure = internal.RequestErrWriteFailure + +type RequestErrCDCWriteFailure = internal.RequestErrCDCWriteFailure + +type RequestErrReadTimeout = internal.RequestErrReadTimeout + +type RequestErrAlreadyExists = internal.RequestErrAlreadyExists + +type RequestErrUnprepared = internal.RequestErrUnprepared + +type RequestErrReadFailure = internal.RequestErrReadFailure + +type RequestErrFunctionFailure = internal.RequestErrFunctionFailure + +type RequestErrCASWriteUnknown = internal.RequestErrCASWriteUnknown diff --git a/frame.go b/frame.go index d374ae574..144f62e3f 100644 --- a/frame.go +++ b/frame.go @@ -26,18 +26,12 @@ package gocql import ( "context" - "errors" "fmt" - "io" - "io/ioutil" - "net" - "runtime" + "github.com/gocql/gocql/internal" "strings" "time" ) -type unsetColumn struct{} - // UnsetValue represents a value used in a query binding that will be ignored by Cassandra. // // By setting a field to the unset value Cassandra will ignore the write completely. @@ -45,151 +39,17 @@ type unsetColumn struct{} // want to update some fields, where before you needed to make another prepared statement. // // UnsetValue is only available when using the version 4 of the protocol. 
-var UnsetValue = unsetColumn{} - -type namedValue struct { - name string - value interface{} -} +var UnsetValue = internal.UnsetColumn{} // NamedValue produce a value which will bind to the named parameter in a query func NamedValue(name string, value interface{}) interface{} { - return &namedValue{ - name: name, - value: value, - } -} - -const ( - protoDirectionMask = 0x80 - protoVersionMask = 0x7F - protoVersion1 = 0x01 - protoVersion2 = 0x02 - protoVersion3 = 0x03 - protoVersion4 = 0x04 - protoVersion5 = 0x05 - - maxFrameSize = 256 * 1024 * 1024 -) - -type protoVersion byte - -func (p protoVersion) request() bool { - return p&protoDirectionMask == 0x00 -} - -func (p protoVersion) response() bool { - return p&protoDirectionMask == 0x80 -} - -func (p protoVersion) version() byte { - return byte(p) & protoVersionMask -} - -func (p protoVersion) String() string { - dir := "REQ" - if p.response() { - dir = "RESP" - } - - return fmt.Sprintf("[version=%d direction=%s]", p.version(), dir) -} - -type frameOp byte - -const ( - // header ops - opError frameOp = 0x00 - opStartup frameOp = 0x01 - opReady frameOp = 0x02 - opAuthenticate frameOp = 0x03 - opOptions frameOp = 0x05 - opSupported frameOp = 0x06 - opQuery frameOp = 0x07 - opResult frameOp = 0x08 - opPrepare frameOp = 0x09 - opExecute frameOp = 0x0A - opRegister frameOp = 0x0B - opEvent frameOp = 0x0C - opBatch frameOp = 0x0D - opAuthChallenge frameOp = 0x0E - opAuthResponse frameOp = 0x0F - opAuthSuccess frameOp = 0x10 -) - -func (f frameOp) String() string { - switch f { - case opError: - return "ERROR" - case opStartup: - return "STARTUP" - case opReady: - return "READY" - case opAuthenticate: - return "AUTHENTICATE" - case opOptions: - return "OPTIONS" - case opSupported: - return "SUPPORTED" - case opQuery: - return "QUERY" - case opResult: - return "RESULT" - case opPrepare: - return "PREPARE" - case opExecute: - return "EXECUTE" - case opRegister: - return "REGISTER" - case opEvent: - return "EVENT" - case 
opBatch: - return "BATCH" - case opAuthChallenge: - return "AUTH_CHALLENGE" - case opAuthResponse: - return "AUTH_RESPONSE" - case opAuthSuccess: - return "AUTH_SUCCESS" - default: - return fmt.Sprintf("UNKNOWN_OP_%d", f) + return &internal.NamedValue{ + Name: name, + Value: value, } } -const ( - // result kind - resultKindVoid = 1 - resultKindRows = 2 - resultKindKeyspace = 3 - resultKindPrepared = 4 - resultKindSchemaChanged = 5 - - // rows flags - flagGlobalTableSpec int = 0x01 - flagHasMorePages int = 0x02 - flagNoMetaData int = 0x04 - - // query flags - flagValues byte = 0x01 - flagSkipMetaData byte = 0x02 - flagPageSize byte = 0x04 - flagWithPagingState byte = 0x08 - flagWithSerialConsistency byte = 0x10 - flagDefaultTimestamp byte = 0x20 - flagWithNameValues byte = 0x40 - flagWithKeyspace byte = 0x80 - - // prepare flags - flagWithPreparedKeyspace uint32 = 0x01 - - // header flags - flagCompress byte = 0x01 - flagTracing byte = 0x02 - flagCustomPayload byte = 0x04 - flagWarning byte = 0x08 - flagBetaProtocol byte = 0x10 -) - +// TODO: Deprecate consystency, and use it from package consistency type Consistency uint16 const ( @@ -321,44 +181,13 @@ func (s *SerialConsistency) UnmarshalText(text []byte) error { return nil } -const ( - apacheCassandraTypePrefix = "org.apache.cassandra.db.marshal." 
-) - -var ( - ErrFrameTooBig = errors.New("frame length is bigger than the maximum allowed") -) - -const maxFrameHeaderSize = 9 - -func readInt(p []byte) int32 { - return int32(p[0])<<24 | int32(p[1])<<16 | int32(p[2])<<8 | int32(p[3]) -} - -type frameHeader struct { - version protoVersion - flags byte - stream int - op frameOp - length int - warnings []string -} - -func (f frameHeader) String() string { - return fmt.Sprintf("[header version=%s flags=0x%x stream=%d op=%s length=%d]", f.version, f.flags, f.stream, f.op, f.length) -} - -func (f frameHeader) Header() frameHeader { - return f -} - -const defaultBufSize = 128 +// TODO type ObservedFrameHeader struct { - Version protoVersion + Version internal.ProtoVersion Flags byte Stream int16 - Opcode frameOp + Opcode internal.FrameOp Length int32 // StartHeader is the time we started reading the frame header off the network connection. @@ -381,1692 +210,3 @@ type FrameHeaderObserver interface { // ObserveFrameHeader gets called on every received frame header. ObserveFrameHeader(context.Context, ObservedFrameHeader) } - -// a framer is responsible for reading, writing and parsing frames on a single stream -type framer struct { - proto byte - // flags are for outgoing flags, enabling compression and tracing etc - flags byte - compres Compressor - headSize int - // if this frame was read then the header will be here - header *frameHeader - - // if tracing flag is set this is not nil - traceID []byte - - // holds a ref to the whole byte slice for buf so that it can be reset to - // 0 after a read. 
- readBuffer []byte - - buf []byte - - customPayload map[string][]byte -} - -func newFramer(compressor Compressor, version byte) *framer { - buf := make([]byte, defaultBufSize) - f := &framer{ - buf: buf[:0], - readBuffer: buf, - } - var flags byte - if compressor != nil { - flags |= flagCompress - } - if version == protoVersion5 { - flags |= flagBetaProtocol - } - - version &= protoVersionMask - - headSize := 8 - if version > protoVersion2 { - headSize = 9 - } - - f.compres = compressor - f.proto = version - f.flags = flags - f.headSize = headSize - - f.header = nil - f.traceID = nil - - return f -} - -type frame interface { - Header() frameHeader -} - -func readHeader(r io.Reader, p []byte) (head frameHeader, err error) { - _, err = io.ReadFull(r, p[:1]) - if err != nil { - return frameHeader{}, err - } - - version := p[0] & protoVersionMask - - if version < protoVersion1 || version > protoVersion5 { - return frameHeader{}, fmt.Errorf("gocql: unsupported protocol response version: %d", version) - } - - headSize := 9 - if version < protoVersion3 { - headSize = 8 - } - - _, err = io.ReadFull(r, p[1:headSize]) - if err != nil { - return frameHeader{}, err - } - - p = p[:headSize] - - head.version = protoVersion(p[0]) - head.flags = p[1] - - if version > protoVersion2 { - if len(p) != 9 { - return frameHeader{}, fmt.Errorf("not enough bytes to read header require 9 got: %d", len(p)) - } - - head.stream = int(int16(p[2])<<8 | int16(p[3])) - head.op = frameOp(p[4]) - head.length = int(readInt(p[5:])) - } else { - if len(p) != 8 { - return frameHeader{}, fmt.Errorf("not enough bytes to read header require 8 got: %d", len(p)) - } - - head.stream = int(int8(p[2])) - head.op = frameOp(p[3]) - head.length = int(readInt(p[4:])) - } - - return head, nil -} - -// explicitly enables tracing for the framers outgoing requests -func (f *framer) trace() { - f.flags |= flagTracing -} - -// explicitly enables the custom payload flag -func (f *framer) payload() { - f.flags |= 
flagCustomPayload -} - -// reads a frame form the wire into the framers buffer -func (f *framer) readFrame(r io.Reader, head *frameHeader) error { - if head.length < 0 { - return fmt.Errorf("frame body length can not be less than 0: %d", head.length) - } else if head.length > maxFrameSize { - // need to free up the connection to be used again - _, err := io.CopyN(ioutil.Discard, r, int64(head.length)) - if err != nil { - return fmt.Errorf("error whilst trying to discard frame with invalid length: %v", err) - } - return ErrFrameTooBig - } - - if cap(f.readBuffer) >= head.length { - f.buf = f.readBuffer[:head.length] - } else { - f.readBuffer = make([]byte, head.length) - f.buf = f.readBuffer - } - - // assume the underlying reader takes care of timeouts and retries - n, err := io.ReadFull(r, f.buf) - if err != nil { - return fmt.Errorf("unable to read frame body: read %d/%d bytes: %v", n, head.length, err) - } - - if head.flags&flagCompress == flagCompress { - if f.compres == nil { - return NewErrProtocol("no compressor available with compressed frame body") - } - - f.buf, err = f.compres.Decode(f.buf) - if err != nil { - return err - } - } - - f.header = head - return nil -} - -func (f *framer) parseFrame() (frame frame, err error) { - defer func() { - if r := recover(); r != nil { - if _, ok := r.(runtime.Error); ok { - panic(r) - } - err = r.(error) - } - }() - - if f.header.version.request() { - return nil, NewErrProtocol("got a request frame from server: %v", f.header.version) - } - - if f.header.flags&flagTracing == flagTracing { - f.readTrace() - } - - if f.header.flags&flagWarning == flagWarning { - f.header.warnings = f.readStringList() - } - - if f.header.flags&flagCustomPayload == flagCustomPayload { - f.customPayload = f.readBytesMap() - } - - // assumes that the frame body has been read into rbuf - switch f.header.op { - case opError: - frame = f.parseErrorFrame() - case opReady: - frame = f.parseReadyFrame() - case opResult: - frame, err = 
f.parseResultFrame() - case opSupported: - frame = f.parseSupportedFrame() - case opAuthenticate: - frame = f.parseAuthenticateFrame() - case opAuthChallenge: - frame = f.parseAuthChallengeFrame() - case opAuthSuccess: - frame = f.parseAuthSuccessFrame() - case opEvent: - frame = f.parseEventFrame() - default: - return nil, NewErrProtocol("unknown op in frame header: %s", f.header.op) - } - - return -} - -func (f *framer) parseErrorFrame() frame { - code := f.readInt() - msg := f.readString() - - errD := errorFrame{ - frameHeader: *f.header, - code: code, - message: msg, - } - - switch code { - case ErrCodeUnavailable: - cl := f.readConsistency() - required := f.readInt() - alive := f.readInt() - return &RequestErrUnavailable{ - errorFrame: errD, - Consistency: cl, - Required: required, - Alive: alive, - } - case ErrCodeWriteTimeout: - cl := f.readConsistency() - received := f.readInt() - blockfor := f.readInt() - writeType := f.readString() - return &RequestErrWriteTimeout{ - errorFrame: errD, - Consistency: cl, - Received: received, - BlockFor: blockfor, - WriteType: writeType, - } - case ErrCodeReadTimeout: - cl := f.readConsistency() - received := f.readInt() - blockfor := f.readInt() - dataPresent := f.readByte() - return &RequestErrReadTimeout{ - errorFrame: errD, - Consistency: cl, - Received: received, - BlockFor: blockfor, - DataPresent: dataPresent, - } - case ErrCodeAlreadyExists: - ks := f.readString() - table := f.readString() - return &RequestErrAlreadyExists{ - errorFrame: errD, - Keyspace: ks, - Table: table, - } - case ErrCodeUnprepared: - stmtId := f.readShortBytes() - return &RequestErrUnprepared{ - errorFrame: errD, - StatementId: copyBytes(stmtId), // defensively copy - } - case ErrCodeReadFailure: - res := &RequestErrReadFailure{ - errorFrame: errD, - } - res.Consistency = f.readConsistency() - res.Received = f.readInt() - res.BlockFor = f.readInt() - if f.proto > protoVersion4 { - res.ErrorMap = f.readErrorMap() - res.NumFailures = 
len(res.ErrorMap) - } else { - res.NumFailures = f.readInt() - } - res.DataPresent = f.readByte() != 0 - - return res - case ErrCodeWriteFailure: - res := &RequestErrWriteFailure{ - errorFrame: errD, - } - res.Consistency = f.readConsistency() - res.Received = f.readInt() - res.BlockFor = f.readInt() - if f.proto > protoVersion4 { - res.ErrorMap = f.readErrorMap() - res.NumFailures = len(res.ErrorMap) - } else { - res.NumFailures = f.readInt() - } - res.WriteType = f.readString() - return res - case ErrCodeFunctionFailure: - res := &RequestErrFunctionFailure{ - errorFrame: errD, - } - res.Keyspace = f.readString() - res.Function = f.readString() - res.ArgTypes = f.readStringList() - return res - - case ErrCodeCDCWriteFailure: - res := &RequestErrCDCWriteFailure{ - errorFrame: errD, - } - return res - case ErrCodeCASWriteUnknown: - res := &RequestErrCASWriteUnknown{ - errorFrame: errD, - } - res.Consistency = f.readConsistency() - res.Received = f.readInt() - res.BlockFor = f.readInt() - return res - case ErrCodeInvalid, ErrCodeBootstrapping, ErrCodeConfig, ErrCodeCredentials, ErrCodeOverloaded, - ErrCodeProtocol, ErrCodeServer, ErrCodeSyntax, ErrCodeTruncate, ErrCodeUnauthorized: - // TODO(zariel): we should have some distinct types for these errors - return errD - default: - panic(fmt.Errorf("unknown error code: 0x%x", errD.code)) - } -} - -func (f *framer) readErrorMap() (errMap ErrorMap) { - errMap = make(ErrorMap) - numErrs := f.readInt() - for i := 0; i < numErrs; i++ { - ip := f.readInetAdressOnly().String() - errMap[ip] = f.readShort() - } - return -} - -func (f *framer) writeHeader(flags byte, op frameOp, stream int) { - f.buf = f.buf[:0] - f.buf = append(f.buf, - f.proto, - flags, - ) - - if f.proto > protoVersion2 { - f.buf = append(f.buf, - byte(stream>>8), - byte(stream), - ) - } else { - f.buf = append(f.buf, - byte(stream), - ) - } - - // pad out length - f.buf = append(f.buf, - byte(op), - 0, - 0, - 0, - 0, - ) -} - -func (f *framer) setLength(length 
int) { - p := 4 - if f.proto > protoVersion2 { - p = 5 - } - - f.buf[p+0] = byte(length >> 24) - f.buf[p+1] = byte(length >> 16) - f.buf[p+2] = byte(length >> 8) - f.buf[p+3] = byte(length) -} - -func (f *framer) finish() error { - if len(f.buf) > maxFrameSize { - // huge app frame, lets remove it so it doesn't bloat the heap - f.buf = make([]byte, defaultBufSize) - return ErrFrameTooBig - } - - if f.buf[1]&flagCompress == flagCompress { - if f.compres == nil { - panic("compress flag set with no compressor") - } - - // TODO: only compress frames which are big enough - compressed, err := f.compres.Encode(f.buf[f.headSize:]) - if err != nil { - return err - } - - f.buf = append(f.buf[:f.headSize], compressed...) - } - length := len(f.buf) - f.headSize - f.setLength(length) - - return nil -} - -func (f *framer) writeTo(w io.Writer) error { - _, err := w.Write(f.buf) - return err -} - -func (f *framer) readTrace() { - f.traceID = f.readUUID().Bytes() -} - -type readyFrame struct { - frameHeader -} - -func (f *framer) parseReadyFrame() frame { - return &readyFrame{ - frameHeader: *f.header, - } -} - -type supportedFrame struct { - frameHeader - - supported map[string][]string -} - -// TODO: if we move the body buffer onto the frameHeader then we only need a single -// framer, and can move the methods onto the header. 
-func (f *framer) parseSupportedFrame() frame { - return &supportedFrame{ - frameHeader: *f.header, - - supported: f.readStringMultiMap(), - } -} - -type writeStartupFrame struct { - opts map[string]string -} - -func (w writeStartupFrame) String() string { - return fmt.Sprintf("[startup opts=%+v]", w.opts) -} - -func (w *writeStartupFrame) buildFrame(f *framer, streamID int) error { - f.writeHeader(f.flags&^flagCompress, opStartup, streamID) - f.writeStringMap(w.opts) - - return f.finish() -} - -type writePrepareFrame struct { - statement string - keyspace string - customPayload map[string][]byte -} - -func (w *writePrepareFrame) buildFrame(f *framer, streamID int) error { - if len(w.customPayload) > 0 { - f.payload() - } - f.writeHeader(f.flags, opPrepare, streamID) - f.writeCustomPayload(&w.customPayload) - f.writeLongString(w.statement) - - var flags uint32 = 0 - if w.keyspace != "" { - if f.proto > protoVersion4 { - flags |= flagWithPreparedKeyspace - } else { - panic(fmt.Errorf("the keyspace can only be set with protocol 5 or higher")) - } - } - if f.proto > protoVersion4 { - f.writeUint(flags) - } - if w.keyspace != "" { - f.writeString(w.keyspace) - } - - return f.finish() -} - -func (f *framer) readTypeInfo() TypeInfo { - // TODO: factor this out so the same code paths can be used to parse custom - // types and other types, as much of the logic will be duplicated. 
- id := f.readShort() - - simple := NativeType{ - proto: f.proto, - typ: Type(id), - } - - if simple.typ == TypeCustom { - simple.custom = f.readString() - if cassType := getApacheCassandraType(simple.custom); cassType != TypeCustom { - simple.typ = cassType - } - } - - switch simple.typ { - case TypeTuple: - n := f.readShort() - tuple := TupleTypeInfo{ - NativeType: simple, - Elems: make([]TypeInfo, n), - } - - for i := 0; i < int(n); i++ { - tuple.Elems[i] = f.readTypeInfo() - } - - return tuple - - case TypeUDT: - udt := UDTTypeInfo{ - NativeType: simple, - } - udt.KeySpace = f.readString() - udt.Name = f.readString() - - n := f.readShort() - udt.Elements = make([]UDTField, n) - for i := 0; i < int(n); i++ { - field := &udt.Elements[i] - field.Name = f.readString() - field.Type = f.readTypeInfo() - } - - return udt - case TypeMap, TypeList, TypeSet: - collection := CollectionType{ - NativeType: simple, - } - - if simple.typ == TypeMap { - collection.Key = f.readTypeInfo() - } - - collection.Elem = f.readTypeInfo() - - return collection - } - - return simple -} - -type preparedMetadata struct { - resultMetadata - - // proto v4+ - pkeyColumns []int - - keyspace string - - table string -} - -func (r preparedMetadata) String() string { - return fmt.Sprintf("[prepared flags=0x%x pkey=%v paging_state=% X columns=%v col_count=%d actual_col_count=%d]", r.flags, r.pkeyColumns, r.pagingState, r.columns, r.colCount, r.actualColCount) -} - -func (f *framer) parsePreparedMetadata() preparedMetadata { - // TODO: deduplicate this from parseMetadata - meta := preparedMetadata{} - - meta.flags = f.readInt() - meta.colCount = f.readInt() - if meta.colCount < 0 { - panic(fmt.Errorf("received negative column count: %d", meta.colCount)) - } - meta.actualColCount = meta.colCount - - if f.proto >= protoVersion4 { - pkeyCount := f.readInt() - pkeys := make([]int, pkeyCount) - for i := 0; i < pkeyCount; i++ { - pkeys[i] = int(f.readShort()) - } - meta.pkeyColumns = pkeys - } - - if 
meta.flags&flagHasMorePages == flagHasMorePages { - meta.pagingState = copyBytes(f.readBytes()) - } - - if meta.flags&flagNoMetaData == flagNoMetaData { - return meta - } - - globalSpec := meta.flags&flagGlobalTableSpec == flagGlobalTableSpec - if globalSpec { - meta.keyspace = f.readString() - meta.table = f.readString() - } - - var cols []ColumnInfo - if meta.colCount < 1000 { - // preallocate columninfo to avoid excess copying - cols = make([]ColumnInfo, meta.colCount) - for i := 0; i < meta.colCount; i++ { - f.readCol(&cols[i], &meta.resultMetadata, globalSpec, meta.keyspace, meta.table) - } - } else { - // use append, huge number of columns usually indicates a corrupt frame or - // just a huge row. - for i := 0; i < meta.colCount; i++ { - var col ColumnInfo - f.readCol(&col, &meta.resultMetadata, globalSpec, meta.keyspace, meta.table) - cols = append(cols, col) - } - } - - meta.columns = cols - - return meta -} - -type resultMetadata struct { - flags int - - // only if flagPageState - pagingState []byte - - columns []ColumnInfo - colCount int - - // this is a count of the total number of columns which can be scanned, - // it is at minimum len(columns) but may be larger, for instance when a column - // is a UDT or tuple. 
- actualColCount int -} - -func (r *resultMetadata) morePages() bool { - return r.flags&flagHasMorePages == flagHasMorePages -} - -func (r resultMetadata) String() string { - return fmt.Sprintf("[metadata flags=0x%x paging_state=% X columns=%v]", r.flags, r.pagingState, r.columns) -} - -func (f *framer) readCol(col *ColumnInfo, meta *resultMetadata, globalSpec bool, keyspace, table string) { - if !globalSpec { - col.Keyspace = f.readString() - col.Table = f.readString() - } else { - col.Keyspace = keyspace - col.Table = table - } - - col.Name = f.readString() - col.TypeInfo = f.readTypeInfo() - switch v := col.TypeInfo.(type) { - // maybe also UDT - case TupleTypeInfo: - // -1 because we already included the tuple column - meta.actualColCount += len(v.Elems) - 1 - } -} - -func (f *framer) parseResultMetadata() resultMetadata { - var meta resultMetadata - - meta.flags = f.readInt() - meta.colCount = f.readInt() - if meta.colCount < 0 { - panic(fmt.Errorf("received negative column count: %d", meta.colCount)) - } - meta.actualColCount = meta.colCount - - if meta.flags&flagHasMorePages == flagHasMorePages { - meta.pagingState = copyBytes(f.readBytes()) - } - - if meta.flags&flagNoMetaData == flagNoMetaData { - return meta - } - - var keyspace, table string - globalSpec := meta.flags&flagGlobalTableSpec == flagGlobalTableSpec - if globalSpec { - keyspace = f.readString() - table = f.readString() - } - - var cols []ColumnInfo - if meta.colCount < 1000 { - // preallocate columninfo to avoid excess copying - cols = make([]ColumnInfo, meta.colCount) - for i := 0; i < meta.colCount; i++ { - f.readCol(&cols[i], &meta, globalSpec, keyspace, table) - } - - } else { - // use append, huge number of columns usually indicates a corrupt frame or - // just a huge row. 
- for i := 0; i < meta.colCount; i++ { - var col ColumnInfo - f.readCol(&col, &meta, globalSpec, keyspace, table) - cols = append(cols, col) - } - } - - meta.columns = cols - - return meta -} - -type resultVoidFrame struct { - frameHeader -} - -func (f *resultVoidFrame) String() string { - return "[result_void]" -} - -func (f *framer) parseResultFrame() (frame, error) { - kind := f.readInt() - - switch kind { - case resultKindVoid: - return &resultVoidFrame{frameHeader: *f.header}, nil - case resultKindRows: - return f.parseResultRows(), nil - case resultKindKeyspace: - return f.parseResultSetKeyspace(), nil - case resultKindPrepared: - return f.parseResultPrepared(), nil - case resultKindSchemaChanged: - return f.parseResultSchemaChange(), nil - } - - return nil, NewErrProtocol("unknown result kind: %x", kind) -} - -type resultRowsFrame struct { - frameHeader - - meta resultMetadata - // dont parse the rows here as we only need to do it once - numRows int -} - -func (f *resultRowsFrame) String() string { - return fmt.Sprintf("[result_rows meta=%v]", f.meta) -} - -func (f *framer) parseResultRows() frame { - result := &resultRowsFrame{} - result.meta = f.parseResultMetadata() - - result.numRows = f.readInt() - if result.numRows < 0 { - panic(fmt.Errorf("invalid row_count in result frame: %d", result.numRows)) - } - - return result -} - -type resultKeyspaceFrame struct { - frameHeader - keyspace string -} - -func (r *resultKeyspaceFrame) String() string { - return fmt.Sprintf("[result_keyspace keyspace=%s]", r.keyspace) -} - -func (f *framer) parseResultSetKeyspace() frame { - return &resultKeyspaceFrame{ - frameHeader: *f.header, - keyspace: f.readString(), - } -} - -type resultPreparedFrame struct { - frameHeader - - preparedID []byte - reqMeta preparedMetadata - respMeta resultMetadata -} - -func (f *framer) parseResultPrepared() frame { - frame := &resultPreparedFrame{ - frameHeader: *f.header, - preparedID: f.readShortBytes(), - reqMeta: 
f.parsePreparedMetadata(), - } - - if f.proto < protoVersion2 { - return frame - } - - frame.respMeta = f.parseResultMetadata() - - return frame -} - -type schemaChangeKeyspace struct { - frameHeader - - change string - keyspace string -} - -func (f schemaChangeKeyspace) String() string { - return fmt.Sprintf("[event schema_change_keyspace change=%q keyspace=%q]", f.change, f.keyspace) -} - -type schemaChangeTable struct { - frameHeader - - change string - keyspace string - object string -} - -func (f schemaChangeTable) String() string { - return fmt.Sprintf("[event schema_change change=%q keyspace=%q object=%q]", f.change, f.keyspace, f.object) -} - -type schemaChangeType struct { - frameHeader - - change string - keyspace string - object string -} - -type schemaChangeFunction struct { - frameHeader - - change string - keyspace string - name string - args []string -} - -type schemaChangeAggregate struct { - frameHeader - - change string - keyspace string - name string - args []string -} - -func (f *framer) parseResultSchemaChange() frame { - if f.proto <= protoVersion2 { - change := f.readString() - keyspace := f.readString() - table := f.readString() - - if table != "" { - return &schemaChangeTable{ - frameHeader: *f.header, - change: change, - keyspace: keyspace, - object: table, - } - } else { - return &schemaChangeKeyspace{ - frameHeader: *f.header, - change: change, - keyspace: keyspace, - } - } - } else { - change := f.readString() - target := f.readString() - - // TODO: could just use a separate type for each target - switch target { - case "KEYSPACE": - frame := &schemaChangeKeyspace{ - frameHeader: *f.header, - change: change, - } - - frame.keyspace = f.readString() - - return frame - case "TABLE": - frame := &schemaChangeTable{ - frameHeader: *f.header, - change: change, - } - - frame.keyspace = f.readString() - frame.object = f.readString() - - return frame - case "TYPE": - frame := &schemaChangeType{ - frameHeader: *f.header, - change: change, - } - - 
frame.keyspace = f.readString() - frame.object = f.readString() - - return frame - case "FUNCTION": - frame := &schemaChangeFunction{ - frameHeader: *f.header, - change: change, - } - - frame.keyspace = f.readString() - frame.name = f.readString() - frame.args = f.readStringList() - - return frame - case "AGGREGATE": - frame := &schemaChangeAggregate{ - frameHeader: *f.header, - change: change, - } - - frame.keyspace = f.readString() - frame.name = f.readString() - frame.args = f.readStringList() - - return frame - default: - panic(fmt.Errorf("gocql: unknown SCHEMA_CHANGE target: %q change: %q", target, change)) - } - } - -} - -type authenticateFrame struct { - frameHeader - - class string -} - -func (a *authenticateFrame) String() string { - return fmt.Sprintf("[authenticate class=%q]", a.class) -} - -func (f *framer) parseAuthenticateFrame() frame { - return &authenticateFrame{ - frameHeader: *f.header, - class: f.readString(), - } -} - -type authSuccessFrame struct { - frameHeader - - data []byte -} - -func (a *authSuccessFrame) String() string { - return fmt.Sprintf("[auth_success data=%q]", a.data) -} - -func (f *framer) parseAuthSuccessFrame() frame { - return &authSuccessFrame{ - frameHeader: *f.header, - data: f.readBytes(), - } -} - -type authChallengeFrame struct { - frameHeader - - data []byte -} - -func (a *authChallengeFrame) String() string { - return fmt.Sprintf("[auth_challenge data=%q]", a.data) -} - -func (f *framer) parseAuthChallengeFrame() frame { - return &authChallengeFrame{ - frameHeader: *f.header, - data: f.readBytes(), - } -} - -type statusChangeEventFrame struct { - frameHeader - - change string - host net.IP - port int -} - -func (t statusChangeEventFrame) String() string { - return fmt.Sprintf("[status_change change=%s host=%v port=%v]", t.change, t.host, t.port) -} - -// essentially the same as statusChange -type topologyChangeEventFrame struct { - frameHeader - - change string - host net.IP - port int -} - -func (t 
topologyChangeEventFrame) String() string { - return fmt.Sprintf("[topology_change change=%s host=%v port=%v]", t.change, t.host, t.port) -} - -func (f *framer) parseEventFrame() frame { - eventType := f.readString() - - switch eventType { - case "TOPOLOGY_CHANGE": - frame := &topologyChangeEventFrame{frameHeader: *f.header} - frame.change = f.readString() - frame.host, frame.port = f.readInet() - - return frame - case "STATUS_CHANGE": - frame := &statusChangeEventFrame{frameHeader: *f.header} - frame.change = f.readString() - frame.host, frame.port = f.readInet() - - return frame - case "SCHEMA_CHANGE": - // this should work for all versions - return f.parseResultSchemaChange() - default: - panic(fmt.Errorf("gocql: unknown event type: %q", eventType)) - } - -} - -type writeAuthResponseFrame struct { - data []byte -} - -func (a *writeAuthResponseFrame) String() string { - return fmt.Sprintf("[auth_response data=%q]", a.data) -} - -func (a *writeAuthResponseFrame) buildFrame(framer *framer, streamID int) error { - return framer.writeAuthResponseFrame(streamID, a.data) -} - -func (f *framer) writeAuthResponseFrame(streamID int, data []byte) error { - f.writeHeader(f.flags, opAuthResponse, streamID) - f.writeBytes(data) - return f.finish() -} - -type queryValues struct { - value []byte - - // optional name, will set With names for values flag - name string - isUnset bool -} - -type queryParams struct { - consistency Consistency - // v2+ - skipMeta bool - values []queryValues - pageSize int - pagingState []byte - serialConsistency SerialConsistency - // v3+ - defaultTimestamp bool - defaultTimestampValue int64 - // v5+ - keyspace string -} - -func (q queryParams) String() string { - return fmt.Sprintf("[query_params consistency=%v skip_meta=%v page_size=%d paging_state=%q serial_consistency=%v default_timestamp=%v values=%v keyspace=%s]", - q.consistency, q.skipMeta, q.pageSize, q.pagingState, q.serialConsistency, q.defaultTimestamp, q.values, q.keyspace) -} - -func (f 
*framer) writeQueryParams(opts *queryParams) { - f.writeConsistency(opts.consistency) - - if f.proto == protoVersion1 { - return - } - - var flags byte - if len(opts.values) > 0 { - flags |= flagValues - } - if opts.skipMeta { - flags |= flagSkipMetaData - } - if opts.pageSize > 0 { - flags |= flagPageSize - } - if len(opts.pagingState) > 0 { - flags |= flagWithPagingState - } - if opts.serialConsistency > 0 { - flags |= flagWithSerialConsistency - } - - names := false - - // protoV3 specific things - if f.proto > protoVersion2 { - if opts.defaultTimestamp { - flags |= flagDefaultTimestamp - } - - if len(opts.values) > 0 && opts.values[0].name != "" { - flags |= flagWithNameValues - names = true - } - } - - if opts.keyspace != "" { - if f.proto > protoVersion4 { - flags |= flagWithKeyspace - } else { - panic(fmt.Errorf("the keyspace can only be set with protocol 5 or higher")) - } - } - - if f.proto > protoVersion4 { - f.writeUint(uint32(flags)) - } else { - f.writeByte(flags) - } - - if n := len(opts.values); n > 0 { - f.writeShort(uint16(n)) - - for i := 0; i < n; i++ { - if names { - f.writeString(opts.values[i].name) - } - if opts.values[i].isUnset { - f.writeUnset() - } else { - f.writeBytes(opts.values[i].value) - } - } - } - - if opts.pageSize > 0 { - f.writeInt(int32(opts.pageSize)) - } - - if len(opts.pagingState) > 0 { - f.writeBytes(opts.pagingState) - } - - if opts.serialConsistency > 0 { - f.writeConsistency(Consistency(opts.serialConsistency)) - } - - if f.proto > protoVersion2 && opts.defaultTimestamp { - // timestamp in microseconds - var ts int64 - if opts.defaultTimestampValue != 0 { - ts = opts.defaultTimestampValue - } else { - ts = time.Now().UnixNano() / 1000 - } - f.writeLong(ts) - } - - if opts.keyspace != "" { - f.writeString(opts.keyspace) - } -} - -type writeQueryFrame struct { - statement string - params queryParams - - // v4+ - customPayload map[string][]byte -} - -func (w *writeQueryFrame) String() string { - return fmt.Sprintf("[query 
statement=%q params=%v]", w.statement, w.params) -} - -func (w *writeQueryFrame) buildFrame(framer *framer, streamID int) error { - return framer.writeQueryFrame(streamID, w.statement, &w.params, w.customPayload) -} - -func (f *framer) writeQueryFrame(streamID int, statement string, params *queryParams, customPayload map[string][]byte) error { - if len(customPayload) > 0 { - f.payload() - } - f.writeHeader(f.flags, opQuery, streamID) - f.writeCustomPayload(&customPayload) - f.writeLongString(statement) - f.writeQueryParams(params) - - return f.finish() -} - -type frameBuilder interface { - buildFrame(framer *framer, streamID int) error -} - -type frameWriterFunc func(framer *framer, streamID int) error - -func (f frameWriterFunc) buildFrame(framer *framer, streamID int) error { - return f(framer, streamID) -} - -type writeExecuteFrame struct { - preparedID []byte - params queryParams - - // v4+ - customPayload map[string][]byte -} - -func (e *writeExecuteFrame) String() string { - return fmt.Sprintf("[execute id=% X params=%v]", e.preparedID, &e.params) -} - -func (e *writeExecuteFrame) buildFrame(fr *framer, streamID int) error { - return fr.writeExecuteFrame(streamID, e.preparedID, &e.params, &e.customPayload) -} - -func (f *framer) writeExecuteFrame(streamID int, preparedID []byte, params *queryParams, customPayload *map[string][]byte) error { - if len(*customPayload) > 0 { - f.payload() - } - f.writeHeader(f.flags, opExecute, streamID) - f.writeCustomPayload(customPayload) - f.writeShortBytes(preparedID) - if f.proto > protoVersion1 { - f.writeQueryParams(params) - } else { - n := len(params.values) - f.writeShort(uint16(n)) - for i := 0; i < n; i++ { - if params.values[i].isUnset { - f.writeUnset() - } else { - f.writeBytes(params.values[i].value) - } - } - f.writeConsistency(params.consistency) - } - - return f.finish() -} - -// TODO: can we replace BatchStatemt with batchStatement? 
As they prety much -// duplicate each other -type batchStatment struct { - preparedID []byte - statement string - values []queryValues -} - -type writeBatchFrame struct { - typ BatchType - statements []batchStatment - consistency Consistency - - // v3+ - serialConsistency SerialConsistency - defaultTimestamp bool - defaultTimestampValue int64 - - //v4+ - customPayload map[string][]byte -} - -func (w *writeBatchFrame) buildFrame(framer *framer, streamID int) error { - return framer.writeBatchFrame(streamID, w, w.customPayload) -} - -func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame, customPayload map[string][]byte) error { - if len(customPayload) > 0 { - f.payload() - } - f.writeHeader(f.flags, opBatch, streamID) - f.writeCustomPayload(&customPayload) - f.writeByte(byte(w.typ)) - - n := len(w.statements) - f.writeShort(uint16(n)) - - var flags byte - - for i := 0; i < n; i++ { - b := &w.statements[i] - if len(b.preparedID) == 0 { - f.writeByte(0) - f.writeLongString(b.statement) - } else { - f.writeByte(1) - f.writeShortBytes(b.preparedID) - } - - f.writeShort(uint16(len(b.values))) - for j := range b.values { - col := b.values[j] - if f.proto > protoVersion2 && col.name != "" { - // TODO: move this check into the caller and set a flag on writeBatchFrame - // to indicate using named values - if f.proto <= protoVersion5 { - return fmt.Errorf("gocql: named query values are not supported in batches, please see https://issues.apache.org/jira/browse/CASSANDRA-10246") - } - flags |= flagWithNameValues - f.writeString(col.name) - } - if col.isUnset { - f.writeUnset() - } else { - f.writeBytes(col.value) - } - } - } - - f.writeConsistency(w.consistency) - - if f.proto > protoVersion2 { - if w.serialConsistency > 0 { - flags |= flagWithSerialConsistency - } - if w.defaultTimestamp { - flags |= flagDefaultTimestamp - } - - if f.proto > protoVersion4 { - f.writeUint(uint32(flags)) - } else { - f.writeByte(flags) - } - - if w.serialConsistency > 0 { - 
f.writeConsistency(Consistency(w.serialConsistency)) - } - - if w.defaultTimestamp { - var ts int64 - if w.defaultTimestampValue != 0 { - ts = w.defaultTimestampValue - } else { - ts = time.Now().UnixNano() / 1000 - } - f.writeLong(ts) - } - } - - return f.finish() -} - -type writeOptionsFrame struct{} - -func (w *writeOptionsFrame) buildFrame(framer *framer, streamID int) error { - return framer.writeOptionsFrame(streamID, w) -} - -func (f *framer) writeOptionsFrame(stream int, _ *writeOptionsFrame) error { - f.writeHeader(f.flags&^flagCompress, opOptions, stream) - return f.finish() -} - -type writeRegisterFrame struct { - events []string -} - -func (w *writeRegisterFrame) buildFrame(framer *framer, streamID int) error { - return framer.writeRegisterFrame(streamID, w) -} - -func (f *framer) writeRegisterFrame(streamID int, w *writeRegisterFrame) error { - f.writeHeader(f.flags, opRegister, streamID) - f.writeStringList(w.events) - - return f.finish() -} - -func (f *framer) readByte() byte { - if len(f.buf) < 1 { - panic(fmt.Errorf("not enough bytes in buffer to read byte require 1 got: %d", len(f.buf))) - } - - b := f.buf[0] - f.buf = f.buf[1:] - return b -} - -func (f *framer) readInt() (n int) { - if len(f.buf) < 4 { - panic(fmt.Errorf("not enough bytes in buffer to read int require 4 got: %d", len(f.buf))) - } - - n = int(int32(f.buf[0])<<24 | int32(f.buf[1])<<16 | int32(f.buf[2])<<8 | int32(f.buf[3])) - f.buf = f.buf[4:] - return -} - -func (f *framer) readShort() (n uint16) { - if len(f.buf) < 2 { - panic(fmt.Errorf("not enough bytes in buffer to read short require 2 got: %d", len(f.buf))) - } - n = uint16(f.buf[0])<<8 | uint16(f.buf[1]) - f.buf = f.buf[2:] - return -} - -func (f *framer) readString() (s string) { - size := f.readShort() - - if len(f.buf) < int(size) { - panic(fmt.Errorf("not enough bytes in buffer to read string require %d got: %d", size, len(f.buf))) - } - - s = string(f.buf[:size]) - f.buf = f.buf[size:] - return -} - -func (f *framer) 
readLongString() (s string) { - size := f.readInt() - - if len(f.buf) < size { - panic(fmt.Errorf("not enough bytes in buffer to read long string require %d got: %d", size, len(f.buf))) - } - - s = string(f.buf[:size]) - f.buf = f.buf[size:] - return -} - -func (f *framer) readUUID() *UUID { - if len(f.buf) < 16 { - panic(fmt.Errorf("not enough bytes in buffer to read uuid require %d got: %d", 16, len(f.buf))) - } - - // TODO: how to handle this error, if it is a uuid, then sureley, problems? - u, _ := UUIDFromBytes(f.buf[:16]) - f.buf = f.buf[16:] - return &u -} - -func (f *framer) readStringList() []string { - size := f.readShort() - - l := make([]string, size) - for i := 0; i < int(size); i++ { - l[i] = f.readString() - } - - return l -} - -func (f *framer) readBytesInternal() ([]byte, error) { - size := f.readInt() - if size < 0 { - return nil, nil - } - - if len(f.buf) < size { - return nil, fmt.Errorf("not enough bytes in buffer to read bytes require %d got: %d", size, len(f.buf)) - } - - l := f.buf[:size] - f.buf = f.buf[size:] - - return l, nil -} - -func (f *framer) readBytes() []byte { - l, err := f.readBytesInternal() - if err != nil { - panic(err) - } - - return l -} - -func (f *framer) readShortBytes() []byte { - size := f.readShort() - if len(f.buf) < int(size) { - panic(fmt.Errorf("not enough bytes in buffer to read short bytes: require %d got %d", size, len(f.buf))) - } - - l := f.buf[:size] - f.buf = f.buf[size:] - - return l -} - -func (f *framer) readInetAdressOnly() net.IP { - if len(f.buf) < 1 { - panic(fmt.Errorf("not enough bytes in buffer to read inet size require %d got: %d", 1, len(f.buf))) - } - - size := f.buf[0] - f.buf = f.buf[1:] - - if !(size == 4 || size == 16) { - panic(fmt.Errorf("invalid IP size: %d", size)) - } - - if len(f.buf) < 1 { - panic(fmt.Errorf("not enough bytes in buffer to read inet require %d got: %d", size, len(f.buf))) - } - - ip := make([]byte, size) - copy(ip, f.buf[:size]) - f.buf = f.buf[size:] - return 
net.IP(ip) -} - -func (f *framer) readInet() (net.IP, int) { - return f.readInetAdressOnly(), f.readInt() -} - -func (f *framer) readConsistency() Consistency { - return Consistency(f.readShort()) -} - -func (f *framer) readBytesMap() map[string][]byte { - size := f.readShort() - m := make(map[string][]byte, size) - - for i := 0; i < int(size); i++ { - k := f.readString() - v := f.readBytes() - m[k] = v - } - - return m -} - -func (f *framer) readStringMultiMap() map[string][]string { - size := f.readShort() - m := make(map[string][]string, size) - - for i := 0; i < int(size); i++ { - k := f.readString() - v := f.readStringList() - m[k] = v - } - - return m -} - -func (f *framer) writeByte(b byte) { - f.buf = append(f.buf, b) -} - -func appendBytes(p []byte, d []byte) []byte { - if d == nil { - return appendInt(p, -1) - } - p = appendInt(p, int32(len(d))) - p = append(p, d...) - return p -} - -func appendShort(p []byte, n uint16) []byte { - return append(p, - byte(n>>8), - byte(n), - ) -} - -func appendInt(p []byte, n int32) []byte { - return append(p, byte(n>>24), - byte(n>>16), - byte(n>>8), - byte(n)) -} - -func appendUint(p []byte, n uint32) []byte { - return append(p, byte(n>>24), - byte(n>>16), - byte(n>>8), - byte(n)) -} - -func appendLong(p []byte, n int64) []byte { - return append(p, - byte(n>>56), - byte(n>>48), - byte(n>>40), - byte(n>>32), - byte(n>>24), - byte(n>>16), - byte(n>>8), - byte(n), - ) -} - -func (f *framer) writeCustomPayload(customPayload *map[string][]byte) { - if len(*customPayload) > 0 { - if f.proto < protoVersion4 { - panic("Custom payload is not supported with version V3 or less") - } - f.writeBytesMap(*customPayload) - } -} - -// these are protocol level binary types -func (f *framer) writeInt(n int32) { - f.buf = appendInt(f.buf, n) -} - -func (f *framer) writeUint(n uint32) { - f.buf = appendUint(f.buf, n) -} - -func (f *framer) writeShort(n uint16) { - f.buf = appendShort(f.buf, n) -} - -func (f *framer) writeLong(n int64) { - 
f.buf = appendLong(f.buf, n) -} - -func (f *framer) writeString(s string) { - f.writeShort(uint16(len(s))) - f.buf = append(f.buf, s...) -} - -func (f *framer) writeLongString(s string) { - f.writeInt(int32(len(s))) - f.buf = append(f.buf, s...) -} - -func (f *framer) writeStringList(l []string) { - f.writeShort(uint16(len(l))) - for _, s := range l { - f.writeString(s) - } -} - -func (f *framer) writeUnset() { - // Protocol version 4 specifies that bind variables do not require having a - // value when executing a statement. Bind variables without a value are - // called 'unset'. The 'unset' bind variable is serialized as the int - // value '-2' without following bytes. - f.writeInt(-2) -} - -func (f *framer) writeBytes(p []byte) { - // TODO: handle null case correctly, - // [bytes] A [int] n, followed by n bytes if n >= 0. If n < 0, - // no byte should follow and the value represented is `null`. - if p == nil { - f.writeInt(-1) - } else { - f.writeInt(int32(len(p))) - f.buf = append(f.buf, p...) - } -} - -func (f *framer) writeShortBytes(p []byte) { - f.writeShort(uint16(len(p))) - f.buf = append(f.buf, p...) 
-} - -func (f *framer) writeConsistency(cons Consistency) { - f.writeShort(uint16(cons)) -} - -func (f *framer) writeStringMap(m map[string]string) { - f.writeShort(uint16(len(m))) - for k, v := range m { - f.writeString(k) - f.writeString(v) - } -} - -func (f *framer) writeBytesMap(m map[string][]byte) { - f.writeShort(uint16(len(m))) - for k, v := range m { - f.writeString(k) - f.writeBytes(v) - } -} diff --git a/helpers.go b/helpers.go index f2faee9e0..6a50c434e 100644 --- a/helpers.go +++ b/helpers.go @@ -26,13 +26,9 @@ package gocql import ( "fmt" - "math/big" "net" "reflect" "strings" - "time" - - "gopkg.in/inf.v0" ) type RowData struct { @@ -40,173 +36,71 @@ type RowData struct { Values []interface{} } -func goType(t TypeInfo) (reflect.Type, error) { - switch t.Type() { - case TypeVarchar, TypeAscii, TypeInet, TypeText: - return reflect.TypeOf(*new(string)), nil - case TypeBigInt, TypeCounter: - return reflect.TypeOf(*new(int64)), nil - case TypeTime: - return reflect.TypeOf(*new(time.Duration)), nil - case TypeTimestamp: - return reflect.TypeOf(*new(time.Time)), nil - case TypeBlob: - return reflect.TypeOf(*new([]byte)), nil - case TypeBoolean: - return reflect.TypeOf(*new(bool)), nil - case TypeFloat: - return reflect.TypeOf(*new(float32)), nil - case TypeDouble: - return reflect.TypeOf(*new(float64)), nil - case TypeInt: - return reflect.TypeOf(*new(int)), nil - case TypeSmallInt: - return reflect.TypeOf(*new(int16)), nil - case TypeTinyInt: - return reflect.TypeOf(*new(int8)), nil - case TypeDecimal: - return reflect.TypeOf(*new(*inf.Dec)), nil - case TypeUUID, TypeTimeUUID: - return reflect.TypeOf(*new(UUID)), nil - case TypeList, TypeSet: - elemType, err := goType(t.(CollectionType).Elem) - if err != nil { - return nil, err - } - return reflect.SliceOf(elemType), nil - case TypeMap: - keyType, err := goType(t.(CollectionType).Key) - if err != nil { - return nil, err - } - valueType, err := goType(t.(CollectionType).Elem) - if err != nil { - return nil, 
err - } - return reflect.MapOf(keyType, valueType), nil - case TypeVarint: - return reflect.TypeOf(*new(*big.Int)), nil - case TypeTuple: - // what can we do here? all there is to do is to make a list of interface{} - tuple := t.(TupleTypeInfo) - return reflect.TypeOf(make([]interface{}, len(tuple.Elems))), nil - case TypeUDT: - return reflect.TypeOf(make(map[string]interface{})), nil - case TypeDate: - return reflect.TypeOf(*new(time.Time)), nil - case TypeDuration: - return reflect.TypeOf(*new(Duration)), nil - default: - return nil, fmt.Errorf("cannot create Go type for unknown CQL type %s", t) - } -} +//func goType(t TypeInfo) (reflect.Type, error) { +// switch t.Type() { +// case TypeVarchar, TypeAscii, TypeInet, TypeText: +// return reflect.TypeOf(*new(string)), nil +// case TypeBigInt, TypeCounter: +// return reflect.TypeOf(*new(int64)), nil +// case TypeTime: +// return reflect.TypeOf(*new(time.Duration)), nil +// case TypeTimestamp: +// return reflect.TypeOf(*new(time.Time)), nil +// case TypeBlob: +// return reflect.TypeOf(*new([]byte)), nil +// case TypeBoolean: +// return reflect.TypeOf(*new(bool)), nil +// case TypeFloat: +// return reflect.TypeOf(*new(float32)), nil +// case TypeDouble: +// return reflect.TypeOf(*new(float64)), nil +// case TypeInt: +// return reflect.TypeOf(*new(int)), nil +// case TypeSmallInt: +// return reflect.TypeOf(*new(int16)), nil +// case TypeTinyInt: +// return reflect.TypeOf(*new(int8)), nil +// case TypeDecimal: +// return reflect.TypeOf(*new(*inf.Dec)), nil +// case TypeUUID, TypeTimeUUID: +// return reflect.TypeOf(*new(UUID)), nil +// case TypeList, TypeSet: +// elemType, err := goType(t.(CollectionType).Elem) +// if err != nil { +// return nil, err +// } +// return reflect.SliceOf(elemType), nil +// case TypeMap: +// keyType, err := goType(t.(CollectionType).Key) +// if err != nil { +// return nil, err +// } +// valueType, err := goType(t.(CollectionType).Elem) +// if err != nil { +// return nil, err +// } +// return 
reflect.MapOf(keyType, valueType), nil +// case TypeVarint: +// return reflect.TypeOf(*new(*big.Int)), nil +// case TypeTuple: +// // what can we do here? all there is to do is to make a list of interface{} +// tuple := t.(TupleTypeInfo) +// return reflect.TypeOf(make([]interface{}, len(tuple.Elems))), nil +// case TypeUDT: +// return reflect.TypeOf(make(map[string]interface{})), nil +// case TypeDate: +// return reflect.TypeOf(*new(time.Time)), nil +// case TypeDuration: +// return reflect.TypeOf(*new(Duration)), nil +// default: +// return nil, fmt.Errorf("cannot create Go type for unknown CQL type %s", t) +// } +//} func dereference(i interface{}) interface{} { return reflect.Indirect(reflect.ValueOf(i)).Interface() } -func getCassandraBaseType(name string) Type { - switch name { - case "ascii": - return TypeAscii - case "bigint": - return TypeBigInt - case "blob": - return TypeBlob - case "boolean": - return TypeBoolean - case "counter": - return TypeCounter - case "date": - return TypeDate - case "decimal": - return TypeDecimal - case "double": - return TypeDouble - case "duration": - return TypeDuration - case "float": - return TypeFloat - case "int": - return TypeInt - case "smallint": - return TypeSmallInt - case "tinyint": - return TypeTinyInt - case "time": - return TypeTime - case "timestamp": - return TypeTimestamp - case "uuid": - return TypeUUID - case "varchar": - return TypeVarchar - case "text": - return TypeText - case "varint": - return TypeVarint - case "timeuuid": - return TypeTimeUUID - case "inet": - return TypeInet - case "MapType": - return TypeMap - case "ListType": - return TypeList - case "SetType": - return TypeSet - case "TupleType": - return TypeTuple - default: - return TypeCustom - } -} - -func getCassandraType(name string, logger StdLogger) TypeInfo { - if strings.HasPrefix(name, "frozen<") { - return getCassandraType(strings.TrimPrefix(name[:len(name)-1], "frozen<"), logger) - } else if strings.HasPrefix(name, "set<") { - return 
CollectionType{ - NativeType: NativeType{typ: TypeSet}, - Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "set<"), logger), - } - } else if strings.HasPrefix(name, "list<") { - return CollectionType{ - NativeType: NativeType{typ: TypeList}, - Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "list<"), logger), - } - } else if strings.HasPrefix(name, "map<") { - names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "map<")) - if len(names) != 2 { - logger.Printf("Error parsing map type, it has %d subelements, expecting 2\n", len(names)) - return NativeType{ - typ: TypeCustom, - } - } - return CollectionType{ - NativeType: NativeType{typ: TypeMap}, - Key: getCassandraType(names[0], logger), - Elem: getCassandraType(names[1], logger), - } - } else if strings.HasPrefix(name, "tuple<") { - names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "tuple<")) - types := make([]TypeInfo, len(names)) - - for i, name := range names { - types[i] = getCassandraType(name, logger) - } - - return TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple}, - Elems: types, - } - } else { - return NativeType{ - typ: getCassandraBaseType(name), - } - } -} - func splitCompositeTypes(name string) []string { if !strings.Contains(name, "<") { return strings.Split(name, ", ") @@ -451,12 +345,6 @@ func (iter *Iter) MapScan(m map[string]interface{}) bool { return false } -func copyBytes(p []byte) []byte { - b := make([]byte, len(p)) - copy(b, p) - return b -} - var failDNS = false func LookupIP(host string) ([]net.IP, error) { diff --git a/host_source.go b/host_source.go index a0bab9ad0..eb5f7d8de 100644 --- a/host_source.go +++ b/host_source.go @@ -458,7 +458,7 @@ func checkSystemSchema(control *controlConn) (bool, error) { iter := control.query("SELECT * FROM system_schema.keyspaces") if err := iter.err; err != nil { if errf, ok := err.(*errorFrame); ok { - if errf.code == ErrCodeSyntax { + if errf.code == gocql_errors.ErrCodeSyntax { return 
false, nil } } diff --git a/internal/error_frame.go b/internal/error_frame.go new file mode 100644 index 000000000..b3d088d5c --- /dev/null +++ b/internal/error_frame.go @@ -0,0 +1,26 @@ +package internal + +import "fmt" + +type ErrorFrame struct { + FrameHeader + + code int + message string +} + +func (e ErrorFrame) Code() int { + return e.code +} + +func (e ErrorFrame) Message() string { + return e.message +} + +func (e ErrorFrame) Error() string { + return e.Message() +} + +func (e ErrorFrame) String() string { + return fmt.Sprintf("[error code=%x message=%q]", e.code, e.message) +} diff --git a/internal/errors.go b/internal/errors.go new file mode 100644 index 000000000..025112394 --- /dev/null +++ b/internal/errors.go @@ -0,0 +1,201 @@ +package internal + +import ( + "fmt" + "github.com/gocql/gocql/consistency" +) + +// See CQL Binary Protocol v5, section 8 for more details. +// https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec +const ( + // ErrCodeServer indicates unexpected error on server-side. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1246-L1247 + ErrCodeServer = 0x0000 + // ErrCodeProtocol indicates a protocol violation by some client message. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1248-L1250 + ErrCodeProtocol = 0x000A + // ErrCodeCredentials indicates missing required authentication. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1251-L1254 + ErrCodeCredentials = 0x0100 + // ErrCodeUnavailable indicates unavailable error. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1255-L1265 + ErrCodeUnavailable = 0x1000 + // ErrCodeOverloaded returned in case of request on overloaded node coordinator. 
+ // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1266-L1267 + ErrCodeOverloaded = 0x1001 + // ErrCodeBootstrapping returned from the coordinator node in bootstrapping phase. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1268-L1269 + ErrCodeBootstrapping = 0x1002 + // ErrCodeTruncate indicates truncation exception. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1270 + ErrCodeTruncate = 0x1003 + // ErrCodeWriteTimeout returned in case of timeout during the request write. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1271-L1304 + ErrCodeWriteTimeout = 0x1100 + // ErrCodeReadTimeout returned in case of timeout during the request read. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1305-L1321 + ErrCodeReadTimeout = 0x1200 + // ErrCodeReadFailure indicates request read error which is not covered by ErrCodeReadTimeout. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1322-L1340 + ErrCodeReadFailure = 0x1300 + // ErrCodeFunctionFailure indicates an error in user-defined function. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1341-L1347 + ErrCodeFunctionFailure = 0x1400 + // ErrCodeWriteFailure indicates request write error which is not covered by ErrCodeWriteTimeout. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1348-L1385 + ErrCodeWriteFailure = 0x1500 + // ErrCodeCDCWriteFailure is defined, but not yet documented in CQLv5 protocol. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1386 + ErrCodeCDCWriteFailure = 0x1600 + // ErrCodeCASWriteUnknown indicates only partially completed CAS operation. 
+ // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1387-L1397 + ErrCodeCASWriteUnknown = 0x1700 + // ErrCodeSyntax indicates the syntax error in the query. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1399 + ErrCodeSyntax = 0x2000 + // ErrCodeUnauthorized indicates access rights violation by user on performed operation. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1400-L1401 + ErrCodeUnauthorized = 0x2100 + // ErrCodeInvalid indicates invalid query error which is not covered by ErrCodeSyntax. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1402 + ErrCodeInvalid = 0x2200 + // ErrCodeConfig indicates the configuration error. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1403 + ErrCodeConfig = 0x2300 + // ErrCodeAlreadyExists is returned for the requests creating the existing keyspace/table. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1404-L1413 + ErrCodeAlreadyExists = 0x2400 + // ErrCodeUnprepared returned from the host for prepared statement which is unknown. 
+ // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1414-L1417 + ErrCodeUnprepared = 0x2500 +) + +type RequestError interface { + Code() int + Message() string + Error() string +} + +//type errorFrame struct { +// FrameHeader +// +// code int +// message string +//} +// +//func (e errorFrame) Code() int { +// return e.code +//} +// +//func (e errorFrame) Message() string { +// return e.message +//} +// +//func (e errorFrame) Error() string { +// return e.Message() +//} +// +//func (e errorFrame) String() string { +// return fmt.Sprintf("[error code=%x message=%q]", e.code, e.message) +//} + +type RequestErrUnavailable struct { + ErrorFrame + Consistency consistency.Consistency + Required int + Alive int +} + +func (e *RequestErrUnavailable) String() string { + return fmt.Sprintf("[request_error_unavailable consistency=%s required=%d alive=%d]", e.Consistency, e.Required, e.Alive) +} + +type ErrorMap map[string]uint16 + +type RequestErrWriteTimeout struct { + ErrorFrame + Consistency consistency.Consistency + Received int + BlockFor int + WriteType string +} + +type RequestErrWriteFailure struct { + ErrorFrame + Consistency consistency.Consistency + Received int + BlockFor int + NumFailures int + WriteType string + ErrorMap ErrorMap +} + +type RequestErrCDCWriteFailure struct { + ErrorFrame +} + +type RequestErrReadTimeout struct { + ErrorFrame + Consistency consistency.Consistency + Received int + BlockFor int + DataPresent byte +} + +type RequestErrAlreadyExists struct { + ErrorFrame + Keyspace string + Table string +} + +type RequestErrUnprepared struct { + ErrorFrame + StatementId []byte +} + +type RequestErrReadFailure struct { + ErrorFrame + Consistency consistency.Consistency + Received int + BlockFor int + NumFailures int + DataPresent bool + ErrorMap ErrorMap +} + +type RequestErrFunctionFailure struct { + ErrorFrame + Keyspace string + Function string + ArgTypes []string +} + +// RequestErrCASWriteUnknown is distinct 
error for ErrCodeCASWriteUnknown. +// +// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1387-L1397 +type RequestErrCASWriteUnknown struct { + ErrorFrame + Consistency consistency.Consistency + Received int + BlockFor int +} diff --git a/internal/framer.go b/internal/framer.go new file mode 100644 index 000000000..cecc09189 --- /dev/null +++ b/internal/framer.go @@ -0,0 +1,1874 @@ +package internal + +import ( + "errors" + "fmt" + "github.com/gocql/gocql/compressor" + "github.com/gocql/gocql/consistency" + "github.com/gocql/gocql/protocol" + "github.com/gocql/gocql/session" + "io" + "io/ioutil" + "net" + "runtime" + "time" +) + +type UnsetColumn struct{} + +type NamedValue struct { + Name string + Value interface{} +} + +const ( + protoDirectionMask = 0x80 + protoVersionMask = 0x7F + protoVersion1 = 0x01 + protoVersion2 = 0x02 + protoVersion3 = 0x03 + protoVersion4 = 0x04 + protoVersion5 = 0x05 + + maxFrameSize = 256 * 1024 * 1024 +) + +type ProtoVersion byte + +func (p ProtoVersion) request() bool { + return p&protoDirectionMask == 0x00 +} + +func (p ProtoVersion) response() bool { + return p&protoDirectionMask == 0x80 +} + +func (p ProtoVersion) version() byte { + return byte(p) & protoVersionMask +} + +func (p ProtoVersion) String() string { + dir := "REQ" + if p.response() { + dir = "RESP" + } + + return fmt.Sprintf("[version=%d direction=%s]", p.version(), dir) +} + +type FrameOp byte + +const ( + // header ops + opError FrameOp = 0x00 + opStartup FrameOp = 0x01 + opReady FrameOp = 0x02 + opAuthenticate FrameOp = 0x03 + opOptions FrameOp = 0x05 + opSupported FrameOp = 0x06 + opQuery FrameOp = 0x07 + opResult FrameOp = 0x08 + opPrepare FrameOp = 0x09 + opExecute FrameOp = 0x0A + opRegister FrameOp = 0x0B + opEvent FrameOp = 0x0C + opBatch FrameOp = 0x0D + opAuthChallenge FrameOp = 0x0E + opAuthResponse FrameOp = 0x0F + opAuthSuccess FrameOp = 0x10 +) + +func (f FrameOp) String() string { + switch f { + case opError: + return
"ERROR" + case opStartup: + return "STARTUP" + case opReady: + return "READY" + case opAuthenticate: + return "AUTHENTICATE" + case opOptions: + return "OPTIONS" + case opSupported: + return "SUPPORTED" + case opQuery: + return "QUERY" + case opResult: + return "RESULT" + case opPrepare: + return "PREPARE" + case opExecute: + return "EXECUTE" + case opRegister: + return "REGISTER" + case opEvent: + return "EVENT" + case opBatch: + return "BATCH" + case opAuthChallenge: + return "AUTH_CHALLENGE" + case opAuthResponse: + return "AUTH_RESPONSE" + case opAuthSuccess: + return "AUTH_SUCCESS" + default: + return fmt.Sprintf("UNKNOWN_OP_%d", f) + } +} + +const ( + // result kind + resultKindVoid = 1 + resultKindRows = 2 + resultKindKeyspace = 3 + resultKindPrepared = 4 + resultKindSchemaChanged = 5 + + // rows flags + flagGlobalTableSpec int = 0x01 + flagHasMorePages int = 0x02 + flagNoMetaData int = 0x04 + + // query flags + flagValues byte = 0x01 + flagSkipMetaData byte = 0x02 + flagPageSize byte = 0x04 + flagWithPagingState byte = 0x08 + flagWithSerialConsistency byte = 0x10 + flagDefaultTimestamp byte = 0x20 + flagWithNameValues byte = 0x40 + flagWithKeyspace byte = 0x80 + + // prepare flags + flagWithPreparedKeyspace uint32 = 0x01 + + // header flags + flagCompress byte = 0x01 + flagTracing byte = 0x02 + flagCustomPayload byte = 0x04 + flagWarning byte = 0x08 + flagBetaProtocol byte = 0x10 +) + +const ( + apacheCassandraTypePrefix = "org.apache.cassandra.db.marshal." 
+) + +var ( + ErrFrameTooBig = errors.New("frame length is bigger than the maximum allowed") +) + +const maxFrameHeaderSize = 9 + +func readInt(p []byte) int32 { + return int32(p[0])<<24 | int32(p[1])<<16 | int32(p[2])<<8 | int32(p[3]) +} + +type FrameHeader struct { + version ProtoVersion + flags byte + stream int + op FrameOp + length int + warnings []string +} + +func (f FrameHeader) String() string { + return fmt.Sprintf("[header version=%s flags=0x%x stream=%d op=%s length=%d]", f.version, f.flags, f.stream, f.op, f.length) +} + +func (f FrameHeader) Header() FrameHeader { + return f +} + +const defaultBufSize = 128 + +// a framer is responsible for reading, writing and parsing frames on a single stream +type framer struct { + proto byte + // flags are for outgoing flags, enabling compression and tracing etc + flags byte + compres compressor.Compressor + headSize int + // if this frame was read then the header will be here + header *FrameHeader + + // if tracing flag is set this is not nil + traceID []byte + + // holds a ref to the whole byte slice for buf so that it can be reset to + // 0 after a read. 
+ readBuffer []byte + + buf []byte + + customPayload map[string][]byte +} + +func newFramer(compressor compressor.Compressor, version byte) *framer { + buf := make([]byte, defaultBufSize) + f := &framer{ + buf: buf[:0], + readBuffer: buf, + } + var flags byte + if compressor != nil { + flags |= flagCompress + } + if version == protoVersion5 { + flags |= flagBetaProtocol + } + + version &= protoVersionMask + + headSize := 8 + if version > protoVersion2 { + headSize = 9 + } + + f.compres = compressor + f.proto = version + f.flags = flags + f.headSize = headSize + + f.header = nil + f.traceID = nil + + return f +} + +type frame interface { + Header() FrameHeader +} + +func readHeader(r io.Reader, p []byte) (head FrameHeader, err error) { + _, err = io.ReadFull(r, p[:1]) + if err != nil { + return FrameHeader{}, err + } + + version := p[0] & protoVersionMask + + if version < protoVersion1 || version > protoVersion5 { + return FrameHeader{}, fmt.Errorf("gocql: unsupported protocol response version: %d", version) + } + + headSize := 9 + if version < protoVersion3 { + headSize = 8 + } + + _, err = io.ReadFull(r, p[1:headSize]) + if err != nil { + return FrameHeader{}, err + } + + p = p[:headSize] + + head.version = ProtoVersion(p[0]) + head.flags = p[1] + + if version > protoVersion2 { + if len(p) != 9 { + return FrameHeader{}, fmt.Errorf("not enough bytes to read header require 9 got: %d", len(p)) + } + + head.stream = int(int16(p[2])<<8 | int16(p[3])) + head.op = FrameOp(p[4]) + head.length = int(readInt(p[5:])) + } else { + if len(p) != 8 { + return FrameHeader{}, fmt.Errorf("not enough bytes to read header require 8 got: %d", len(p)) + } + + head.stream = int(int8(p[2])) + head.op = FrameOp(p[3]) + head.length = int(readInt(p[4:])) + } + + return head, nil +} + +// explicitly enables tracing for the framers outgoing requests +func (f *framer) trace() { + f.flags |= flagTracing +} + +// explicitly enables the custom payload flag +func (f *framer) payload() { + f.flags 
|= flagCustomPayload +} + +// reads a frame from the wire into the framer's buffer +func (f *framer) readFrame(r io.Reader, head *FrameHeader) error { + if head.length < 0 { + return fmt.Errorf("frame body length can not be less than 0: %d", head.length) + } else if head.length > maxFrameSize { + // need to free up the connection to be used again + _, err := io.CopyN(ioutil.Discard, r, int64(head.length)) + if err != nil { + return fmt.Errorf("error whilst trying to discard frame with invalid length: %v", err) + } + return ErrFrameTooBig + } + + if cap(f.readBuffer) >= head.length { + f.buf = f.readBuffer[:head.length] + } else { + f.readBuffer = make([]byte, head.length) + f.buf = f.readBuffer + } + + // assume the underlying reader takes care of timeouts and retries + n, err := io.ReadFull(r, f.buf) + if err != nil { + return fmt.Errorf("unable to read frame body: read %d/%d bytes: %v", n, head.length, err) + } + + if head.flags&flagCompress == flagCompress { + if f.compres == nil { + return session.NewErrProtocol("no compressor available with compressed frame body") + } + + f.buf, err = f.compres.Decode(f.buf) + if err != nil { + return err + } + } + + f.header = head + return nil +} + +func (f *framer) parseFrame() (frame frame, err error) { + defer func() { + if r := recover(); r != nil { + if _, ok := r.(runtime.Error); ok { + panic(r) + } + err = r.(error) + } + }() + + if f.header.version.request() { + return nil, session.NewErrProtocol("got a request frame from server: %v", f.header.version) + } + + if f.header.flags&flagTracing == flagTracing { + f.readTrace() + } + + if f.header.flags&flagWarning == flagWarning { + f.header.warnings = f.readStringList() + } + + if f.header.flags&flagCustomPayload == flagCustomPayload { + f.customPayload = f.readBytesMap() + } + + // assumes that the frame body has been read into rbuf + switch f.header.op { + case opError: + frame = f.parseErrorFrame() + case opReady: + frame = f.parseReadyFrame() + case opResult: + frame,
err = f.parseResultFrame() + case opSupported: + frame = f.parseSupportedFrame() + case opAuthenticate: + frame = f.parseAuthenticateFrame() + case opAuthChallenge: + frame = f.parseAuthChallengeFrame() + case opAuthSuccess: + frame = f.parseAuthSuccessFrame() + case opEvent: + frame = f.parseEventFrame() + default: + return nil, session.NewErrProtocol("unknown op in frame header: %s", f.header.op) + } + + return +} + +func (f *framer) parseErrorFrame() frame { + code := f.readInt() + msg := f.readString() + + errD := ErrorFrame{ + FrameHeader: *f.header, + code: code, + message: msg, + } + + switch code { + case ErrCodeUnavailable: + cl := f.readConsistency() + required := f.readInt() + alive := f.readInt() + return &RequestErrUnavailable{ + ErrorFrame: errD, + Consistency: cl, + Required: required, + Alive: alive, + } + case ErrCodeWriteTimeout: + cl := f.readConsistency() + received := f.readInt() + blockfor := f.readInt() + writeType := f.readString() + return &RequestErrWriteTimeout{ + ErrorFrame: errD, + Consistency: cl, + Received: received, + BlockFor: blockfor, + WriteType: writeType, + } + case ErrCodeReadTimeout: + cl := f.readConsistency() + received := f.readInt() + blockfor := f.readInt() + dataPresent := f.readByte() + return &RequestErrReadTimeout{ + ErrorFrame: errD, + Consistency: cl, + Received: received, + BlockFor: blockfor, + DataPresent: dataPresent, + } + case ErrCodeAlreadyExists: + ks := f.readString() + table := f.readString() + return &RequestErrAlreadyExists{ + ErrorFrame: errD, + Keyspace: ks, + Table: table, + } + case ErrCodeUnprepared: + stmtId := f.readShortBytes() + return &RequestErrUnprepared{ + ErrorFrame: errD, + StatementId: CopyBytes(stmtId), // defensively copy + } + case ErrCodeReadFailure: + res := &RequestErrReadFailure{ + ErrorFrame: errD, + } + res.Consistency = f.readConsistency() + res.Received = f.readInt() + res.BlockFor = f.readInt() + if f.proto > protoVersion4 { + res.ErrorMap = f.readErrorMap() + 
res.NumFailures = len(res.ErrorMap) + } else { + res.NumFailures = f.readInt() + } + res.DataPresent = f.readByte() != 0 + + return res + case ErrCodeWriteFailure: + res := &RequestErrWriteFailure{ + ErrorFrame: errD, + } + res.Consistency = f.readConsistency() + res.Received = f.readInt() + res.BlockFor = f.readInt() + if f.proto > protoVersion4 { + res.ErrorMap = f.readErrorMap() + res.NumFailures = len(res.ErrorMap) + } else { + res.NumFailures = f.readInt() + } + res.WriteType = f.readString() + return res + case ErrCodeFunctionFailure: + res := &RequestErrFunctionFailure{ + ErrorFrame: errD, + } + res.Keyspace = f.readString() + res.Function = f.readString() + res.ArgTypes = f.readStringList() + return res + + case ErrCodeCDCWriteFailure: + res := &RequestErrCDCWriteFailure{ + ErrorFrame: errD, + } + return res + case ErrCodeCASWriteUnknown: + res := &RequestErrCASWriteUnknown{ + ErrorFrame: errD, + } + res.Consistency = f.readConsistency() + res.Received = f.readInt() + res.BlockFor = f.readInt() + return res + case ErrCodeInvalid, ErrCodeBootstrapping, ErrCodeConfig, ErrCodeCredentials, ErrCodeOverloaded, + ErrCodeProtocol, ErrCodeServer, ErrCodeSyntax, ErrCodeTruncate, ErrCodeUnauthorized: + // TODO(zariel): we should have some distinct types for these errors + return errD + default: + panic(fmt.Errorf("unknown error code: 0x%x", errD.code)) + } +} + +func (f *framer) readErrorMap() (errMap ErrorMap) { + errMap = make(ErrorMap) + numErrs := f.readInt() + for i := 0; i < numErrs; i++ { + ip := f.readInetAdressOnly().String() + errMap[ip] = f.readShort() + } + return +} + +func (f *framer) writeHeader(flags byte, op FrameOp, stream int) { + f.buf = f.buf[:0] + f.buf = append(f.buf, + f.proto, + flags, + ) + + if f.proto > protoVersion2 { + f.buf = append(f.buf, + byte(stream>>8), + byte(stream), + ) + } else { + f.buf = append(f.buf, + byte(stream), + ) + } + + // pad out length + f.buf = append(f.buf, + byte(op), + 0, + 0, + 0, + 0, + ) +} + +func (f 
*framer) setLength(length int) { + p := 4 + if f.proto > protoVersion2 { + p = 5 + } + + f.buf[p+0] = byte(length >> 24) + f.buf[p+1] = byte(length >> 16) + f.buf[p+2] = byte(length >> 8) + f.buf[p+3] = byte(length) +} + +func (f *framer) finish() error { + if len(f.buf) > maxFrameSize { + // huge app frame, lets remove it so it doesn't bloat the heap + f.buf = make([]byte, defaultBufSize) + return ErrFrameTooBig + } + + if f.buf[1]&flagCompress == flagCompress { + if f.compres == nil { + panic("compress flag set with no compressor") + } + + // TODO: only compress frames which are big enough + compressed, err := f.compres.Encode(f.buf[f.headSize:]) + if err != nil { + return err + } + + f.buf = append(f.buf[:f.headSize], compressed...) + } + length := len(f.buf) - f.headSize + f.setLength(length) + + return nil +} + +func (f *framer) writeTo(w io.Writer) error { + _, err := w.Write(f.buf) + return err +} + +func (f *framer) readTrace() { + f.traceID = f.readUUID().Bytes() +} + +type readyFrame struct { + FrameHeader +} + +func (f *framer) parseReadyFrame() frame { + return &readyFrame{ + FrameHeader: *f.header, + } +} + +type supportedFrame struct { + FrameHeader + + supported map[string][]string +} + +// TODO: if we move the body buffer onto the FrameHeader then we only need a single +// framer, and can move the methods onto the header. 
+func (f *framer) parseSupportedFrame() frame { + return &supportedFrame{ + FrameHeader: *f.header, + + supported: f.readStringMultiMap(), + } +} + +type writeStartupFrame struct { + opts map[string]string +} + +func (w writeStartupFrame) String() string { + return fmt.Sprintf("[startup opts=%+v]", w.opts) +} + +func (w *writeStartupFrame) buildFrame(f *framer, streamID int) error { + f.writeHeader(f.flags&^flagCompress, opStartup, streamID) + f.writeStringMap(w.opts) + + return f.finish() +} + +type writePrepareFrame struct { + statement string + keyspace string + customPayload map[string][]byte +} + +func (w *writePrepareFrame) buildFrame(f *framer, streamID int) error { + if len(w.customPayload) > 0 { + f.payload() + } + f.writeHeader(f.flags, opPrepare, streamID) + f.writeCustomPayload(&w.customPayload) + f.writeLongString(w.statement) + + var flags uint32 = 0 + if w.keyspace != "" { + if f.proto > protoVersion4 { + flags |= flagWithPreparedKeyspace + } else { + panic(fmt.Errorf("the keyspace can only be set with protocol 5 or higher")) + } + } + if f.proto > protoVersion4 { + f.writeUint(flags) + } + if w.keyspace != "" { + f.writeString(w.keyspace) + } + + return f.finish() +} + +func (f *framer) readTypeInfo() protocol.TypeInfo { + // TODO: factor this out so the same code paths can be used to parse custom + // types and other types, as much of the logic will be duplicated. 
+ id := f.readShort() + + simple := protocol.NativeType{ + Proto: f.proto, + Typ: protocol.Type(id), + } + + if simple.Typ == protocol.TypeCustom { + simple.Cust = f.readString() + if cassType := getApacheCassandraType(simple.Cust); cassType != protocol.TypeCustom { + simple.Typ = cassType + } + } + + switch simple.Typ { + case protocol.TypeTuple: + n := f.readShort() + tuple := protocol.TupleTypeInfo{ + NativeType: simple, + Elems: make([]protocol.TypeInfo, n), + } + + for i := 0; i < int(n); i++ { + tuple.Elems[i] = f.readTypeInfo() + } + + return tuple + + case protocol.TypeUDT: + udt := protocol.UDTTypeInfo{ + NativeType: simple, + } + udt.KeySpace = f.readString() + udt.Name = f.readString() + + n := f.readShort() + udt.Elements = make([]protocol.UDTField, n) + for i := 0; i < int(n); i++ { + field := &udt.Elements[i] + field.Name = f.readString() + field.Type = f.readTypeInfo() + } + + return udt + case protocol.TypeMap, protocol.TypeList, protocol.TypeSet: + collection := protocol.CollectionType{ + NativeType: simple, + } + + if simple.Typ == protocol.TypeMap { + collection.Key = f.readTypeInfo() + } + + collection.Elem = f.readTypeInfo() + + return collection + } + + return simple +} + +type preparedMetadata struct { + resultMetadata + + // proto v4+ + pkeyColumns []int + + keyspace string + + table string +} + +func (r preparedMetadata) String() string { + return fmt.Sprintf("[prepared flags=0x%x pkey=%v paging_state=% X columns=%v col_count=%d actual_col_count=%d]", r.flags, r.pkeyColumns, r.pagingState, r.columns, r.colCount, r.actualColCount) +} + +func (f *framer) parsePreparedMetadata() preparedMetadata { + // TODO: deduplicate this from parseMetadata + meta := preparedMetadata{} + + meta.flags = f.readInt() + meta.colCount = f.readInt() + if meta.colCount < 0 { + panic(fmt.Errorf("received negative column count: %d", meta.colCount)) + } + meta.actualColCount = meta.colCount + + if f.proto >= protoVersion4 { + pkeyCount := f.readInt() + pkeys := 
make([]int, pkeyCount) + for i := 0; i < pkeyCount; i++ { + pkeys[i] = int(f.readShort()) + } + meta.pkeyColumns = pkeys + } + + if meta.flags&flagHasMorePages == flagHasMorePages { + meta.pagingState = CopyBytes(f.readBytes()) + } + + if meta.flags&flagNoMetaData == flagNoMetaData { + return meta + } + + globalSpec := meta.flags&flagGlobalTableSpec == flagGlobalTableSpec + if globalSpec { + meta.keyspace = f.readString() + meta.table = f.readString() + } + + var cols []session.ColumnInfo + if meta.colCount < 1000 { + // preallocate columninfo to avoid excess copying + cols = make([]session.ColumnInfo, meta.colCount) + for i := 0; i < meta.colCount; i++ { + f.readCol(&cols[i], &meta.resultMetadata, globalSpec, meta.keyspace, meta.table) + } + } else { + // use append, huge number of columns usually indicates a corrupt frame or + // just a huge row. + for i := 0; i < meta.colCount; i++ { + var col session.ColumnInfo + f.readCol(&col, &meta.resultMetadata, globalSpec, meta.keyspace, meta.table) + cols = append(cols, col) + } + } + + meta.columns = cols + + return meta +} + +type resultMetadata struct { + flags int + + // only if flagPageState + pagingState []byte + + columns []session.ColumnInfo + colCount int + + // this is a count of the total number of columns which can be scanned, + // it is at minimum len(columns) but may be larger, for instance when a column + // is a UDT or tuple. 
+ actualColCount int +} + +func (r *resultMetadata) morePages() bool { + return r.flags&flagHasMorePages == flagHasMorePages +} + +func (r resultMetadata) String() string { + return fmt.Sprintf("[metadata flags=0x%x paging_state=% X columns=%v]", r.flags, r.pagingState, r.columns) +} + +func (f *framer) readCol(col *session.ColumnInfo, meta *resultMetadata, globalSpec bool, keyspace, table string) { + if !globalSpec { + col.Keyspace = f.readString() + col.Table = f.readString() + } else { + col.Keyspace = keyspace + col.Table = table + } + + col.Name = f.readString() + col.TypeInfo = f.readTypeInfo() + switch v := col.TypeInfo.(type) { + // maybe also UDT + case protocol.TupleTypeInfo: + // -1 because we already included the tuple column + meta.actualColCount += len(v.Elems) - 1 + } +} + +func (f *framer) parseResultMetadata() resultMetadata { + var meta resultMetadata + + meta.flags = f.readInt() + meta.colCount = f.readInt() + if meta.colCount < 0 { + panic(fmt.Errorf("received negative column count: %d", meta.colCount)) + } + meta.actualColCount = meta.colCount + + if meta.flags&flagHasMorePages == flagHasMorePages { + meta.pagingState = CopyBytes(f.readBytes()) + } + + if meta.flags&flagNoMetaData == flagNoMetaData { + return meta + } + + var keyspace, table string + globalSpec := meta.flags&flagGlobalTableSpec == flagGlobalTableSpec + if globalSpec { + keyspace = f.readString() + table = f.readString() + } + + var cols []session.ColumnInfo + if meta.colCount < 1000 { + // preallocate columninfo to avoid excess copying + cols = make([]session.ColumnInfo, meta.colCount) + for i := 0; i < meta.colCount; i++ { + f.readCol(&cols[i], &meta, globalSpec, keyspace, table) + } + + } else { + // use append, huge number of columns usually indicates a corrupt frame or + // just a huge row. 
+ for i := 0; i < meta.colCount; i++ { + var col session.ColumnInfo + f.readCol(&col, &meta, globalSpec, keyspace, table) + cols = append(cols, col) + } + } + + meta.columns = cols + + return meta +} + +type resultVoidFrame struct { + FrameHeader +} + +func (f *resultVoidFrame) String() string { + return "[result_void]" +} + +func (f *framer) parseResultFrame() (frame, error) { + kind := f.readInt() + + switch kind { + case resultKindVoid: + return &resultVoidFrame{FrameHeader: *f.header}, nil + case resultKindRows: + return f.parseResultRows(), nil + case resultKindKeyspace: + return f.parseResultSetKeyspace(), nil + case resultKindPrepared: + return f.parseResultPrepared(), nil + case resultKindSchemaChanged: + return f.parseResultSchemaChange(), nil + } + + return nil, session.NewErrProtocol("unknown result kind: %x", kind) +} + +type resultRowsFrame struct { + FrameHeader + + meta resultMetadata + // dont parse the rows here as we only need to do it once + numRows int +} + +func (f *resultRowsFrame) String() string { + return fmt.Sprintf("[result_rows meta=%v]", f.meta) +} + +func (f *framer) parseResultRows() frame { + result := &resultRowsFrame{} + result.meta = f.parseResultMetadata() + + result.numRows = f.readInt() + if result.numRows < 0 { + panic(fmt.Errorf("invalid row_count in result frame: %d", result.numRows)) + } + + return result +} + +type resultKeyspaceFrame struct { + FrameHeader + keyspace string +} + +func (r *resultKeyspaceFrame) String() string { + return fmt.Sprintf("[result_keyspace keyspace=%s]", r.keyspace) +} + +func (f *framer) parseResultSetKeyspace() frame { + return &resultKeyspaceFrame{ + FrameHeader: *f.header, + keyspace: f.readString(), + } +} + +type resultPreparedFrame struct { + FrameHeader + + preparedID []byte + reqMeta preparedMetadata + respMeta resultMetadata +} + +func (f *framer) parseResultPrepared() frame { + frame := &resultPreparedFrame{ + FrameHeader: *f.header, + preparedID: f.readShortBytes(), + reqMeta: 
f.parsePreparedMetadata(), + } + + if f.proto < protoVersion2 { + return frame + } + + frame.respMeta = f.parseResultMetadata() + + return frame +} + +type schemaChangeKeyspace struct { + FrameHeader + + change string + keyspace string +} + +func (f schemaChangeKeyspace) String() string { + return fmt.Sprintf("[event schema_change_keyspace change=%q keyspace=%q]", f.change, f.keyspace) +} + +type schemaChangeTable struct { + FrameHeader + + change string + keyspace string + object string +} + +func (f schemaChangeTable) String() string { + return fmt.Sprintf("[event schema_change change=%q keyspace=%q object=%q]", f.change, f.keyspace, f.object) +} + +type schemaChangeType struct { + FrameHeader + + change string + keyspace string + object string +} + +type schemaChangeFunction struct { + FrameHeader + + change string + keyspace string + name string + args []string +} + +type schemaChangeAggregate struct { + FrameHeader + + change string + keyspace string + name string + args []string +} + +func (f *framer) parseResultSchemaChange() frame { + if f.proto <= protoVersion2 { + change := f.readString() + keyspace := f.readString() + table := f.readString() + + if table != "" { + return &schemaChangeTable{ + FrameHeader: *f.header, + change: change, + keyspace: keyspace, + object: table, + } + } else { + return &schemaChangeKeyspace{ + FrameHeader: *f.header, + change: change, + keyspace: keyspace, + } + } + } else { + change := f.readString() + target := f.readString() + + // TODO: could just use a separate type for each target + switch target { + case "KEYSPACE": + frame := &schemaChangeKeyspace{ + FrameHeader: *f.header, + change: change, + } + + frame.keyspace = f.readString() + + return frame + case "TABLE": + frame := &schemaChangeTable{ + FrameHeader: *f.header, + change: change, + } + + frame.keyspace = f.readString() + frame.object = f.readString() + + return frame + case "TYPE": + frame := &schemaChangeType{ + FrameHeader: *f.header, + change: change, + } + + 
frame.keyspace = f.readString() + frame.object = f.readString() + + return frame + case "FUNCTION": + frame := &schemaChangeFunction{ + FrameHeader: *f.header, + change: change, + } + + frame.keyspace = f.readString() + frame.name = f.readString() + frame.args = f.readStringList() + + return frame + case "AGGREGATE": + frame := &schemaChangeAggregate{ + FrameHeader: *f.header, + change: change, + } + + frame.keyspace = f.readString() + frame.name = f.readString() + frame.args = f.readStringList() + + return frame + default: + panic(fmt.Errorf("gocql: unknown SCHEMA_CHANGE target: %q change: %q", target, change)) + } + } + +} + +type authenticateFrame struct { + FrameHeader + + class string +} + +func (a *authenticateFrame) String() string { + return fmt.Sprintf("[authenticate class=%q]", a.class) +} + +func (f *framer) parseAuthenticateFrame() frame { + return &authenticateFrame{ + FrameHeader: *f.header, + class: f.readString(), + } +} + +type authSuccessFrame struct { + FrameHeader + + data []byte +} + +func (a *authSuccessFrame) String() string { + return fmt.Sprintf("[auth_success data=%q]", a.data) +} + +func (f *framer) parseAuthSuccessFrame() frame { + return &authSuccessFrame{ + FrameHeader: *f.header, + data: f.readBytes(), + } +} + +type authChallengeFrame struct { + FrameHeader + + data []byte +} + +func (a *authChallengeFrame) String() string { + return fmt.Sprintf("[auth_challenge data=%q]", a.data) +} + +func (f *framer) parseAuthChallengeFrame() frame { + return &authChallengeFrame{ + FrameHeader: *f.header, + data: f.readBytes(), + } +} + +type statusChangeEventFrame struct { + FrameHeader + + change string + host net.IP + port int +} + +func (t statusChangeEventFrame) String() string { + return fmt.Sprintf("[status_change change=%s host=%v port=%v]", t.change, t.host, t.port) +} + +// essentially the same as statusChange +type topologyChangeEventFrame struct { + FrameHeader + + change string + host net.IP + port int +} + +func (t 
topologyChangeEventFrame) String() string { + return fmt.Sprintf("[topology_change change=%s host=%v port=%v]", t.change, t.host, t.port) +} + +func (f *framer) parseEventFrame() frame { + eventType := f.readString() + + switch eventType { + case "TOPOLOGY_CHANGE": + frame := &topologyChangeEventFrame{FrameHeader: *f.header} + frame.change = f.readString() + frame.host, frame.port = f.readInet() + + return frame + case "STATUS_CHANGE": + frame := &statusChangeEventFrame{FrameHeader: *f.header} + frame.change = f.readString() + frame.host, frame.port = f.readInet() + + return frame + case "SCHEMA_CHANGE": + // this should work for all versions + return f.parseResultSchemaChange() + default: + panic(fmt.Errorf("gocql: unknown event type: %q", eventType)) + } + +} + +type writeAuthResponseFrame struct { + data []byte +} + +func (a *writeAuthResponseFrame) String() string { + return fmt.Sprintf("[auth_response data=%q]", a.data) +} + +func (a *writeAuthResponseFrame) buildFrame(framer *framer, streamID int) error { + return framer.writeAuthResponseFrame(streamID, a.data) +} + +func (f *framer) writeAuthResponseFrame(streamID int, data []byte) error { + f.writeHeader(f.flags, opAuthResponse, streamID) + f.writeBytes(data) + return f.finish() +} + +type queryValues struct { + value []byte + + // optional name, will set With names for values flag + name string + isUnset bool +} + +type queryParams struct { + consistency consistency.Consistency + // v2+ + skipMeta bool + values []queryValues + pageSize int + pagingState []byte + serialConsistency consistency.SerialConsistency + // v3+ + defaultTimestamp bool + defaultTimestampValue int64 + // v5+ + keyspace string +} + +func (q queryParams) String() string { + return fmt.Sprintf("[query_params consistency=%v skip_meta=%v page_size=%d paging_state=%q serial_consistency=%v default_timestamp=%v values=%v keyspace=%s]", + q.consistency, q.skipMeta, q.pageSize, q.pagingState, q.serialConsistency, q.defaultTimestamp, q.values, 
q.keyspace) +} + +func (f *framer) writeQueryParams(opts *queryParams) { + f.writeConsistency(opts.consistency) + + if f.proto == protoVersion1 { + return + } + + var flags byte + if len(opts.values) > 0 { + flags |= flagValues + } + if opts.skipMeta { + flags |= flagSkipMetaData + } + if opts.pageSize > 0 { + flags |= flagPageSize + } + if len(opts.pagingState) > 0 { + flags |= flagWithPagingState + } + if opts.serialConsistency > 0 { + flags |= flagWithSerialConsistency + } + + names := false + + // protoV3 specific things + if f.proto > protoVersion2 { + if opts.defaultTimestamp { + flags |= flagDefaultTimestamp + } + + if len(opts.values) > 0 && opts.values[0].name != "" { + flags |= flagWithNameValues + names = true + } + } + + if opts.keyspace != "" { + if f.proto > protoVersion4 { + flags |= flagWithKeyspace + } else { + panic(fmt.Errorf("the keyspace can only be set with protocol 5 or higher")) + } + } + + if f.proto > protoVersion4 { + f.writeUint(uint32(flags)) + } else { + f.writeByte(flags) + } + + if n := len(opts.values); n > 0 { + f.writeShort(uint16(n)) + + for i := 0; i < n; i++ { + if names { + f.writeString(opts.values[i].name) + } + if opts.values[i].isUnset { + f.writeUnset() + } else { + f.writeBytes(opts.values[i].value) + } + } + } + + if opts.pageSize > 0 { + f.writeInt(int32(opts.pageSize)) + } + + if len(opts.pagingState) > 0 { + f.writeBytes(opts.pagingState) + } + + if opts.serialConsistency > 0 { + f.writeConsistency(consistency.Consistency(opts.serialConsistency)) + } + + if f.proto > protoVersion2 && opts.defaultTimestamp { + // timestamp in microseconds + var ts int64 + if opts.defaultTimestampValue != 0 { + ts = opts.defaultTimestampValue + } else { + ts = time.Now().UnixNano() / 1000 + } + f.writeLong(ts) + } + + if opts.keyspace != "" { + f.writeString(opts.keyspace) + } +} + +type writeQueryFrame struct { + statement string + params queryParams + + // v4+ + customPayload map[string][]byte +} + +func (w *writeQueryFrame) String() 
string { + return fmt.Sprintf("[query statement=%q params=%v]", w.statement, w.params) +} + +func (w *writeQueryFrame) buildFrame(framer *framer, streamID int) error { + return framer.writeQueryFrame(streamID, w.statement, &w.params, w.customPayload) +} + +func (f *framer) writeQueryFrame(streamID int, statement string, params *queryParams, customPayload map[string][]byte) error { + if len(customPayload) > 0 { + f.payload() + } + f.writeHeader(f.flags, opQuery, streamID) + f.writeCustomPayload(&customPayload) + f.writeLongString(statement) + f.writeQueryParams(params) + + return f.finish() +} + +type frameBuilder interface { + buildFrame(framer *framer, streamID int) error +} + +type frameWriterFunc func(framer *framer, streamID int) error + +func (f frameWriterFunc) buildFrame(framer *framer, streamID int) error { + return f(framer, streamID) +} + +type writeExecuteFrame struct { + preparedID []byte + params queryParams + + // v4+ + customPayload map[string][]byte +} + +func (e *writeExecuteFrame) String() string { + return fmt.Sprintf("[execute id=% X params=%v]", e.preparedID, &e.params) +} + +func (e *writeExecuteFrame) buildFrame(fr *framer, streamID int) error { + return fr.writeExecuteFrame(streamID, e.preparedID, &e.params, &e.customPayload) +} + +func (f *framer) writeExecuteFrame(streamID int, preparedID []byte, params *queryParams, customPayload *map[string][]byte) error { + if len(*customPayload) > 0 { + f.payload() + } + f.writeHeader(f.flags, opExecute, streamID) + f.writeCustomPayload(customPayload) + f.writeShortBytes(preparedID) + if f.proto > protoVersion1 { + f.writeQueryParams(params) + } else { + n := len(params.values) + f.writeShort(uint16(n)) + for i := 0; i < n; i++ { + if params.values[i].isUnset { + f.writeUnset() + } else { + f.writeBytes(params.values[i].value) + } + } + f.writeConsistency(params.consistency) + } + + return f.finish() +} + +// TODO: can we replace BatchStatemt with batchStatement? 
As they pretty much +// duplicate each other +type batchStatment struct { + preparedID []byte + statement string + values []queryValues +} + +type writeBatchFrame struct { + typ session.BatchType + statements []batchStatment + consistency consistency.Consistency + + // v3+ + serialConsistency consistency.SerialConsistency + defaultTimestamp bool + defaultTimestampValue int64 + + //v4+ + customPayload map[string][]byte +} + +func (w *writeBatchFrame) buildFrame(framer *framer, streamID int) error { + return framer.writeBatchFrame(streamID, w, w.customPayload) +} + +func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame, customPayload map[string][]byte) error { + if len(customPayload) > 0 { + f.payload() + } + f.writeHeader(f.flags, opBatch, streamID) + f.writeCustomPayload(&customPayload) + f.writeByte(byte(w.typ)) + + n := len(w.statements) + f.writeShort(uint16(n)) + + var flags byte + + for i := 0; i < n; i++ { + b := &w.statements[i] + if len(b.preparedID) == 0 { + f.writeByte(0) + f.writeLongString(b.statement) + } else { + f.writeByte(1) + f.writeShortBytes(b.preparedID) + } + + f.writeShort(uint16(len(b.values))) + for j := range b.values { + col := b.values[j] + if f.proto > protoVersion2 && col.name != "" { + // TODO: move this check into the caller and set a flag on writeBatchFrame + // to indicate using named values + if f.proto <= protoVersion5 { + return fmt.Errorf("gocql: named query values are not supported in batches, please see https://issues.apache.org/jira/browse/CASSANDRA-10246") + } + flags |= flagWithNameValues + f.writeString(col.name) + } + if col.isUnset { + f.writeUnset() + } else { + f.writeBytes(col.value) + } + } + } + + f.writeConsistency(w.consistency) + + if f.proto > protoVersion2 { + if w.serialConsistency > 0 { + flags |= flagWithSerialConsistency + } + if w.defaultTimestamp { + flags |= flagDefaultTimestamp + } + + if f.proto > protoVersion4 { + f.writeUint(uint32(flags)) + } else { + f.writeByte(flags) + } + + if 
w.serialConsistency > 0 { + f.writeConsistency(consistency.Consistency(w.serialConsistency)) + } + + if w.defaultTimestamp { + var ts int64 + if w.defaultTimestampValue != 0 { + ts = w.defaultTimestampValue + } else { + ts = time.Now().UnixNano() / 1000 + } + f.writeLong(ts) + } + } + + return f.finish() +} + +type writeOptionsFrame struct{} + +func (w *writeOptionsFrame) buildFrame(framer *framer, streamID int) error { + return framer.writeOptionsFrame(streamID, w) +} + +func (f *framer) writeOptionsFrame(stream int, _ *writeOptionsFrame) error { + f.writeHeader(f.flags&^flagCompress, opOptions, stream) + return f.finish() +} + +type writeRegisterFrame struct { + events []string +} + +func (w *writeRegisterFrame) buildFrame(framer *framer, streamID int) error { + return framer.writeRegisterFrame(streamID, w) +} + +func (f *framer) writeRegisterFrame(streamID int, w *writeRegisterFrame) error { + f.writeHeader(f.flags, opRegister, streamID) + f.writeStringList(w.events) + + return f.finish() +} + +func (f *framer) readByte() byte { + if len(f.buf) < 1 { + panic(fmt.Errorf("not enough bytes in buffer to read byte require 1 got: %d", len(f.buf))) + } + + b := f.buf[0] + f.buf = f.buf[1:] + return b +} + +func (f *framer) readInt() (n int) { + if len(f.buf) < 4 { + panic(fmt.Errorf("not enough bytes in buffer to read int require 4 got: %d", len(f.buf))) + } + + n = int(int32(f.buf[0])<<24 | int32(f.buf[1])<<16 | int32(f.buf[2])<<8 | int32(f.buf[3])) + f.buf = f.buf[4:] + return +} + +func (f *framer) readShort() (n uint16) { + if len(f.buf) < 2 { + panic(fmt.Errorf("not enough bytes in buffer to read short require 2 got: %d", len(f.buf))) + } + n = uint16(f.buf[0])<<8 | uint16(f.buf[1]) + f.buf = f.buf[2:] + return +} + +func (f *framer) readString() (s string) { + size := f.readShort() + + if len(f.buf) < int(size) { + panic(fmt.Errorf("not enough bytes in buffer to read string require %d got: %d", size, len(f.buf))) + } + + s = string(f.buf[:size]) + f.buf = 
f.buf[size:] + return +} + +func (f *framer) readLongString() (s string) { + size := f.readInt() + + if len(f.buf) < size { + panic(fmt.Errorf("not enough bytes in buffer to read long string require %d got: %d", size, len(f.buf))) + } + + s = string(f.buf[:size]) + f.buf = f.buf[size:] + return +} + +func (f *framer) readUUID() *protocol.UUID { + if len(f.buf) < 16 { + panic(fmt.Errorf("not enough bytes in buffer to read uuid require %d got: %d", 16, len(f.buf))) + } + + // TODO: how to handle this error, if it is a uuid, then surely, problems? + u, _ := protocol.UUIDFromBytes(f.buf[:16]) + f.buf = f.buf[16:] + return &u +} + +func (f *framer) readStringList() []string { + size := f.readShort() + + l := make([]string, size) + for i := 0; i < int(size); i++ { + l[i] = f.readString() + } + + return l +} + +func (f *framer) readBytesInternal() ([]byte, error) { + size := f.readInt() + if size < 0 { + return nil, nil + } + + if len(f.buf) < size { + return nil, fmt.Errorf("not enough bytes in buffer to read bytes require %d got: %d", size, len(f.buf)) + } + + l := f.buf[:size] + f.buf = f.buf[size:] + + return l, nil +} + +func (f *framer) readBytes() []byte { + l, err := f.readBytesInternal() + if err != nil { + panic(err) + } + + return l +} + +func (f *framer) readShortBytes() []byte { + size := f.readShort() + if len(f.buf) < int(size) { + panic(fmt.Errorf("not enough bytes in buffer to read short bytes: require %d got %d", size, len(f.buf))) + } + + l := f.buf[:size] + f.buf = f.buf[size:] + + return l +} + +func (f *framer) readInetAdressOnly() net.IP { + if len(f.buf) < 1 { + panic(fmt.Errorf("not enough bytes in buffer to read inet size require %d got: %d", 1, len(f.buf))) + } + + size := f.buf[0] + f.buf = f.buf[1:] + + if !(size == 4 || size == 16) { + panic(fmt.Errorf("invalid IP size: %d", size)) + } + + // must have the full address available: the panic message below already + // reports `size` as the requirement, but the guard only checked for 1 byte, + // letting a truncated address reach f.buf[:size] and fail with a raw + // slice-bounds panic instead of this framing error. + if len(f.buf) < int(size) { + panic(fmt.Errorf("not enough bytes in buffer to read inet require %d got: %d", size, len(f.buf))) + } + + ip := make([]byte, size) + 
copy(ip, f.buf[:size]) + f.buf = f.buf[size:] + return net.IP(ip) +} + +func (f *framer) readInet() (net.IP, int) { + return f.readInetAdressOnly(), f.readInt() +} + +func (f *framer) readConsistency() consistency.Consistency { + return consistency.Consistency(f.readShort()) +} + +func (f *framer) readBytesMap() map[string][]byte { + size := f.readShort() + m := make(map[string][]byte, size) + + for i := 0; i < int(size); i++ { + k := f.readString() + v := f.readBytes() + m[k] = v + } + + return m +} + +func (f *framer) readStringMultiMap() map[string][]string { + size := f.readShort() + m := make(map[string][]string, size) + + for i := 0; i < int(size); i++ { + k := f.readString() + v := f.readStringList() + m[k] = v + } + + return m +} + +func (f *framer) writeByte(b byte) { + f.buf = append(f.buf, b) +} + +func appendBytes(p []byte, d []byte) []byte { + if d == nil { + return appendInt(p, -1) + } + p = appendInt(p, int32(len(d))) + p = append(p, d...) + return p +} + +func appendShort(p []byte, n uint16) []byte { + return append(p, + byte(n>>8), + byte(n), + ) +} + +func appendInt(p []byte, n int32) []byte { + return append(p, byte(n>>24), + byte(n>>16), + byte(n>>8), + byte(n)) +} + +func appendUint(p []byte, n uint32) []byte { + return append(p, byte(n>>24), + byte(n>>16), + byte(n>>8), + byte(n)) +} + +func appendLong(p []byte, n int64) []byte { + return append(p, + byte(n>>56), + byte(n>>48), + byte(n>>40), + byte(n>>32), + byte(n>>24), + byte(n>>16), + byte(n>>8), + byte(n), + ) +} + +func (f *framer) writeCustomPayload(customPayload *map[string][]byte) { + if len(*customPayload) > 0 { + if f.proto < protoVersion4 { + panic("Custom payload is not supported with version V3 or less") + } + f.writeBytesMap(*customPayload) + } +} + +// these are protocol level binary types +func (f *framer) writeInt(n int32) { + f.buf = appendInt(f.buf, n) +} + +func (f *framer) writeUint(n uint32) { + f.buf = appendUint(f.buf, n) +} + +func (f *framer) writeShort(n uint16) { + 
f.buf = appendShort(f.buf, n) +} + +func (f *framer) writeLong(n int64) { + f.buf = appendLong(f.buf, n) +} + +func (f *framer) writeString(s string) { + f.writeShort(uint16(len(s))) + f.buf = append(f.buf, s...) +} + +func (f *framer) writeLongString(s string) { + f.writeInt(int32(len(s))) + f.buf = append(f.buf, s...) +} + +func (f *framer) writeStringList(l []string) { + f.writeShort(uint16(len(l))) + for _, s := range l { + f.writeString(s) + } +} + +func (f *framer) writeUnset() { + // Protocol version 4 specifies that bind variables do not require having a + // value when executing a statement. Bind variables without a value are + // called 'unset'. The 'unset' bind variable is serialized as the int + // value '-2' without following bytes. + f.writeInt(-2) +} + +func (f *framer) writeBytes(p []byte) { + // TODO: handle null case correctly, + // [bytes] A [int] n, followed by n bytes if n >= 0. If n < 0, + // no byte should follow and the value represented is `null`. + if p == nil { + f.writeInt(-1) + } else { + f.writeInt(int32(len(p))) + f.buf = append(f.buf, p...) + } +} + +func (f *framer) writeShortBytes(p []byte) { + f.writeShort(uint16(len(p))) + f.buf = append(f.buf, p...) 
+} + +func (f *framer) writeConsistency(cons consistency.Consistency) { + f.writeShort(uint16(cons)) +} + +func (f *framer) writeStringMap(m map[string]string) { + f.writeShort(uint16(len(m))) + for k, v := range m { + f.writeString(k) + f.writeString(v) + } +} + +func (f *framer) writeBytesMap(m map[string][]byte) { + f.writeShort(uint16(len(m))) + for k, v := range m { + f.writeString(k) + f.writeBytes(v) + } +} diff --git a/internal/helpers.go b/internal/helpers.go new file mode 100644 index 000000000..79dd110c3 --- /dev/null +++ b/internal/helpers.go @@ -0,0 +1,170 @@ +package internal + +import ( + "github.com/gocql/gocql/protocol" + "strings" +) + +// apacheCassandraTypePrefix is the Java class-name prefix stripped by getApacheCassandraType; the new package did not carry this constant over from the root gocql package. TODO(review): drop this if another file in package internal already declares it. +const apacheCassandraTypePrefix = "org.apache.cassandra.db.marshal." + +func CopyBytes(p []byte) []byte { + b := make([]byte, len(p)) + copy(b, p) + return b +} + +func getApacheCassandraType(class string) protocol.Type { + switch strings.TrimPrefix(class, apacheCassandraTypePrefix) { + case "AsciiType": + return protocol.TypeAscii + case "LongType": + return protocol.TypeBigInt + case "BytesType": + return protocol.TypeBlob + case "BooleanType": + return protocol.TypeBoolean + case "CounterColumnType": + return protocol.TypeCounter + case "DecimalType": + return protocol.TypeDecimal + case "DoubleType": + return protocol.TypeDouble + case "FloatType": + return protocol.TypeFloat + case "Int32Type": + return protocol.TypeInt + case "ShortType": + return protocol.TypeSmallInt + case "ByteType": + return protocol.TypeTinyInt + case "TimeType": + return protocol.TypeTime + case "DateType", "TimestampType": + return protocol.TypeTimestamp + case "UUIDType", "LexicalUUIDType": + return protocol.TypeUUID + case "UTF8Type": + return protocol.TypeVarchar + case "IntegerType": + return protocol.TypeVarint + case "TimeUUIDType": + return protocol.TypeTimeUUID + case "InetAddressType": + return protocol.TypeInet + case "MapType": + return protocol.TypeMap + case "ListType": + return protocol.TypeList + case "SetType": + return protocol.TypeSet + case "TupleType": + return protocol.TypeTuple + 
case "DurationType": + return protocol.TypeDuration + default: + return protocol.TypeCustom + } +} + +func getCassandraBaseType(name string) protocol.Type { + switch name { + case "ascii": + return protocol.TypeAscii + case "bigint": + return protocol.TypeBigInt + case "blob": + return protocol.TypeBlob + case "boolean": + return protocol.TypeBoolean + case "counter": + return protocol.TypeCounter + case "date": + return protocol.TypeDate + case "decimal": + return protocol.TypeDecimal + case "double": + return protocol.TypeDouble + case "duration": + return protocol.TypeDuration + case "float": + return protocol.TypeFloat + case "int": + return protocol.TypeInt + case "smallint": + return protocol.TypeSmallInt + case "tinyint": + return protocol.TypeTinyInt + case "time": + return protocol.TypeTime + case "timestamp": + return protocol.TypeTimestamp + case "uuid": + return protocol.TypeUUID + case "varchar": + return protocol.TypeVarchar + case "text": + return protocol.TypeText + case "varint": + return protocol.TypeVarint + case "timeuuid": + return protocol.TypeTimeUUID + case "inet": + return protocol.TypeInet + case "MapType": + return protocol.TypeMap + case "ListType": + return protocol.TypeList + case "SetType": + return protocol.TypeSet + case "TupleType": + return protocol.TypeTuple + default: + return protocol.TypeCustom + } +} + +//func getCassandraType(name string, logger StdLogger) TypeInfo { +// if strings.HasPrefix(name, "frozen<") { +// return getCassandraType(strings.TrimPrefix(name[:len(name)-1], "frozen<"), logger) +// } else if strings.HasPrefix(name, "set<") { +// return CollectionType{ +// NativeType: NativeType{typ: TypeSet}, +// Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "set<"), logger), +// } +// } else if strings.HasPrefix(name, "list<") { +// return CollectionType{ +// NativeType: NativeType{typ: TypeList}, +// Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "list<"), logger), +// } +// } else if 
strings.HasPrefix(name, "map<") { +// names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "map<")) +// if len(names) != 2 { +// logger.Printf("Error parsing map type, it has %d subelements, expecting 2\n", len(names)) +// return NativeType{ +// typ: TypeCustom, +// } +// } +// return CollectionType{ +// NativeType: NativeType{typ: TypeMap}, +// Key: getCassandraType(names[0], logger), +// Elem: getCassandraType(names[1], logger), +// } +// } else if strings.HasPrefix(name, "tuple<") { +// names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "tuple<")) +// types := make([]TypeInfo, len(names)) +// +// for i, name := range names { +// types[i] = getCassandraType(name, logger) +// } +// +// return TupleTypeInfo{ +// NativeType: NativeType{typ: TypeTuple}, +// Elems: types, +// } +// } else { +// return NativeType{ +// typ: getCassandraBaseType(name), +// } +// } +//} diff --git a/marshal.go b/marshal.go index 4d0adb923..ab25d0f4f 100644 --- a/marshal.go +++ b/marshal.go @@ -29,6 +29,8 @@ import ( "encoding/binary" "errors" "fmt" + "github.com/gocql/gocql/internal" + "github.com/gocql/gocql/protocol" "math" "math/big" "math/bits" @@ -53,13 +55,13 @@ var ( // Marshaler is the interface implemented by objects that can marshal // themselves into values understood by Cassandra. type Marshaler interface { - MarshalCQL(info TypeInfo) ([]byte, error) + MarshalCQL(info protocol.TypeInfo) ([]byte, error) } // Unmarshaler is the interface implemented by objects that can unmarshal // a Cassandra specific description of themselves. 
type Unmarshaler interface { - UnmarshalCQL(info TypeInfo, data []byte) error + UnmarshalCQL(info protocol.TypeInfo, data []byte) error } // Marshal returns the CQL encoding of the value for the Cassandra @@ -110,7 +112,7 @@ type Unmarshaler interface { // duration | time.Duration | // duration | gocql.Duration | // duration | string | parsed with time.ParseDuration -func Marshal(info TypeInfo, value interface{}) ([]byte, error) { +func Marshal(info protocol.TypeInfo, value interface{}) ([]byte, error) { if info.Version() < protoVersion1 { panic("protocol version not set") } @@ -222,7 +224,7 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) { // date | *time.Time | time of beginning of the day (in UTC) // date | *string | formatted with 2006-01-02 format // duration | *gocql.Duration | -func Unmarshal(info TypeInfo, data []byte, value interface{}) error { +func Unmarshal(info protocol.TypeInfo, data []byte, value interface{}) error { if v, ok := value.(Unmarshaler); ok { return v.UnmarshalCQL(info, data) } @@ -290,11 +292,11 @@ func isNullableValue(value interface{}) bool { return v.Kind() == reflect.Ptr && v.Type().Elem().Kind() == reflect.Ptr } -func isNullData(info TypeInfo, data []byte) bool { +func isNullData(info protocol.TypeInfo, data []byte) bool { return data == nil } -func unmarshalNullable(info TypeInfo, data []byte, value interface{}) error { +func unmarshalNullable(info protocol.TypeInfo, data []byte, value interface{}) error { valueRef := reflect.ValueOf(value) if isNullData(info, data) { @@ -308,7 +310,7 @@ func unmarshalNullable(info TypeInfo, data []byte, value interface{}) error { return Unmarshal(info, data, newValue.Interface()) } -func marshalVarchar(info TypeInfo, value interface{}) ([]byte, error) { +func marshalVarchar(info protocol.TypeInfo, value interface{}) ([]byte, error) { switch v := value.(type) { case Marshaler: return v.MarshalCQL(info) @@ -336,7 +338,7 @@ func marshalVarchar(info TypeInfo, value interface{}) 
([]byte, error) { return nil, marshalErrorf("can not marshal %T into %s", value, info) } -func unmarshalVarchar(info TypeInfo, data []byte, value interface{}) error { +func unmarshalVarchar(info protocol.TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case Unmarshaler: return v.UnmarshalCQL(info, data) @@ -375,7 +377,7 @@ func unmarshalVarchar(info TypeInfo, data []byte, value interface{}) error { return unmarshalErrorf("can not unmarshal %s into %T", info, value) } -func marshalSmallInt(info TypeInfo, value interface{}) ([]byte, error) { +func marshalSmallInt(info protocol.TypeInfo, value interface{}) ([]byte, error) { switch v := value.(type) { case Marshaler: return v.MarshalCQL(info) @@ -453,7 +455,7 @@ func marshalSmallInt(info TypeInfo, value interface{}) ([]byte, error) { return nil, marshalErrorf("can not marshal %T into %s", value, info) } -func marshalTinyInt(info TypeInfo, value interface{}) ([]byte, error) { +func marshalTinyInt(info protocol.TypeInfo, value interface{}) ([]byte, error) { switch v := value.(type) { case Marshaler: return v.MarshalCQL(info) @@ -537,7 +539,7 @@ func marshalTinyInt(info TypeInfo, value interface{}) ([]byte, error) { return nil, marshalErrorf("can not marshal %T into %s", value, info) } -func marshalInt(info TypeInfo, value interface{}) ([]byte, error) { +func marshalInt(info protocol.TypeInfo, value interface{}) ([]byte, error) { switch v := value.(type) { case Marshaler: return v.MarshalCQL(info) @@ -641,7 +643,7 @@ func decTiny(p []byte) int8 { return int8(p[0]) } -func marshalBigInt(info TypeInfo, value interface{}) ([]byte, error) { +func marshalBigInt(info protocol.TypeInfo, value interface{}) ([]byte, error) { switch v := value.(type) { case Marshaler: return v.MarshalCQL(info) @@ -718,23 +720,23 @@ func bytesToUint64(data []byte) (ret uint64) { return ret } -func unmarshalBigInt(info TypeInfo, data []byte, value interface{}) error { +func unmarshalBigInt(info protocol.TypeInfo, data 
[]byte, value interface{}) error { return unmarshalIntlike(info, decBigInt(data), data, value) } -func unmarshalInt(info TypeInfo, data []byte, value interface{}) error { +func unmarshalInt(info protocol.TypeInfo, data []byte, value interface{}) error { return unmarshalIntlike(info, int64(decInt(data)), data, value) } -func unmarshalSmallInt(info TypeInfo, data []byte, value interface{}) error { +func unmarshalSmallInt(info protocol.TypeInfo, data []byte, value interface{}) error { return unmarshalIntlike(info, int64(decShort(data)), data, value) } -func unmarshalTinyInt(info TypeInfo, data []byte, value interface{}) error { +func unmarshalTinyInt(info protocol.TypeInfo, data []byte, value interface{}) error { return unmarshalIntlike(info, int64(decTiny(data)), data, value) } -func unmarshalVarint(info TypeInfo, data []byte, value interface{}) error { +func unmarshalVarint(info protocol.TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case *big.Int: return unmarshalIntlike(info, 0, data, value) @@ -756,7 +758,7 @@ func unmarshalVarint(info TypeInfo, data []byte, value interface{}) error { return unmarshalIntlike(info, int64Val, data, value) } -func marshalVarint(info TypeInfo, value interface{}) ([]byte, error) { +func marshalVarint(info protocol.TypeInfo, value interface{}) ([]byte, error) { var ( retBytes []byte err error @@ -807,7 +809,7 @@ func marshalVarint(info TypeInfo, value interface{}) ([]byte, error) { return retBytes, err } -func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interface{}) error { +func unmarshalIntlike(info protocol.TypeInfo, int64Val int64, data []byte, value interface{}) error { switch v := value.(type) { case *int: if ^uint(0) == math.MaxUint32 && (int64Val < math.MinInt32 || int64Val > math.MaxInt32) { @@ -1019,7 +1021,7 @@ func decBigInt(data []byte) int64 { int64(data[6])<<8 | int64(data[7]) } -func marshalBool(info TypeInfo, value interface{}) ([]byte, error) { +func 
marshalBool(info protocol.TypeInfo, value interface{}) ([]byte, error) { switch v := value.(type) { case Marshaler: return v.MarshalCQL(info) @@ -1048,7 +1050,7 @@ func encBool(v bool) []byte { return []byte{0} } -func unmarshalBool(info TypeInfo, data []byte, value interface{}) error { +func unmarshalBool(info protocol.TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case Unmarshaler: return v.UnmarshalCQL(info, data) @@ -1076,7 +1078,7 @@ func decBool(v []byte) bool { return v[0] != 0 } -func marshalFloat(info TypeInfo, value interface{}) ([]byte, error) { +func marshalFloat(info protocol.TypeInfo, value interface{}) ([]byte, error) { switch v := value.(type) { case Marshaler: return v.MarshalCQL(info) @@ -1098,7 +1100,7 @@ func marshalFloat(info TypeInfo, value interface{}) ([]byte, error) { return nil, marshalErrorf("can not marshal %T into %s", value, info) } -func unmarshalFloat(info TypeInfo, data []byte, value interface{}) error { +func unmarshalFloat(info protocol.TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case Unmarshaler: return v.UnmarshalCQL(info, data) @@ -1119,7 +1121,7 @@ func unmarshalFloat(info TypeInfo, data []byte, value interface{}) error { return unmarshalErrorf("can not unmarshal %s into %T", info, value) } -func marshalDouble(info TypeInfo, value interface{}) ([]byte, error) { +func marshalDouble(info protocol.TypeInfo, value interface{}) ([]byte, error) { switch v := value.(type) { case Marshaler: return v.MarshalCQL(info) @@ -1139,7 +1141,7 @@ func marshalDouble(info TypeInfo, value interface{}) ([]byte, error) { return nil, marshalErrorf("can not marshal %T into %s", value, info) } -func unmarshalDouble(info TypeInfo, data []byte, value interface{}) error { +func unmarshalDouble(info protocol.TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case Unmarshaler: return v.UnmarshalCQL(info, data) @@ -1160,7 +1162,7 @@ func unmarshalDouble(info TypeInfo, 
data []byte, value interface{}) error { return unmarshalErrorf("can not unmarshal %s into %T", info, value) } -func marshalDecimal(info TypeInfo, value interface{}) ([]byte, error) { +func marshalDecimal(info protocol.TypeInfo, value interface{}) ([]byte, error) { if value == nil { return nil, nil } @@ -1184,7 +1186,7 @@ func marshalDecimal(info TypeInfo, value interface{}) ([]byte, error) { return nil, marshalErrorf("can not marshal %T into %s", value, info) } -func unmarshalDecimal(info TypeInfo, data []byte, value interface{}) error { +func unmarshalDecimal(info protocol.TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case Unmarshaler: return v.UnmarshalCQL(info, data) @@ -1240,7 +1242,7 @@ func encBigInt2C(n *big.Int) []byte { return nil } -func marshalTime(info TypeInfo, value interface{}) ([]byte, error) { +func marshalTime(info protocol.TypeInfo, value interface{}) ([]byte, error) { switch v := value.(type) { case Marshaler: return v.MarshalCQL(info) @@ -1264,7 +1266,7 @@ func marshalTime(info TypeInfo, value interface{}) ([]byte, error) { return nil, marshalErrorf("can not marshal %T into %s", value, info) } -func marshalTimestamp(info TypeInfo, value interface{}) ([]byte, error) { +func marshalTimestamp(info protocol.TypeInfo, value interface{}) ([]byte, error) { switch v := value.(type) { case Marshaler: return v.MarshalCQL(info) @@ -1292,7 +1294,7 @@ func marshalTimestamp(info TypeInfo, value interface{}) ([]byte, error) { return nil, marshalErrorf("can not marshal %T into %s", value, info) } -func unmarshalTime(info TypeInfo, data []byte, value interface{}) error { +func unmarshalTime(info protocol.TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case Unmarshaler: return v.UnmarshalCQL(info, data) @@ -1317,7 +1319,7 @@ func unmarshalTime(info TypeInfo, data []byte, value interface{}) error { return unmarshalErrorf("can not unmarshal %s into %T", info, value) } -func unmarshalTimestamp(info 
TypeInfo, data []byte, value interface{}) error { +func unmarshalTimestamp(info protocol.TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case Unmarshaler: return v.UnmarshalCQL(info, data) @@ -1351,7 +1353,7 @@ func unmarshalTimestamp(info TypeInfo, data []byte, value interface{}) error { const millisecondsInADay int64 = 24 * 60 * 60 * 1000 -func marshalDate(info TypeInfo, value interface{}) ([]byte, error) { +func marshalDate(info protocol.TypeInfo, value interface{}) ([]byte, error) { var timestamp int64 switch v := value.(type) { case Marshaler: @@ -1395,7 +1397,7 @@ func marshalDate(info TypeInfo, value interface{}) ([]byte, error) { return nil, marshalErrorf("can not marshal %T into %s", value, info) } -func unmarshalDate(info TypeInfo, data []byte, value interface{}) error { +func unmarshalDate(info protocol.TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case Unmarshaler: return v.UnmarshalCQL(info, data) @@ -1423,7 +1425,7 @@ func unmarshalDate(info TypeInfo, data []byte, value interface{}) error { return unmarshalErrorf("can not unmarshal %s into %T", info, value) } -func marshalDuration(info TypeInfo, value interface{}) ([]byte, error) { +func marshalDuration(info protocol.TypeInfo, value interface{}) ([]byte, error) { switch v := value.(type) { case Marshaler: return v.MarshalCQL(info) @@ -1439,7 +1441,7 @@ func marshalDuration(info TypeInfo, value interface{}) ([]byte, error) { return nil, err } return encVints(0, 0, d.Nanoseconds()), nil - case Duration: + case protocol.Duration: return encVints(v.Months, v.Days, v.Nanoseconds), nil } @@ -1455,13 +1457,13 @@ func marshalDuration(info TypeInfo, value interface{}) ([]byte, error) { return nil, marshalErrorf("can not marshal %T into %s", value, info) } -func unmarshalDuration(info TypeInfo, data []byte, value interface{}) error { +func unmarshalDuration(info protocol.TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case 
Unmarshaler: return v.UnmarshalCQL(info, data) - case *Duration: + case *protocol.Duration: if len(data) == 0 { - *v = Duration{ + *v = protocol.Duration{ Months: 0, Days: 0, Nanoseconds: 0, @@ -1472,7 +1474,7 @@ func unmarshalDuration(info TypeInfo, data []byte, value interface{}) error { if err != nil { return unmarshalErrorf("failed to unmarshal %s into %T: %s", info, value, err.Error()) } - *v = Duration{ + *v = protocol.Duration{ Months: months, Days: days, Nanoseconds: nanos, @@ -1572,7 +1574,7 @@ func writeCollectionSize(info CollectionType, n int, buf *bytes.Buffer) error { return nil } -func marshalList(info TypeInfo, value interface{}) ([]byte, error) { +func marshalList(info protocol.TypeInfo, value interface{}) ([]byte, error) { listInfo, ok := info.(CollectionType) if !ok { return nil, marshalErrorf("marshal: can not marshal non collection type into list") @@ -1647,7 +1649,7 @@ func readCollectionSize(info CollectionType, data []byte) (size, read int, err e return } -func unmarshalList(info TypeInfo, data []byte, value interface{}) error { +func unmarshalList(info protocol.TypeInfo, data []byte, value interface{}) error { listInfo, ok := info.(CollectionType) if !ok { return unmarshalErrorf("unmarshal: can not unmarshal none collection type into list") @@ -1709,7 +1711,7 @@ func unmarshalList(info TypeInfo, data []byte, value interface{}) error { return unmarshalErrorf("can not unmarshal %s into %T", info, value) } -func marshalMap(info TypeInfo, value interface{}) ([]byte, error) { +func marshalMap(info protocol.TypeInfo, value interface{}) ([]byte, error) { mapInfo, ok := info.(CollectionType) if !ok { return nil, marshalErrorf("marshal: can not marshal none collection type into map") @@ -1772,7 +1774,7 @@ func marshalMap(info TypeInfo, value interface{}) ([]byte, error) { return buf.Bytes(), nil } -func unmarshalMap(info TypeInfo, data []byte, value interface{}) error { +func unmarshalMap(info protocol.TypeInfo, data []byte, value interface{}) error 
{ mapInfo, ok := info.(CollectionType) if !ok { return unmarshalErrorf("unmarshal: can not unmarshal none collection type into map") @@ -1845,7 +1847,7 @@ func unmarshalMap(info TypeInfo, data []byte, value interface{}) error { return nil } -func marshalUUID(info TypeInfo, value interface{}) ([]byte, error) { +func marshalUUID(info protocol.TypeInfo, value interface{}) ([]byte, error) { switch val := value.(type) { case unsetColumn: return nil, nil @@ -1873,7 +1875,7 @@ func marshalUUID(info TypeInfo, value interface{}) ([]byte, error) { return nil, marshalErrorf("can not marshal %T into %s", value, info) } -func unmarshalUUID(info TypeInfo, data []byte, value interface{}) error { +func unmarshalUUID(info protocol.TypeInfo, data []byte, value interface{}) error { if len(data) == 0 { switch v := value.(type) { case *string: @@ -1918,7 +1920,7 @@ func unmarshalUUID(info TypeInfo, data []byte, value interface{}) error { return unmarshalErrorf("can not unmarshal X %s into %T", info, value) } -func unmarshalTimeUUID(info TypeInfo, data []byte, value interface{}) error { +func unmarshalTimeUUID(info protocol.TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case Unmarshaler: return v.UnmarshalCQL(info, data) @@ -1936,7 +1938,7 @@ func unmarshalTimeUUID(info TypeInfo, data []byte, value interface{}) error { } } -func marshalInet(info TypeInfo, value interface{}) ([]byte, error) { +func marshalInet(info protocol.TypeInfo, value interface{}) ([]byte, error) { // we return either the 4 or 16 byte representation of an // ip address here otherwise the db value will be prefixed // with the remaining byte values e.g. 
::ffff:127.0.0.1 and not 127.0.0.1 @@ -1968,7 +1970,7 @@ func marshalInet(info TypeInfo, value interface{}) ([]byte, error) { return nil, marshalErrorf("cannot marshal %T into %s", value, info) } -func unmarshalInet(info TypeInfo, data []byte, value interface{}) error { +func unmarshalInet(info protocol.TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case Unmarshaler: return v.UnmarshalCQL(info, data) @@ -1976,7 +1978,7 @@ func unmarshalInet(info TypeInfo, data []byte, value interface{}) error { if x := len(data); !(x == 4 || x == 16) { return unmarshalErrorf("cannot unmarshal %s into %T: invalid sized IP: got %d bytes not 4 or 16", info, value, x) } - buf := copyBytes(data) + buf := internal.CopyBytes(data) ip := net.IP(buf) if v4 := ip.To4(); v4 != nil { *v = v4 @@ -2000,7 +2002,7 @@ func unmarshalInet(info TypeInfo, data []byte, value interface{}) error { return unmarshalErrorf("cannot unmarshal %s into %T", info, value) } -func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) { +func marshalTuple(info protocol.TypeInfo, value interface{}) ([]byte, error) { tuple := info.(TupleTypeInfo) switch v := value.(type) { case unsetColumn: @@ -2104,7 +2106,7 @@ func readBytes(p []byte) ([]byte, []byte) { // currently only support unmarshal into a list of values, this makes it possible // to support tuples without changing the query API. In the future this can be extend // to allow unmarshalling into custom tuple types. -func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error { +func unmarshalTuple(info protocol.TypeInfo, data []byte, value interface{}) error { if v, ok := value.(Unmarshaler); ok { return v.UnmarshalCQL(info, data) } @@ -2218,7 +2220,7 @@ type UDTMarshaler interface { // MarshalUDT will be called for each field in the the UDT returned by Cassandra, // the implementor should marshal the type to return by for example calling // Marshal. 
- MarshalUDT(name string, info TypeInfo) ([]byte, error) + MarshalUDT(name string, info protocol.TypeInfo) ([]byte, error) } // UDTUnmarshaler should be implemented by users wanting to implement custom @@ -2227,10 +2229,10 @@ type UDTUnmarshaler interface { // UnmarshalUDT will be called for each field in the UDT return by Cassandra, // the implementor should unmarshal the data into the value of their chosing, // for example by calling Unmarshal. - UnmarshalUDT(name string, info TypeInfo, data []byte) error + UnmarshalUDT(name string, info protocol.TypeInfo, data []byte) error } -func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) { +func marshalUDT(info protocol.TypeInfo, value interface{}) ([]byte, error) { udt := info.(UDTTypeInfo) switch v := value.(type) { @@ -2315,12 +2317,12 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) { return buf, nil } -func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error { +func unmarshalUDT(info protocol.TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case Unmarshaler: return v.UnmarshalCQL(info, data) case UDTUnmarshaler: - udt := info.(UDTTypeInfo) + udt := info.(protocol.UDTTypeInfo) for id, e := range udt.Elements { if len(data) == 0 { @@ -2339,7 +2341,7 @@ func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error { return nil case *map[string]interface{}: - udt := info.(UDTTypeInfo) + udt := info.(protocol.UDTTypeInfo) rv := reflect.ValueOf(value) if rv.Kind() != reflect.Ptr { @@ -2449,282 +2451,282 @@ func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error { return nil } -// TypeInfo describes a Cassandra specific data type. -type TypeInfo interface { - Type() Type - Version() byte - Custom() string - - // New creates a pointer to an empty version of whatever type - // is referenced by the TypeInfo receiver. - // - // If there is no corresponding Go type for the CQL type, New panics. 
- // - // Deprecated: Use NewWithError instead. - New() interface{} - - // NewWithError creates a pointer to an empty version of whatever type - // is referenced by the TypeInfo receiver. - // - // If there is no corresponding Go type for the CQL type, NewWithError returns an error. - NewWithError() (interface{}, error) -} - -type NativeType struct { - proto byte - typ Type - custom string // only used for TypeCustom -} - -func NewNativeType(proto byte, typ Type, custom string) NativeType { - return NativeType{proto, typ, custom} -} - -func (t NativeType) NewWithError() (interface{}, error) { - typ, err := goType(t) - if err != nil { - return nil, err - } - return reflect.New(typ).Interface(), nil -} - -func (t NativeType) New() interface{} { - val, err := t.NewWithError() - if err != nil { - panic(err.Error()) - } - return val -} - -func (s NativeType) Type() Type { - return s.typ -} - -func (s NativeType) Version() byte { - return s.proto -} - -func (s NativeType) Custom() string { - return s.custom -} - -func (s NativeType) String() string { - switch s.typ { - case TypeCustom: - return fmt.Sprintf("%s(%s)", s.typ, s.custom) - default: - return s.typ.String() - } -} - -type CollectionType struct { - NativeType - Key TypeInfo // only used for TypeMap - Elem TypeInfo // only used for TypeMap, TypeList and TypeSet -} - -func (t CollectionType) NewWithError() (interface{}, error) { - typ, err := goType(t) - if err != nil { - return nil, err - } - return reflect.New(typ).Interface(), nil -} - -func (t CollectionType) New() interface{} { - val, err := t.NewWithError() - if err != nil { - panic(err.Error()) - } - return val -} - -func (c CollectionType) String() string { - switch c.typ { - case TypeMap: - return fmt.Sprintf("%s(%s, %s)", c.typ, c.Key, c.Elem) - case TypeList, TypeSet: - return fmt.Sprintf("%s(%s)", c.typ, c.Elem) - case TypeCustom: - return fmt.Sprintf("%s(%s)", c.typ, c.custom) - default: - return c.typ.String() - } -} - -type TupleTypeInfo struct { - 
NativeType - Elems []TypeInfo -} - -func (t TupleTypeInfo) String() string { - var buf bytes.Buffer - buf.WriteString(fmt.Sprintf("%s(", t.typ)) - for _, elem := range t.Elems { - buf.WriteString(fmt.Sprintf("%s, ", elem)) - } - buf.Truncate(buf.Len() - 2) - buf.WriteByte(')') - return buf.String() -} - -func (t TupleTypeInfo) NewWithError() (interface{}, error) { - typ, err := goType(t) - if err != nil { - return nil, err - } - return reflect.New(typ).Interface(), nil -} - -func (t TupleTypeInfo) New() interface{} { - val, err := t.NewWithError() - if err != nil { - panic(err.Error()) - } - return val -} - -type UDTField struct { - Name string - Type TypeInfo -} - -type UDTTypeInfo struct { - NativeType - KeySpace string - Name string - Elements []UDTField -} - -func (u UDTTypeInfo) NewWithError() (interface{}, error) { - typ, err := goType(u) - if err != nil { - return nil, err - } - return reflect.New(typ).Interface(), nil -} - -func (u UDTTypeInfo) New() interface{} { - val, err := u.NewWithError() - if err != nil { - panic(err.Error()) - } - return val -} - -func (u UDTTypeInfo) String() string { - buf := &bytes.Buffer{} - - fmt.Fprintf(buf, "%s.%s{", u.KeySpace, u.Name) - first := true - for _, e := range u.Elements { - if !first { - fmt.Fprint(buf, ",") - } else { - first = false - } - - fmt.Fprintf(buf, "%s=%v", e.Name, e.Type) - } - fmt.Fprint(buf, "}") - - return buf.String() -} - -// String returns a human readable name for the Cassandra datatype -// described by t. -// Type is the identifier of a Cassandra internal datatype. 
-type Type int - -const ( - TypeCustom Type = 0x0000 - TypeAscii Type = 0x0001 - TypeBigInt Type = 0x0002 - TypeBlob Type = 0x0003 - TypeBoolean Type = 0x0004 - TypeCounter Type = 0x0005 - TypeDecimal Type = 0x0006 - TypeDouble Type = 0x0007 - TypeFloat Type = 0x0008 - TypeInt Type = 0x0009 - TypeText Type = 0x000A - TypeTimestamp Type = 0x000B - TypeUUID Type = 0x000C - TypeVarchar Type = 0x000D - TypeVarint Type = 0x000E - TypeTimeUUID Type = 0x000F - TypeInet Type = 0x0010 - TypeDate Type = 0x0011 - TypeTime Type = 0x0012 - TypeSmallInt Type = 0x0013 - TypeTinyInt Type = 0x0014 - TypeDuration Type = 0x0015 - TypeList Type = 0x0020 - TypeMap Type = 0x0021 - TypeSet Type = 0x0022 - TypeUDT Type = 0x0030 - TypeTuple Type = 0x0031 -) +//// TypeInfo describes a Cassandra specific data type. +//type TypeInfo interface { +// Type() Type +// Version() byte +// Custom() string +// +// // New creates a pointer to an empty version of whatever type +// // is referenced by the TypeInfo receiver. +// // +// // If there is no corresponding Go type for the CQL type, New panics. +// // +// // Deprecated: Use NewWithError instead. +// New() interface{} +// +// // NewWithError creates a pointer to an empty version of whatever type +// // is referenced by the TypeInfo receiver. +// // +// // If there is no corresponding Go type for the CQL type, NewWithError returns an error. 
+// NewWithError() (interface{}, error) +//} + +//type NativeType struct { +// proto byte +// typ Type +// custom string // only used for TypeCustom +//} +// +//func NewNativeType(proto byte, typ Type, custom string) NativeType { +// return NativeType{proto, typ, custom} +//} +// +//func (t NativeType) NewWithError() (interface{}, error) { +// typ, err := goType(t) +// if err != nil { +// return nil, err +// } +// return reflect.New(typ).Interface(), nil +//} +// +//func (t NativeType) New() interface{} { +// val, err := t.NewWithError() +// if err != nil { +// panic(err.Error()) +// } +// return val +//} +// +//func (s NativeType) Type() Type { +// return s.typ +//} +// +//func (s NativeType) Version() byte { +// return s.proto +//} +// +//func (s NativeType) Custom() string { +// return s.custom +//} +// +//func (s NativeType) String() string { +// switch s.typ { +// case TypeCustom: +// return fmt.Sprintf("%s(%s)", s.typ, s.custom) +// default: +// return s.typ.String() +// } +//} + +//type CollectionType struct { +// NativeType +// Key protocol.TypeInfo // only used for TypeMap +// Elem protocol.TypeInfo // only used for TypeMap, TypeList and TypeSet +//} +// +//func (t CollectionType) NewWithError() (interface{}, error) { +// typ, err := goType(t) +// if err != nil { +// return nil, err +// } +// return reflect.New(typ).Interface(), nil +//} +// +//func (t CollectionType) New() interface{} { +// val, err := t.NewWithError() +// if err != nil { +// panic(err.Error()) +// } +// return val +//} +// +//func (c CollectionType) String() string { +// switch c.typ { +// case TypeMap: +// return fmt.Sprintf("%s(%s, %s)", c.typ, c.Key, c.Elem) +// case TypeList, TypeSet: +// return fmt.Sprintf("%s(%s)", c.typ, c.Elem) +// case TypeCustom: +// return fmt.Sprintf("%s(%s)", c.typ, c.custom) +// default: +// return c.typ.String() +// } +//} +// +//type TupleTypeInfo struct { +// NativeType +// Elems []protocol.TypeInfo +//} +// +//func (t TupleTypeInfo) String() string { 
+// var buf bytes.Buffer +// buf.WriteString(fmt.Sprintf("%s(", t.typ)) +// for _, elem := range t.Elems { +// buf.WriteString(fmt.Sprintf("%s, ", elem)) +// } +// buf.Truncate(buf.Len() - 2) +// buf.WriteByte(')') +// return buf.String() +//} +// +//func (t TupleTypeInfo) NewWithError() (interface{}, error) { +// typ, err := goType(t) +// if err != nil { +// return nil, err +// } +// return reflect.New(typ).Interface(), nil +//} +// +//func (t TupleTypeInfo) New() interface{} { +// val, err := t.NewWithError() +// if err != nil { +// panic(err.Error()) +// } +// return val +//} + +//type UDTField struct { +// Name string +// Type protocol.TypeInfo +//} + +//type UDTTypeInfo struct { +// protocol.NativeType +// KeySpace string +// Name string +// Elements []UDTField +//} +// +//func (u UDTTypeInfo) NewWithError() (interface{}, error) { +// typ, err := goType(u) +// if err != nil { +// return nil, err +// } +// return reflect.New(typ).Interface(), nil +//} +// +//func (u UDTTypeInfo) New() interface{} { +// val, err := u.NewWithError() +// if err != nil { +// panic(err.Error()) +// } +// return val +//} +// +//func (u UDTTypeInfo) String() string { +// buf := &bytes.Buffer{} +// +// fmt.Fprintf(buf, "%s.%s{", u.KeySpace, u.Name) +// first := true +// for _, e := range u.Elements { +// if !first { +// fmt.Fprint(buf, ",") +// } else { +// first = false +// } +// +// fmt.Fprintf(buf, "%s=%v", e.Name, e.Type) +// } +// fmt.Fprint(buf, "}") +// +// return buf.String() +//} -// String returns the name of the identifier. 
-func (t Type) String() string { - switch t { - case TypeCustom: - return "custom" - case TypeAscii: - return "ascii" - case TypeBigInt: - return "bigint" - case TypeBlob: - return "blob" - case TypeBoolean: - return "boolean" - case TypeCounter: - return "counter" - case TypeDecimal: - return "decimal" - case TypeDouble: - return "double" - case TypeFloat: - return "float" - case TypeInt: - return "int" - case TypeText: - return "text" - case TypeTimestamp: - return "timestamp" - case TypeUUID: - return "uuid" - case TypeVarchar: - return "varchar" - case TypeTimeUUID: - return "timeuuid" - case TypeInet: - return "inet" - case TypeDate: - return "date" - case TypeDuration: - return "duration" - case TypeTime: - return "time" - case TypeSmallInt: - return "smallint" - case TypeTinyInt: - return "tinyint" - case TypeList: - return "list" - case TypeMap: - return "map" - case TypeSet: - return "set" - case TypeVarint: - return "varint" - case TypeTuple: - return "tuple" - default: - return fmt.Sprintf("unknown_type_%d", t) - } -} +//// String returns a human readable name for the Cassandra datatype +//// described by t. +//// Type is the identifier of a Cassandra internal datatype. 
+//type Type int +// +//const ( +// TypeCustom Type = 0x0000 +// TypeAscii Type = 0x0001 +// TypeBigInt Type = 0x0002 +// TypeBlob Type = 0x0003 +// TypeBoolean Type = 0x0004 +// TypeCounter Type = 0x0005 +// TypeDecimal Type = 0x0006 +// TypeDouble Type = 0x0007 +// TypeFloat Type = 0x0008 +// TypeInt Type = 0x0009 +// TypeText Type = 0x000A +// TypeTimestamp Type = 0x000B +// TypeUUID Type = 0x000C +// TypeVarchar Type = 0x000D +// TypeVarint Type = 0x000E +// TypeTimeUUID Type = 0x000F +// TypeInet Type = 0x0010 +// TypeDate Type = 0x0011 +// TypeTime Type = 0x0012 +// TypeSmallInt Type = 0x0013 +// TypeTinyInt Type = 0x0014 +// TypeDuration Type = 0x0015 +// TypeList Type = 0x0020 +// TypeMap Type = 0x0021 +// TypeSet Type = 0x0022 +// TypeUDT Type = 0x0030 +// TypeTuple Type = 0x0031 +//) +// +//// String returns the name of the identifier. +//func (t Type) String() string { +// switch t { +// case TypeCustom: +// return "custom" +// case TypeAscii: +// return "ascii" +// case TypeBigInt: +// return "bigint" +// case TypeBlob: +// return "blob" +// case TypeBoolean: +// return "boolean" +// case TypeCounter: +// return "counter" +// case TypeDecimal: +// return "decimal" +// case TypeDouble: +// return "double" +// case TypeFloat: +// return "float" +// case TypeInt: +// return "int" +// case TypeText: +// return "text" +// case TypeTimestamp: +// return "timestamp" +// case TypeUUID: +// return "uuid" +// case TypeVarchar: +// return "varchar" +// case TypeTimeUUID: +// return "timeuuid" +// case TypeInet: +// return "inet" +// case TypeDate: +// return "date" +// case TypeDuration: +// return "duration" +// case TypeTime: +// return "time" +// case TypeSmallInt: +// return "smallint" +// case TypeTinyInt: +// return "tinyint" +// case TypeList: +// return "list" +// case TypeMap: +// return "map" +// case TypeSet: +// return "set" +// case TypeVarint: +// return "varint" +// case TypeTuple: +// return "tuple" +// default: +// return 
fmt.Sprintf("unknown_type_%d", t) +// } +//} type MarshalError string diff --git a/marshal_test.go b/marshal_test.go index 6c139e6bc..cc41a0abf 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -30,6 +30,7 @@ package gocql import ( "bytes" "encoding/binary" + "github.com/gocql/gocql/protocol" "math" "math/big" "net" @@ -56,56 +57,56 @@ var marshalTests = []struct { UnmarshalError error }{ { - NativeType{proto: 2, typ: TypeVarchar}, + protocol.NativeType{proto: 2, typ: TypeVarchar}, []byte("hello world"), []byte("hello world"), nil, nil, }, { - NativeType{proto: 2, typ: TypeVarchar}, + protocol.NativeType{proto: 2, typ: TypeVarchar}, []byte("hello world"), "hello world", nil, nil, }, { - NativeType{proto: 2, typ: TypeVarchar}, + protocol.NativeType{proto: 2, typ: TypeVarchar}, []byte(nil), []byte(nil), nil, nil, }, { - NativeType{proto: 2, typ: TypeVarchar}, + protocol.NativeType{proto: 2, typ: TypeVarchar}, []byte("hello world"), MyString("hello world"), nil, nil, }, { - NativeType{proto: 2, typ: TypeVarchar}, + protocol.NativeType{proto: 2, typ: TypeVarchar}, []byte("HELLO WORLD"), CustomString("hello world"), nil, nil, }, { - NativeType{proto: 2, typ: TypeBlob}, + protocol.NativeType{proto: 2, typ: TypeBlob}, []byte("hello\x00"), []byte("hello\x00"), nil, nil, }, { - NativeType{proto: 2, typ: TypeBlob}, + protocol.NativeType{proto: 2, typ: TypeBlob}, []byte(nil), []byte(nil), nil, nil, }, { - NativeType{proto: 2, typ: TypeTimeUUID}, + protocol.NativeType{proto: 2, typ: TypeTimeUUID}, []byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0}, func() UUID { x, _ := UUIDFromBytes([]byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0}) @@ -115,126 +116,126 @@ var marshalTests = []struct { nil, }, { - NativeType{proto: 2, typ: TypeTimeUUID}, + protocol.NativeType{proto: 2, typ: TypeTimeUUID}, []byte{0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0}, []byte{0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0}, 
MarshalError("can not marshal []byte 6 bytes long into timeuuid, must be exactly 16 bytes long"), UnmarshalError("unable to parse UUID: UUIDs must be exactly 16 bytes long"), }, { - NativeType{proto: 2, typ: TypeTimeUUID}, + protocol.NativeType{proto: 2, typ: TypeTimeUUID}, []byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0}, [16]byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0}, nil, nil, }, { - NativeType{proto: 2, typ: TypeInt}, + protocol.NativeType{proto: 2, typ: TypeInt}, []byte("\x00\x00\x00\x00"), 0, nil, nil, }, { - NativeType{proto: 2, typ: TypeInt}, + protocol.NativeType{proto: 2, typ: TypeInt}, []byte("\x01\x02\x03\x04"), int(16909060), nil, nil, }, { - NativeType{proto: 2, typ: TypeInt}, + protocol.NativeType{proto: 2, typ: TypeInt}, []byte("\x01\x02\x03\x04"), AliasInt(16909060), nil, nil, }, { - NativeType{proto: 2, typ: TypeInt}, + protocol.NativeType{proto: 2, typ: TypeInt}, []byte("\x80\x00\x00\x00"), int32(math.MinInt32), nil, nil, }, { - NativeType{proto: 2, typ: TypeInt}, + protocol.NativeType{proto: 2, typ: TypeInt}, []byte("\x7f\xff\xff\xff"), int32(math.MaxInt32), nil, nil, }, { - NativeType{proto: 2, typ: TypeInt}, + protocol.NativeType{proto: 2, typ: TypeInt}, []byte("\x00\x00\x00\x00"), "0", nil, nil, }, { - NativeType{proto: 2, typ: TypeInt}, + protocol.NativeType{proto: 2, typ: TypeInt}, []byte("\x01\x02\x03\x04"), "16909060", nil, nil, }, { - NativeType{proto: 2, typ: TypeInt}, + protocol.NativeType{proto: 2, typ: TypeInt}, []byte("\x80\x00\x00\x00"), "-2147483648", // math.MinInt32 nil, nil, }, { - NativeType{proto: 2, typ: TypeInt}, + protocol.NativeType{proto: 2, typ: TypeInt}, []byte("\x7f\xff\xff\xff"), "2147483647", // math.MaxInt32 nil, nil, }, { - NativeType{proto: 2, typ: TypeBigInt}, + protocol.NativeType{proto: 2, typ: TypeBigInt}, []byte("\x00\x00\x00\x00\x00\x00\x00\x00"), 0, nil, nil, }, { - NativeType{proto: 2, typ: 
TypeBigInt}, + protocol.NativeType{proto: 2, typ: TypeBigInt}, []byte("\x01\x02\x03\x04\x05\x06\x07\x08"), 72623859790382856, nil, nil, }, { - NativeType{proto: 2, typ: TypeBigInt}, + protocol.NativeType{proto: 2, typ: TypeBigInt}, []byte("\x80\x00\x00\x00\x00\x00\x00\x00"), int64(math.MinInt64), nil, nil, }, { - NativeType{proto: 2, typ: TypeBigInt}, + protocol.NativeType{proto: 2, typ: TypeBigInt}, []byte("\x7f\xff\xff\xff\xff\xff\xff\xff"), int64(math.MaxInt64), nil, nil, }, { - NativeType{proto: 2, typ: TypeBigInt}, + protocol.NativeType{proto: 2, typ: TypeBigInt}, []byte("\x00\x00\x00\x00\x00\x00\x00\x00"), "0", nil, nil, }, { - NativeType{proto: 2, typ: TypeBigInt}, + protocol.NativeType{proto: 2, typ: TypeBigInt}, []byte("\x01\x02\x03\x04\x05\x06\x07\x08"), "72623859790382856", nil, nil, }, { - NativeType{proto: 2, typ: TypeBigInt}, + protocol.NativeType{proto: 2, typ: TypeBigInt}, []byte("\x80\x00\x00\x00\x00\x00\x00\x00"), "-9223372036854775808", // math.MinInt64 nil, @@ -383,21 +384,21 @@ var marshalTests = []struct { { NativeType{proto: 5, typ: TypeDuration}, []byte("\x89\xa2\xc3\xc2\x9a\xe0F\x91\x06"), - Duration{Months: 1233, Days: 123213, Nanoseconds: 2312323}, + protocol.Duration{Months: 1233, Days: 123213, Nanoseconds: 2312323}, nil, nil, }, { NativeType{proto: 5, typ: TypeDuration}, []byte("\x89\xa1\xc3\xc2\x99\xe0F\x91\x05"), - Duration{Months: -1233, Days: -123213, Nanoseconds: -2312323}, + protocol.Duration{Months: -1233, Days: -123213, Nanoseconds: -2312323}, nil, nil, }, { NativeType{proto: 5, typ: TypeDuration}, []byte("\x02\x04\x80\xe6"), - Duration{Months: 1, Days: 2, Nanoseconds: 115}, + protocol.Duration{Months: 1, Days: 2, Nanoseconds: 115}, nil, nil, }, @@ -1297,31 +1298,31 @@ var unmarshalTests = []struct { { NativeType{proto: 5, typ: TypeDuration}, []byte("\x89\xa2\xc3\xc2\x9a\xe0F\x91"), - Duration{}, + protocol.Duration{}, UnmarshalError("failed to unmarshal duration into *gocql.Duration: failed to extract nanoseconds: data expect 
to have 9 bytes, but it has only 8"), }, { NativeType{proto: 5, typ: TypeDuration}, []byte("\x89\xa2\xc3\xc2\x9a"), - Duration{}, + protocol.Duration{}, UnmarshalError("failed to unmarshal duration into *gocql.Duration: failed to extract nanoseconds: unexpected eof"), }, { NativeType{proto: 5, typ: TypeDuration}, []byte("\x89\xa2\xc3\xc2"), - Duration{}, + protocol.Duration{}, UnmarshalError("failed to unmarshal duration into *gocql.Duration: failed to extract days: data expect to have 5 bytes, but it has only 4"), }, { NativeType{proto: 5, typ: TypeDuration}, []byte("\x89\xa2"), - Duration{}, + protocol.Duration{}, UnmarshalError("failed to unmarshal duration into *gocql.Duration: failed to extract days: unexpected eof"), }, { NativeType{proto: 5, typ: TypeDuration}, []byte("\x89"), - Duration{}, + protocol.Duration{}, UnmarshalError("failed to unmarshal duration into *gocql.Duration: failed to extract month: data expect to have 2 bytes, but it has only 1"), }, } diff --git a/metadata.go b/metadata.go index 6eb798f8a..0081d6399 100644 --- a/metadata.go +++ b/metadata.go @@ -32,6 +32,7 @@ import ( "encoding/hex" "encoding/json" "fmt" + "github.com/gocql/gocql/protocol" "strconv" "strings" "sync" @@ -1232,8 +1233,8 @@ func (t *typeParser) parse() typeParserResult { // treat this is a custom type return typeParserResult{ isComposite: false, - types: []TypeInfo{ - NativeType{ + types: []protocol.TypeInfo{ + protocol.NativeType{ typ: TypeCustom, custom: t.input, }, @@ -1312,7 +1313,7 @@ func (class *typeParserClassNode) asTypeInfo() TypeInfo { if strings.HasPrefix(class.name, LIST_TYPE) { elem := class.params[0].class.asTypeInfo() return CollectionType{ - NativeType: NativeType{ + NativeType: protocol.NativeType{ typ: TypeList, }, Elem: elem, @@ -1321,7 +1322,7 @@ func (class *typeParserClassNode) asTypeInfo() TypeInfo { if strings.HasPrefix(class.name, SET_TYPE) { elem := class.params[0].class.asTypeInfo() return CollectionType{ - NativeType: NativeType{ + NativeType: 
protocol.NativeType{ typ: TypeSet, }, Elem: elem, @@ -1331,7 +1332,7 @@ func (class *typeParserClassNode) asTypeInfo() TypeInfo { key := class.params[0].class.asTypeInfo() elem := class.params[1].class.asTypeInfo() return CollectionType{ - NativeType: NativeType{ + NativeType: protocol.NativeType{ typ: TypeMap, }, Key: key, @@ -1340,7 +1341,7 @@ func (class *typeParserClassNode) asTypeInfo() TypeInfo { } // must be a simple type or custom type - info := NativeType{typ: getApacheCassandraType(class.name)} + info := protocol.NativeType{typ: getApacheCassandraType(class.name)} if info.typ == TypeCustom { // add the entire class definition info.custom = class.input diff --git a/cqltypes.go b/protocol/cqltypes.go similarity index 98% rename from cqltypes.go rename to protocol/cqltypes.go index ce2e1cee7..72cba3fe3 100644 --- a/cqltypes.go +++ b/protocol/cqltypes.go @@ -22,7 +22,7 @@ * See the NOTICE file distributed with this work for additional information. */ -package gocql +package protocol type Duration struct { Months int32 diff --git a/protocol/marshal.go b/protocol/marshal.go new file mode 100644 index 000000000..159712894 --- /dev/null +++ b/protocol/marshal.go @@ -0,0 +1,348 @@ +package protocol + +import ( + "bytes" + "fmt" + "gopkg.in/inf.v0" + "math/big" + "reflect" + "time" +) + +// TypeInfo describes a Cassandra specific data type. +type TypeInfo interface { + Type() Type + Version() byte + Custom() string + + // New creates a pointer to an empty version of whatever type + // is referenced by the TypeInfo receiver. + // + // If there is no corresponding Go type for the CQL type, New panics. + // + // Deprecated: Use NewWithError instead. + New() interface{} + + // NewWithError creates a pointer to an empty version of whatever type + // is referenced by the TypeInfo receiver. + // + // If there is no corresponding Go type for the CQL type, NewWithError returns an error. 
+ NewWithError() (interface{}, error) +} + +type NativeType struct { + Proto byte + Typ Type + Cust string // only used for TypeCustom +} + +func NewNativeType(proto byte, typ Type, custom string) NativeType { + return NativeType{proto, typ, custom} +} + +func (t NativeType) NewWithError() (interface{}, error) { + typ, err := goType(t) + if err != nil { + return nil, err + } + return reflect.New(typ).Interface(), nil +} + +func (t NativeType) New() interface{} { + val, err := t.NewWithError() + if err != nil { + panic(err.Error()) + } + return val +} + +func (s NativeType) Type() Type { + return s.Typ +} + +func (s NativeType) Version() byte { + return s.Proto +} + +func (s NativeType) Custom() string { + return s.Cust +} + +func (s NativeType) String() string { + switch s.Typ { + case TypeCustom: + return fmt.Sprintf("%s(%s)", s.Typ, s.Cust) + default: + return s.Typ.String() + } +} + +type CollectionType struct { + NativeType + Key TypeInfo // only used for TypeMap + Elem TypeInfo // only used for TypeMap, TypeList and TypeSet +} + +func (t CollectionType) NewWithError() (interface{}, error) { + typ, err := goType(t) + if err != nil { + return nil, err + } + return reflect.New(typ).Interface(), nil +} + +func (t CollectionType) New() interface{} { + val, err := t.NewWithError() + if err != nil { + panic(err.Error()) + } + return val +} + +func (c CollectionType) String() string { + switch c.Typ { + case TypeMap: + return fmt.Sprintf("%s(%s, %s)", c.Typ, c.Key, c.Elem) + case TypeList, TypeSet: + return fmt.Sprintf("%s(%s)", c.Typ, c.Elem) + case TypeCustom: + return fmt.Sprintf("%s(%s)", c.Typ, c.Cust) + default: + return c.Typ.String() + } +} + +type TupleTypeInfo struct { + NativeType + Elems []TypeInfo +} + +func (t TupleTypeInfo) String() string { + var buf bytes.Buffer + buf.WriteString(fmt.Sprintf("%s(", t.Typ)) + for _, elem := range t.Elems { + buf.WriteString(fmt.Sprintf("%s, ", elem)) + } + buf.Truncate(buf.Len() - 2) + buf.WriteByte(')') + return 
buf.String() +} + +func (t TupleTypeInfo) NewWithError() (interface{}, error) { + typ, err := goType(t) + if err != nil { + return nil, err + } + return reflect.New(typ).Interface(), nil +} + +func (t TupleTypeInfo) New() interface{} { + val, err := t.NewWithError() + if err != nil { + panic(err.Error()) + } + return val +} + +type UDTField struct { + Name string + Type TypeInfo +} + +type UDTTypeInfo struct { + NativeType + KeySpace string + Name string + Elements []UDTField +} + +func (u UDTTypeInfo) NewWithError() (interface{}, error) { + typ, err := goType(u) + if err != nil { + return nil, err + } + return reflect.New(typ).Interface(), nil +} + +func (u UDTTypeInfo) New() interface{} { + val, err := u.NewWithError() + if err != nil { + panic(err.Error()) + } + return val +} + +func (u UDTTypeInfo) String() string { + buf := &bytes.Buffer{} + + fmt.Fprintf(buf, "%s.%s{", u.KeySpace, u.Name) + first := true + for _, e := range u.Elements { + if !first { + fmt.Fprint(buf, ",") + } else { + first = false + } + + fmt.Fprintf(buf, "%s=%v", e.Name, e.Type) + } + fmt.Fprint(buf, "}") + + return buf.String() +} + +// String returns a human readable name for the Cassandra datatype +// described by t. +// Type is the identifier of a Cassandra internal datatype. 
+type Type int + +const ( + TypeCustom Type = 0x0000 + TypeAscii Type = 0x0001 + TypeBigInt Type = 0x0002 + TypeBlob Type = 0x0003 + TypeBoolean Type = 0x0004 + TypeCounter Type = 0x0005 + TypeDecimal Type = 0x0006 + TypeDouble Type = 0x0007 + TypeFloat Type = 0x0008 + TypeInt Type = 0x0009 + TypeText Type = 0x000A + TypeTimestamp Type = 0x000B + TypeUUID Type = 0x000C + TypeVarchar Type = 0x000D + TypeVarint Type = 0x000E + TypeTimeUUID Type = 0x000F + TypeInet Type = 0x0010 + TypeDate Type = 0x0011 + TypeTime Type = 0x0012 + TypeSmallInt Type = 0x0013 + TypeTinyInt Type = 0x0014 + TypeDuration Type = 0x0015 + TypeList Type = 0x0020 + TypeMap Type = 0x0021 + TypeSet Type = 0x0022 + TypeUDT Type = 0x0030 + TypeTuple Type = 0x0031 +) + +// String returns the name of the identifier. +func (t Type) String() string { + switch t { + case TypeCustom: + return "custom" + case TypeAscii: + return "ascii" + case TypeBigInt: + return "bigint" + case TypeBlob: + return "blob" + case TypeBoolean: + return "boolean" + case TypeCounter: + return "counter" + case TypeDecimal: + return "decimal" + case TypeDouble: + return "double" + case TypeFloat: + return "float" + case TypeInt: + return "int" + case TypeText: + return "text" + case TypeTimestamp: + return "timestamp" + case TypeUUID: + return "uuid" + case TypeVarchar: + return "varchar" + case TypeTimeUUID: + return "timeuuid" + case TypeInet: + return "inet" + case TypeDate: + return "date" + case TypeDuration: + return "duration" + case TypeTime: + return "time" + case TypeSmallInt: + return "smallint" + case TypeTinyInt: + return "tinyint" + case TypeList: + return "list" + case TypeMap: + return "map" + case TypeSet: + return "set" + case TypeVarint: + return "varint" + case TypeTuple: + return "tuple" + default: + return fmt.Sprintf("unknown_type_%d", t) + } +} + +func goType(t TypeInfo) (reflect.Type, error) { + switch t.Type() { + case TypeVarchar, TypeAscii, TypeInet, TypeText: + return reflect.TypeOf(*new(string)), 
nil + case TypeBigInt, TypeCounter: + return reflect.TypeOf(*new(int64)), nil + case TypeTime: + return reflect.TypeOf(*new(time.Duration)), nil + case TypeTimestamp: + return reflect.TypeOf(*new(time.Time)), nil + case TypeBlob: + return reflect.TypeOf(*new([]byte)), nil + case TypeBoolean: + return reflect.TypeOf(*new(bool)), nil + case TypeFloat: + return reflect.TypeOf(*new(float32)), nil + case TypeDouble: + return reflect.TypeOf(*new(float64)), nil + case TypeInt: + return reflect.TypeOf(*new(int)), nil + case TypeSmallInt: + return reflect.TypeOf(*new(int16)), nil + case TypeTinyInt: + return reflect.TypeOf(*new(int8)), nil + case TypeDecimal: + return reflect.TypeOf(*new(*inf.Dec)), nil + case TypeUUID, TypeTimeUUID: + return reflect.TypeOf(*new(UUID)), nil + case TypeList, TypeSet: + elemType, err := goType(t.(CollectionType).Elem) + if err != nil { + return nil, err + } + return reflect.SliceOf(elemType), nil + case TypeMap: + keyType, err := goType(t.(CollectionType).Key) + if err != nil { + return nil, err + } + valueType, err := goType(t.(CollectionType).Elem) + if err != nil { + return nil, err + } + return reflect.MapOf(keyType, valueType), nil + case TypeVarint: + return reflect.TypeOf(*new(*big.Int)), nil + case TypeTuple: + // what can we do here? 
all there is to do is to make a list of interface{} + tuple := t.(TupleTypeInfo) + return reflect.TypeOf(make([]interface{}, len(tuple.Elems))), nil + case TypeUDT: + return reflect.TypeOf(make(map[string]interface{})), nil + case TypeDate: + return reflect.TypeOf(*new(time.Time)), nil + case TypeDuration: + return reflect.TypeOf(*new(Duration)), nil + default: + return nil, fmt.Errorf("cannot create Go type for unknown CQL type %s", t) + } +} diff --git a/protocol/proto.go b/protocol/proto.go new file mode 100644 index 000000000..6b7ba2df1 --- /dev/null +++ b/protocol/proto.go @@ -0,0 +1,11 @@ +package protocol + +type Version byte + +const ( + ProtoVersion1 Version = 0x01 + ProtoVersion2 Version = 0x02 + ProtoVersion3 Version = 0x03 + ProtoVersion4 Version = 0x04 + ProtoVersion5 Version = 0x05 +) diff --git a/protocol/uuid.go b/protocol/uuid.go new file mode 100644 index 000000000..0c88428c0 --- /dev/null +++ b/protocol/uuid.go @@ -0,0 +1,321 @@ +package protocol + +import ( + "crypto/rand" + "errors" + "fmt" + "io" + "net" + "strings" + "sync/atomic" + "time" +) + +// The uuid package can be used to generate and parse universally unique +// identifiers, a standardized format in the form of a 128 bit number. +// +// http://tools.ietf.org/html/rfc4122 + +type UUID [16]byte + +var hardwareAddr []byte +var clockSeq uint32 + +const ( + VariantNCSCompat = 0 + VariantIETF = 2 + VariantMicrosoft = 6 + VariantFuture = 7 +) + +func init() { + if interfaces, err := net.Interfaces(); err == nil { + for _, i := range interfaces { + if i.Flags&net.FlagLoopback == 0 && len(i.HardwareAddr) > 0 { + hardwareAddr = i.HardwareAddr + break + } + } + } + if hardwareAddr == nil { + // If we failed to obtain the MAC address of the current computer, + // we will use a randomly generated 6 byte sequence instead and set + // the multicast bit as recommended in RFC 4122.
+ hardwareAddr = make([]byte, 6) + _, err := io.ReadFull(rand.Reader, hardwareAddr) + if err != nil { + panic(err) + } + hardwareAddr[0] = hardwareAddr[0] | 0x01 + } + + // initialize the clock sequence with a random number + var clockSeqRand [2]byte + io.ReadFull(rand.Reader, clockSeqRand[:]) + clockSeq = uint32(clockSeqRand[1])<<8 | uint32(clockSeqRand[0]) +} + +// ParseUUID parses a 32 digit hexadecimal number (that might contain hyphens) +// representing an UUID. +func ParseUUID(input string) (UUID, error) { + var u UUID + j := 0 + for _, r := range input { + switch { + case r == '-' && j&1 == 0: + continue + case r >= '0' && r <= '9' && j < 32: + u[j/2] |= byte(r-'0') << uint(4-j&1*4) + case r >= 'a' && r <= 'f' && j < 32: + u[j/2] |= byte(r-'a'+10) << uint(4-j&1*4) + case r >= 'A' && r <= 'F' && j < 32: + u[j/2] |= byte(r-'A'+10) << uint(4-j&1*4) + default: + return UUID{}, fmt.Errorf("invalid UUID %q", input) + } + j += 1 + } + if j != 32 { + return UUID{}, fmt.Errorf("invalid UUID %q", input) + } + return u, nil +} + +// UUIDFromBytes converts a raw byte slice to an UUID. +// It replaces the deprecated top-level gocql.UUIDFromBytes. +func UUIDFromBytes(input []byte) (UUID, error) { + var u UUID + if len(input) != 16 { + return u, errors.New("UUIDs must be exactly 16 bytes long") + } + + copy(u[:], input) + return u, nil +} + +func MustRandomUUID() UUID { + uuid, err := RandomUUID() + if err != nil { + panic(err) + } + return uuid +} + +// RandomUUID generates a totally random UUID (version 4) as described in +// RFC 4122. +func RandomUUID() (UUID, error) { + var u UUID + _, err := io.ReadFull(rand.Reader, u[:]) + if err != nil { + return u, err + } + u[6] &= 0x0F // clear version + u[6] |= 0x40 // set version to 4 (random uuid) + u[8] &= 0x3F // clear variant + u[8] |= 0x80 // set to IETF variant + return u, nil +} + +var timeBase = time.Date(1582, time.October, 15, 0, 0, 0, 0, time.UTC).Unix() + +// getTimestamp converts time to UUID (version 1) timestamp.
+// It must be an interval of 100-nanoseconds since timeBase. +func getTimestamp(t time.Time) int64 { + utcTime := t.In(time.UTC) + ts := int64(utcTime.Unix()-timeBase)*10000000 + int64(utcTime.Nanosecond()/100) + + return ts +} + +// TimeUUID generates a new time based UUID (version 1) using the current +// time as the timestamp. +func TimeUUID() UUID { + return UUIDFromTime(time.Now()) +} + +// The min and max clock values for a UUID. +// +// Cassandra's TimeUUIDType compares the lsb parts as signed byte arrays. +// Thus, the min value for each byte is -128 and the max is +127. +const ( + minClock = 0x8080 + maxClock = 0x7f7f +) + +// The min and max node values for a UUID. +// +// See explanation about Cassandra's TimeUUIDType comparison logic above. +var ( + minNode = []byte{0x80, 0x80, 0x80, 0x80, 0x80, 0x80} + maxNode = []byte{0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f} +) + +// MinTimeUUID generates a "fake" time based UUID (version 1) which will be +// the smallest possible UUID generated for the provided timestamp. +// +// UUIDs generated by this function are not unique and are mostly suitable only +// in queries to select a time range of a Cassandra's TimeUUID column. +func MinTimeUUID(t time.Time) UUID { + return TimeUUIDWith(getTimestamp(t), minClock, minNode) +} + +// MaxTimeUUID generates a "fake" time based UUID (version 1) which will be +// the biggest possible UUID generated for the provided timestamp. +// +// UUIDs generated by this function are not unique and are mostly suitable only +// in queries to select a time range of a Cassandra's TimeUUID column. +func MaxTimeUUID(t time.Time) UUID { + return TimeUUIDWith(getTimestamp(t), maxClock, maxNode) +} + +// UUIDFromTime generates a new time based UUID (version 1) as described in +// RFC 4122. This UUID contains the MAC address of the node that generated +// the UUID, the given timestamp and a sequence number. 
+func UUIDFromTime(t time.Time) UUID { + ts := getTimestamp(t) + clock := atomic.AddUint32(&clockSeq, 1) + + return TimeUUIDWith(ts, clock, hardwareAddr) +} + +// TimeUUIDWith generates a new time based UUID (version 1) as described in +// RFC4122 with given parameters. t is the number of 100's of nanoseconds +// since 15 Oct 1582 (60bits). clock is the number of clock sequence (14bits). +// node is a slice to guarantee the uniqueness of the UUID (up to 6bytes). +// Note: calling this function does not increment the static clock sequence. +func TimeUUIDWith(t int64, clock uint32, node []byte) UUID { + var u UUID + + u[0], u[1], u[2], u[3] = byte(t>>24), byte(t>>16), byte(t>>8), byte(t) + u[4], u[5] = byte(t>>40), byte(t>>32) + u[6], u[7] = byte(t>>56)&0x0F, byte(t>>48) + + u[8] = byte(clock >> 8) + u[9] = byte(clock) + + copy(u[10:], node) + + u[6] |= 0x10 // set version to 1 (time based uuid) + u[8] &= 0x3F // clear variant + u[8] |= 0x80 // set to IETF variant + + return u +} + +// String returns the UUID in its canonical form, a 32 digit hexadecimal +// number in the form of xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx. +func (u UUID) String() string { + var offsets = [...]int{0, 2, 4, 6, 9, 11, 14, 16, 19, 21, 24, 26, 28, 30, 32, 34} + const hexString = "0123456789abcdef" + r := make([]byte, 36) + for i, b := range u { + r[offsets[i]] = hexString[b>>4] + r[offsets[i]+1] = hexString[b&0xF] + } + r[8] = '-' + r[13] = '-' + r[18] = '-' + r[23] = '-' + return string(r) + +} + +// Bytes returns the raw byte slice for this UUID. A UUID is always 128 bits +// (16 bytes) long. +func (u UUID) Bytes() []byte { + return u[:] +} + +// Variant returns the variant of this UUID. This package will only generate +// UUIDs in the IETF variant.
+func (u UUID) Variant() int { + x := u[8] + if x&0x80 == 0 { + return VariantNCSCompat + } + if x&0x40 == 0 { + return VariantIETF + } + if x&0x20 == 0 { + return VariantMicrosoft + } + return VariantFuture +} + +// Version extracts the version of this UUID variant. The RFC 4122 describes +// five kinds of UUIDs. +func (u UUID) Version() int { + return int(u[6] & 0xF0 >> 4) +} + +// Node extracts the MAC address of the node who generated this UUID. It will +// return nil if the UUID is not a time based UUID (version 1). +func (u UUID) Node() []byte { + if u.Version() != 1 { + return nil + } + return u[10:] +} + +// Clock extracts the clock sequence of this UUID. It will return zero if the +// UUID is not a time based UUID (version 1). +func (u UUID) Clock() uint32 { + if u.Version() != 1 { + return 0 + } + + // Clock sequence is the lower 14bits of u[8:10] + return uint32(u[8]&0x3F)<<8 | uint32(u[9]) +} + +// Timestamp extracts the timestamp information from a time based UUID +// (version 1). +func (u UUID) Timestamp() int64 { + if u.Version() != 1 { + return 0 + } + return int64(uint64(u[0])<<24|uint64(u[1])<<16| + uint64(u[2])<<8|uint64(u[3])) + + int64(uint64(u[4])<<40|uint64(u[5])<<32) + + int64(uint64(u[6]&0x0F)<<56|uint64(u[7])<<48) +} + +// Time is like Timestamp, except that it returns a time.Time. 
+func (u UUID) Time() time.Time { + if u.Version() != 1 { + return time.Time{} + } + t := u.Timestamp() + sec := t / 1e7 + nsec := (t % 1e7) * 100 + return time.Unix(sec+timeBase, nsec).UTC() +} + +// Marshaling for JSON +func (u UUID) MarshalJSON() ([]byte, error) { + return []byte(`"` + u.String() + `"`), nil +} + +// Unmarshaling for JSON +func (u *UUID) UnmarshalJSON(data []byte) error { + str := strings.Trim(string(data), `"`) + if len(str) > 36 { + return fmt.Errorf("invalid JSON UUID %s", str) + } + + parsed, err := ParseUUID(str) + if err == nil { + copy(u[:], parsed[:]) + } + + return err +} + +func (u UUID) MarshalText() ([]byte, error) { + return []byte(u.String()), nil +} + +func (u *UUID) UnmarshalText(text []byte) (err error) { + *u, err = ParseUUID(string(text)) + return +} diff --git a/session.go b/session.go index b884735c2..1bdd12ddb 100644 --- a/session.go +++ b/session.go @@ -30,6 +30,8 @@ import ( "encoding/binary" "errors" "fmt" + "github.com/gocql/gocql/protocol" + "github.com/gocql/gocql/session" "io" "net" "strings" @@ -655,7 +657,7 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI if len(info.request.pkeyColumns) > 0 { // proto v4 dont need to calculate primary key columns - types := make([]TypeInfo, len(info.request.pkeyColumns)) + types := make([]protocol.TypeInfo, len(info.request.pkeyColumns)) for i, col := range info.request.pkeyColumns { types[i] = info.request.columns[col].TypeInfo } @@ -695,7 +697,7 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI size := len(partitionKey) routingKeyInfo := &routingKeyInfo{ indexes: make([]int, size), - types: make([]TypeInfo, size), + types: make([]protocol.TypeInfo, size), keyspace: keyspace, table: table, } @@ -1506,9 +1508,9 @@ func scanColumn(p []byte, col ColumnInfo, dest []interface{}) (int, error) { return 1, nil } - if col.TypeInfo.Type() == TypeTuple { + if col.TypeInfo.Type() == protocol.TypeTuple { // this will panic, 
actually a bug, please report - tuple := col.TypeInfo.(TupleTypeInfo) + tuple := col.TypeInfo.(protocol.TupleTypeInfo) count := len(tuple.Elems) // here we pass in a slice of the struct which has the number number of @@ -1725,7 +1727,7 @@ func (n *nextIter) fetch() *Iter { } type Batch struct { - Type BatchType + Type session.BatchType Entries []BatchEntry Cons Consistency routingKey []byte @@ -1748,7 +1750,7 @@ type Batch struct { } // NewBatch creates a new batch operation using defaults defined in the cluster -func (s *Session) NewBatch(typ BatchType) *Batch { +func (s *Session) NewBatch(typ session.BatchType) *Batch { s.mu.RLock() batch := &Batch{ Type: typ, @@ -2030,12 +2032,12 @@ func (b *Batch) releaseAfterExecution() { // that would race with speculative executions. } -type BatchType byte +//type BatchType byte const ( - LoggedBatch BatchType = 0 - UnloggedBatch BatchType = 1 - CounterBatch BatchType = 2 + LoggedBatch session.BatchType = 0 + UnloggedBatch session.BatchType = 1 + CounterBatch session.BatchType = 2 ) type BatchEntry struct { @@ -2049,7 +2051,7 @@ type ColumnInfo struct { Keyspace string Table string Name string - TypeInfo TypeInfo + TypeInfo protocol.TypeInfo } func (c ColumnInfo) String() string { @@ -2064,7 +2066,7 @@ type routingKeyInfoLRU struct { type routingKeyInfo struct { indexes []int - types []TypeInfo + types []protocol.TypeInfo keyspace string table string } diff --git a/session/session.go b/session/session.go new file mode 100644 index 000000000..c89c9d5f9 --- /dev/null +++ b/session/session.go @@ -0,0 +1,25 @@ +package session + +import ( + "fmt" + "github.com/gocql/gocql/protocol" +) + +type ErrProtocol struct{ error } + +type BatchType byte + +func NewErrProtocol(format string, args ...interface{}) error { + return ErrProtocol{fmt.Errorf(format, args...)} +} + +type ColumnInfo struct { + Keyspace string + Table string + Name string + TypeInfo protocol.TypeInfo +} + +func (c ColumnInfo) String() string { + return 
fmt.Sprintf("[column keyspace=%s table=%s name=%s type=%v]", c.Keyspace, c.Table, c.Name, c.TypeInfo) +} diff --git a/uuid.go b/uuid.go index cc5f1c21f..e6ffe57f6 100644 --- a/uuid.go +++ b/uuid.go @@ -105,6 +105,7 @@ func ParseUUID(input string) (UUID, error) { return u, nil } +// Deprecated: use protocol.UUIDFromBytes instead. // UUIDFromBytes converts a raw byte slice to an UUID. func UUIDFromBytes(input []byte) (UUID, error) { var u UUID