diff --git a/cluster.go b/cluster.go index 13e62f3b0..1fd81fd65 100644 --- a/cluster.go +++ b/cluster.go @@ -214,6 +214,10 @@ type ClusterConfig struct { // See https://issues.apache.org/jira/browse/CASSANDRA-10786 DisableSkipMetadata bool + // QueryAttemptInterceptor will set the provided query interceptor on all queries created from this session. + // Use it to intercept and modify queries by providing an implementation of QueryAttemptInterceptor. + QueryAttemptInterceptor QueryAttemptInterceptor + // QueryObserver will set the provided query observer on all queries created from this session. // Use it to collect metrics / stats from queries by providing an implementation of QueryObserver. QueryObserver QueryObserver diff --git a/doc.go b/doc.go index f23e812c5..94ad85211 100644 --- a/doc.go +++ b/doc.go @@ -362,6 +362,18 @@ // // See Example_userDefinedTypesMap, Example_userDefinedTypesStruct, ExampleUDTMarshaler, ExampleUDTUnmarshaler. // +// # Interceptors +// +// A QueryAttemptInterceptor wraps query execution and can be used to inject logic that should apply to all query +// and batch execution attempts. For example, interceptors can be used for rate limiting, logging, attaching +// distributed tracing metadata to the context, modifying queries, and inspecting query results. +// +// A QueryAttemptInterceptor will be invoked once prior to each query execution attempt, including retry attempts +// and speculative execution attempts. Interceptors are responsible for calling the provided handler and returning +// a non-nil Iter or an error. +// +// See Example_interceptor for full example. +// // # Metrics and tracing // // It is possible to provide observer implementations that could be used to gather metrics: diff --git a/example_interceptor_test.go b/example_interceptor_test.go new file mode 100644 index 000000000..097fa4eb2 --- /dev/null +++ b/example_interceptor_test.go @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* + * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 + * Copyright (c) 2016, The Gocql authors, + * provided under the BSD-3-Clause License. + * See the NOTICE file distributed with this work for additional information. + */ + +package gocql_test + +import ( + "context" + "log" + "time" + + gocql "github.com/gocql/gocql" +) + +type MyQueryAttemptInterceptor struct { + injectFault bool +} + +func (q MyQueryAttemptInterceptor) Intercept( + ctx context.Context, + attempt gocql.QueryAttempt, + handler gocql.QueryAttemptHandler, +) (*gocql.Iter, error) { + switch q := attempt.Query.(type) { + case *gocql.Query: + // Inspect or modify query + attempt.Query = q + case *gocql.Batch: + // Inspect or modify batch + attempt.Query = q + } + + // Inspect or modify context + ctx = context.WithValue(ctx, "trace-id", "123") + + // Optionally bypass the handler and return an error to prevent query execution. + // For example, to simulate query timeouts. + if q.injectFault && attempt.Attempts == 0 { + <-time.After(1 * time.Second) + return nil, gocql.RequestErrWriteTimeout{} + } + + // The interceptor *must* invoke the handler to execute the query. + return handler(ctx, attempt) +} + +// Example_interceptor demonstrates how to implement a QueryAttemptInterceptor. +func Example_interceptor() { + cluster := gocql.NewCluster("localhost:9042") + cluster.QueryAttemptInterceptor = MyQueryAttemptInterceptor{injectFault: true} + + session, err := cluster.CreateSession() + if err != nil { + log.Fatal(err) + } + defer session.Close() + + ctx := context.Background() + + var stringValue string + err = session.Query("select now() from system.local"). + WithContext(ctx). + RetryPolicy(&gocql.SimpleRetryPolicy{NumRetries: 2}). + Scan(&stringValue) + if err != nil { + log.Fatalf("query failed %T", err) + } +} + +type QueryAttemptInterceptorChain struct { + interceptors []gocql.QueryAttemptInterceptor +} + +func (c QueryAttemptInterceptorChain) Intercept( + ctx context.Context, + attempt gocql.QueryAttempt, + handler gocql.QueryAttemptHandler, +) (*gocql.Iter, error) { + return c.interceptors[0].Intercept(ctx, attempt, c.getNextHandler(0, handler)) +} + +func (c QueryAttemptInterceptorChain) getNextHandler(curr int, final gocql.QueryAttemptHandler) gocql.QueryAttemptHandler { + if curr == len(c.interceptors)-1 { + return final + } + + return func(ctx context.Context, attempt gocql.QueryAttempt) (*gocql.Iter, error) { + return c.interceptors[curr+1].Intercept(ctx, attempt, c.getNextHandler(curr+1, final)) + } +} + +// Example_interceptor_chain demonstrates how to chain QueryAttemptInterceptors. +func Example_interceptor_chain() { + cluster := gocql.NewCluster("localhost:9042") + cluster.QueryAttemptInterceptor = QueryAttemptInterceptorChain{ + []gocql.QueryAttemptInterceptor{ + MyQueryAttemptInterceptor{}, + MyQueryAttemptInterceptor{}, + MyQueryAttemptInterceptor{}, + }, + } + + session, err := cluster.CreateSession() + if err != nil { + log.Fatal(err) + } + defer session.Close() + + ctx := context.Background() + + var stringValue string + err = session.Query("select now() from system.local"). + WithContext(ctx). + RetryPolicy(&gocql.SimpleRetryPolicy{NumRetries: 2}). + Scan(&stringValue) + if err != nil { + log.Fatalf("query failed %T", err) + } +} diff --git a/query_executor.go b/query_executor.go index fb68b07f2..03f0c3aa0 100644 --- a/query_executor.go +++ b/query_executor.go @@ -26,6 +26,7 @@ package gocql import ( "context" + "net" "sync" "time" ) @@ -34,7 +35,7 @@ type ExecutableQuery interface { borrowForExecution() // Used to ensure that the query stays alive for lifetime of a particular execution goroutine. releaseAfterExecution() // Used when a goroutine finishes its execution attempts, either with ok result or an error. execute(ctx context.Context, conn *Conn) *Iter - attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) + attempt(ctx context.Context, keyspace string, end, start time.Time, iter *Iter, host *HostInfo) retryPolicy() RetryPolicy speculativeExecutionPolicy() SpeculativeExecutionPolicy GetRoutingKey() ([]byte, error) @@ -48,16 +49,68 @@ type ExecutableQuery interface { } type queryExecutor struct { - pool *policyConnPool - policy HostSelectionPolicy + pool *policyConnPool + policy HostSelectionPolicy + interceptor QueryAttemptInterceptor +} + +type QueryAttempt struct { + // The query to execute, either a *gocql.Query or *gocql.Batch. + Query ExecutableQuery + // The host that will receive the query. + Host *HostInfo + // The local address of the connection used to execute the query. + LocalAddr net.Addr + // The remote address of the connection used to execute the query. + RemoteAddr net.Addr + // The number of previous query attempts. 0 for the initial attempt, 1 for the first retry, etc. + Attempts int +} + +// QueryAttemptHandler is a function that attempts query execution. +type QueryAttemptHandler = func(context.Context, QueryAttempt) (*Iter, error) + +// QueryAttemptInterceptor is the interface implemented by query interceptors / middleware. +// +// Interceptors are well-suited to logic that is not specific to a single query or batch. +type QueryAttemptInterceptor interface { + // Intercept is invoked once immediately before a query execution attempt, including retry attempts and + // speculative execution attempts. + + // The interceptor is responsible for calling the `handler` function and returning the handler result. If the + // interceptor wants to bypass the handler and skip query execution, it should return an error. Failure to + // return either the handler result or an error will panic. + Intercept(ctx context.Context, attempt QueryAttempt, handler QueryAttemptHandler) (*Iter, error) } func (q *queryExecutor) attemptQuery(ctx context.Context, qry ExecutableQuery, conn *Conn) *Iter { start := time.Now() - iter := qry.execute(ctx, conn) - end := time.Now() + var iter *Iter + var err error + if q.interceptor != nil { + // Propagate interceptor context modifications. + _ctx := ctx + attempt := QueryAttempt{ + Query: qry, + Host: conn.host, + LocalAddr: conn.conn.LocalAddr(), + RemoteAddr: conn.conn.RemoteAddr(), + Attempts: qry.Attempts(), + } + iter, err = q.interceptor.Intercept(_ctx, attempt, func(_ctx context.Context, attempt QueryAttempt) (*Iter, error) { + ctx = _ctx + iter := attempt.Query.execute(ctx, conn) + return iter, iter.err + }) + if err != nil { + iter = &Iter{err: err} + } + } else { + iter = qry.execute(ctx, conn) + } - qry.attempt(q.pool.keyspace, end, start, iter, conn.host) + end := time.Now() + qry.attempt(ctx, q.pool.keyspace, end, start, iter, conn.host) return iter } diff --git a/session.go b/session.go index a600b95f3..8fe777bac 100644 --- a/session.go +++ b/session.go @@ -178,8 +178,9 @@ func NewSession(cfg ClusterConfig) (*Session, error) { s.policy.Init(s) s.executor = &queryExecutor{ - pool: s.pool, - policy: cfg.PoolConfig.HostSelectionPolicy, + pool: s.pool, + policy: cfg.PoolConfig.HostSelectionPolicy, + interceptor: cfg.QueryAttemptInterceptor, } s.queryObserver = cfg.QueryObserver @@ -1111,12 +1112,12 @@ func (q *Query) execute(ctx context.Context, conn *Conn) *Iter { return conn.executeQuery(ctx, q) } -func (q *Query) attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) { +func (q *Query) attempt(ctx context.Context, keyspace string, end, start time.Time, iter *Iter, host *HostInfo) { latency := end.Sub(start) attempt, metricsForHost := q.metrics.attempt(1, latency, host, q.observer != nil) if q.observer != nil { - q.observer.ObserveQuery(q.Context(), ObservedQuery{ + q.observer.ObserveQuery(ctx, ObservedQuery{ Keyspace: keyspace, Statement: q.stmt, Values: q.values, @@ -1942,7 +1943,7 @@ func (b *Batch) WithTimestamp(timestamp int64) *Batch { return b } -func (b *Batch) attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) { +func (b *Batch) attempt(ctx context.Context, keyspace string, end, start time.Time, iter *Iter, host *HostInfo) { latency := end.Sub(start) attempt, metricsForHost := b.metrics.attempt(1, latency, host, b.observer != nil) @@ -1958,7 +1959,7 @@ func (b *Batch) attempt(keyspace string, end, start time.Time, iter *Iter, host values[i] = entry.Args } - b.observer.ObserveBatch(b.Context(), ObservedBatch{ + b.observer.ObserveBatch(ctx, ObservedBatch{ Keyspace: keyspace, Statements: statements, Values: values,