From fae71ea2dbbd9cee5bf637d954a4360f811223f4 Mon Sep 17 00:00:00 2001 From: kolzeq Date: Thu, 4 Jun 2015 13:20:03 +0200 Subject: [PATCH] CONJ-152 - rewriteBatchedStatements and multiple executeBatch error --- .../org/mariadb/jdbc/MySQLConnection.java | 13 -- .../mariadb/jdbc/MySQLPreparedStatement.java | 67 ++--------- .../java/org/mariadb/jdbc/MySQLStatement.java | 112 +++++++++--------- .../packet/MaxAllowedPacketException.java | 18 +++ .../common/packet/PacketOutputStream.java | 22 +++- .../packet/commands/StreamedQueryPacket.java | 39 ++++-- .../common/query/MySQLParameterizedQuery.java | 17 ++- .../internal/common/query/MySQLQuery.java | 34 +++--- .../jdbc/internal/common/query/Query.java | 1 + .../jdbc/internal/mysql/MySQLProtocol.java | 60 ++++++---- src/test/java/org/mariadb/jdbc/BaseTest.java | 2 +- .../java/org/mariadb/jdbc/ConnectionTest.java | 93 +++++++++++---- src/test/java/org/mariadb/jdbc/MultiTest.java | 5 +- 13 files changed, 282 insertions(+), 201 deletions(-) create mode 100644 src/main/java/org/mariadb/jdbc/internal/common/packet/MaxAllowedPacketException.java diff --git a/src/main/java/org/mariadb/jdbc/MySQLConnection.java b/src/main/java/org/mariadb/jdbc/MySQLConnection.java index f99fa1f7b..1e2182141 100644 --- a/src/main/java/org/mariadb/jdbc/MySQLConnection.java +++ b/src/main/java/org/mariadb/jdbc/MySQLConnection.java @@ -124,19 +124,6 @@ public static MySQLConnection newConnection(MySQLProtocol protocol) throws SQLEx connection.nullCatalogMeansCurrent = false; } - Statement st = null; - try { - st = connection.createStatement(); - if (sessionVariables != null) { - st.executeUpdate("set session " + sessionVariables); - } - ResultSet rs = st.executeQuery("show variables like 'max_allowed_packet'"); - rs.next(); - protocol.setMaxAllowedPacket(Integer.parseInt(rs.getString(2))); - } finally { - if (st != null) - st.close(); - } return connection; } diff --git a/src/main/java/org/mariadb/jdbc/MySQLPreparedStatement.java b/src/main/java/org/mariadb/jdbc/MySQLPreparedStatement.java index 4526b474a..63a15d02a 100644 --- a/src/main/java/org/mariadb/jdbc/MySQLPreparedStatement.java +++ b/src/main/java/org/mariadb/jdbc/MySQLPreparedStatement.java @@ -50,8 +50,7 @@ WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWIS import org.mariadb.jdbc.internal.SQLExceptionMapper; import org.mariadb.jdbc.internal.common.Utils; -import org.mariadb.jdbc.internal.common.query.IllegalParameterException; -import org.mariadb.jdbc.internal.common.query.MySQLParameterizedQuery; +import org.mariadb.jdbc.internal.common.query.*; import org.mariadb.jdbc.internal.common.query.parameters.*; import java.io.IOException; @@ -75,7 +74,6 @@ public class MySQLPreparedStatement extends MySQLStatement implements PreparedSt private String sql; private boolean useFractionalSeconds; boolean parametersCleared; - List batchPreparedStatements; public MySQLPreparedStatement(MySQLConnection connection, @@ -92,13 +90,6 @@ public MySQLPreparedStatement(MySQLConnection connection, parametersCleared = true; } - private MySQLPreparedStatement (MySQLConnection connection, String sql, MySQLParameterizedQuery dQuery, boolean useFractionalSeconds ) { - super(connection); - this.dQuery = dQuery.cloneQuery(); - this.sql = sql; - this.useFractionalSeconds = useFractionalSeconds; - } - /** * Executes the SQL query in this PreparedStatement object * and returns the ResultSet object generated by the query. @@ -192,64 +183,30 @@ public void setNull(final int parameterIndex, final int sqlType) throws SQLExcep */ public void addBatch() throws SQLException { checkBatchFields(); - batchPreparedStatements.add(new MySQLPreparedStatement(connection,sql, dQuery, useFractionalSeconds)); - + batchQueries.add(dQuery.cloneQuery()); + isInsertRewriteable(dQuery.getQuery()); } public void addBatch(final String sql) throws SQLException { checkBatchFields(); - batchPreparedStatements.add(new MySQLPreparedStatement(connection, sql)); + isInsertRewriteable(sql); + batchQueries.add(new MySQLQuery(sql)); } private void checkBatchFields() { - if (batchPreparedStatements == null) { - batchPreparedStatements = new ArrayList(); + if (batchQueries == null) { + batchQueries = new ArrayList(); } } public void clearBatch() { - if (batchPreparedStatements != null) { - batchPreparedStatements.clear(); - } - } - - - @Override - public int[] executeBatch() throws SQLException { - if (batchPreparedStatements == null || batchPreparedStatements.isEmpty()) { - return new int[0]; + if (batchQueries != null) { + batchQueries.clear(); } - int[] ret = new int[batchPreparedStatements.size()]; - int i = 0; - MySQLResultSet rs = null; - try { - synchronized (this.getProtocol()) { - for (; i < batchPreparedStatements.size(); i++) { - PreparedStatement ps = batchPreparedStatements.get(i); - ps.execute(); - int updateCount = ps.getUpdateCount(); - if (updateCount == -1) { - ret[i] = SUCCESS_NO_INFO; - } else { - ret[i] = updateCount; - } - if (i == 0) { - rs = (MySQLResultSet)ps.getGeneratedKeys(); - } else { - rs = rs.joinResultSets((MySQLResultSet)ps.getGeneratedKeys()); - } - } - } - } catch (SQLException sqle) { - throw new BatchUpdateException(sqle.getMessage(), sqle.getSQLState(), sqle.getErrorCode(), Arrays.copyOf(ret, i), sqle); - } finally { - clearBatch(); - } - batchResultSet = rs; - return ret; + firstRewrite = null; + isRewriteable = true; } - - /** + /** * Sets the designated parameter to the given Reader object, which is the given number of characters * long. When a very large UNICODE value is input to a LONGVARCHAR parameter, it may be more practical * to send it via a java.io.Reader object. The data will be read from the stream as needed until diff --git a/src/main/java/org/mariadb/jdbc/MySQLStatement.java b/src/main/java/org/mariadb/jdbc/MySQLStatement.java index b153016a1..3489c68e7 100644 --- a/src/main/java/org/mariadb/jdbc/MySQLStatement.java +++ b/src/main/java/org/mariadb/jdbc/MySQLStatement.java @@ -96,10 +96,10 @@ public class MySQLStatement implements Statement { boolean isTimedout; volatile boolean executing; - List batchQueries; + List batchQueries; Queue cachedResultSets; - private boolean isRewriteable = true; - private String firstRewrite = null; + protected boolean isRewriteable = true; + protected String firstRewrite = null; protected ResultSet batchResultSet = null; @@ -299,6 +299,45 @@ protected boolean execute(Query query) throws SQLException { } } + /** + * Execute statements. if many queries, those queries will be rewritten + * if isRewritable = false, the query will be agreggated : + * INSERT INTO jdbc (`name`) VALUES ('Line 1: Lorem ipsum ...') + * INSERT INTO jdbc (`name`) VALUES ('Line 2: Lorem ipsum ...') + * will be agreggate as + * INSERT INTO jdbc (`name`) VALUES ('Line 1: Lorem ipsum ...');INSERT INTO jdbc (`name`) VALUES ('Line 2: Lorem ipsum ...') + * and if isRewritable, agreggated as + * INSERT INTO jdbc (`name`) VALUES ('Line 1: Lorem ipsum ...'),('Line 2: Lorem ipsum ...') + * @param queries list of queries + * @param isRewritable are the queries of the same type to be agreggated + * @param rewriteOffset offset of the parameter if query are similar + * @return true if there was a result set, false otherwise. + * @throws SQLException + */ + protected boolean execute(List queries, boolean isRewritable, int rewriteOffset) throws SQLException { + //System.out.println(query); + synchronized (protocol) { + if (protocol.activeResult != null) { + protocol.activeResult.close(); + } + executing = true; + QueryException exception = null; + executeQueryProlog(); + try { + batchResultSet = null; + queryResult = protocol.executeQuery(queries, isStreaming(), isRewritable, rewriteOffset); + cacheMoreResults(); + return (queryResult.getResultSetType() == ResultSetType.SELECT); + } catch (QueryException e) { + exception = e; + return false; + } finally { + executeQueryEpilog(exception, queries.get(0)); + executing = false; + } + } + } + /** * executes a select query. * @@ -1095,17 +1134,17 @@ public int getResultSetType() throws SQLException { */ public void addBatch(final String sql) throws SQLException { if (batchQueries == null) { - batchQueries = new ArrayList(); + batchQueries = new ArrayList(); } - batchQueries.add(sql); isInsertRewriteable(sql); + batchQueries.add(new MySQLQuery(sql)); } /** * Parses the sql string to understand whether it is compatible with rewritten batches. * @param sql the sql string */ - private void isInsertRewriteable(String sql) { + protected void isInsertRewriteable(String sql) { if (!isRewriteable) { return; } @@ -1157,24 +1196,7 @@ protected int getInsertIncipit(String sql) { return startBracket; } - - /** - * If the batch array contains only rewriteable sql strings, returns the rewritten statement. - * @return the rewritten statement - */ - private String rewrittenBatch() { - StringBuilder result = null; - if(isRewriteable) { - result = new StringBuilder(""); - result.append(firstRewrite); - for (String query : batchQueries) { - result.append(query.substring(getInsertIncipit(query))); - result.append(","); - } - result.deleteCharAt(result.length() - 1); - } - return (result == null ? null : result.toString()); - } + /** @@ -1231,18 +1253,23 @@ public void clearBatch() throws SQLException { * @since 1.3 */ public int[] executeBatch() throws SQLException { - if (batchQueries == null) - return new int[0]; + if (batchQueries == null || batchQueries.size() == 0) return new int[0]; int[] ret = new int[batchQueries.size()]; int i = 0; MySQLResultSet rs = null; + + boolean allowMultiQueries = "true".equals(getProtocol().getInfo().getProperty("allowMultiQueries")); + boolean rewriteBatchedStatements = "true".equals(getProtocol().getInfo().getProperty("rewriteBatchedStatements")); + if (rewriteBatchedStatements) allowMultiQueries=true; try { synchronized (this.protocol) { - if (getProtocol().getInfo().getProperty("rewriteBatchedStatements") != null - && "true".equalsIgnoreCase(getProtocol().getInfo().getProperty("rewriteBatchedStatements"))) { - ret = executeBatchAsMultiQueries(); - } else { + if (allowMultiQueries) { + int size = batchQueries.size(); + MySQLStatement ps = (MySQLStatement) connection.createStatement(); + ps.execute(batchQueries, isRewriteable && rewriteBatchedStatements, (isRewriteable && rewriteBatchedStatements)?firstRewrite.length():0); + return isRewriteable?getUpdateCountsForReWrittenBatch(ps, size):getUpdateCounts(ps, size); + } else { for(; i < batchQueries.size(); i++) { execute(batchQueries.get(i)); int updateCount = getUpdateCount(); @@ -1267,31 +1294,6 @@ public int[] executeBatch() throws SQLException { batchResultSet = rs; return ret; } - - /** - * Builds a new statement which contains the batched Statements and executes it. - * @return an array of update counts containing one element for each command in the batch. - * The elements of the array are ordered according to the order in which commands were added to the batch. - * @throws SQLException - */ - private int[] executeBatchAsMultiQueries() throws SQLException { - int i = 0; - StringBuilder stringBuilder = new StringBuilder(); - String rewrite = rewrittenBatch(); - boolean rewrittenBatch = rewrite != null; - if (rewrittenBatch) { - stringBuilder.append(rewrite); - i = batchQueries.size(); - } else { - for (; i < batchQueries.size(); i++) { - stringBuilder.append(batchQueries.get(i) + ";"); - } - } - Statement ps = connection.createStatement(); - ps.execute(stringBuilder.toString()); - return rewrittenBatch ? getUpdateCountsForReWrittenBatch(ps, i) : getUpdateCounts(ps, i); - } - /** * Retrieves the update counts for the batched statements rewritten as diff --git a/src/main/java/org/mariadb/jdbc/internal/common/packet/MaxAllowedPacketException.java b/src/main/java/org/mariadb/jdbc/internal/common/packet/MaxAllowedPacketException.java new file mode 100644 index 000000000..2d0f59929 --- /dev/null +++ b/src/main/java/org/mariadb/jdbc/internal/common/packet/MaxAllowedPacketException.java @@ -0,0 +1,18 @@ +package org.mariadb.jdbc.internal.common.packet; + +import java.io.IOException; + +/** + * Created by diego_000 on 03/06/2015. + */ +public class MaxAllowedPacketException extends IOException { + boolean mustReconnect; + public MaxAllowedPacketException(String message, boolean mustReconnect) { + super(message); + this.mustReconnect = mustReconnect; + } + + public boolean isMustReconnect() { + return mustReconnect; + } +} diff --git a/src/main/java/org/mariadb/jdbc/internal/common/packet/PacketOutputStream.java b/src/main/java/org/mariadb/jdbc/internal/common/packet/PacketOutputStream.java index eceae57ef..2012ea0fe 100644 --- a/src/main/java/org/mariadb/jdbc/internal/common/packet/PacketOutputStream.java +++ b/src/main/java/org/mariadb/jdbc/internal/common/packet/PacketOutputStream.java @@ -1,11 +1,16 @@ package org.mariadb.jdbc.internal.common.packet; +import org.mariadb.jdbc.internal.common.packet.commands.StreamedQueryPacket; +import org.mariadb.jdbc.internal.common.query.MySQLQuery; + import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; +import java.util.logging.Level; +import java.util.logging.Logger; - -public class PacketOutputStream extends OutputStream{ +public class PacketOutputStream extends OutputStream { + private final static Logger log = Logger.getLogger("org.maria.jdbc"); private static final int MAX_PACKET_LENGTH = 0x00ffffff; private static final int SEQNO_OFFSET = 3; @@ -100,7 +105,6 @@ public void write(byte[] bytes, int off, int len) throws IOException{ System.arraycopy(byteBuffer, 0, tmp, 0, position); byteBuffer = tmp; } - System.arraycopy(bytes, off, byteBuffer, position, bytesToWrite); position += bytesToWrite; off += bytesToWrite; @@ -123,12 +127,18 @@ private void internalFlush() throws IOException { byteBuffer[1] = (byte)((dataLen >> 8) & 0xff); byteBuffer[2] = (byte)((dataLen >> 16) & 0xff); byteBuffer[SEQNO_OFFSET] = (byte)this.seqNo; - bytesWritten += dataLen; + bytesWritten += dataLen + HEADER_LENGTH; if (maxAllowedPacket > 0 && bytesWritten > maxAllowedPacket && checkPacketLength) { - baseStream.close(); - throw new IOException("max_allowed_packet exceeded. wrote " + bytesWritten + ", max_allowed_packet = " +maxAllowedPacket); + this.seqNo=-1; + throw new MaxAllowedPacketException("max_allowed_packet exceeded. wrote " + bytesWritten + ", max_allowed_packet = " +maxAllowedPacket, this.seqNo != 0); } baseStream.write(byteBuffer, 0, position); + if (log.isLoggable(Level.FINEST)) { + byte[] tmp = new byte[Math.min(1000, position)]; + System.arraycopy(byteBuffer, 0, tmp, 0, Math.min(1000, position)); + log.finest(new String(tmp)); + } + position = HEADER_LENGTH; this.seqNo++; } diff --git a/src/main/java/org/mariadb/jdbc/internal/common/packet/commands/StreamedQueryPacket.java b/src/main/java/org/mariadb/jdbc/internal/common/packet/commands/StreamedQueryPacket.java index e1047edd9..1445ba43d 100644 --- a/src/main/java/org/mariadb/jdbc/internal/common/packet/commands/StreamedQueryPacket.java +++ b/src/main/java/org/mariadb/jdbc/internal/common/packet/commands/StreamedQueryPacket.java @@ -49,7 +49,6 @@ WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWIS package org.mariadb.jdbc.internal.common.packet.commands; -import org.mariadb.jdbc.internal.SQLExceptionMapper; import org.mariadb.jdbc.internal.common.QueryException; import org.mariadb.jdbc.internal.common.packet.CommandPacket; import org.mariadb.jdbc.internal.common.packet.PacketOutputStream; @@ -57,22 +56,44 @@ WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWIS import java.io.IOException; import java.io.OutputStream; +import java.util.ArrayList; +import java.util.List; public class StreamedQueryPacket implements CommandPacket { - private final Query query; + private List queries; + private boolean isRewritable; + private int rewriteOffset; - public StreamedQueryPacket(final Query query) { - this.query = query; + public StreamedQueryPacket(final List queries, boolean isRewritable, int rewriteOffset) { + this.queries = queries; + this.isRewritable = isRewritable; + this.rewriteOffset = rewriteOffset; } public int send(final OutputStream ostream) throws IOException, QueryException { - PacketOutputStream pos = (PacketOutputStream)ostream; - pos.startPacket(0); - pos.write(0x03); - query.writeTo(ostream); - pos.finishPacket(); + if (queries.size() == 1) { + PacketOutputStream pos = (PacketOutputStream)ostream; + pos.startPacket(0); + pos.write(0x03); + queries.get(0).writeTo(ostream); + pos.finishPacket(); + } else { + PacketOutputStream pos = (PacketOutputStream)ostream; + pos.startPacket(0); + pos.write(0x03); + queries.get(0).writeTo(ostream); + for (int i=1;i dQueries, boolean streaming) throws QueryException{ RawPacket rawPacket; ResultPacket resultPacket; try { @@ -973,8 +987,8 @@ public QueryResult getResult(Query dQuery, boolean streaming) throws QueryExcept this.moreResults = false; this.hasWarnings = false; ErrorPacket ep = (ErrorPacket) resultPacket; - if (dQuery != null) { - log.warning("Could not execute query " + dQuery + ": " + ((ErrorPacket) resultPacket).getMessage()); + if (dQueries != null && dQueries.size() == 1) { + log.warning("Could not execute query " + dQueries.get(0) + ": " + ((ErrorPacket) resultPacket).getMessage()); } else { log.warning("Got error from server: " + ((ErrorPacket) resultPacket).getMessage()); } @@ -1010,37 +1024,41 @@ public QueryResult getResult(Query dQuery, boolean streaming) throws QueryExcept } } - public QueryResult executeQuery(final Query dQuery, boolean streaming) throws QueryException, SQLException - { - dQuery.validate(); - log.log(Level.FINEST, "Executing streamed query: {0}", dQuery); + public QueryResult executeQuery(final Query query, boolean streaming) throws QueryException, SQLException { + List queries = new ArrayList(); + queries.add(query); + return executeQuery(queries, streaming, false, 0); + } + + public QueryResult executeQuery(final List dQueries, boolean streaming, boolean isRewritable, int rewriteOffset) throws QueryException, SQLException { + for (Query query : dQueries) query.validate(); + this.moreResults = false; - final StreamedQueryPacket packet = new StreamedQueryPacket(dQuery); + final StreamedQueryPacket packet = new StreamedQueryPacket(dQueries, isRewritable, rewriteOffset); try { packet.send(writer); + } catch (MaxAllowedPacketException e) { + if (e.isMustReconnect()) connect(); + throw new QueryException("Could not send query: " + e.getMessage(), -1, SQLExceptionMapper.SQLStates.INTERRUPTED_EXCEPTION.getSqlState(), e); } catch (IOException e) { - throw new QueryException("Could not send query: " + e.getMessage(), - -1, - SQLExceptionMapper.SQLStates.CONNECTION_EXCEPTION.getSqlState(), - e); + throw new QueryException("Could not send query: " + e.getMessage(), -1, SQLExceptionMapper.SQLStates.CONNECTION_EXCEPTION.getSqlState(), e); } if (!isMasterConnection()) queriesSinceFailover++; try { - return getResult(dQuery, streaming); + return getResult(dQueries, streaming); } catch (QueryException qex) { - if (qex.getCause() instanceof SocketTimeoutException) { - close(); - throw SQLExceptionMapper.getSQLException("Connection timed out"); - } else { - throw qex; - } + if (qex.getCause() instanceof SocketTimeoutException) { + close(); + throw SQLExceptionMapper.getSQLException("Connection timed out"); + } else { + throw qex; + } } } - public String getServerVariable(String variable) throws QueryException, SQLException { CachedSelectResult qr = (CachedSelectResult) executeQuery(new MySQLQuery("select @@" + variable)); try { diff --git a/src/test/java/org/mariadb/jdbc/BaseTest.java b/src/test/java/org/mariadb/jdbc/BaseTest.java index 23ed301ad..2205fa621 100644 --- a/src/test/java/org/mariadb/jdbc/BaseTest.java +++ b/src/test/java/org/mariadb/jdbc/BaseTest.java @@ -47,7 +47,7 @@ public static void beforeClassBaseTest() { consoleHandler.setFormatter(new CustomFormatter()); consoleHandler.setLevel(Level.parse(logLevel)); log.addHandler(consoleHandler); - log.setLevel(Level.ALL); + log.setLevel(Level.FINE); } } diff --git a/src/test/java/org/mariadb/jdbc/ConnectionTest.java b/src/test/java/org/mariadb/jdbc/ConnectionTest.java index cb5442406..2ba1427d2 100644 --- a/src/test/java/org/mariadb/jdbc/ConnectionTest.java +++ b/src/test/java/org/mariadb/jdbc/ConnectionTest.java @@ -5,13 +5,16 @@ import java.io.InputStream; import java.io.OutputStream; import java.lang.reflect.Field; +import java.lang.reflect.Method; import java.sql.*; import java.io.UnsupportedEncodingException; import java.util.Arrays; import java.util.Properties; import java.util.concurrent.Executor; +import org.junit.Assume; import org.junit.Test; +import org.mariadb.jdbc.internal.common.query.MySQLQuery; import org.mariadb.jdbc.internal.mysql.MySQLProtocol; public class ConnectionTest extends BaseTest { @@ -142,58 +145,77 @@ public void isValid_shouldThrowExceptionWithNegativeTimeout() * @throws UnsupportedEncodingException */ @Test - public void maxAllowedPackedExceptionIsPrettyTest() throws SQLException, UnsupportedEncodingException { - int maxAllowedPacket = 1024 * 1024; - Statement statement = connection.createStatement(); - ResultSet rs = statement.executeQuery("SHOW VARIABLES LIKE 'max_allowed_packet'"); - if (rs.next()) { - maxAllowedPacket = rs.getInt(2); + public void maxAllowedPackedExceptionIsPrettyTest() throws Throwable, SQLException, UnsupportedEncodingException { + //test without reconnection if packet to long + Connection tmpConnection = getChangedAllowedPacketConnection(8 * 1024 * 1024); + try { + checkMaxAllowedPacket(tmpConnection, 8 * 1024 * 1024); + } finally { + tmpConnection.close(); + } + + Connection tmpConnection2 = getChangedAllowedPacketConnection(32 * 1024 * 1024); + try { + checkMaxAllowedPacket(tmpConnection2, 32 * 1024 * 1024); + } finally { + tmpConnection2.close(); } - rs.close(); + } + + private void checkMaxAllowedPacket(Connection tmpConnection, int maxAllowedPacket ) throws Throwable, SQLException, UnsupportedEncodingException { + Statement statement = tmpConnection.createStatement(); statement.execute("DROP TABLE IF EXISTS dummy"); statement.execute("CREATE TABLE dummy (a BLOB)"); - //Create a SQL packet bigger than maxAllowedPacket + ResultSet rs = statement.executeQuery("show variables like 'max_allowed_packet'"); + rs.next(); + log.fine("max_allowed_packet DB" + rs.getString(2) + " / " + maxAllowedPacket); + + /**Create a SQL packet bigger than maxAllowedPacket**/ StringBuilder sb = new StringBuilder(); String rowData = "('this is a dummy row values')"; int rowsToWrite = (maxAllowedPacket / rowData.getBytes("UTF-8").length) + 1; - try { + try { for (int row = 1; row <= rowsToWrite; row++) { if (row >= 2) { sb.append(", "); } sb.append(rowData); } + statement.executeUpdate("INSERT INTO dummy VALUES " + sb.toString()); + fail("The previous statement should throw an SQLException"); } catch (OutOfMemoryError e) { log.warning("skip test 'maxAllowedPackedExceptionIsPrettyTest' - not enough memory"); - return; - } - String sql = "INSERT INTO dummy VALUES " + sb.toString(); - try { - statement.executeUpdate(sql); - fail("The previous statement should throw an SQLException"); + Assume.assumeNoException(e); } catch (SQLException e) { assertTrue(e.getMessage().contains("max_allowed_packet")); } catch (Exception e) { fail("The previous statement should throw an SQLException not a general Exception"); } - //added in CONJ-151 to check the 2 differents type of query - PreparedStatement preparedStatement = connection.prepareStatement("INSERT INTO dummy VALUES (?)"); - byte [] arr = new byte[maxAllowedPacket + 1000]; - Arrays.fill(arr, (byte) 'a'); - preparedStatement.setBytes(1,arr); - preparedStatement.addBatch(); + resetProtocolMaxAllowedPacket(tmpConnection, maxAllowedPacket); //reset maxPacket because if connection reconnect, the maxAllowedPacket is reset to default value + statement.execute("select count(*) from dummy"); //check that the connection is still working + + + /**added in CONJ-151 to check the 2 differents type of query implementation**/ + PreparedStatement preparedStatement = tmpConnection.prepareStatement("INSERT INTO dummy VALUES (?)"); try { + byte [] arr = new byte[maxAllowedPacket + 1000]; + Arrays.fill(arr, (byte) 'a'); + preparedStatement.setBytes(1,arr); + preparedStatement.addBatch(); preparedStatement.executeBatch(); fail("The previous statement should throw an SQLException"); + } catch (OutOfMemoryError e) { + log.warning("skip second test 'maxAllowedPackedExceptionIsPrettyTest' - not enough memory"); + Assume.assumeNoException(e); } catch (SQLException e) { + log.fine("normal SQlExeption "+e.getMessage()); assertTrue(e.getMessage().contains("max_allowed_packet")); } catch (Exception e) { fail("The previous statement should throw an SQLException not a general Exception"); } finally { - statement.execute("DROP TABLE dummy"); + statement.execute("select count(*) from dummy"); //to check that connection is open } - } @@ -261,4 +283,29 @@ private long getServerThreadId() throws Exception { return threadId; } + private Connection getChangedAllowedPacketConnection(int maxAllowedPacket) throws Throwable { + Statement statement = null; + try { + statement = connection.createStatement(); + ResultSet rs = statement.executeQuery("show variables like 'max_allowed_packet'"); + rs.next(); + int currentAllowedPacket = rs.getInt(2); + + statement.execute("SET GLOBAL max_allowed_packet=" + maxAllowedPacket); + Connection tmpConnection = openNewConnection(connURI, new Properties()); + resetProtocolMaxAllowedPacket(tmpConnection, maxAllowedPacket); + statement.execute("SET GLOBAL max_allowed_packet=" + currentAllowedPacket); + return tmpConnection; + } finally { + statement.close(); + } + } + private void resetProtocolMaxAllowedPacket(Connection conn, int maxAllowedPacket) throws Throwable { + Method getProtocolMethod = MySQLConnection.class.getDeclaredMethod("getProtocol", new Class[0]); + getProtocolMethod.setAccessible(true); + MySQLProtocol protocol = (MySQLProtocol) getProtocolMethod.invoke(conn); + protocol.setMaxAllowedPacket(maxAllowedPacket); + + } + } diff --git a/src/test/java/org/mariadb/jdbc/MultiTest.java b/src/test/java/org/mariadb/jdbc/MultiTest.java index 518aad863..f41b83d9f 100644 --- a/src/test/java/org/mariadb/jdbc/MultiTest.java +++ b/src/test/java/org/mariadb/jdbc/MultiTest.java @@ -4,8 +4,11 @@ import java.sql.*; import java.util.Properties; +import java.util.logging.Level; +import java.util.logging.Logger; import junit.framework.Assert; +import org.mariadb.jdbc.internal.common.packet.PacketOutputStream; import static org.junit.Assert.*; @@ -14,7 +17,6 @@ public class MultiTest extends BaseTest { private static Connection connection; public MultiTest() throws SQLException { - } @BeforeClass @@ -434,6 +436,7 @@ public void updateCountTest() throws SQLException { sqlUpdate.addBatch(); int[] updateCounts = sqlUpdate.executeBatch(); + log.finest("updateCounts : "+updateCounts.length); Assert.assertEquals(3, updateCounts.length); Assert.assertEquals(1, updateCounts[0]); Assert.assertEquals(0, updateCounts[1]);