Skip to content

Commit

Permalink
feat: added dsl for batches
Browse files Browse the repository at this point in the history
  • Loading branch information
red-bashmak committed May 21, 2021
1 parent 73dd1fd commit 4772472
Show file tree
Hide file tree
Showing 8 changed files with 222 additions and 24 deletions.
8 changes: 5 additions & 3 deletions src/main/scala/ru/tinkoff/load/jdbc/JdbcDsl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@ package ru.tinkoff.load.jdbc

import io.gatling.core.protocol.Protocol
import io.gatling.core.session.Expression
import ru.tinkoff.load.jdbc.actions.DBBaseAction

import ru.tinkoff.load.jdbc.actions.{BatchUpdateBaseAction, BatchInsertBaseAction, Columns, DBBaseAction}
import ru.tinkoff.load.jdbc.check.JdbcCheckSupport
import ru.tinkoff.load.jdbc.protocol.{JdbcProtocolBuilder, JdbcProtocolBuilderBase}

trait JdbcDsl extends JdbcCheckSupport{
trait JdbcDsl extends JdbcCheckSupport {
def DB: JdbcProtocolBuilderBase.type = JdbcProtocolBuilderBase
def jdbc(name: Expression[String]): DBBaseAction = DBBaseAction(name)
def insertInto(tableName: Expression[String], columns: Columns): BatchInsertBaseAction =
BatchInsertBaseAction(tableName, columns)
def update(tableName: Expression[String]): BatchUpdateBaseAction = BatchUpdateBaseAction(tableName)

implicit def jdbcProtocolBuilder2jdbcProtocol(builder: JdbcProtocolBuilder): Protocol = builder.build
}
93 changes: 93 additions & 0 deletions src/main/scala/ru/tinkoff/load/jdbc/actions/DBBatchAction.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package ru.tinkoff.load.jdbc.actions

import io.gatling.commons.stats.{KO, OK}
import io.gatling.core.action.{Action, ChainableAction}
import io.gatling.core.session.{Expression, Session}
import io.gatling.core.structure.ScenarioContext
import io.gatling.core.util.NameGen
import io.gatling.commons.validation._
import ru.tinkoff.load.jdbc.db.{SQL, SqlWithParam}

final case class DBBatchAction(
batchName: Expression[String],
actions: Seq[BatchAction],
next: Action,
ctx: ScenarioContext
) extends ChainableAction with NameGen with ActionBase {

private implicit class TrSeq[+T](seq: Seq[T]) {
def traverse[S](f: T => Validation[S]): Validation[Seq[S]] = seq.foldRight(Seq.empty[S].success)(
(i, r) => r.flatMap(s => f(i).map(s.prepended))
)
}

override def name: String = genName("jdbcBatchAction")

private def resolveParams(session: Session, values: Seq[(String, Expression[Any])]) =
values.traverse { case (k, v) => v(session).map((k, _)) }.map(_.toMap)

private def resolveBatchAction(session: Session): PartialFunction[BatchAction, Validation[SqlWithParam]] = {

case BatchUpdateAction(tableName, updateValues, None) =>
for {
tName <- tableName(session)
iParams <- resolveParams(session, updateValues)
sql <- SQL(s"UPDATE $tName SET ${iParams.map(c => s"${c._1} = {${c._1}}").mkString(",")}")
.withParamsMap(iParams)
.success
} yield sql

case BatchUpdateAction(tableName, updateValues, Some(whereExpression)) =>
for {
tName <- tableName(session)
iParams <- resolveParams(session, updateValues)
resolvedWhere <- whereExpression(session)
sql <- SQL(s"UPDATE $tName SET ${iParams.map(c => s"${c._1} = {${c._1}}").mkString(",")} WHERE $resolvedWhere")
.withParamsMap(iParams)
.success
} yield sql

case BatchInsertAction(tableName, columns, sessionValues) =>
for {
tName <- tableName(session)
iParams <- resolveParams(session, sessionValues)
sql <- SQL(
s"INSERT INTO $tName (${columns.names.mkString(",")}) VALUES(${columns.names.map(s => s"{$s}").mkString(",")})")
.withParamsMap(iParams)
.success
} yield sql
}

override protected def execute(session: Session): Unit =
(for {
resolvedBatchName <- batchName(session)
sqlQueriesWithParams <- actions.traverse(resolveBatchAction(session))
startTime <- ctx.coreComponents.clock.nowMillis.success
} yield
db.executeBatch(implicit c => c.batch(sqlQueriesWithParams))
.fold(
e => {
println(s"ERROR: ${e.getMessage}")
executeNext(session,
startTime,
ctx.coreComponents.clock.nowMillis,
KO,
next,
resolvedBatchName,
Some("ERROR"),
Some(e.getMessage))
},
_ => executeNext(session, startTime, ctx.coreComponents.clock.nowMillis, OK, next, resolvedBatchName, None, None)
)).onFailure(m =>
batchName(session).map { rn =>
ctx.coreComponents.statsEngine.logCrash(session.scenario, session.groups, rn, m)
executeNext(session,
ctx.coreComponents.clock.nowMillis,
ctx.coreComponents.clock.nowMillis,
KO,
next,
rn,
Some("ERROR"),
Some(m))
})
}
34 changes: 33 additions & 1 deletion src/main/scala/ru/tinkoff/load/jdbc/actions/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,25 @@ package object actions {
def rawSql(queryString: Expression[String]): RawSqlActionBuilder = RawSqlActionBuilder(requestName, queryString)

def queryP(sql: Expression[String]): QueryActionParamsStep = QueryActionParamsStep(requestName, sql)
def query(sql: Expression[String]): QueryActionBuilder = QueryActionBuilder(requestName, sql, params = Seq.empty)
def query(sql: Expression[String]): QueryActionBuilder = QueryActionBuilder(requestName, sql, params = Seq.empty)
def batch(actions: BatchAction*): BatchActionBuilder = BatchActionBuilder(requestName, actions)
}

final case class BatchInsertBaseAction(tableName: Expression[String], columns: Columns) {
def values(values: (String, Expression[Any])*): BatchInsertAction = BatchInsertAction(tableName, columns, values)
}

final case class BatchUpdateBaseAction(tableName: Expression[String]) {
def set(updateValues: (String, Expression[Any])*): BatchUpdateValuesStepAction =
BatchUpdateValuesStepAction(tableName, updateValues)
}

final case class BatchUpdateValuesStepAction(tableName: Expression[String], updateValues: Seq[(String, Expression[Any])]) {
def where(whereExpression: Expression[String]): BatchUpdateAction = {
BatchUpdateAction(tableName, updateValues, Some(whereExpression))
}

val all: BatchUpdateAction = BatchUpdateAction(tableName, updateValues)
}

case class QueryActionParamsStep(requestName: Expression[String], sql: Expression[String]) {
Expand Down Expand Up @@ -73,4 +91,18 @@ package object actions {
DBInsertAction(requestName, tableName, columns.names, next, ctx, sessionValues)
}

sealed trait BatchAction
final case class BatchInsertAction(tableName: Expression[String],
columns: Columns,
sessionValues: Seq[(String, Expression[Any])])
extends BatchAction

final case class BatchUpdateAction(tableName: Expression[String],
updateValues: Seq[(String, Expression[Any])],
where: Option[Expression[String]] = None)
extends BatchAction

final case class BatchActionBuilder(batchName: Expression[String], actions: Seq[BatchAction]) extends ActionBuilder {
override def build(ctx: ScenarioContext, next: Action): Action = DBBatchAction(batchName, actions, next, ctx)
}
}
11 changes: 11 additions & 0 deletions src/main/scala/ru/tinkoff/load/jdbc/db/ConnectedDB.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,15 @@ case class ConnectedDB(pool: HikariDataSource) {
result <- exec(ManagedConnection(c))
_ <- Try(c.close())
} yield result

def executeBatch(implicit exec: ManagedConnection => Try[List[Int]]): Try[List[Int]] =
for {
c <- Try(pool.getConnection)
autoCommit <- Try(c.getAutoCommit)
_ <- Try(c.setAutoCommit(false))
result <- exec(ManagedConnection(c))
_ <- Try(c.commit())
_ <- Try(c.setAutoCommit(autoCommit))
_ <- Try(c.close())
} yield result
}
18 changes: 18 additions & 0 deletions src/main/scala/ru/tinkoff/load/jdbc/db/ManagedConnection.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,22 @@ case class ManagedConnection(connection: Connection) {
prepareStatement(sql, params.toMap, Map.empty, (sql, c) => c.prepareStatement(sql)),
s => Try(s.close())
)(stmt => Try(stmt.executeQuery().iterator.toList))

