diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index 34451f9ed05..194f04a428b 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -55,6 +55,7 @@ var ( mysqlConnReadTimeout = flag.Duration("mysql_server_read_timeout", 0, "connection read timeout") mysqlConnWriteTimeout = flag.Duration("mysql_server_write_timeout", 0, "connection write timeout") + mysqlQueryTimeout = flag.Duration("mysql_server_query_timeout", 0, "mysql query timeout") busyConnections int32 ) @@ -77,7 +78,14 @@ func (vh *vtgateHandler) NewConnection(c *mysql.Conn) { func (vh *vtgateHandler) ConnectionClosed(c *mysql.Conn) { // Rollback if there is an ongoing transaction. Ignore error. - ctx := context.Background() + var ctx context.Context + var cancel context.CancelFunc + if *mysqlQueryTimeout != 0 { + ctx, cancel = context.WithTimeout(context.Background(), *mysqlQueryTimeout) + defer cancel() + } else { + ctx = context.Background() + } session, _ := c.ClientData.(*vtgatepb.Session) if session != nil { if session.InTransaction { @@ -88,8 +96,14 @@ func (vh *vtgateHandler) ConnectionClosed(c *mysql.Conn) { } func (vh *vtgateHandler) ComQuery(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error { - // FIXME(alainjobart): Add some kind of timeout to the context. - ctx := context.Background() + var ctx context.Context + var cancel context.CancelFunc + if *mysqlQueryTimeout != 0 { + ctx, cancel = context.WithTimeout(context.Background(), *mysqlQueryTimeout) + defer cancel() + } else { + ctx = context.Background() + } // Fill in the ImmediateCallerID with the UserData returned by // the AuthServer plugin for that user. If nothing was diff --git a/go/vt/vtqueryserver/plugin_mysql_server.go b/go/vt/vtqueryserver/plugin_mysql_server.go index f0400a88c3e..ab0327b6273 100644 --- a/go/vt/vtqueryserver/plugin_mysql_server.go +++ b/go/vt/vtqueryserver/plugin_mysql_server.go @@ -51,6 +51,7 @@ var ( mysqlConnReadTimeout = flag.Duration("mysql_server_read_timeout", 0, "connection read timeout") mysqlConnWriteTimeout = flag.Duration("mysql_server_write_timeout", 0, "connection write timeout") + mysqlQueryTimeout = flag.Duration("mysql_server_query_timeout", 0, "mysql query timeout") ) // proxyHandler implements the Listener interface. @@ -71,7 +72,14 @@ func (mh *proxyHandler) NewConnection(c *mysql.Conn) { func (mh *proxyHandler) ConnectionClosed(c *mysql.Conn) { // Rollback if there is an ongoing transaction. Ignore error. - ctx := context.Background() + var ctx context.Context + var cancel context.CancelFunc + if *mysqlQueryTimeout != 0 { + ctx, cancel = context.WithTimeout(context.Background(), *mysqlQueryTimeout) + defer cancel() + } else { + ctx = context.Background() + } session, _ := c.ClientData.(*mysqlproxy.ProxySession) if session != nil && session.TransactionID != 0 { _ = mh.mp.Rollback(ctx, session) @@ -79,9 +87,14 @@ func (mh *proxyHandler) ConnectionClosed(c *mysql.Conn) { } func (mh *proxyHandler) ComQuery(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error { - // FIXME(alainjobart): Add some kind of timeout to the context. - ctx := context.Background() - + var ctx context.Context + var cancel context.CancelFunc + if *mysqlQueryTimeout != 0 { + ctx, cancel = context.WithTimeout(context.Background(), *mysqlQueryTimeout) + defer cancel() + } else { + ctx = context.Background() + } // Fill in the ImmediateCallerID with the UserData returned by // the AuthServer plugin for that user. If nothing was // returned, use the User. This lets the plugin map a MySQL diff --git a/test/mysql_server_test.py b/test/mysql_server_test.py index 6f07e804ed4..3cd11cf148d 100755 --- a/test/mysql_server_test.py +++ b/test/mysql_server_test.py @@ -104,7 +104,7 @@ def test_mysql_connector(self): fd.write("""{ "table_groups": [ { - "table_names_or_prefixes": ["vt_insert_test"], + "table_names_or_prefixes": ["vt_insert_test", "dual"], "readers": ["vtgate client 1"], "writers": ["vtgate client 1"], "admins": ["vtgate client 1"] @@ -143,6 +143,7 @@ def test_mysql_connector(self): # start vtgate utils.VtGate(mysql_server=True).start( extra_args=['-mysql_auth_server_impl', 'static', + '-mysql_server_query_timeout', '1s', '-mysql_auth_server_static_file', mysql_auth_server_static]) # We use gethostbyname('localhost') so we don't presume # of the IP format (travis is only IP v4, really). @@ -189,6 +190,18 @@ def test_mysql_connector(self): conn.close() + # 'vtgate client' this query should timeout + conn = MySQLdb.Connect(**params) + try: + cursor = conn.cursor() + cursor.execute('SELECT SLEEP(5)', {}) + self.fail('Execute went through') + except MySQLdb.OperationalError, e: + s = str(e) + # 1317 is DeadlineExceeded error code + self.assertIn('1317', s) + conn.close() + # 'vtgate client 2' is not authorized to access vt_insert_test params['user'] = 'testuser2' params['passwd'] = 'testpassword2'