diff --git a/src/main/scala/ru/tinkoff/load/jdbc/JdbcDsl.scala b/src/main/scala/ru/tinkoff/load/jdbc/JdbcDsl.scala index c38607b..b6ce78f 100644 --- a/src/main/scala/ru/tinkoff/load/jdbc/JdbcDsl.scala +++ b/src/main/scala/ru/tinkoff/load/jdbc/JdbcDsl.scala @@ -2,14 +2,16 @@ package ru.tinkoff.load.jdbc import io.gatling.core.protocol.Protocol import io.gatling.core.session.Expression -import ru.tinkoff.load.jdbc.actions.DBBaseAction - +import ru.tinkoff.load.jdbc.actions.{BatchUpdateBaseAction, BatchInsertBaseAction, Columns, DBBaseAction} import ru.tinkoff.load.jdbc.check.JdbcCheckSupport import ru.tinkoff.load.jdbc.protocol.{JdbcProtocolBuilder, JdbcProtocolBuilderBase} -trait JdbcDsl extends JdbcCheckSupport{ +trait JdbcDsl extends JdbcCheckSupport { def DB: JdbcProtocolBuilderBase.type = JdbcProtocolBuilderBase def jdbc(name: Expression[String]): DBBaseAction = DBBaseAction(name) + def insertInto(tableName: Expression[String], columns: Columns): BatchInsertBaseAction = + BatchInsertBaseAction(tableName, columns) + def update(tableName: Expression[String]): BatchUpdateBaseAction = BatchUpdateBaseAction(tableName) implicit def jdbcProtocolBuilder2jdbcProtocol(builder: JdbcProtocolBuilder): Protocol = builder.build } diff --git a/src/main/scala/ru/tinkoff/load/jdbc/actions/DBBatchAction.scala b/src/main/scala/ru/tinkoff/load/jdbc/actions/DBBatchAction.scala new file mode 100644 index 0000000..3a865f3 --- /dev/null +++ b/src/main/scala/ru/tinkoff/load/jdbc/actions/DBBatchAction.scala @@ -0,0 +1,93 @@ +package ru.tinkoff.load.jdbc.actions + +import io.gatling.commons.stats.{KO, OK} +import io.gatling.core.action.{Action, ChainableAction} +import io.gatling.core.session.{Expression, Session} +import io.gatling.core.structure.ScenarioContext +import io.gatling.core.util.NameGen +import io.gatling.commons.validation._ +import ru.tinkoff.load.jdbc.db.{SQL, SqlWithParam} + +final case class DBBatchAction( + batchName: Expression[String], + actions: Seq[BatchAction], + next: Action, + ctx: ScenarioContext +) extends ChainableAction with NameGen with ActionBase { + + private implicit class TrSeq[+T](seq: Seq[T]) { + def traverse[S](f: T => Validation[S]): Validation[Seq[S]] = seq.foldRight(Seq.empty[S].success)( + (i, r) => r.flatMap(s => f(i).map(s.prepended)) + ) + } + + override def name: String = genName("jdbcBatchAction") + + private def resolveParams(session: Session, values: Seq[(String, Expression[Any])]) = + values.traverse { case (k, v) => v(session).map((k, _)) }.map(_.toMap) + + private def resolveBatchAction(session: Session): PartialFunction[BatchAction, Validation[SqlWithParam]] = { + + case BatchUpdateAction(tableName, updateValues, None) => + for { + tName <- tableName(session) + iParams <- resolveParams(session, updateValues) + sql <- SQL(s"UPDATE $tName SET ${iParams.map(c => s"${c._1} = {${c._1}}").mkString(",")}") + .withParamsMap(iParams) + .success + } yield sql + + case BatchUpdateAction(tableName, updateValues, Some(whereExpression)) => + for { + tName <- tableName(session) + iParams <- resolveParams(session, updateValues) + resolvedWhere <- whereExpression(session) + sql <- SQL(s"UPDATE $tName SET ${iParams.map(c => s"${c._1} = {${c._1}}").mkString(",")} WHERE $resolvedWhere") + .withParamsMap(iParams) + .success + } yield sql + + case BatchInsertAction(tableName, columns, sessionValues) => + for { + tName <- tableName(session) + iParams <- resolveParams(session, sessionValues) + sql <- SQL( + s"INSERT INTO $tName (${columns.names.mkString(",")}) VALUES(${columns.names.map(s => s"{$s}").mkString(",")})") + .withParamsMap(iParams) + .success + } yield sql + } + + override protected def execute(session: Session): Unit = + (for { + resolvedBatchName <- batchName(session) + sqlQueriesWithParams <- actions.traverse(resolveBatchAction(session)) + startTime <- ctx.coreComponents.clock.nowMillis.success + } yield + db.executeBatch(implicit c => c.batch(sqlQueriesWithParams)) + .fold( + e => { + println(s"ERROR: ${e.getMessage}") + executeNext(session, + startTime, + ctx.coreComponents.clock.nowMillis, + KO, + next, + resolvedBatchName, + Some("ERROR"), + Some(e.getMessage)) + }, + _ => executeNext(session, startTime, ctx.coreComponents.clock.nowMillis, OK, next, resolvedBatchName, None, None) + )).onFailure(m => + batchName(session).map { rn => + ctx.coreComponents.statsEngine.logCrash(session.scenario, session.groups, rn, m) + executeNext(session, + ctx.coreComponents.clock.nowMillis, + ctx.coreComponents.clock.nowMillis, + KO, + next, + rn, + Some("ERROR"), + Some(m)) + }) +} diff --git a/src/main/scala/ru/tinkoff/load/jdbc/actions/package.scala b/src/main/scala/ru/tinkoff/load/jdbc/actions/package.scala index 9fd1373..7069f32 100644 --- a/src/main/scala/ru/tinkoff/load/jdbc/actions/package.scala +++ b/src/main/scala/ru/tinkoff/load/jdbc/actions/package.scala @@ -14,7 +14,25 @@ package object actions { def rawSql(queryString: Expression[String]): RawSqlActionBuilder = RawSqlActionBuilder(requestName, queryString) def queryP(sql: Expression[String]): QueryActionParamsStep = QueryActionParamsStep(requestName, sql) - def query(sql: Expression[String]): QueryActionBuilder = QueryActionBuilder(requestName, sql, params = Seq.empty) + def query(sql: Expression[String]): QueryActionBuilder = QueryActionBuilder(requestName, sql, params = Seq.empty) + def batch(actions: BatchAction*): BatchActionBuilder = BatchActionBuilder(requestName, actions) + } + + final case class BatchInsertBaseAction(tableName: Expression[String], columns: Columns) { + def values(values: (String, Expression[Any])*): BatchInsertAction = BatchInsertAction(tableName, columns, values) + } + + final case class BatchUpdateBaseAction(tableName: Expression[String]) { + def set(updateValues: (String, Expression[Any])*): BatchUpdateValuesStepAction = + BatchUpdateValuesStepAction(tableName, updateValues) + } + + final case class BatchUpdateValuesStepAction(tableName: Expression[String], updateValues: Seq[(String, Expression[Any])]) { + def where(whereExpression: Expression[String]): BatchUpdateAction = { + BatchUpdateAction(tableName, updateValues, Some(whereExpression)) + } + + val all: BatchUpdateAction = BatchUpdateAction(tableName, updateValues) } case class QueryActionParamsStep(requestName: Expression[String], sql: Expression[String]) { @@ -73,4 +91,18 @@ package object actions { DBInsertAction(requestName, tableName, columns.names, next, ctx, sessionValues) } + sealed trait BatchAction + final case class BatchInsertAction(tableName: Expression[String], + columns: Columns, + sessionValues: Seq[(String, Expression[Any])]) + extends BatchAction + + final case class BatchUpdateAction(tableName: Expression[String], + updateValues: Seq[(String, Expression[Any])], + where: Option[Expression[String]] = None) + extends BatchAction + + final case class BatchActionBuilder(batchName: Expression[String], actions: Seq[BatchAction]) extends ActionBuilder { + override def build(ctx: ScenarioContext, next: Action): Action = DBBatchAction(batchName, actions, next, ctx) + } } diff --git a/src/main/scala/ru/tinkoff/load/jdbc/db/ConnectedDB.scala b/src/main/scala/ru/tinkoff/load/jdbc/db/ConnectedDB.scala index f59eee5..65649cc 100644 --- a/src/main/scala/ru/tinkoff/load/jdbc/db/ConnectedDB.scala +++ b/src/main/scala/ru/tinkoff/load/jdbc/db/ConnectedDB.scala @@ -26,4 +26,15 @@ case class ConnectedDB(pool: HikariDataSource) { result <- exec(ManagedConnection(c)) _ <- Try(c.close()) } yield result + + def executeBatch(implicit exec: ManagedConnection => Try[List[Int]]): Try[List[Int]] = + for { + c <- Try(pool.getConnection) + autoCommit <- Try(c.getAutoCommit) + _ <- Try(c.setAutoCommit(false)) + result <- exec(ManagedConnection(c)) + _ <- Try(c.commit()) + _ <- Try(c.setAutoCommit(autoCommit)) + _ <- Try(c.close()) + } yield result } diff --git a/src/main/scala/ru/tinkoff/load/jdbc/db/ManagedConnection.scala b/src/main/scala/ru/tinkoff/load/jdbc/db/ManagedConnection.scala index 599d84e..e93be10 100644 --- a/src/main/scala/ru/tinkoff/load/jdbc/db/ManagedConnection.scala +++ b/src/main/scala/ru/tinkoff/load/jdbc/db/ManagedConnection.scala @@ -71,4 +71,22 @@ case class ManagedConnection(connection: Connection) { prepareStatement(sql, params.toMap, Map.empty, (sql, c) => c.prepareStatement(sql)), s => Try(s.close()) )(stmt => Try(stmt.executeQuery().iterator.toList)) + + def batch(queries: Seq[SqlWithParam], batchSize: Int = 1000): Try[List[Int]] = + for { + stmt <- Try(connection.createStatement()) + intermediateResult <- queries + .map(_.substituteParams) + .foldLeft(Try(List.empty[Int], 0))((r, q) => + r.flatMap { + case (lr, counter) => + for { + lr <- Try(if (counter % batchSize == 0) lr ++ stmt.executeBatch() else lr) + _ <- Try(stmt.addBatch(q)) + } yield (lr, counter + 1) + }) + .map(_._1) + r <- Try(intermediateResult ++ stmt.executeBatch()) + _ <- Try(stmt.close()) + } yield r } diff --git a/src/main/scala/ru/tinkoff/load/jdbc/db/package.scala b/src/main/scala/ru/tinkoff/load/jdbc/db/package.scala index d1a8752..25b0e35 100644 --- a/src/main/scala/ru/tinkoff/load/jdbc/db/package.scala +++ b/src/main/scala/ru/tinkoff/load/jdbc/db/package.scala @@ -1,8 +1,7 @@ package ru.tinkoff.load.jdbc import java.sql.ResultSet -import java.time.LocalDateTime - +import java.time.{LocalDateTime, OffsetDateTime} import scala.util.Try package object db { @@ -38,13 +37,36 @@ package object db { } case class SqlWithParam(sql: String, params: Seq[(String, ParamVal)], outParams: Seq[(String, Int)] = Seq.empty) { + private val paramsMap = params.toMap + private def paramValueToSql(name: String) = + paramsMap.get(name) match { + case Some(IntParam(v)) => s"$v" + case Some(DoubleParam(v)) => s"$v" + case Some(StrParam(v)) => s"'$v'" + case Some(LongParam(v)) => s"$v" + case Some(NullParam) => "NULL" + case Some(DateParam(v)) => s"CAST(${v.toInstant(OffsetDateTime.now().getOffset).toEpochMilli} AS DATE)" + case None => "" + } + + def substituteParams: String = { + sql + .foldLeft(("", "", false)) { + case ((r, curName, false), '{') => (r, curName, true) + case ((r, curName, true), '}') => (s"$r ${paramValueToSql(curName.trim)}", "", false) + case ((r, curName, true), c) => (r, s"$curName$c", true) + case ((r, curName, false), c) => (s"$r$c", curName, false) + } + ._1 + } + def withOutParams(ps: Seq[(String, Int)]): SqlWithParam = SqlWithParam(sql, params, ps) def executeInsert(implicit managedConnection: ManagedConnection): Try[Int] = managedConnection.execute(sql, params) def call(implicit managedConnection: ManagedConnection): Try[Int] = managedConnection.call(sql, params, outParams) - def executeQuery(implicit managedConnection: ManagedConnection): Try[List[Map[String,Any]]] = + def executeQuery(implicit managedConnection: ManagedConnection): Try[List[Map[String, Any]]] = managedConnection.execSelect(sql, params) } diff --git a/src/test/scala/ru/tinkoff/load/jdbc/test/cases/Actions.scala b/src/test/scala/ru/tinkoff/load/jdbc/test/cases/Actions.scala index 4638a63..443434f 100644 --- a/src/test/scala/ru/tinkoff/load/jdbc/test/cases/Actions.scala +++ b/src/test/scala/ru/tinkoff/load/jdbc/test/cases/Actions.scala @@ -7,11 +7,13 @@ import ru.tinkoff.load.jdbc.actions.Columns object Actions { - def createTable(): actions.RawSqlActionBuilder = jdbc("Create Table") - .rawSql("CREATE TABLE TEST_TABLE (ID INT PRIMARY KEY, NAME VARCHAR(64));") + def createTable(): actions.RawSqlActionBuilder = + jdbc("Create Table") + .rawSql("CREATE TABLE TEST_TABLE (ID INT PRIMARY KEY, NAME VARCHAR(64));") - def createProcedure(): actions.RawSqlActionBuilder = jdbc("Procedure create") - .rawSql("""CREATE ALIAS TEST_PROCEDURE AS $$ + def createProcedure(): actions.RawSqlActionBuilder = + jdbc("Procedure create") + .rawSql("""CREATE ALIAS TEST_PROCEDURE AS $$ |String testProcedure(String p1, Long p2) { | String suf = p1 + "test"; | return p2.toString() + suf; @@ -28,18 +30,33 @@ object Actions { .call("TEST_PROCEDURE") .params("p1" -> "value1", "p2" -> 24L) - def selectTest: actions.QueryActionBuilder = jdbc("SELECT TEST") - .queryP("SELECT * FROM TEST_TABLE WHERE ID = {id}") - .params("id" -> 1) - .check( - allRecordsCheck{ - r => - r.isEmpty - }, - allResults.is(List( - Map("ID"-> 1, "NAME" -> "Test3") - )), - allResults.saveAs("R") + def batchTest: actions.BatchActionBuilder = jdbc("Batch records").batch( + insertInto("TEST_TABLE", Columns("ID", "NAME")).values("ID" -> 2, "NAME" -> "Test 56"), + insertInto("TEST_TABLE", Columns("ID", "NAME")).values("ID" -> 3, "NAME" -> "Test 78"), + update("TEST_TABLE").set("NAME" -> "TEST 5").where("ID = 2"), +// update("TEST_TABLE").set("NAME" -> "bird").all ) + def selectTest: actions.QueryActionBuilder = + jdbc("SELECT TEST") + .queryP("SELECT * FROM TEST_TABLE WHERE ID = {id}") + .params("id" -> 1) + .check( + allRecordsCheck { r => + r.isEmpty + }, + allResults.is( + List( + Map("ID" -> 1, "NAME" -> "Test3") + )), + allResults.saveAs("R") + ) + + def selectAfterBatch: actions.QueryActionBuilder = + jdbc("SELECT SOME") + .query("SELECT * FROM TEST_TABLE") + .check( + allResults.saveAs("RR") + ) + } diff --git a/src/test/scala/ru/tinkoff/load/jdbc/test/scenarios/BasicSimulation.scala b/src/test/scala/ru/tinkoff/load/jdbc/test/scenarios/BasicSimulation.scala index b8455e7..d52dbd0 100644 --- a/src/test/scala/ru/tinkoff/load/jdbc/test/scenarios/BasicSimulation.scala +++ b/src/test/scala/ru/tinkoff/load/jdbc/test/scenarios/BasicSimulation.scala @@ -17,8 +17,11 @@ class BasicSimulation { .exec(Actions.insertTest()) .exec(Actions.callTest()) .exec(Actions.selectTest) + .exec(Actions.batchTest) + .exec(Actions.selectAfterBatch) .exec { s => - print(s("R").as[List[Map[String, Any]]]) + println(s("R").as[List[Map[String, Any]]]) + println(s("RR").as[List[Map[String, Any]]]) s }