diff --git a/go/pools/smartconnpool/pool.go b/go/pools/smartconnpool/pool.go index 2beb0b8ac46..af19bb30593 100644 --- a/go/pools/smartconnpool/pool.go +++ b/go/pools/smartconnpool/pool.go @@ -388,7 +388,7 @@ func (pool *ConnPool[C]) recordWait(start time.Time) { // The connection must be returned to the pool once it's not needed by calling Pooled.Recycle func (pool *ConnPool[C]) Get(ctx context.Context, setting *Setting) (*Pooled[C], error) { if ctx.Err() != nil { - return nil, ErrCtxTimeout + return nil, context.Cause(ctx) } if pool.capacity.Load() == 0 { return nil, ErrConnPoolClosed @@ -594,10 +594,12 @@ func (pool *ConnPool[C]) get(ctx context.Context) (*Pooled[C], error) { if err != nil { return nil, err } + // if we don't have capacity, try popping a connection from any of the setting stacks if conn == nil { conn = pool.getFromSettingsStack(nil) } + // if there are no connections in the setting stacks and we've lent out connections // to other clients, wait until one of the connections is returned if conn == nil { @@ -610,14 +612,10 @@ func (pool *ConnPool[C]) get(ctx context.Context) (*Pooled[C], error) { conn, err = pool.wait.waitForConn(ctx, nil, *closeChan) if err != nil { - return nil, ErrTimeout + return nil, err } pool.recordWait(start) } - // no connections available and no connections to wait for (pool is closed) - if conn == nil { - return nil, ErrTimeout - } // if the connection we've acquired has a Setting applied, we must reset it before returning if conn.Conn.Setting() != nil { @@ -649,6 +647,7 @@ func (pool *ConnPool[C]) getWithSetting(ctx context.Context, setting *Setting) ( if conn == nil { conn = pool.pop(&pool.clean) } + // otherwise try opening a brand new connection and we'll apply the setting to it if conn == nil { conn, err = pool.getNew(ctx) @@ -656,11 +655,13 @@ func (pool *ConnPool[C]) getWithSetting(ctx context.Context, setting *Setting) ( return nil, err } } + // try on the _other_ setting stacks, even if we have to reset the Setting for the returned // connection if conn == nil { conn = pool.getFromSettingsStack(setting) } + // no connections anywhere in the pool; if we've lent out connections to other clients // wait for one of them if conn == nil { @@ -673,14 +674,10 @@ func (pool *ConnPool[C]) getWithSetting(ctx context.Context, setting *Setting) ( conn, err = pool.wait.waitForConn(ctx, setting, *closeChan) if err != nil { - return nil, ErrTimeout + return nil, err } pool.recordWait(start) } - // no connections available and no connections to wait for (pool is closed) - if conn == nil { - return nil, ErrTimeout - } // ensure that the setting applied to the connection matches the one we want connSetting := conn.Conn.Setting() diff --git a/go/pools/smartconnpool/pool_test.go b/go/pools/smartconnpool/pool_test.go index 24100864d25..a8640f8b98d 100644 --- a/go/pools/smartconnpool/pool_test.go +++ b/go/pools/smartconnpool/pool_test.go @@ -18,6 +18,7 @@ package smartconnpool import ( "context" + "errors" "fmt" "reflect" "sync" @@ -965,18 +966,31 @@ func TestTimeout(t *testing.T) { for _, setting := range []*Setting{nil, sFoo} { // trying to get the connection without a timeout. - newctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond) - _, err = p.Get(newctx, setting) - cancel() - assert.EqualError(t, err, "connection pool timed out") + { + newctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond) + defer cancel() + _, err = p.Get(newctx, setting) + assert.ErrorIs(t, err, context.DeadlineExceeded) + } + + // trying to get the connection with a timeout and a specific cause + { + expectedError := errors.New("test error") + + newctx, cancel := context.WithTimeoutCause(ctx, 10*time.Millisecond, expectedError) + defer cancel() + + _, err = p.Get(newctx, setting) + assert.ErrorIs(t, err, expectedError) + } } // put the connection take was taken initially. p.put(r) } -func TestExpired(t *testing.T) { +func TestGetWithExpiredContext(t *testing.T) { var state TestState p := NewPool(&Config[*TestConn]{ @@ -989,10 +1003,24 @@ func TestExpired(t *testing.T) { for _, setting := range []*Setting{nil, sFoo} { // expired context - ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-1*time.Second)) - _, err := p.Get(ctx, setting) - cancel() - require.EqualError(t, err, "connection pool context already expired") + { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-1*time.Second)) + defer cancel() + + _, err := p.Get(ctx, setting) + require.ErrorIs(t, err, context.DeadlineExceeded) + } + + // context cancelled with cause + { + expectedError := errors.New("test error") + + ctx, cancel := context.WithCancelCause(context.Background()) + cancel(expectedError) + + _, err := p.Get(ctx, setting) + require.ErrorIs(t, err, expectedError) + } } } diff --git a/go/vt/vttablet/tabletserver/connpool/pool.go b/go/vt/vttablet/tabletserver/connpool/pool.go index 141d8257062..f49136741f0 100644 --- a/go/vt/vttablet/tabletserver/connpool/pool.go +++ b/go/vt/vttablet/tabletserver/connpool/pool.go @@ -134,7 +134,7 @@ func (cp *Pool) Get(ctx context.Context, setting *smartconnpool.Setting) (*Poole if cp.timeout != 0 { var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, cp.timeout) + ctx, cancel = context.WithTimeoutCause(ctx, cp.timeout, smartconnpool.ErrTimeout) defer cancel() } diff --git a/go/vt/vttablet/tabletserver/tx_pool.go b/go/vt/vttablet/tabletserver/tx_pool.go index cca44056608..398d7a8d6ce 100644 --- a/go/vt/vttablet/tabletserver/tx_pool.go +++ b/go/vt/vttablet/tabletserver/tx_pool.go @@ -300,19 +300,24 @@ func (tp *TxPool) begin(ctx context.Context, options *querypb.ExecuteOptions, re } func (tp *TxPool) createConn(ctx context.Context, options *querypb.ExecuteOptions, setting *smartconnpool.Setting) (*StatefulConnection, error) { + if ctx.Err() != nil { + tp.LogActive() + + errCode := vterrors.Code(smartconnpool.ErrCtxTimeout) + return nil, vterrors.Errorf(errCode, "transaction pool aborting request due to already expired context") + } + conn, err := tp.scp.NewConn(ctx, options, setting) if err != nil { - errCode := vterrors.Code(err) - switch err { - case smartconnpool.ErrCtxTimeout: + if err == smartconnpool.ErrTimeout { tp.LogActive() - err = vterrors.Errorf(errCode, "transaction pool aborting request due to already expired context") - case smartconnpool.ErrTimeout: - tp.LogActive() - err = vterrors.Errorf(errCode, "transaction pool connection limit exceeded") + errCode := vterrors.Code(smartconnpool.ErrTimeout) + return nil, vterrors.Errorf(errCode, "transaction pool connection limit exceeded") } + return nil, err } + return conn, nil }