Skip to content

Commit

Permalink
Adds timeout to mysql queries.
Browse files Browse the repository at this point in the history
* Also, it refactors default value to use infinity instead of setting to 0. I've
  seeing this pattern in gRPC codebase and I think is more succint. We don't
  need to special case "0"

Signed-off-by: Rafael Chacon <[email protected]>
  • Loading branch information
rafael committed May 16, 2018
1 parent c58c5ae commit 7399bde
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 14 deletions.
4 changes: 1 addition & 3 deletions go/mysql/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,7 @@ func (l *Listener) Accept() {
// handle is called in a go routine for each client connection.
// FIXME(alainjobart) handle per-connection logs in a way that makes sense.
func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Time) {
if l.connReadTimeout != 0 || l.connWriteTimeout != 0 {
conn = netutil.NewConnWithTimeouts(conn, l.connReadTimeout, l.connWriteTimeout)
}
conn = netutil.NewConnWithTimeouts(conn, l.connReadTimeout, l.connWriteTimeout)
c := newConn(conn)
c.ConnectionID = connectionID

Expand Down
14 changes: 9 additions & 5 deletions go/vt/vtgate/plugin_mysql_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package vtgate
import (
"flag"
"fmt"
"math"
"net"
"os"
"sync/atomic"
Expand All @@ -39,6 +40,7 @@ import (
)

var (
infinity = time.Duration(math.MaxInt64)
mysqlServerPort = flag.Int("mysql_server_port", -1, "If set, also listen for MySQL binary protocol connections on this port.")
mysqlServerBindAddress = flag.String("mysql_server_bind_address", "", "Binds on this address when listening to MySQL binary protocol. Useful to restrict listening to 'localhost' only for instance.")
mysqlServerSocketPath = flag.String("mysql_server_socket_path", "", "This option specifies the Unix socket file to use when listening for local connections. By default it will be empty and it won't listen to a unix socket")
Expand All @@ -53,8 +55,9 @@ var (

mysqlSlowConnectWarnThreshold = flag.Duration("mysql_slow_connect_warn_threshold", 0, "Warn if it takes more than the given threshold for a mysql connection to establish")

mysqlConnReadTimeout = flag.Duration("mysql_server_read_timeout", 0, "connection read timeout")
mysqlConnWriteTimeout = flag.Duration("mysql_server_write_timeout", 0, "connection write timeout")
mysqlConnReadTimeout = flag.Duration("mysql_server_read_timeout", infinity, "connection read timeout")
mysqlConnWriteTimeout = flag.Duration("mysql_server_write_timeout", infinity, "connection write timeout")
mysqlQueryTimeout = flag.Duration("mysql_server_query_timeout", infinity, "mysql query timeout")

busyConnections int32
)
Expand All @@ -77,7 +80,8 @@ func (vh *vtgateHandler) NewConnection(c *mysql.Conn) {

func (vh *vtgateHandler) ConnectionClosed(c *mysql.Conn) {
// Rollback if there is an ongoing transaction. Ignore error.
ctx := context.Background()
ctx, cancel := context.WithTimeout(context.Background(), *mysqlQueryTimeout)
defer cancel()
session, _ := c.ClientData.(*vtgatepb.Session)
if session != nil {
if session.InTransaction {
Expand All @@ -88,8 +92,8 @@ func (vh *vtgateHandler) ConnectionClosed(c *mysql.Conn) {
}

func (vh *vtgateHandler) ComQuery(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error {
// FIXME(alainjobart): Add some kind of timeout to the context.
ctx := context.Background()
ctx, cancel := context.WithTimeout(context.Background(), *mysqlQueryTimeout)
defer cancel()

// Fill in the ImmediateCallerID with the UserData returned by
// the AuthServer plugin for that user. If nothing was
Expand Down
15 changes: 10 additions & 5 deletions go/vt/vtqueryserver/plugin_mysql_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ package vtqueryserver
import (
"flag"
"fmt"
"math"
"net"
"os"
"syscall"
"time"

"golang.org/x/net/context"

Expand All @@ -37,6 +39,7 @@ import (
)

var (
infinity = time.Duration(math.MaxInt64)
mysqlServerPort = flag.Int("mysqlproxy_server_port", -1, "If set, also listen for MySQL binary protocol connections on this port.")
mysqlServerBindAddress = flag.String("mysqlproxy_server_bind_address", "", "Binds on this address when listening to MySQL binary protocol. Useful to restrict listening to 'localhost' only for instance.")
mysqlServerSocketPath = flag.String("mysqlproxy_server_socket_path", "", "This option specifies the Unix socket file to use when listening for local connections. By default it will be empty and it won't listen to a unix socket")
Expand All @@ -49,8 +52,9 @@ var (

mysqlSlowConnectWarnThreshold = flag.Duration("mysqlproxy_slow_connect_warn_threshold", 0, "Warn if it takes more than the given threshold for a mysql connection to establish")

mysqlConnReadTimeout = flag.Duration("mysql_server_read_timeout", 0, "connection read timeout")
mysqlConnWriteTimeout = flag.Duration("mysql_server_write_timeout", 0, "connection write timeout")
mysqlConnReadTimeout = flag.Duration("mysql_server_read_timeout", infinity, "connection read timeout")
mysqlConnWriteTimeout = flag.Duration("mysql_server_write_timeout", infinity, "connection write timeout")
mysqlQueryTimeout = flag.Duration("mysql_server_query_timeout", infinity, "mysql query timeout")
)

// proxyHandler implements the Listener interface.
Expand All @@ -71,16 +75,17 @@ func (mh *proxyHandler) NewConnection(c *mysql.Conn) {

func (mh *proxyHandler) ConnectionClosed(c *mysql.Conn) {
// Rollback if there is an ongoing transaction. Ignore error.
ctx := context.Background()
ctx, cancel := context.WithTimeout(context.Background(), *mysqlQueryTimeout)
defer cancel()
session, _ := c.ClientData.(*mysqlproxy.ProxySession)
if session != nil && session.TransactionID != 0 {
_ = mh.mp.Rollback(ctx, session)
}
}

func (mh *proxyHandler) ComQuery(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error {
// FIXME(alainjobart): Add some kind of timeout to the context.
ctx := context.Background()
ctx, cancel := context.WithTimeout(context.Background(), *mysqlQueryTimeout)
defer cancel()

// Fill in the ImmediateCallerID with the UserData returned by
// the AuthServer plugin for that user. If nothing was
Expand Down
15 changes: 14 additions & 1 deletion test/mysql_server_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def test_mysql_connector(self):
fd.write("""{
"table_groups": [
{
"table_names_or_prefixes": ["vt_insert_test"],
"table_names_or_prefixes": ["vt_insert_test", "dual"],
"readers": ["vtgate client 1"],
"writers": ["vtgate client 1"],
"admins": ["vtgate client 1"]
Expand Down Expand Up @@ -143,6 +143,7 @@ def test_mysql_connector(self):
# start vtgate
utils.VtGate(mysql_server=True).start(
extra_args=['-mysql_auth_server_impl', 'static',
'-mysql_server_query_timeout', '1s',
'-mysql_auth_server_static_file', mysql_auth_server_static])
# We use gethostbyname('localhost') so we don't presume
# of the IP format (travis is only IP v4, really).
Expand Down Expand Up @@ -189,6 +190,18 @@ def test_mysql_connector(self):

conn.close()

# 'vtgate client' this query should timeout
conn = MySQLdb.Connect(**params)
try:
cursor = conn.cursor()
cursor.execute('SELECT SLEEP(5)', {})
self.fail('Execute went through')
except MySQLdb.OperationalError, e:
s = str(e)
# 1317 is DeadlineExceeded error code
self.assertIn('1317', s)
conn.close()

# 'vtgate client 2' is not authorized to access vt_insert_test
params['user'] = 'testuser2'
params['passwd'] = 'testpassword2'
Expand Down

0 comments on commit 7399bde

Please sign in to comment.