Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Let retry policy to decide about non idempotent queries #376

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 40 additions & 15 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1101,13 +1101,13 @@ func (c *Conn) addCall(call *callReq) error {

func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*framer, error) {
if ctxErr := ctx.Err(); ctxErr != nil {
return nil, ctxErr
return nil, &QueryError{err: ctxErr, potentiallyExecuted: false}
}

// TODO: move tracer onto conn
stream, ok := c.streams.GetStream()
if !ok {
return nil, ErrNoStreams
return nil, &QueryError{err: ErrNoStreams, potentiallyExecuted: false}
}

// resp is basically a waiting semaphore protecting the framer
Expand All @@ -1125,7 +1125,7 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram
}

if err := c.addCall(call); err != nil {
return nil, err
return nil, &QueryError{err: err, potentiallyExecuted: false}
}

// After this point, we need to either read from call.resp or close(call.timeout)
Expand Down Expand Up @@ -1157,7 +1157,7 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram
// We need to release the stream after we remove the call from c.calls, otherwise the existingCall != nil
// check above could fail.
c.releaseStream(call)
return nil, err
return nil, &QueryError{err: err, potentiallyExecuted: false}
}

n, err := c.w.writeContext(ctx, framer.buf)
Expand Down Expand Up @@ -1185,7 +1185,7 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram
// send a frame on, with all the streams used up and not returned.
c.closeWithError(err)
}
return nil, err
return nil, &QueryError{err: err, potentiallyExecuted: true}
}

var timeoutCh <-chan time.Time
Expand Down Expand Up @@ -1222,7 +1222,7 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram
// connection to close.
c.releaseStream(call)
}
return nil, resp.err
return nil, &QueryError{err: resp.err, potentiallyExecuted: true}
}
// dont release the stream if detect a timeout as another request can reuse
// that stream and get a response for the old request, which we have no
Expand All @@ -1233,20 +1233,20 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram
defer c.releaseStream(call)

if v := resp.framer.header.version.version(); v != c.version {
return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version)
return nil, &QueryError{err: NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version), potentiallyExecuted: true}
}

return resp.framer, nil
case <-timeoutCh:
close(call.timeout)
c.handleTimeout()
return nil, ErrTimeoutNoResponse
return nil, &QueryError{err: ErrTimeoutNoResponse, potentiallyExecuted: true}
case <-ctxDone:
close(call.timeout)
return nil, ctx.Err()
return nil, &QueryError{err: ctx.Err(), potentiallyExecuted: true}
case <-c.ctx.Done():
close(call.timeout)
return nil, ErrConnectionClosed
return nil, &QueryError{err: ErrConnectionClosed, potentiallyExecuted: true}
}
}

Expand Down Expand Up @@ -1906,11 +1906,14 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) error {
}

var (
ErrQueryArgLength = errors.New("gocql: query argument length mismatch")
ErrTimeoutNoResponse = errors.New("gocql: no response received from cassandra within timeout period")
ErrTooManyTimeouts = errors.New("gocql: too many query timeouts on the connection")
ErrConnectionClosed = errors.New("gocql: connection closed waiting for response")
ErrNoStreams = errors.New("gocql: no streams available on connection")
ErrQueryArgLength = errors.New("gocql: query argument length mismatch")
ErrTimeoutNoResponse = errors.New("gocql: no response received from cassandra within timeout period")
ErrTooManyTimeouts = errors.New("gocql: too many query timeouts on the connection")
ErrConnectionClosed = errors.New("gocql: connection closed waiting for response")
ErrNoStreams = errors.New("gocql: no streams available on connection")
ErrHostDown = errors.New("gocql: host is nil or down")
ErrNoPool = errors.New("gocql: host does not have a pool")
ErrNoConnectionsInPool = errors.New("gocql: host pool does not have connections")
)

