Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 8 additions & 11 deletions go/pools/smartconnpool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ func (pool *ConnPool[C]) recordWait(start time.Time) {
// The connection must be returned to the pool once it's not needed by calling Pooled.Recycle
func (pool *ConnPool[C]) Get(ctx context.Context, setting *Setting) (*Pooled[C], error) {
if ctx.Err() != nil {
return nil, ErrCtxTimeout
return nil, context.Cause(ctx)
}
if pool.capacity.Load() == 0 {
return nil, ErrConnPoolClosed
Expand Down Expand Up @@ -594,10 +594,12 @@ func (pool *ConnPool[C]) get(ctx context.Context) (*Pooled[C], error) {
if err != nil {
return nil, err
}

// if we don't have capacity, try popping a connection from any of the setting stacks
if conn == nil {
conn = pool.getFromSettingsStack(nil)
}

// if there are no connections in the setting stacks and we've lent out connections
// to other clients, wait until one of the connections is returned
if conn == nil {
Expand All @@ -610,14 +612,10 @@ func (pool *ConnPool[C]) get(ctx context.Context) (*Pooled[C], error) {

conn, err = pool.wait.waitForConn(ctx, nil, *closeChan)
if err != nil {
return nil, ErrTimeout
return nil, err
}
pool.recordWait(start)
}
// no connections available and no connections to wait for (pool is closed)
if conn == nil {
return nil, ErrTimeout
}

// if the connection we've acquired has a Setting applied, we must reset it before returning
if conn.Conn.Setting() != nil {
Expand Down Expand Up @@ -649,18 +647,21 @@ func (pool *ConnPool[C]) getWithSetting(ctx context.Context, setting *Setting) (
if conn == nil {
conn = pool.pop(&pool.clean)
}

// otherwise try opening a brand new connection and we'll apply the setting to it
if conn == nil {
conn, err = pool.getNew(ctx)
if err != nil {
return nil, err
}
}

// try on the _other_ setting stacks, even if we have to reset the Setting for the returned
// connection
if conn == nil {
conn = pool.getFromSettingsStack(setting)
}

// no connections anywhere in the pool; if we've lent out connections to other clients
// wait for one of them
if conn == nil {
Expand All @@ -673,14 +674,10 @@ func (pool *ConnPool[C]) getWithSetting(ctx context.Context, setting *Setting) (

conn, err = pool.wait.waitForConn(ctx, setting, *closeChan)
if err != nil {
return nil, ErrTimeout
return nil, err
}
pool.recordWait(start)
}
// no connections available and no connections to wait for (pool is closed)
if conn == nil {
return nil, ErrTimeout
}

// ensure that the setting applied to the connection matches the one we want
connSetting := conn.Conn.Setting()
Expand Down
46 changes: 37 additions & 9 deletions go/pools/smartconnpool/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package smartconnpool

import (
"context"
"errors"
"fmt"
"reflect"
"sync"
Expand Down Expand Up @@ -965,18 +966,31 @@ func TestTimeout(t *testing.T) {

for _, setting := range []*Setting{nil, sFoo} {
// trying to get the connection without a timeout.
newctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond)
_, err = p.Get(newctx, setting)
cancel()
assert.EqualError(t, err, "connection pool timed out")
{
newctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond)
defer cancel()

_, err = p.Get(newctx, setting)
assert.ErrorIs(t, err, context.DeadlineExceeded)
}

// trying to get the connection with a timeout and a specific cause
{
expectedError := errors.New("test error")

newctx, cancel := context.WithTimeoutCause(ctx, 10*time.Millisecond, expectedError)
defer cancel()

_, err = p.Get(newctx, setting)
assert.ErrorIs(t, err, expectedError)
}
}

// put the connection take was taken initially.
p.put(r)
}

func TestExpired(t *testing.T) {
func TestGetWithExpiredContext(t *testing.T) {
var state TestState

p := NewPool(&Config[*TestConn]{
Expand All @@ -989,10 +1003,24 @@ func TestExpired(t *testing.T) {

for _, setting := range []*Setting{nil, sFoo} {
// expired context
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-1*time.Second))
_, err := p.Get(ctx, setting)
cancel()
require.EqualError(t, err, "connection pool context already expired")
{
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-1*time.Second))
defer cancel()

_, err := p.Get(ctx, setting)
require.ErrorIs(t, err, context.DeadlineExceeded)
}

// context cancelled with cause
{
expectedError := errors.New("test error")

ctx, cancel := context.WithCancelCause(context.Background())
cancel(expectedError)

_, err := p.Get(ctx, setting)
require.ErrorIs(t, err, expectedError)
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion go/vt/vttablet/tabletserver/connpool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func (cp *Pool) Get(ctx context.Context, setting *smartconnpool.Setting) (*Poole

if cp.timeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, cp.timeout)
ctx, cancel = context.WithTimeoutCause(ctx, cp.timeout, smartconnpool.ErrTimeout)
defer cancel()
}

Expand Down
19 changes: 12 additions & 7 deletions go/vt/vttablet/tabletserver/tx_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,19 +300,24 @@ func (tp *TxPool) begin(ctx context.Context, options *querypb.ExecuteOptions, re
}

func (tp *TxPool) createConn(ctx context.Context, options *querypb.ExecuteOptions, setting *smartconnpool.Setting) (*StatefulConnection, error) {
if ctx.Err() != nil {
tp.LogActive()

errCode := vterrors.Code(smartconnpool.ErrCtxTimeout)
return nil, vterrors.Errorf(errCode, "transaction pool aborting request due to already expired context")
}

conn, err := tp.scp.NewConn(ctx, options, setting)
if err != nil {
errCode := vterrors.Code(err)
switch err {
case smartconnpool.ErrCtxTimeout:
if err == smartconnpool.ErrTimeout {
tp.LogActive()
err = vterrors.Errorf(errCode, "transaction pool aborting request due to already expired context")
case smartconnpool.ErrTimeout:
tp.LogActive()
err = vterrors.Errorf(errCode, "transaction pool connection limit exceeded")
errCode := vterrors.Code(smartconnpool.ErrTimeout)
return nil, vterrors.Errorf(errCode, "transaction pool connection limit exceeded")
}

return nil, err
}

return conn, nil
}

Expand Down
Loading