From 59f3d68c1bbd32936e631b9e533c31f45d5ec770 Mon Sep 17 00:00:00 2001 From: vaintroub Date: Fri, 29 May 2015 16:18:32 -0700 Subject: [PATCH] fix for max_allowed_packet check. also remove some obsolete stuff . also, remove batch rewrites for prepared statements, it does not work as it is written --- .../mariadb/jdbc/MySQLPreparedStatement.java | 109 +++--------------- .../common/packet/PacketOutputStream.java | 18 ++- .../packet/commands/StreamedQueryPacket.java | 15 +-- .../common/query/MySQLParameterizedQuery.java | 47 +++----- .../internal/common/query/MySQLQuery.java | 4 +- .../jdbc/internal/common/query/Query.java | 3 +- .../jdbc/internal/mysql/MySQLProtocol.java | 9 +- 7 files changed, 59 insertions(+), 146 deletions(-) diff --git a/src/main/java/org/mariadb/jdbc/MySQLPreparedStatement.java b/src/main/java/org/mariadb/jdbc/MySQLPreparedStatement.java index 61b3d036d..4526b474a 100644 --- a/src/main/java/org/mariadb/jdbc/MySQLPreparedStatement.java +++ b/src/main/java/org/mariadb/jdbc/MySQLPreparedStatement.java @@ -76,8 +76,6 @@ public class MySQLPreparedStatement extends MySQLStatement implements PreparedSt private boolean useFractionalSeconds; boolean parametersCleared; List batchPreparedStatements; - private boolean isRewriteable = true; - private String firstRewrite = null; public MySQLPreparedStatement(MySQLConnection connection, @@ -195,12 +193,11 @@ public void setNull(final int parameterIndex, final int sqlType) throws SQLExcep public void addBatch() throws SQLException { checkBatchFields(); batchPreparedStatements.add(new MySQLPreparedStatement(connection,sql, dQuery, useFractionalSeconds)); - isInsertRewriteable(sql); + } public void addBatch(final String sql) throws SQLException { checkBatchFields(); batchPreparedStatements.add(new MySQLPreparedStatement(connection, sql)); - isInsertRewriteable(sql); } private void checkBatchFields() { @@ -213,50 +210,8 @@ public void clearBatch() { if (batchPreparedStatements != null) { batchPreparedStatements.clear(); } - firstRewrite = null; - isRewriteable = true; } - /** - * Parses the sql string to understand whether it is compatible with rewritten batches. - * @param sql the sql string - */ - private void isInsertRewriteable(String sql) { - if (!isRewriteable) { - return; - } - int index = getInsertIncipit(sql); - if (index == -1) { - isRewriteable = false; - return; - } - if (firstRewrite == null) { - firstRewrite = sql.substring(0, index); - } - boolean isRewrite = sql.startsWith(firstRewrite); - if (isRewrite) { - isRewriteable = isRewriteable && true; - } - } - - /** - * If the batch array contains only rewriteable sql strings, returns the rewritten statement. - * @return the rewritten statement - */ - private String rewrittenBatch() { - StringBuilder result = null; - if(isRewriteable) { - result = new StringBuilder(""); - result.append(firstRewrite); - for (MySQLPreparedStatement mySQLPS : batchPreparedStatements) { - String query = mySQLPS.dQuery.toSQL(); - result.append(query.substring(firstRewrite.length())); - result.append(","); - } - result.deleteCharAt(result.length() - 1); - } - return (result == null ? null : result.toString()); - } @Override public int[] executeBatch() throws SQLException { @@ -267,28 +222,23 @@ public int[] executeBatch() throws SQLException { int i = 0; MySQLResultSet rs = null; try { - synchronized (this.getProtocol()) { - if (getProtocol().getInfo().getProperty("rewriteBatchedStatements") != null - && "true".equalsIgnoreCase(getProtocol().getInfo().getProperty("rewriteBatchedStatements"))) { - ret = executeBatchAsMultiQueries(); - } else { - for (; i < batchPreparedStatements.size(); i++) { - PreparedStatement ps = batchPreparedStatements.get(i); - ps.execute(); - int updateCount = ps.getUpdateCount(); - if (updateCount == -1) { - ret[i] = SUCCESS_NO_INFO; - } else { - ret[i] = updateCount; - } - if (i == 0) { - rs = (MySQLResultSet)ps.getGeneratedKeys(); - } else { - rs = rs.joinResultSets((MySQLResultSet)ps.getGeneratedKeys()); - } - } - } - } + synchronized (this.getProtocol()) { + for (; i < batchPreparedStatements.size(); i++) { + PreparedStatement ps = batchPreparedStatements.get(i); + ps.execute(); + int updateCount = ps.getUpdateCount(); + if (updateCount == -1) { + ret[i] = SUCCESS_NO_INFO; + } else { + ret[i] = updateCount; + } + if (i == 0) { + rs = (MySQLResultSet)ps.getGeneratedKeys(); + } else { + rs = rs.joinResultSets((MySQLResultSet)ps.getGeneratedKeys()); + } + } + } } catch (SQLException sqle) { throw new BatchUpdateException(sqle.getMessage(), sqle.getSQLState(), sqle.getErrorCode(), Arrays.copyOf(ret, i), sqle); } finally { @@ -298,29 +248,6 @@ public int[] executeBatch() throws SQLException { return ret; } - /** - * Builds a new statement which contains the batched Statements and executes it. - * @return an array of update counts containing one element for each command in the batch. - * The elements of the array are ordered according to the order in which commands were added to the batch. - * @throws SQLException - */ - private int[] executeBatchAsMultiQueries() throws SQLException { - StringBuilder stringBuilder = new StringBuilder(); - int i = 0; - String rewrite = rewrittenBatch(); - boolean rewrittenBatch = rewrite != null; - if (rewrittenBatch) { - stringBuilder.append(rewrite); - i = batchPreparedStatements.size(); - } else { - for (; i < batchPreparedStatements.size(); i++) { - stringBuilder.append(batchPreparedStatements.get(i).dQuery.toSQL() + ";"); - } - } - Statement ps = connection.createStatement(); - ps.execute(stringBuilder.toString()); - return rewrittenBatch ? getUpdateCountsForReWrittenBatch(ps, i) : getUpdateCounts(ps, i); - } /** * Sets the designated parameter to the given Reader object, which is the given number of characters diff --git a/src/main/java/org/mariadb/jdbc/internal/common/packet/PacketOutputStream.java b/src/main/java/org/mariadb/jdbc/internal/common/packet/PacketOutputStream.java index 535d59bf9..eceae57ef 100644 --- a/src/main/java/org/mariadb/jdbc/internal/common/packet/PacketOutputStream.java +++ b/src/main/java/org/mariadb/jdbc/internal/common/packet/PacketOutputStream.java @@ -17,7 +17,9 @@ public class PacketOutputStream extends OutputStream{ int position; int seqNo; boolean compress; - int maxAllowedPacket = 0; + int maxAllowedPacket; + int bytesWritten; + boolean checkPacketLength; public PacketOutputStream(OutputStream baseStream) { this.baseStream = baseStream; @@ -31,14 +33,19 @@ public void setCompress(boolean value) { compress = value; } - public void startPacket(int seqNo) throws IOException { + public void startPacket(int seqNo, boolean checkPacketLength) throws IOException { if (this.seqNo != -1) { throw new IOException("Last packet not finished"); } this.seqNo = seqNo; position = HEADER_LENGTH; + bytesWritten = 0; + this.checkPacketLength = checkPacketLength; } + public void startPacket(int seqNo) throws IOException { + startPacket(seqNo, true); + } public int getSeqNo() { return seqNo; } @@ -59,7 +66,7 @@ public void sendFile(InputStream is, int seq) throws IOException{ byte[] buffer = new byte[bufferSize]; int len; while((len = is.read(buffer)) > 0) { - startPacket(seq++); + startPacket(seq++, false); write(buffer, 0, len); finishPacket(); } @@ -116,6 +123,11 @@ private void internalFlush() throws IOException { byteBuffer[1] = (byte)((dataLen >> 8) & 0xff); byteBuffer[2] = (byte)((dataLen >> 16) & 0xff); byteBuffer[SEQNO_OFFSET] = (byte)this.seqNo; + bytesWritten += dataLen; + if (maxAllowedPacket > 0 && bytesWritten > maxAllowedPacket && checkPacketLength) { + baseStream.close(); + throw new IOException("max_allowed_packet exceeded. wrote " + bytesWritten + ", max_allowed_packet = " +maxAllowedPacket); + } baseStream.write(byteBuffer, 0, position); position = HEADER_LENGTH; this.seqNo++; diff --git a/src/main/java/org/mariadb/jdbc/internal/common/packet/commands/StreamedQueryPacket.java b/src/main/java/org/mariadb/jdbc/internal/common/packet/commands/StreamedQueryPacket.java index 11a778777..e1047edd9 100644 --- a/src/main/java/org/mariadb/jdbc/internal/common/packet/commands/StreamedQueryPacket.java +++ b/src/main/java/org/mariadb/jdbc/internal/common/packet/commands/StreamedQueryPacket.java @@ -62,27 +62,16 @@ WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWIS public class StreamedQueryPacket implements CommandPacket { private final Query query; - private final int maxAllowedPacket; - public StreamedQueryPacket(final Query query, int maxAllowedPacket) { + public StreamedQueryPacket(final Query query) { this.query = query; - this.maxAllowedPacket = maxAllowedPacket; } public int send(final OutputStream ostream) throws IOException, QueryException { - byte[] queryStream = query.sqlByteArray(); - if (maxAllowedPacket > 0 && queryStream.length > maxAllowedPacket) { - throw new QueryException("Packet for query is too large (" - + queryStream.length - + " > " - + maxAllowedPacket - + "). You can change this value on the server by setting the max_allowed_packet' variable.", - -1, SQLExceptionMapper.SQLStates.UNDEFINED_SQLSTATE.getSqlState()); - } PacketOutputStream pos = (PacketOutputStream)ostream; pos.startPacket(0); pos.write(0x03); - ostream.write(queryStream, 0, queryStream.length); + query.writeTo(ostream); pos.finishPacket(); return 0; } diff --git a/src/main/java/org/mariadb/jdbc/internal/common/query/MySQLParameterizedQuery.java b/src/main/java/org/mariadb/jdbc/internal/common/query/MySQLParameterizedQuery.java index 526251cf5..bd05992d3 100644 --- a/src/main/java/org/mariadb/jdbc/internal/common/query/MySQLParameterizedQuery.java +++ b/src/main/java/org/mariadb/jdbc/internal/common/query/MySQLParameterizedQuery.java @@ -50,9 +50,12 @@ WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWIS import org.mariadb.jdbc.internal.common.QueryException; import org.mariadb.jdbc.internal.common.query.parameters.ParameterHolder; + import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.OutputStream; import java.io.UnsupportedEncodingException; +import java.nio.charset.Charset; import java.util.List; import static org.mariadb.jdbc.internal.common.Utils.createQueryParts; @@ -118,18 +121,17 @@ public void validate() throws QueryException{ } - public byte[] sqlByteArray() throws IOException { - ByteArrayOutputStream baos = new ByteArrayOutputStream(); + public void writeTo(final OutputStream os) throws IOException, QueryException { + if(queryPartsArray.length == 0) { throw new AssertionError("Invalid query, queryParts was empty"); } - baos.write(queryPartsArray[0]); + os.write(queryPartsArray[0]); for(int i = 1; i 0) { sb.append(", parameters : ["); for(int i = 0; i < parameters.length; i++) { - if (parameters[i] == null) { - sb.append("null"); - } else { - sb.append(parameters[i].toString()); - } - if (i != parameters.length -1) { - sb.append(","); - } + if (parameters[i] == null) { + sb.append("null"); + } else { + sb.append(parameters[i].toString()); + } + if (i != parameters.length -1) { + sb.append(","); + } } sb.append("]"); } return sb.toString(); } - - /** - * Returns a string representing the SQL of the query. - * @return - */ - public String toSQL() { - try { - return new String(sqlByteArray()); - } catch (IOException e) { - return ""; - } - } - - - -} +} \ No newline at end of file diff --git a/src/main/java/org/mariadb/jdbc/internal/common/query/MySQLQuery.java b/src/main/java/org/mariadb/jdbc/internal/common/query/MySQLQuery.java index c59d1f1b9..9e6415f57 100644 --- a/src/main/java/org/mariadb/jdbc/internal/common/query/MySQLQuery.java +++ b/src/main/java/org/mariadb/jdbc/internal/common/query/MySQLQuery.java @@ -82,8 +82,8 @@ public int length() { return queryToSend.length; } - public byte[] sqlByteArray() { - return queryToSend; + public void writeTo(final OutputStream os) throws IOException { + os.write(queryToSend, 0, queryToSend.length); } public String getQuery() { diff --git a/src/main/java/org/mariadb/jdbc/internal/common/query/Query.java b/src/main/java/org/mariadb/jdbc/internal/common/query/Query.java index 3b75b603d..39dc782d9 100644 --- a/src/main/java/org/mariadb/jdbc/internal/common/query/Query.java +++ b/src/main/java/org/mariadb/jdbc/internal/common/query/Query.java @@ -51,10 +51,11 @@ WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWIS import org.mariadb.jdbc.internal.common.QueryException; import java.io.IOException; +import java.io.OutputStream; public interface Query { - byte[] sqlByteArray() throws IOException; String getQuery(); + void writeTo(OutputStream os) throws IOException, QueryException; QueryType getQueryType(); void validate() throws QueryException; } diff --git a/src/main/java/org/mariadb/jdbc/internal/mysql/MySQLProtocol.java b/src/main/java/org/mariadb/jdbc/internal/mysql/MySQLProtocol.java index 02ad8f290..d81a33e0e 100644 --- a/src/main/java/org/mariadb/jdbc/internal/mysql/MySQLProtocol.java +++ b/src/main/java/org/mariadb/jdbc/internal/mysql/MySQLProtocol.java @@ -526,7 +526,6 @@ void connect(String host, int port) throws QueryException, IOException, SQLExcep hasWarnings = false; connected = true; hostFailed = false; // Prevent reconnects - writer.setMaxAllowedPacket(this.maxAllowedPacket); } catch (IOException e) { throw new QueryException("Could not connect to " + host + ":" + port + ": " + e.getMessage(), @@ -1016,7 +1015,7 @@ public QueryResult executeQuery(final Query dQuery, boolean streaming) throws Qu dQuery.validate(); log.log(Level.FINEST, "Executing streamed query: {0}", dQuery); this.moreResults = false; - final StreamedQueryPacket packet = new StreamedQueryPacket(dQuery, this.maxAllowedPacket); + final StreamedQueryPacket packet = new StreamedQueryPacket(dQuery); try { packet.send(writer); @@ -1173,11 +1172,9 @@ public void setLocalInfileInputStream(InputStream inputStream) { this.localInfileInputStream = inputStream; } - public int getMaxAllowedPacket() { - return this.maxAllowedPacket; - } + public void setMaxAllowedPacket(int maxAllowedPacket) { - this.maxAllowedPacket = maxAllowedPacket; + writer.setMaxAllowedPacket(maxAllowedPacket); } /**