type ErrSchemaMismatch struct {
Expand All @@ -1920,3 +1923,25 @@ type ErrSchemaMismatch struct {
func (e *ErrSchemaMismatch) Error() string {
return fmt.Sprintf("gocql: cluster schema versions not consistent: %+v", e.schemas)
}

type QueryError struct {
err error
potentiallyExecuted bool
isIdempotent bool
}

func (e *QueryError) IsIdempotent() bool {
return e.isIdempotent
}

func (e *QueryError) PotentiallyExecuted() bool {
return e.potentiallyExecuted
}

func (e *QueryError) Error() string {
return fmt.Sprintf("%s (potentially executed: %v)", e.err.Error(), e.potentiallyExecuted)
}

func (e *QueryError) Unwrap() error {
return e.err
}
16 changes: 10 additions & 6 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ func TestCancel(t *testing.T) {
wg.Add(1)

go func() {
if err := qry.Exec(); err != context.Canceled {
if err := qry.Exec(); !errors.Is(err, context.Canceled) {
t.Fatalf("expected to get context cancel error: '%v', got '%v'", context.Canceled, err)
}
wg.Done()
Expand Down Expand Up @@ -456,6 +456,10 @@ func (t *testRetryPolicy) Attempt(qry RetryableQuery) bool {
return qry.Attempts() <= t.NumRetries
}
func (t *testRetryPolicy) GetRetryType(err error) RetryType {
var executedErr *QueryError
if errors.As(err, &executedErr) && executedErr.PotentiallyExecuted() && !executedErr.IsIdempotent() {
return Rethrow
}
return Retry
}

Expand Down Expand Up @@ -573,7 +577,7 @@ func TestQueryTimeout(t *testing.T) {

select {
case err := <-ch:
if err != ErrTimeoutNoResponse {
if !errors.Is(err, ErrTimeoutNoResponse) {
t.Fatalf("expected to get %v for timeout got %v", ErrTimeoutNoResponse, err)
}
case <-time.After(40*time.Millisecond + db.cfg.Timeout):
Expand Down Expand Up @@ -667,8 +671,8 @@ func TestQueryTimeoutClose(t *testing.T) {
t.Fatal("timedout waiting to get a response once cluster is closed")
}

if err != ErrConnectionClosed {
t.Fatalf("expected to get %v got %v", ErrConnectionClosed, err)
if !errors.Is(err, ErrConnectionClosed) {
t.Fatalf("expected to get %v or an error wrapping it, got %v", ErrConnectionClosed, err)
}
}

Expand Down Expand Up @@ -721,7 +725,7 @@ func TestContext_Timeout(t *testing.T) {
cancel()

err = db.Query("timeout").WithContext(ctx).Exec()
if err != context.Canceled {
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected to get context cancel error: %v got %v", context.Canceled, err)
}
}
Expand Down Expand Up @@ -838,7 +842,7 @@ func TestContext_CanceledBeforeExec(t *testing.T) {
cancel()

err = db.Query("timeout").WithContext(ctx).Exec()
if err != context.Canceled {
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected to get context cancel error: %v got %v", context.Canceled, err)
}

Expand Down
24 changes: 24 additions & 0 deletions policies.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ func (s *SimpleRetryPolicy) AttemptLWT(q RetryableQuery) bool {
}

func (s *SimpleRetryPolicy) GetRetryType(err error) RetryType {
var executedErr *QueryError
if errors.As(err, &executedErr) && executedErr.PotentiallyExecuted() && !executedErr.IsIdempotent() {
return Rethrow
}
return RetryNextHost
}

Expand All @@ -168,6 +172,10 @@ func (s *SimpleRetryPolicy) GetRetryType(err error) RetryType {
// even timeouts if other clients send statements touching the same
// partition to the original node at the same time.
func (s *SimpleRetryPolicy) GetRetryTypeLWT(err error) RetryType {
var executedErr *QueryError
if errors.As(err, &executedErr) && executedErr.PotentiallyExecuted() && !executedErr.IsIdempotent() {
return Rethrow
}
return Retry
}

Expand Down Expand Up @@ -208,6 +216,10 @@ func getExponentialTime(min time.Duration, max time.Duration, attempts int) time
}

func (e *ExponentialBackoffRetryPolicy) GetRetryType(err error) RetryType {
var executedErr *QueryError
if errors.As(err, &executedErr) && executedErr.PotentiallyExecuted() && !executedErr.IsIdempotent() {
return Rethrow
}
return RetryNextHost
}

Expand All @@ -216,6 +228,10 @@ func (e *ExponentialBackoffRetryPolicy) GetRetryType(err error) RetryType {
// even timeouts if other clients send statements touching the same
// partition to the original node at the same time.
func (e *ExponentialBackoffRetryPolicy) GetRetryTypeLWT(err error) RetryType {
var executedErr *QueryError
if errors.As(err, &executedErr) && executedErr.PotentiallyExecuted() && !executedErr.IsIdempotent() {
return Rethrow
}
return Retry
}

Expand Down Expand Up @@ -250,6 +266,14 @@ func (d *DowngradingConsistencyRetryPolicy) Attempt(q RetryableQuery) bool {
}

func (d *DowngradingConsistencyRetryPolicy) GetRetryType(err error) RetryType {
var executedErr *QueryError
if errors.As(err, &executedErr) {
err = executedErr.err
if executedErr.PotentiallyExecuted() && !executedErr.IsIdempotent() {
return Rethrow
}
}

switch t := err.(type) {
case *RequestErrUnavailable:
if t.Alive > 0 {
Expand Down
99 changes: 61 additions & 38 deletions query_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package gocql

import (
"context"
"errors"
"sync"
"time"
)
Expand Down Expand Up @@ -115,45 +116,69 @@ func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter Ne

var lastErr error
var iter *Iter
var conn *Conn
var potentiallyExecuted bool
for selectedHost != nil {
host := selectedHost.Info()
if host == nil || !host.IsUp() {
selectedHost = hostIter()
continue
}

pool, ok := q.pool.getPool(host)
if !ok {
selectedHost = hostIter()
continue
}

conn := pool.Pick(selectedHost.Token(), qry)
if conn == nil {
selectedHost = hostIter()
continue
}

iter = q.attemptQuery(ctx, qry, conn)
iter.host = selectedHost.Info()
// Update host
switch iter.err {
case context.Canceled, context.DeadlineExceeded, ErrNotFound:
// those errors represents logical errors, they should not count
// toward removing a node from the pool
selectedHost.Mark(nil)
return iter
default:
selectedHost.Mark(iter.err)
}

// Exit if the query was successful
// or no retry policy defined
if iter.err == nil || rt == nil {
return iter
if rt == nil {
dkropachev marked this conversation as resolved.
Show resolved Hide resolved
selectedHost = hostIter()
lastErr = ErrHostDown
continue
}
} else {
pool, ok := q.pool.getPool(host)
if !ok {
if rt == nil {
selectedHost = hostIter()
lastErr = ErrNoPool
continue
}
} else {
conn = pool.Pick(selectedHost.Token(), qry)
if conn == nil {
if rt == nil {
selectedHost = hostIter()
lastErr = ErrNoConnectionsInPool
continue
}
} else {
iter = q.attemptQuery(ctx, qry, conn)
iter.host = selectedHost.Info()
// Update host
switch {
case errors.Is(iter.err, context.Canceled),
errors.Is(iter.err, context.DeadlineExceeded),
errors.Is(iter.err, ErrNotFound):
// those errors represents logical errors, they should not count
// toward removing a node from the pool
selectedHost.Mark(nil)
if potentiallyExecuted && !qry.IsIdempotent() {
iter.err = &QueryError{err: iter.err, potentiallyExecuted: true, isIdempotent: false}
}
return iter
dkropachev marked this conversation as resolved.
Show resolved Hide resolved
default:
selectedHost.Mark(iter.err)
}

// Exit if the query was successful
// or no retry policy defined
if iter.err == nil || rt == nil {
return iter
dkropachev marked this conversation as resolved.
Show resolved Hide resolved
}

lastErr = iter.err

if customErr, ok := iter.err.(*QueryError); ok && customErr.PotentiallyExecuted() {
customErr.isIdempotent = qry.IsIdempotent()
lastErr = customErr
potentiallyExecuted = true
}
}
}
}

// or retry policy decides to not retry anymore
// Exit if retry policy decides to not retry anymore
if use_lwt_rt {
if !lwt_rt.AttemptLWT(qry) {
return iter
Expand All @@ -164,13 +189,11 @@ func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter Ne
}
}

lastErr = iter.err

var retry_type RetryType
if use_lwt_rt {
retry_type = lwt_rt.GetRetryTypeLWT(iter.err)
retry_type = lwt_rt.GetRetryTypeLWT(lastErr)
} else {
retry_type = rt.GetRetryType(iter.err)
retry_type = rt.GetRetryType(lastErr)
}

// If query is unsuccessful, check the error with RetryPolicy to retry
Expand Down
Loading