def batch(queries: Seq[SqlWithParam], batchSize: Int = 1000): Try[List[Int]] =
for {
stmt <- Try(connection.createStatement())
intermediateResult <- queries
.map(_.substituteParams)
.foldLeft(Try(List.empty[Int], 0))((r, q) =>
r.flatMap {
case (lr, counter) =>
for {
lr <- Try(if (counter % batchSize == 0) lr ++ stmt.executeBatch() else lr)
_ <- Try(stmt.addBatch(q))
} yield (lr, counter + 1)
})
.map(_._1)
r <- Try(intermediateResult ++ stmt.executeBatch())
_ <- Try(stmt.close())
} yield r
}
28 changes: 25 additions & 3 deletions src/main/scala/ru/tinkoff/load/jdbc/db/package.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
package ru.tinkoff.load.jdbc

import java.sql.ResultSet
import java.time.LocalDateTime

import java.time.{LocalDateTime, OffsetDateTime}
import scala.util.Try

package object db {
Expand Down Expand Up @@ -38,13 +37,36 @@ package object db {
}

case class SqlWithParam(sql: String, params: Seq[(String, ParamVal)], outParams: Seq[(String, Int)] = Seq.empty) {
private val paramsMap = params.toMap
private def paramValueToSql(name: String) =
paramsMap.get(name) match {
case Some(IntParam(v)) => s"$v"
case Some(DoubleParam(v)) => s"$v"
case Some(StrParam(v)) => s"'$v'"
case Some(LongParam(v)) => s"$v"
case Some(NullParam) => "NULL"
case Some(DateParam(v)) => s"CAST(${v.toInstant(OffsetDateTime.now().getOffset).toEpochMilli} AS DATE)"
case None => ""
}

def substituteParams: String = {
sql
.foldLeft(("", "", false)) {
case ((r, curName, false), '{') => (r, curName, true)
case ((r, curName, true), '}') => (s"$r ${paramValueToSql(curName.trim)}", "", false)
case ((r, curName, true), c) => (r, s"$curName$c", true)
case ((r, curName, false), c) => (s"$r$c", curName, false)
}
._1
}

def withOutParams(ps: Seq[(String, Int)]): SqlWithParam = SqlWithParam(sql, params, ps)

def executeInsert(implicit managedConnection: ManagedConnection): Try[Int] = managedConnection.execute(sql, params)

def call(implicit managedConnection: ManagedConnection): Try[Int] = managedConnection.call(sql, params, outParams)

def executeQuery(implicit managedConnection: ManagedConnection): Try[List[Map[String,Any]]] =
def executeQuery(implicit managedConnection: ManagedConnection): Try[List[Map[String, Any]]] =
managedConnection.execSelect(sql, params)
}

Expand Down
49 changes: 33 additions & 16 deletions src/test/scala/ru/tinkoff/load/jdbc/test/cases/Actions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@ import ru.tinkoff.load.jdbc.actions.Columns

object Actions {

def createTable(): actions.RawSqlActionBuilder = jdbc("Create Table")
.rawSql("CREATE TABLE TEST_TABLE (ID INT PRIMARY KEY, NAME VARCHAR(64));")
def createTable(): actions.RawSqlActionBuilder =
jdbc("Create Table")
.rawSql("CREATE TABLE TEST_TABLE (ID INT PRIMARY KEY, NAME VARCHAR(64));")

def createProcedure(): actions.RawSqlActionBuilder = jdbc("Procedure create")
.rawSql("""CREATE ALIAS TEST_PROCEDURE AS $$
def createProcedure(): actions.RawSqlActionBuilder =
jdbc("Procedure create")
.rawSql("""CREATE ALIAS TEST_PROCEDURE AS $$
|String testProcedure(String p1, Long p2) {
| String suf = p1 + "test";
| return p2.toString() + suf;
Expand All @@ -28,18 +30,33 @@ object Actions {
.call("TEST_PROCEDURE")
.params("p1" -> "value1", "p2" -> 24L)

def selectTest: actions.QueryActionBuilder = jdbc("SELECT TEST")
.queryP("SELECT * FROM TEST_TABLE WHERE ID = {id}")
.params("id" -> 1)
.check(
allRecordsCheck{
r =>
r.isEmpty
},
allResults.is(List(
Map("ID"-> 1, "NAME" -> "Test3")
)),
allResults.saveAs("R")
def batchTest: actions.BatchActionBuilder = jdbc("Batch records").batch(
insertInto("TEST_TABLE", Columns("ID", "NAME")).values("ID" -> 2, "NAME" -> "Test 56"),
insertInto("TEST_TABLE", Columns("ID", "NAME")).values("ID" -> 3, "NAME" -> "Test 78"),
update("TEST_TABLE").set("NAME" -> "TEST 5").where("ID = 2"),
// update("TEST_TABLE").set("NAME" -> "bird").all
)

def selectTest: actions.QueryActionBuilder =
jdbc("SELECT TEST")
.queryP("SELECT * FROM TEST_TABLE WHERE ID = {id}")
.params("id" -> 1)
.check(
allRecordsCheck { r =>
r.isEmpty
},
allResults.is(
List(
Map("ID" -> 1, "NAME" -> "Test3")
)),
allResults.saveAs("R")
)

def selectAfterBatch: actions.QueryActionBuilder =
jdbc("SELECT SOME")
.query("SELECT * FROM TEST_TABLE")
.check(
allResults.saveAs("RR")
)

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@ class BasicSimulation {
.exec(Actions.insertTest())
.exec(Actions.callTest())
.exec(Actions.selectTest)
.exec(Actions.batchTest)
.exec(Actions.selectAfterBatch)
.exec { s =>
print(s("R").as[List[Map[String, Any]]])
println(s("R").as[List[Map[String, Any]]])
println(s("RR").as[List[Map[String, Any]]])
s
}

Expand Down

0 comments on commit 4772472

Please sign in to comment.