diff --git a/go/mysql/server.go b/go/mysql/server.go index 6b634b9c8cc..0380f58e688 100644 --- a/go/mysql/server.go +++ b/go/mysql/server.go @@ -174,9 +174,7 @@ func (l *Listener) Accept() { // handle is called in a go routine for each client connection. // FIXME(alainjobart) handle per-connection logs in a way that makes sense. func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Time) { - if l.connReadTimeout != 0 || l.connWriteTimeout != 0 { - conn = netutil.NewConnWithTimeouts(conn, l.connReadTimeout, l.connWriteTimeout) - } + conn = netutil.NewConnWithTimeouts(conn, l.connReadTimeout, l.connWriteTimeout) c := newConn(conn) c.ConnectionID = connectionID diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index 34451f9ed05..b6790cb71fc 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -19,6 +19,7 @@ package vtgate import ( "flag" "fmt" + "math" "net" "os" "sync/atomic" @@ -39,6 +40,7 @@ import ( ) var ( + infinity = time.Duration(math.MaxInt64) mysqlServerPort = flag.Int("mysql_server_port", -1, "If set, also listen for MySQL binary protocol connections on this port.") mysqlServerBindAddress = flag.String("mysql_server_bind_address", "", "Binds on this address when listening to MySQL binary protocol. Useful to restrict listening to 'localhost' only for instance.") mysqlServerSocketPath = flag.String("mysql_server_socket_path", "", "This option specifies the Unix socket file to use when listening for local connections. By default it will be empty and it won't listen to a unix socket") @@ -53,8 +55,9 @@ var ( mysqlSlowConnectWarnThreshold = flag.Duration("mysql_slow_connect_warn_threshold", 0, "Warn if it takes more than the given threshold for a mysql connection to establish") - mysqlConnReadTimeout = flag.Duration("mysql_server_read_timeout", 0, "connection read timeout") - mysqlConnWriteTimeout = flag.Duration("mysql_server_write_timeout", 0, "connection write timeout") + mysqlConnReadTimeout = flag.Duration("mysql_server_read_timeout", infinity, "connection read timeout") + mysqlConnWriteTimeout = flag.Duration("mysql_server_write_timeout", infinity, "connection write timeout") + mysqlQueryTimeout = flag.Duration("mysql_server_query_timeout", infinity, "mysql query timeout") busyConnections int32 ) @@ -77,7 +80,8 @@ func (vh *vtgateHandler) NewConnection(c *mysql.Conn) { func (vh *vtgateHandler) ConnectionClosed(c *mysql.Conn) { // Rollback if there is an ongoing transaction. Ignore error. - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), *mysqlQueryTimeout) + defer cancel() session, _ := c.ClientData.(*vtgatepb.Session) if session != nil { if session.InTransaction { @@ -88,8 +92,8 @@ func (vh *vtgateHandler) ConnectionClosed(c *mysql.Conn) { } func (vh *vtgateHandler) ComQuery(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error { - // FIXME(alainjobart): Add some kind of timeout to the context. - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), *mysqlQueryTimeout) + defer cancel() // Fill in the ImmediateCallerID with the UserData returned by // the AuthServer plugin for that user. If nothing was diff --git a/go/vt/vtqueryserver/plugin_mysql_server.go b/go/vt/vtqueryserver/plugin_mysql_server.go index f0400a88c3e..530a6110eea 100644 --- a/go/vt/vtqueryserver/plugin_mysql_server.go +++ b/go/vt/vtqueryserver/plugin_mysql_server.go @@ -19,9 +19,11 @@ package vtqueryserver import ( "flag" "fmt" + "math" "net" "os" "syscall" + "time" "golang.org/x/net/context" @@ -37,6 +39,7 @@ import ( ) var ( + infinity = time.Duration(math.MaxInt64) mysqlServerPort = flag.Int("mysqlproxy_server_port", -1, "If set, also listen for MySQL binary protocol connections on this port.") mysqlServerBindAddress = flag.String("mysqlproxy_server_bind_address", "", "Binds on this address when listening to MySQL binary protocol. Useful to restrict listening to 'localhost' only for instance.") mysqlServerSocketPath = flag.String("mysqlproxy_server_socket_path", "", "This option specifies the Unix socket file to use when listening for local connections. By default it will be empty and it won't listen to a unix socket") @@ -49,8 +52,9 @@ var ( mysqlSlowConnectWarnThreshold = flag.Duration("mysqlproxy_slow_connect_warn_threshold", 0, "Warn if it takes more than the given threshold for a mysql connection to establish") - mysqlConnReadTimeout = flag.Duration("mysql_server_read_timeout", 0, "connection read timeout") - mysqlConnWriteTimeout = flag.Duration("mysql_server_write_timeout", 0, "connection write timeout") + mysqlConnReadTimeout = flag.Duration("mysql_server_read_timeout", infinity, "connection read timeout") + mysqlConnWriteTimeout = flag.Duration("mysql_server_write_timeout", infinity, "connection write timeout") + mysqlQueryTimeout = flag.Duration("mysql_server_query_timeout", infinity, "mysql query timeout") ) // proxyHandler implements the Listener interface. @@ -71,7 +75,8 @@ func (mh *proxyHandler) NewConnection(c *mysql.Conn) { func (mh *proxyHandler) ConnectionClosed(c *mysql.Conn) { // Rollback if there is an ongoing transaction. Ignore error. - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), *mysqlQueryTimeout) + defer cancel() session, _ := c.ClientData.(*mysqlproxy.ProxySession) if session != nil && session.TransactionID != 0 { _ = mh.mp.Rollback(ctx, session) @@ -79,8 +84,8 @@ func (mh *proxyHandler) ConnectionClosed(c *mysql.Conn) { } func (mh *proxyHandler) ComQuery(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error { - // FIXME(alainjobart): Add some kind of timeout to the context. - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), *mysqlQueryTimeout) + defer cancel() // Fill in the ImmediateCallerID with the UserData returned by // the AuthServer plugin for that user. If nothing was diff --git a/test/mysql_server_test.py b/test/mysql_server_test.py index 6f07e804ed4..3cd11cf148d 100755 --- a/test/mysql_server_test.py +++ b/test/mysql_server_test.py @@ -104,7 +104,7 @@ def test_mysql_connector(self): fd.write("""{ "table_groups": [ { - "table_names_or_prefixes": ["vt_insert_test"], + "table_names_or_prefixes": ["vt_insert_test", "dual"], "readers": ["vtgate client 1"], "writers": ["vtgate client 1"], "admins": ["vtgate client 1"] @@ -143,6 +143,7 @@ def test_mysql_connector(self): # start vtgate utils.VtGate(mysql_server=True).start( extra_args=['-mysql_auth_server_impl', 'static', + '-mysql_server_query_timeout', '1s', '-mysql_auth_server_static_file', mysql_auth_server_static]) # We use gethostbyname('localhost') so we don't presume # of the IP format (travis is only IP v4, really). @@ -189,6 +190,18 @@ def test_mysql_connector(self): conn.close() + # 'vtgate client' this query should timeout + conn = MySQLdb.Connect(**params) + try: + cursor = conn.cursor() + cursor.execute('SELECT SLEEP(5)', {}) + self.fail('Execute went through') + except MySQLdb.OperationalError, e: + s = str(e) + # 1317 is DeadlineExceeded error code + self.assertIn('1317', s) + conn.close() + # 'vtgate client 2' is not authorized to access vt_insert_test params['user'] = 'testuser2' params['passwd'] = 'testpassword2'