diff --git a/build.sbt b/build.sbt index 31c6c178ab..e03c3e740a 100644 --- a/build.sbt +++ b/build.sbt @@ -131,6 +131,9 @@ lazy val root = project scope.jvm, scope.js, `scope-examples`, + sql.jvm, + sql.js, + `sql-zio`, schema.jvm, schema.js, `schema-avro`, @@ -299,6 +302,45 @@ lazy val scope = crossProject(JSPlatform, JVMPlatform) coverageMinimumBranchTotal := 65 ) +lazy val sql = crossProject(JSPlatform, JVMPlatform) + .crossType(CrossType.Full) + .dependsOn(schema, scope) + .settings(stdSettings("zio-blocks-sql", Seq(BuildHelper.Scala3, BuildHelper.Scala33))) + .settings(crossProjectSettings) + .settings(buildInfoSettings("zio.blocks.sql")) + .enablePlugins(BuildInfoPlugin) + .jvmSettings(mimaSettings(failOnProblem = false)) + .jsSettings(jsSettings) + .settings( + libraryDependencies ++= Seq( + "dev.zio" %%% "zio-test" % "2.1.24" % Test, + "dev.zio" %%% "zio-test-sbt" % "2.1.24" % Test + ), + coverageMinimumStmtTotal := 0, + coverageMinimumBranchTotal := 0 + ) + .jvmSettings( + libraryDependencies ++= Seq( + "org.xerial" % "sqlite-jdbc" % "3.49.1.0" % Test, + "org.postgresql" % "postgresql" % "42.7.5" % Test + ) + ) + +lazy val `sql-zio` = project + .settings(stdSettings("zio-blocks-sql-zio", Seq(BuildHelper.Scala3, BuildHelper.Scala33))) + .dependsOn(sql.jvm) + .settings(buildInfoSettings("zio.blocks.sql.zio")) + .enablePlugins(BuildInfoPlugin) + .settings( + libraryDependencies ++= Seq( + "dev.zio" %% "zio" % "2.1.24", + "dev.zio" %% "zio-test" % "2.1.24" % Test, + "dev.zio" %% "zio-test-sbt" % "2.1.24" % Test + ), + coverageMinimumStmtTotal := 0, + coverageMinimumBranchTotal := 0 + ) + lazy val `scope-examples` = project .settings(stdSettings("zio-blocks-scope-examples", Seq(BuildHelper.Scala3))) .dependsOn(scope.jvm) diff --git a/sql-zio/src/main/scala/zio/blocks/sql/zio/TransactorZIO.scala b/sql-zio/src/main/scala/zio/blocks/sql/zio/TransactorZIO.scala new file mode 100644 index 0000000000..3bc97d1a9d --- /dev/null +++ b/sql-zio/src/main/scala/zio/blocks/sql/zio/TransactorZIO.scala @@ -0,0 +1,36 @@ +package zio.blocks.sql.zio + +import zio._ +import zio.blocks.sql._ + +class TransactorZIO(underlying: Transactor) { + + def connect[A](f: DbCon ?=> A): Task[A] = + ZIO.attemptBlocking(underlying.connect(f)) + + def transact[A](f: DbTx ?=> A): Task[A] = + ZIO.attemptBlocking(underlying.transact(f)) +} + +object TransactorZIO { + def fromTransactor(transactor: Transactor): TransactorZIO = + new TransactorZIO(transactor) + + def fromUrl(url: String, dialect: SqlDialect): TransactorZIO = + new TransactorZIO(JdbcTransactor.fromUrl(url, dialect)) + + def fromUrl( + url: String, + user: String, + password: String, + dialect: SqlDialect + ): TransactorZIO = + new TransactorZIO(JdbcTransactor.fromUrl(url, user, password, dialect)) + + def fromDataSource(dataSource: javax.sql.DataSource, dialect: SqlDialect): TransactorZIO = + new TransactorZIO(JdbcTransactor.fromDataSource(dataSource, dialect)) + + // ZLayer for dependency injection + def layer(url: String, dialect: SqlDialect): ZLayer[Any, Nothing, TransactorZIO] = + ZLayer.succeed(fromUrl(url, dialect)) +} diff --git a/sql/js/src/main/scala/zio/blocks/sql/.gitkeep b/sql/js/src/main/scala/zio/blocks/sql/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sql/jvm/src/main/scala/zio/blocks/sql/.gitkeep b/sql/jvm/src/main/scala/zio/blocks/sql/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sql/jvm/src/main/scala/zio/blocks/sql/JdbcConnection.scala b/sql/jvm/src/main/scala/zio/blocks/sql/JdbcConnection.scala new file mode 100644 index 0000000000..f0489c4807 --- /dev/null +++ b/sql/jvm/src/main/scala/zio/blocks/sql/JdbcConnection.scala @@ -0,0 +1,46 @@ +package zio.blocks.sql + +import java.sql.Connection + +class JdbcConnection(val underlying: Connection) extends DbConnection { + + def prepareStatement(sql: String): DbPreparedStatement = + new JdbcPreparedStatement(underlying.prepareStatement(sql)) + + def close(): Unit = underlying.close() + + def isClosed: Boolean = underlying.isClosed + + def setAutoCommit(autoCommit: Boolean): Unit = underlying.setAutoCommit(autoCommit) + + def getAutoCommit: Boolean = underlying.getAutoCommit + + def commit(): Unit = underlying.commit() + + def rollback(): Unit = underlying.rollback() +} + +class JdbcPreparedStatement(val underlying: java.sql.PreparedStatement) extends DbPreparedStatement { + + def executeQuery(): DbResultSet = + new JdbcResultSet(underlying.executeQuery()) + + def executeUpdate(): Int = underlying.executeUpdate() + + def close(): Unit = underlying.close() + + def paramWriter: DbParamWriter = new JdbcParamWriter(underlying) + + def addBatch(): Unit = underlying.addBatch() + + def executeBatch(): Array[Int] = underlying.executeBatch() +} + +class JdbcResultSet(val underlying: java.sql.ResultSet) extends DbResultSet { + + def next(): Boolean = underlying.next() + + def close(): Unit = underlying.close() + + def reader: DbResultReader = new JdbcResultReader(underlying) +} diff --git a/sql/jvm/src/main/scala/zio/blocks/sql/JdbcParamWriter.scala b/sql/jvm/src/main/scala/zio/blocks/sql/JdbcParamWriter.scala new file mode 100644 index 0000000000..c31be3556e --- /dev/null +++ b/sql/jvm/src/main/scala/zio/blocks/sql/JdbcParamWriter.scala @@ -0,0 +1,41 @@ +package zio.blocks.sql + +import java.sql.PreparedStatement + +class JdbcParamWriter(val underlying: PreparedStatement) extends DbParamWriter { + + def setInt(index: Int, value: Int): Unit = underlying.setInt(index, value) + + def setLong(index: Int, value: Long): Unit = underlying.setLong(index, value) + + def setDouble(index: Int, value: Double): Unit = underlying.setDouble(index, value) + + def setFloat(index: Int, value: Float): Unit = underlying.setFloat(index, value) + + def setBoolean(index: Int, value: Boolean): Unit = underlying.setBoolean(index, value) + + def setString(index: Int, value: String): Unit = underlying.setString(index, value) + + def setBigDecimal(index: Int, value: java.math.BigDecimal): Unit = underlying.setBigDecimal(index, value) + + def setBytes(index: Int, value: Array[Byte]): Unit = underlying.setBytes(index, value) + + def setShort(index: Int, value: Short): Unit = underlying.setShort(index, value) + + def setByte(index: Int, value: Byte): Unit = underlying.setByte(index, value) + + def setLocalDate(index: Int, value: java.time.LocalDate): Unit = underlying.setObject(index, value) + + def setLocalDateTime(index: Int, value: java.time.LocalDateTime): Unit = underlying.setObject(index, value) + + def setLocalTime(index: Int, value: java.time.LocalTime): Unit = underlying.setObject(index, value) + + def setInstant(index: Int, value: java.time.Instant): Unit = + underlying.setTimestamp(index, java.sql.Timestamp.from(value)) + + def setDuration(index: Int, value: java.time.Duration): Unit = underlying.setString(index, value.toString) + + def setUUID(index: Int, value: java.util.UUID): Unit = underlying.setObject(index, value) + + def setNull(index: Int, sqlType: Int): Unit = underlying.setNull(index, sqlType) +} diff --git a/sql/jvm/src/main/scala/zio/blocks/sql/JdbcResultReader.scala b/sql/jvm/src/main/scala/zio/blocks/sql/JdbcResultReader.scala new file mode 100644 index 0000000000..631a3fa4e6 --- /dev/null +++ b/sql/jvm/src/main/scala/zio/blocks/sql/JdbcResultReader.scala @@ -0,0 +1,51 @@ +package zio.blocks.sql + +import java.sql.ResultSet +import java.util.UUID + +class JdbcResultReader(val underlying: ResultSet) extends DbResultReader { + + def getInt(index: Int): Int = underlying.getInt(index) + + def getLong(index: Int): Long = underlying.getLong(index) + + def getDouble(index: Int): Double = underlying.getDouble(index) + + def getFloat(index: Int): Float = underlying.getFloat(index) + + def getBoolean(index: Int): Boolean = underlying.getBoolean(index) + + def getString(index: Int): String = underlying.getString(index) + + def getBigDecimal(index: Int): java.math.BigDecimal = underlying.getBigDecimal(index) + + def getBytes(index: Int): Array[Byte] = underlying.getBytes(index) + + def getShort(index: Int): Short = underlying.getShort(index) + + def getByte(index: Int): Byte = underlying.getByte(index) + + def getLocalDate(index: Int): java.time.LocalDate = underlying.getObject(index, classOf[java.time.LocalDate]) + + def getLocalDateTime(index: Int): java.time.LocalDateTime = + underlying.getObject(index, classOf[java.time.LocalDateTime]) + + def getLocalTime(index: Int): java.time.LocalTime = underlying.getObject(index, classOf[java.time.LocalTime]) + + def getInstant(index: Int): java.time.Instant = { + val ts = underlying.getTimestamp(index) + if (ts == null) null else ts.toInstant + } + + def getDuration(index: Int): java.time.Duration = { + val s = underlying.getString(index) + if (s == null) null else java.time.Duration.parse(s) + } + + def getUUID(index: Int): UUID = { + val s = underlying.getString(index) + if (s == null) null else UUID.fromString(s) + } + + def wasNull: Boolean = underlying.wasNull() +} diff --git a/sql/jvm/src/main/scala/zio/blocks/sql/JdbcTransactor.scala b/sql/jvm/src/main/scala/zio/blocks/sql/JdbcTransactor.scala new file mode 100644 index 0000000000..60dae3966d --- /dev/null +++ b/sql/jvm/src/main/scala/zio/blocks/sql/JdbcTransactor.scala @@ -0,0 +1,62 @@ +package zio.blocks.sql + +import java.sql.{Connection, DriverManager} + +class JdbcTransactor( + connectionFactory: () => Connection, + val dialect: SqlDialect, + val sqlLogger: SqlLogger = SqlLogger.noop +) extends Transactor { + + def connect[A](f: DbCon ?=> A): A = { + val conn = connectionFactory() + val dbConn = new JdbcConnection(conn) + try { + given con: DbCon = new DbCon { + val connection: DbConnection = dbConn + val dialect: SqlDialect = JdbcTransactor.this.dialect + val logger: SqlLogger = JdbcTransactor.this.sqlLogger + } + f + } finally { + try dbConn.close() + catch { case _: Throwable => () } + } + } + + def transact[A](f: DbTx ?=> A): A = { + val conn = connectionFactory() + val dbConn = new JdbcConnection(conn) + conn.setAutoCommit(false) + try { + given tx: DbTx = new DbTx { + val connection: DbConnection = dbConn + val dialect: SqlDialect = JdbcTransactor.this.dialect + val logger: SqlLogger = JdbcTransactor.this.sqlLogger + } + val result = f + conn.commit() + result + } catch { + case e: Throwable => + try conn.rollback() + catch { case rb: Throwable => e.addSuppressed(rb) } + throw e + } finally { + try dbConn.close() + catch { case _: Throwable => () } + } + } +} + +object JdbcTransactor { + + def fromUrl(url: String, dialect: SqlDialect): JdbcTransactor = + new JdbcTransactor(() => DriverManager.getConnection(url), dialect) + + def fromUrl(url: String, user: String, password: String, dialect: SqlDialect): JdbcTransactor = + new JdbcTransactor(() => DriverManager.getConnection(url, user, password), dialect) + + def fromDataSource(dataSource: javax.sql.DataSource, dialect: SqlDialect): JdbcTransactor = + new JdbcTransactor(() => dataSource.getConnection, dialect) +} diff --git a/sql/jvm/src/test/scala/zio/blocks/sql/.gitkeep b/sql/jvm/src/test/scala/zio/blocks/sql/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sql/jvm/src/test/scala/zio/blocks/sql/RepoIntegrationSpec.scala b/sql/jvm/src/test/scala/zio/blocks/sql/RepoIntegrationSpec.scala new file mode 100644 index 0000000000..7f626a3abf --- /dev/null +++ b/sql/jvm/src/test/scala/zio/blocks/sql/RepoIntegrationSpec.scala @@ -0,0 +1,487 @@ +package zio.blocks.sql + +import zio.test._ +import zio.blocks.schema._ +import java.sql.DriverManager +import scala.collection.mutable.ArrayBuffer + +object RepoIntegrationSpec extends ZIOSpecDefault { + private val _ = Class.forName("org.sqlite.JDBC") + + case class User(id: Int, name: String, email: String) + object User { + implicit val schema: Schema[User] = Schema.derived + } + + enum Priority { + case Low, Medium, High + } + object Priority { + implicit val schema: Schema[Priority] = Schema.derived + } + + case class Task(id: Int, title: String, priority: Priority) + object Task { + implicit val schema: Schema[Task] = Schema.derived + } + + private val userTable = Table.derived[User](SqlDialect.SQLite) + private val taskTable = Table.derived[Task](SqlDialect.SQLite) + + private given DbCodec[User] = User.schema.deriving(DbCodecDeriver).derive + private given DbCodec[Task] = Task.schema.deriving(DbCodecDeriver).derive + private given DbCodec[Int] = implicitly[Schema[Int]].deriving(DbCodecDeriver).derive + private given DbCodec[Priority] = Priority.schema.deriving(DbCodecDeriver).derive + + private val intCodec: DbCodec[Int] = summon[DbCodec[Int]] + + private val userRepo = Repo(userTable, "id", intCodec, (_: User).id) + private val taskRepo = Repo(taskTable, "id", intCodec, (_: Task).id) + + private def withFreshDb[A](f: JdbcTransactor => A): A = { + val conn = DriverManager.getConnection("jdbc:sqlite::memory:") + val tx = new JdbcTransactor(() => conn, SqlDialect.SQLite) { + override def connect[B](f: DbCon ?=> B): B = { + val dbConn = new JdbcConnection(conn) + given con: DbCon = new DbCon { + val connection: DbConnection = dbConn + val dialect: SqlDialect = SqlDialect.SQLite + val logger: SqlLogger = SqlLogger.noop + } + f + } + } + tx.connect { + SqlOps.update( + Frag.const( + "CREATE TABLE IF NOT EXISTS user (id INTEGER NOT NULL, name TEXT NOT NULL, email TEXT NOT NULL)" + ) + ) + SqlOps.update( + Frag.const( + "CREATE TABLE IF NOT EXISTS task (id INTEGER NOT NULL, title TEXT NOT NULL, priority TEXT NOT NULL)" + ) + ) + } + f(tx) + } + + private def withFreshDbAndLogger[A](f: (JdbcTransactor, CapturingLogger) => A): A = { + val conn = DriverManager.getConnection("jdbc:sqlite::memory:") + val testLogger = new CapturingLogger + val tx = new JdbcTransactor(() => conn, SqlDialect.SQLite) { + override def connect[B](f: DbCon ?=> B): B = { + val dbConn = new JdbcConnection(conn) + given con: DbCon = new DbCon { + val connection: DbConnection = dbConn + val dialect: SqlDialect = SqlDialect.SQLite + val logger: SqlLogger = testLogger + } + f + } + } + tx.connect { + SqlOps.update( + Frag.const( + "CREATE TABLE IF NOT EXISTS user (id INTEGER NOT NULL, name TEXT NOT NULL, email TEXT NOT NULL)" + ) + ) + SqlOps.update( + Frag.const( + "CREATE TABLE IF NOT EXISTS task (id INTEGER NOT NULL, title TEXT NOT NULL, priority TEXT NOT NULL)" + ) + ) + } + testLogger.clear() + f(tx, testLogger) + } + + private class CapturingLogger extends SqlLogger { + val successes: ArrayBuffer[SqlLogger.SuccessEvent] = ArrayBuffer.empty + val errors: ArrayBuffer[SqlLogger.ErrorEvent] = ArrayBuffer.empty + + def onSuccess(event: SqlLogger.SuccessEvent): Unit = successes += event + def onError(event: SqlLogger.ErrorEvent): Unit = errors += event + + def clear(): Unit = { + successes.clear() + errors.clear() + } + } + + def spec: Spec[TestEnvironment, Any] = suite("RepoIntegrationSpec")( + test("insert and findById roundtrip") { + withFreshDb { tx => + tx.connect { + userRepo.insert(User(1, "Alice", "alice@test.com")) + val found = userRepo.findById(1) + assertTrue( + found.isDefined, + found.get.id == 1, + found.get.name == "Alice", + found.get.email == "alice@test.com" + ) + } + } + }, + test("insert returns 1 for single insert") { + withFreshDb { tx => + tx.connect { + val rows = userRepo.insert(User(1, "Alice", "alice@test.com")) + assertTrue(rows == 1) + } + } + }, + test("findById returns None for non-existing") { + withFreshDb { tx => + tx.connect { + val found = userRepo.findById(999) + assertTrue(found.isEmpty) + } + } + }, + test("findAll returns all inserted rows") { + withFreshDb { tx => + tx.connect { + userRepo.insert(User(1, "Alice", "a@test.com")) + userRepo.insert(User(2, "Bob", "b@test.com")) + userRepo.insert(User(3, "Charlie", "c@test.com")) + val all = userRepo.findAll + assertTrue(all.size == 3) + } + } + }, + test("count returns correct count") { + withFreshDb { tx => + tx.connect { + assertTrue(userRepo.count == 0L) + } + } + }, + test("count reflects insertions") { + withFreshDb { tx => + tx.connect { + userRepo.insert(User(1, "Alice", "a@test.com")) + userRepo.insert(User(2, "Bob", "b@test.com")) + assertTrue(userRepo.count == 2L) + } + } + }, + test("existsById returns true for existing, false for missing") { + withFreshDb { tx => + tx.connect { + userRepo.insert(User(1, "Alice", "a@test.com")) + assertTrue( + userRepo.existsById(1), + !userRepo.existsById(999) + ) + } + } + }, + test("update modifies existing row") { + withFreshDb { tx => + tx.connect { + userRepo.insert(User(1, "Alice", "old@test.com")) + userRepo.update(User(1, "Alice Updated", "new@test.com")) + val found = userRepo.findById(1) + assertTrue( + found.isDefined, + found.get.name == "Alice Updated", + found.get.email == "new@test.com" + ) + } + } + }, + test("update returns 1 for existing, 0 for non-existing") { + withFreshDb { tx => + tx.connect { + userRepo.insert(User(1, "Alice", "a@test.com")) + val updated = userRepo.update(User(1, "Alice Updated", "new@test.com")) + val notUpdated = userRepo.update(User(999, "Ghost", "ghost@test.com")) + assertTrue(updated == 1, notUpdated == 0) + } + } + }, + test("deleteById removes the row") { + withFreshDb { tx => + tx.connect { + userRepo.insert(User(1, "Alice", "a@test.com")) + userRepo.deleteById(1) + assertTrue(userRepo.findById(1).isEmpty) + } + } + }, + test("deleteById returns 1 for existing, 0 for non-existing") { + withFreshDb { tx => + tx.connect { + userRepo.insert(User(1, "Alice", "a@test.com")) + val deleted = userRepo.deleteById(1) + val notDeleted = userRepo.deleteById(999) + assertTrue(deleted == 1, notDeleted == 0) + } + } + }, + test("delete removes by entity") { + withFreshDb { tx => + tx.connect { + val user = User(1, "Alice", "a@test.com") + userRepo.insert(user) + userRepo.delete(user) + assertTrue(userRepo.findById(1).isEmpty) + } + } + }, + test("truncate removes all rows") { + withFreshDb { tx => + tx.connect { + userRepo.insert(User(1, "Alice", "a@test.com")) + userRepo.insert(User(2, "Bob", "b@test.com")) + userRepo.truncate() + assertTrue( + userRepo.count == 0L, + userRepo.findAll.isEmpty + ) + } + } + }, + test("full CRUD lifecycle") { + withFreshDb { tx => + tx.connect { + // Create + userRepo.insert(User(1, "Alice", "alice@test.com")) + userRepo.insert(User(2, "Bob", "bob@test.com")) + assertTrue(userRepo.count == 2L) + + // Read + val alice = userRepo.findById(1) + assertTrue(alice.get.name == "Alice") + + // Update + userRepo.update(User(1, "Alice Smith", "alice.smith@test.com")) + val updated = userRepo.findById(1) + assertTrue(updated.get.name == "Alice Smith") + + // Delete + userRepo.deleteById(2) + assertTrue(userRepo.count == 1L) + + // Truncate + userRepo.truncate() + assertTrue(userRepo.count == 0L) + } + } + }, + suite("enum integration")( + test("insert and findById with enum field") { + withFreshDb { tx => + tx.connect { + taskRepo.insert(Task(1, "Write tests", Priority.High)) + val found = taskRepo.findById(1) + assertTrue( + found.isDefined, + found.get.id == 1, + found.get.title == "Write tests", + found.get.priority == Priority.High + ) + } + } + }, + test("enum values round-trip through database") { + withFreshDb { tx => + tx.connect { + taskRepo.insert(Task(1, "Low task", Priority.Low)) + taskRepo.insert(Task(2, "Medium task", Priority.Medium)) + taskRepo.insert(Task(3, "High task", Priority.High)) + + val t1 = taskRepo.findById(1) + val t2 = taskRepo.findById(2) + val t3 = taskRepo.findById(3) + + assertTrue( + t1.get.priority == Priority.Low, + t2.get.priority == Priority.Medium, + t3.get.priority == Priority.High + ) + } + } + }, + test("update enum field") { + withFreshDb { tx => + tx.connect { + taskRepo.insert(Task(1, "A task", Priority.Low)) + taskRepo.update(Task(1, "A task", Priority.High)) + val found = taskRepo.findById(1) + assertTrue(found.get.priority == Priority.High) + } + } + }, + test("findAll with enum fields") { + withFreshDb { tx => + tx.connect { + taskRepo.insert(Task(1, "Task 1", Priority.Low)) + taskRepo.insert(Task(2, "Task 2", Priority.High)) + val all = taskRepo.findAll + assertTrue(all.size == 2) + } + } + } + ), + suite("insertReturning")( + test("returns the inserted entity") { + withFreshDb { tx => + tx.connect { + val returned = userRepo.insertReturning(User(1, "Alice", "alice@test.com")) + assertTrue( + returned.id == 1, + returned.name == "Alice", + returned.email == "alice@test.com" + ) + } + } + }, + test("entity exists in database after insertReturning") { + withFreshDb { tx => + tx.connect { + userRepo.insertReturning(User(1, "Bob", "bob@test.com")) + val found = userRepo.findById(1) + assertTrue( + found.isDefined, + found.get.name == "Bob" + ) + } + } + } + ), + suite("insertAll")( + test("inserts multiple entities in batch") { + withFreshDb { tx => + tx.connect { + val count = userRepo.insertAll( + List( + User(1, "Alice", "a@test.com"), + User(2, "Bob", "b@test.com"), + User(3, "Charlie", "c@test.com") + ) + ) + assertTrue(count == 3) + } + } + }, + test("all entities are queryable after insertAll") { + withFreshDb { tx => + tx.connect { + userRepo.insertAll( + List( + User(1, "Alice", "a@test.com"), + User(2, "Bob", "b@test.com"), + User(3, "Charlie", "c@test.com") + ) + ) + val all = userRepo.findAll + assertTrue( + all.size == 3, + all.map(_.name).toSet == Set("Alice", "Bob", "Charlie") + ) + } + } + }, + test("insertAll with empty iterable returns 0") { + withFreshDb { tx => + tx.connect { + val count = userRepo.insertAll(List.empty[User]) + assertTrue(count == 0, userRepo.count == 0L) + } + } + }, + test("insertAll with single entity") { + withFreshDb { tx => + tx.connect { + val count = userRepo.insertAll(List(User(1, "Solo", "solo@test.com"))) + assertTrue(count == 1, userRepo.count == 1L) + } + } + } + ), + suite("SqlLogger")( + test("onSuccess is called for insert") { + withFreshDbAndLogger { (tx, logger) => + tx.connect { + userRepo.insert(User(1, "Alice", "a@test.com")) + assertTrue( + logger.successes.size == 1, + logger.successes.head.sql.contains("INSERT INTO"), + logger.successes.head.rowCount == 1, + logger.errors.isEmpty + ) + } + } + }, + test("onSuccess is called for query") { + withFreshDbAndLogger { (tx, logger) => + tx.connect { + userRepo.insert(User(1, "Alice", "a@test.com")) + logger.clear() + val _ = userRepo.findAll + assertTrue( + logger.successes.size == 1, + logger.successes.head.sql.contains("SELECT"), + logger.successes.head.rowCount == 1, + logger.errors.isEmpty + ) + } + } + }, + test("onError is called on SQL error") { + withFreshDbAndLogger { (tx, logger) => + tx.connect { + try { + SqlOps.update(Frag.const("INSERT INTO nonexistent_table (id) VALUES (1)")) + } catch { + case _: Throwable => () + } + assertTrue( + logger.errors.size == 1, + logger.errors.head.sql.contains("nonexistent_table"), + logger.successes.isEmpty + ) + } + } + }, + test("duration is non-negative") { + withFreshDbAndLogger { (tx, logger) => + tx.connect { + userRepo.insert(User(1, "Alice", "a@test.com")) + assertTrue(!logger.successes.head.duration.isNegative) + } + } + }, + test("insertAll logs success") { + withFreshDbAndLogger { (tx, logger) => + tx.connect { + userRepo.insertAll( + List( + User(1, "Alice", "a@test.com"), + User(2, "Bob", "b@test.com") + ) + ) + assertTrue( + logger.successes.size == 1, + logger.successes.head.sql.contains("INSERT INTO"), + logger.successes.head.rowCount == 2, + logger.errors.isEmpty + ) + } + } + }, + test("noop logger does not throw") { + withFreshDb { tx => + tx.connect { + userRepo.insert(User(1, "Alice", "a@test.com")) + val _ = userRepo.findAll + assertTrue(true) + } + } + } + ) + ) +} diff --git a/sql/jvm/src/test/scala/zio/blocks/sql/TransactorSpec.scala b/sql/jvm/src/test/scala/zio/blocks/sql/TransactorSpec.scala new file mode 100644 index 0000000000..1d0e8a84a2 --- /dev/null +++ b/sql/jvm/src/test/scala/zio/blocks/sql/TransactorSpec.scala @@ -0,0 +1,501 @@ +package zio.blocks.sql + +import zio.test.* +import zio.blocks.schema.* +import java.sql.DriverManager + +object TransactorSpec extends ZIOSpecDefault { + private val _ = Class.forName("org.sqlite.JDBC") + + case class User(id: Int, name: String, email: String) + object User { + implicit val schema: Schema[User] = Schema.derived + } + + case class AllTypes( + intVal: Int, + longVal: Long, + doubleVal: Double, + floatVal: Float, + boolVal: Boolean, + strVal: String, + shortVal: Short, + byteVal: Byte + ) + object AllTypes { + implicit val schema: Schema[AllTypes] = Schema.derived + } + + case class WithOption(id: Int, nickname: Option[String]) + object WithOption { + implicit val schema: Schema[WithOption] = Schema.derived + } + + private val transactor = JdbcTransactor.fromUrl("jdbc:sqlite::memory:", SqlDialect.SQLite) + + private given DbCodec[User] = User.schema.deriving(DbCodecDeriver).derive + private given DbCodec[AllTypes] = AllTypes.schema.deriving(DbCodecDeriver).derive + private given DbCodec[WithOption] = WithOption.schema.deriving(DbCodecDeriver).derive + + private given DbCodec[Int] = implicitly[Schema[Int]].deriving(DbCodecDeriver).derive + private given DbCodec[String] = implicitly[Schema[String]].deriving(DbCodecDeriver).derive + private given DbCodec[Long] = implicitly[Schema[Long]].deriving(DbCodecDeriver).derive + private given DbCodec[Double] = implicitly[Schema[Double]].deriving(DbCodecDeriver).derive + private given DbCodec[Boolean] = + implicitly[Schema[Boolean]].deriving(DbCodecDeriver).derive + + private def sharedConnTransactor(): (JdbcTransactor, java.sql.Connection) = { + val conn = DriverManager.getConnection("jdbc:sqlite::memory:") + val tx = new JdbcTransactor(() => conn, SqlDialect.SQLite) { + override def connect[A](f: DbCon ?=> A): A = { + val dbConn = new JdbcConnection(conn) + given con: DbCon = new DbCon { + val connection: DbConnection = dbConn + val dialect: SqlDialect = SqlDialect.SQLite + val logger: SqlLogger = SqlLogger.noop + } + f + } + + override def transact[A](f: DbTx ?=> A): A = { + val dbConn = new JdbcConnection(conn) + conn.setAutoCommit(false) + try { + given tx: DbTx = new DbTx { + val connection: DbConnection = dbConn + val dialect: SqlDialect = SqlDialect.SQLite + val logger: SqlLogger = SqlLogger.noop + } + val result = f + conn.commit() + result + } catch { + case e: Throwable => + conn.rollback() + throw e + } finally conn.setAutoCommit(true) + } + } + (tx, conn) + } + + def spec: Spec[TestEnvironment, Any] = suite("TransactorSpec")( + test("connect executes queries") { + transactor.connect { + SqlOps.update(Frag.const("CREATE TABLE IF NOT EXISTS test_connect (id INTEGER NOT NULL)")) + SqlOps.update( + sql"INSERT INTO test_connect (id) VALUES (${DbValue.DbInt(1)})" + ) + val ids = SqlOps.query[Int](sql"SELECT id FROM test_connect") + assertTrue(ids == List(1)) + } + }, + test("INSERT and SELECT roundtrip") { + transactor.connect { + SqlOps.update( + Frag.const( + "CREATE TABLE IF NOT EXISTS users (id INTEGER NOT NULL, name TEXT NOT NULL, email TEXT NOT NULL)" + ) + ) + SqlOps.update( + sql"INSERT INTO users (id, name, email) VALUES (${DbValue.DbInt(1)}, ${DbValue.DbString("Alice")}, ${DbValue.DbString("alice@example.com")})" + ) + SqlOps.update( + sql"INSERT INTO users (id, name, email) VALUES (${DbValue.DbInt(2)}, ${DbValue.DbString("Bob")}, ${DbValue.DbString("bob@example.com")})" + ) + val users = SqlOps.query[User](sql"SELECT id, name, email FROM users ORDER BY id") + assertTrue( + users.length == 2, + users.head.id == 1, + users.head.name == "Alice", + users.head.email == "alice@example.com", + users(1).name == "Bob" + ) + } + }, + test("queryOne returns first result") { + transactor.connect { + SqlOps.update( + Frag.const( + "CREATE TABLE IF NOT EXISTS query_one_test (id INTEGER NOT NULL, val TEXT NOT NULL)" + ) + ) + SqlOps.update( + sql"INSERT INTO query_one_test (id, val) VALUES (${DbValue.DbInt(1)}, ${DbValue.DbString("first")})" + ) + SqlOps.update( + sql"INSERT INTO query_one_test (id, val) VALUES (${DbValue.DbInt(2)}, ${DbValue.DbString("second")})" + ) + val result = SqlOps.queryOne[String]( + sql"SELECT val FROM query_one_test WHERE id = ${DbValue.DbInt(1)}" + ) + assertTrue(result == Some("first")) + } + }, + test("empty result returns empty List") { + transactor.connect { + SqlOps.update(Frag.const("CREATE TABLE IF NOT EXISTS empty_test (id INTEGER NOT NULL)")) + val result = SqlOps.query[Int](sql"SELECT id FROM empty_test") + assertTrue(result.isEmpty) + } + }, + test("queryOne on empty result returns None") { + transactor.connect { + SqlOps.update(Frag.const("CREATE TABLE IF NOT EXISTS empty_one_test (id INTEGER NOT NULL)")) + val result = SqlOps.queryOne[Int](sql"SELECT id FROM empty_one_test") + assertTrue(result.isEmpty) + } + }, + test("transaction commits on success") { + val (tx, conn) = sharedConnTransactor() + try { + tx.connect { + SqlOps.update(Frag.const("CREATE TABLE tx_commit (id INTEGER NOT NULL, name TEXT NOT NULL)")) + } + tx.transact { + SqlOps.update( + sql"INSERT INTO tx_commit (id, name) VALUES (${DbValue.DbInt(1)}, ${DbValue.DbString("committed")})" + ) + } + tx.connect { + val rows = SqlOps.query[String](sql"SELECT name FROM tx_commit") + assertTrue(rows == List("committed")) + } + } finally conn.close() + }, + test("transaction rolls back on exception") { + val (tx, conn) = sharedConnTransactor() + try { + tx.connect { + SqlOps.update(Frag.const("CREATE TABLE tx_rollback (id INTEGER NOT NULL, name TEXT NOT NULL)")) + SqlOps.update( + sql"INSERT INTO tx_rollback (id, name) VALUES (${DbValue.DbInt(0)}, ${DbValue.DbString("before")})" + ) + } + try { + tx.transact { + SqlOps.update( + sql"INSERT INTO tx_rollback (id, name) VALUES (${DbValue.DbInt(1)}, ${DbValue.DbString("should_rollback")})" + ) + throw new RuntimeException("forced error") + } + } catch { + case _: RuntimeException => () + } + tx.connect { + val rows = SqlOps.query[String](sql"SELECT name FROM tx_rollback WHERE id = ${DbValue.DbInt(1)}") + assertTrue(rows.isEmpty) + } + } finally conn.close() + }, + test("update returns affected row count") { + transactor.connect { + SqlOps.update(Frag.const("CREATE TABLE IF NOT EXISTS count_test (id INTEGER NOT NULL)")) + SqlOps.update( + sql"INSERT INTO count_test (id) VALUES (${DbValue.DbInt(1)})" + ) + SqlOps.update( + sql"INSERT INTO count_test (id) VALUES (${DbValue.DbInt(2)})" + ) + val deleted = SqlOps.update(Frag.const("DELETE FROM count_test")) + assertTrue(deleted == 2) + } + }, + suite("type roundtrip tests")( + test("Long roundtrip") { + transactor.connect { + SqlOps.update(Frag.const("CREATE TABLE rt_long (v INTEGER NOT NULL)")) + SqlOps.update(sql"INSERT INTO rt_long (v) VALUES (${DbValue.DbLong(9876543210L)})") + val result = SqlOps.query[Long](sql"SELECT v FROM rt_long") + assertTrue(result == List(9876543210L)) + } + }, + test("Double roundtrip") { + transactor.connect { + SqlOps.update(Frag.const("CREATE TABLE rt_double (v REAL NOT NULL)")) + SqlOps.update(sql"INSERT INTO rt_double (v) VALUES (${DbValue.DbDouble(3.14159)})") + val result = SqlOps.query[Double](sql"SELECT v FROM rt_double") + assertTrue(result == List(3.14159)) + } + }, + test("Boolean roundtrip") { + transactor.connect { + SqlOps.update(Frag.const("CREATE TABLE rt_bool (v INTEGER NOT NULL)")) + SqlOps.update(sql"INSERT INTO rt_bool (v) VALUES (${DbValue.DbBoolean(true)})") + SqlOps.update(sql"INSERT INTO rt_bool (v) VALUES (${DbValue.DbBoolean(false)})") + val result = SqlOps.query[Boolean](sql"SELECT v FROM rt_bool ORDER BY v") + assertTrue(result == List(false, true)) + } + }, + test("all primitive types roundtrip") { + transactor.connect { + SqlOps.update( + Frag.const( + "CREATE TABLE rt_all (" + + "int_val INTEGER NOT NULL, " + + "long_val INTEGER NOT NULL, " + + "double_val REAL NOT NULL, " + + "float_val REAL NOT NULL, " + + "bool_val INTEGER NOT NULL, " + + "str_val TEXT NOT NULL, " + + "short_val INTEGER NOT NULL, " + + "byte_val INTEGER NOT NULL)" + ) + ) + SqlOps.update( + sql"INSERT INTO rt_all VALUES (${DbValue.DbInt(42)}, ${DbValue.DbLong(123456789L)}, ${DbValue.DbDouble(2.718)}, ${DbValue.DbFloat(1.5f)}, ${DbValue.DbBoolean(true)}, ${DbValue.DbString("test")}, ${DbValue.DbShort(100.toShort)}, ${DbValue.DbByte(7.toByte)})" + ) + val result = SqlOps.query[AllTypes](sql"SELECT * FROM rt_all") + assertTrue( + result.length == 1, + result.head.intVal == 42, + result.head.longVal == 123456789L, + result.head.doubleVal == 2.718, + result.head.floatVal == 1.5f, + result.head.boolVal == true, + result.head.strVal == "test", + result.head.shortVal == 100.toShort, + result.head.byteVal == 7.toByte + ) + } + }, + test("DbNull writeParams via Option None roundtrip") { + transactor.connect { + SqlOps.update(Frag.const("CREATE TABLE rt_null (id INTEGER NOT NULL, nick TEXT)")) + SqlOps.update( + sql"INSERT INTO rt_null (id, nick) VALUES (${DbValue.DbInt(1)}, ${DbValue.DbString("present")})" + ) + SqlOps.update( + sql"INSERT INTO rt_null (id, nick) VALUES (${DbValue.DbInt(2)}, ${DbValue.DbNull})" + ) + val results = SqlOps.query[WithOption](sql"SELECT id, nick FROM rt_null ORDER BY id") + assertTrue( + results.length == 2, + results(0) == WithOption(1, Some("present")), + results(1) == WithOption(2, None) + ) + } + }, + test("multiple rows insert and select") { + transactor.connect { + SqlOps.update(Frag.const("CREATE TABLE rt_multi (id INTEGER NOT NULL, name TEXT NOT NULL)")) + SqlOps.update(sql"INSERT INTO rt_multi VALUES (${DbValue.DbInt(1)}, ${DbValue.DbString("a")})") + SqlOps.update(sql"INSERT INTO rt_multi VALUES (${DbValue.DbInt(2)}, ${DbValue.DbString("b")})") + SqlOps.update(sql"INSERT INTO rt_multi VALUES (${DbValue.DbInt(3)}, ${DbValue.DbString("c")})") + val results = SqlOps.query[User]( + sql"SELECT id, name, name FROM rt_multi ORDER BY id" + ) + assertTrue( + results.length == 3, + results(0).id == 1, + results(1).id == 2, + results(2).id == 3 + ) + } + }, + test("queryOne returns None for non-existing row") { + transactor.connect { + SqlOps.update(Frag.const("CREATE TABLE rt_qone (id INTEGER NOT NULL)")) + SqlOps.update(sql"INSERT INTO rt_qone VALUES (${DbValue.DbInt(1)})") + val existing = SqlOps.queryOne[Int](sql"SELECT id FROM rt_qone WHERE id = ${DbValue.DbInt(1)}") + val nonExisting = SqlOps.queryOne[Int](sql"SELECT id FROM rt_qone WHERE id = ${DbValue.DbInt(999)}") + assertTrue( + existing == Some(1), + nonExisting.isEmpty + ) + } + }, + test("BigDecimal writeParams") { + transactor.connect { + SqlOps.update(Frag.const("CREATE TABLE rt_bigdec (v TEXT NOT NULL)")) + SqlOps.update(sql"INSERT INTO rt_bigdec (v) VALUES (${DbValue.DbBigDecimal(BigDecimal("123.456"))})") + val ps = summon[DbCon].connection.prepareStatement("SELECT v FROM rt_bigdec") + try { + val rs = ps.executeQuery() + try { + rs.next() + val str = rs.reader.getString(1) + assertTrue(str == "123.456") + } finally rs.close() + } finally ps.close() + } + }, + test("Short and Byte writeParams") { + transactor.connect { + SqlOps.update(Frag.const("CREATE TABLE rt_small (s INTEGER NOT NULL, b INTEGER NOT NULL)")) + SqlOps.update( + sql"INSERT INTO rt_small VALUES (${DbValue.DbShort(32000.toShort)}, ${DbValue.DbByte(127.toByte)})" + ) + val ps = summon[DbCon].connection.prepareStatement("SELECT s, b FROM rt_small") + try { + val rs = ps.executeQuery() + try { + rs.next() + val reader = rs.reader + val s = reader.getShort(1) + val b = reader.getByte(2) + assertTrue(s == 32000.toShort, b == 127.toByte) + } finally rs.close() + } finally ps.close() + } + }, + test("Float writeParams and read") { + transactor.connect { + SqlOps.update(Frag.const("CREATE TABLE rt_float (v REAL NOT NULL)")) + SqlOps.update(sql"INSERT INTO rt_float (v) VALUES (${DbValue.DbFloat(2.5f)})") + val ps = summon[DbCon].connection.prepareStatement("SELECT v FROM rt_float") + try { + val rs = ps.executeQuery() + try { + rs.next() + val v = rs.reader.getFloat(1) + assertTrue(v == 2.5f) + } finally rs.close() + } finally ps.close() + } + }, + test("Char roundtrip via DbChar") { + transactor.connect { + SqlOps.update(Frag.const("CREATE TABLE rt_char (v TEXT NOT NULL)")) + SqlOps.update(sql"INSERT INTO rt_char (v) VALUES (${DbValue.DbChar('X')})") + val result = SqlOps.query[String](sql"SELECT v FROM rt_char") + assertTrue(result == List("X")) + } + }, + test("UUID roundtrip via TEXT") { + transactor.connect { + val uuid = java.util.UUID.fromString("550e8400-e29b-41d4-a716-446655440000") + SqlOps.update(Frag.const("CREATE TABLE rt_uuid (v TEXT NOT NULL)")) + SqlOps.update(sql"INSERT INTO rt_uuid (v) VALUES (${DbValue.DbUUID(uuid)})") + val ps = summon[DbCon].connection.prepareStatement("SELECT v FROM rt_uuid") + try { + val rs = ps.executeQuery() + try { + rs.next() + val v = rs.reader.getUUID(1) + assertTrue(v == uuid) + } finally rs.close() + } finally ps.close() + } + }, + test("Duration roundtrip via TEXT") { + transactor.connect { + val dur = java.time.Duration.ofHours(2).plusMinutes(30) + SqlOps.update(Frag.const("CREATE TABLE rt_dur (v TEXT NOT NULL)")) + SqlOps.update(sql"INSERT INTO rt_dur (v) VALUES (${DbValue.DbDuration(dur)})") + val ps = summon[DbCon].connection.prepareStatement("SELECT v FROM rt_dur") + try { + val rs = ps.executeQuery() + try { + rs.next() + val v = rs.reader.getDuration(1) + assertTrue(v == dur) + } finally rs.close() + } finally ps.close() + } + }, + test("Instant writeParams via setTimestamp") { + transactor.connect { + val instant = java.time.Instant.parse("2024-06-15T10:30:00Z") + SqlOps.update(Frag.const("CREATE TABLE rt_inst (v TEXT NOT NULL)")) + SqlOps.update(sql"INSERT INTO rt_inst (v) VALUES (${DbValue.DbInstant(instant)})") + val ps = summon[DbCon].connection.prepareStatement("SELECT v FROM rt_inst") + try { + val rs = ps.executeQuery() + try { + rs.next() + val v = rs.reader.getString(1) + assertTrue(v != null && v.nonEmpty) + } finally rs.close() + } finally ps.close() + } + }, + test("Bytes roundtrip") { + transactor.connect { + val bytes = Array[Byte](10, 20, 30, 40, 50) + SqlOps.update(Frag.const("CREATE TABLE rt_bytes (v BLOB NOT NULL)")) + SqlOps.update(sql"INSERT INTO rt_bytes (v) VALUES (${DbValue.DbBytes(bytes)})") + val ps = summon[DbCon].connection.prepareStatement("SELECT v FROM rt_bytes") + try { + val rs = ps.executeQuery() + try { + rs.next() + val read = rs.reader.getBytes(1) + assertTrue(read.sameElements(bytes)) + } finally rs.close() + } finally ps.close() + } + } + ), + suite("Frag extension methods")( + test("frag.query delegates to SqlOps.query") { + transactor.connect { + SqlOps.update(Frag.const("CREATE TABLE ext_query (id INTEGER NOT NULL, name TEXT NOT NULL)")) + SqlOps.update(sql"INSERT INTO ext_query VALUES (${DbValue.DbInt(1)}, ${DbValue.DbString("a")})") + SqlOps.update(sql"INSERT INTO ext_query VALUES (${DbValue.DbInt(2)}, ${DbValue.DbString("b")})") + val viaOps = SqlOps.query[Int](sql"SELECT id FROM ext_query ORDER BY id") + val viaExt = sql"SELECT id FROM ext_query ORDER BY id".query[Int] + assertTrue(viaOps == viaExt, viaExt == List(1, 2)) + } + }, + test("frag.queryOne returns Some for match, None for no match") { + transactor.connect { + SqlOps.update(Frag.const("CREATE TABLE ext_qone (id INTEGER NOT NULL)")) + SqlOps.update(sql"INSERT INTO ext_qone VALUES (${DbValue.DbInt(42)})") + val found = sql"SELECT id FROM ext_qone WHERE id = ${DbValue.DbInt(42)}".queryOne[Int] + val notFound = sql"SELECT id FROM ext_qone WHERE id = ${DbValue.DbInt(999)}".queryOne[Int] + assertTrue(found == Some(42), notFound.isEmpty) + } + }, + test("frag.queryLimit returns at most N rows") { + transactor.connect { + SqlOps.update(Frag.const("CREATE TABLE ext_qlimit (id INTEGER NOT NULL)")) + SqlOps.update(sql"INSERT INTO ext_qlimit VALUES (${DbValue.DbInt(1)})") + SqlOps.update(sql"INSERT INTO ext_qlimit VALUES (${DbValue.DbInt(2)})") + SqlOps.update(sql"INSERT INTO ext_qlimit VALUES (${DbValue.DbInt(3)})") + SqlOps.update(sql"INSERT INTO ext_qlimit VALUES (${DbValue.DbInt(4)})") + SqlOps.update(sql"INSERT INTO ext_qlimit VALUES (${DbValue.DbInt(5)})") + val limited = sql"SELECT id FROM ext_qlimit ORDER BY id".queryLimit[Int](2) + val all = sql"SELECT id FROM ext_qlimit ORDER BY id".query[Int] + assertTrue(limited == List(1, 2), all.length == 5) + } + }, + test("frag.queryLimit with limit larger than result set returns all") { + transactor.connect { + SqlOps.update(Frag.const("CREATE TABLE ext_qlimit2 (id INTEGER NOT NULL)")) + SqlOps.update(sql"INSERT INTO ext_qlimit2 VALUES (${DbValue.DbInt(1)})") + SqlOps.update(sql"INSERT INTO ext_qlimit2 VALUES (${DbValue.DbInt(2)})") + val result = sql"SELECT id FROM ext_qlimit2 ORDER BY id".queryLimit[Int](100) + assertTrue(result == List(1, 2)) + } + }, + test("frag.update returns affected row count") { + transactor.connect { + SqlOps.update(Frag.const("CREATE TABLE ext_upd (id INTEGER NOT NULL)")) + sql"INSERT INTO ext_upd VALUES (${DbValue.DbInt(1)})".update + sql"INSERT INTO ext_upd VALUES (${DbValue.DbInt(2)})".update + val count = Frag.const("DELETE FROM ext_upd").update + assertTrue(count == 2) + } + }, + test("SqlOps.queryLimit stops early") { + transactor.connect { + SqlOps.update(Frag.const("CREATE TABLE qlimit_ops (id INTEGER NOT NULL)")) + SqlOps.update(sql"INSERT INTO qlimit_ops VALUES (${DbValue.DbInt(1)})") + SqlOps.update(sql"INSERT INTO qlimit_ops VALUES (${DbValue.DbInt(2)})") + SqlOps.update(sql"INSERT INTO qlimit_ops VALUES (${DbValue.DbInt(3)})") + val result = SqlOps.queryLimit[Int](sql"SELECT id FROM qlimit_ops ORDER BY id", 2) + assertTrue(result == List(1, 2)) + } + }, + test("SqlOps.queryLimit with zero returns empty") { + transactor.connect { + SqlOps.update(Frag.const("CREATE TABLE qlimit_zero (id INTEGER NOT NULL)")) + SqlOps.update(sql"INSERT INTO qlimit_zero VALUES (${DbValue.DbInt(1)})") + val result = SqlOps.queryLimit[Int](sql"SELECT id FROM qlimit_zero", 0) + assertTrue(result.isEmpty) + } + } + ) + ) +} diff --git a/sql/shared/src/main/scala/zio/blocks/sql/DbCodec.scala b/sql/shared/src/main/scala/zio/blocks/sql/DbCodec.scala new file mode 100644 index 0000000000..0118caffbc --- /dev/null +++ b/sql/shared/src/main/scala/zio/blocks/sql/DbCodec.scala @@ -0,0 +1,53 @@ +package zio.blocks.sql + +trait DbCodec[A] { + def columns: IndexedSeq[String] + def readValue(reader: DbResultReader, startIndex: Int): A + def writeValue(writer: DbParamWriter, startIndex: Int, value: A): Unit + def toDbValues(value: A): IndexedSeq[DbValue] + def columnCount: Int = columns.size +} + +object DbCodec { + def apply[A](implicit codec: DbCodec[A]): DbCodec[A] = codec +} + +trait DbResultReader { + def getInt(index: Int): Int + def getLong(index: Int): Long + def getDouble(index: Int): Double + def getFloat(index: Int): Float + def getBoolean(index: Int): Boolean + def getString(index: Int): String + def getBigDecimal(index: Int): java.math.BigDecimal + def getBytes(index: Int): Array[Byte] + def getShort(index: Int): Short + def getByte(index: Int): Byte + def getLocalDate(index: Int): java.time.LocalDate + def getLocalDateTime(index: Int): java.time.LocalDateTime + def getLocalTime(index: Int): java.time.LocalTime + def getInstant(index: Int): java.time.Instant + def getDuration(index: Int): java.time.Duration + def getUUID(index: Int): java.util.UUID + def wasNull: Boolean +} + +trait DbParamWriter { + def setInt(index: Int, value: Int): Unit + def setLong(index: Int, value: Long): Unit + def setDouble(index: Int, value: Double): Unit + def setFloat(index: Int, value: Float): Unit + def setBoolean(index: Int, value: Boolean): Unit + def setString(index: Int, value: String): Unit + def setBigDecimal(index: Int, value: java.math.BigDecimal): Unit + def setBytes(index: Int, value: Array[Byte]): Unit + def setShort(index: Int, value: Short): Unit + def setByte(index: Int, value: Byte): Unit + def setLocalDate(index: Int, value: java.time.LocalDate): Unit + def setLocalDateTime(index: Int, value: java.time.LocalDateTime): Unit + def setLocalTime(index: Int, value: java.time.LocalTime): Unit + def setInstant(index: Int, value: java.time.Instant): Unit + def setDuration(index: Int, value: java.time.Duration): Unit + def setUUID(index: Int, value: java.util.UUID): Unit + def setNull(index: Int, sqlType: Int): Unit +} diff --git a/sql/shared/src/main/scala/zio/blocks/sql/DbCodecDeriver.scala b/sql/shared/src/main/scala/zio/blocks/sql/DbCodecDeriver.scala new file mode 100644 index 0000000000..abac57343f --- /dev/null +++ b/sql/shared/src/main/scala/zio/blocks/sql/DbCodecDeriver.scala @@ -0,0 +1,593 @@ +package zio.blocks.sql + +import zio.blocks.schema._ +import zio.blocks.schema.binding._ +import zio.blocks.schema.derive._ +import zio.blocks.docs.Doc +import zio.blocks.typeid.TypeId + +class DbCodecDeriver(columnNameMapper: SqlNameMapper = SqlNameMapper.SnakeCase) extends Deriver[DbCodec] { + + override def derivePrimitive[A]( + primitiveType: PrimitiveType[A], + typeId: TypeId[A], + binding: Binding.Primitive[A], + doc: Doc, + modifiers: Seq[Modifier.Reflect], + defaultValue: Option[A], + examples: Seq[A] + ): Lazy[DbCodec[A]] = Lazy { + primitiveType match { + case PrimitiveType.Unit => unitCodec.asInstanceOf[DbCodec[A]] + case _: PrimitiveType.Boolean => booleanCodec.asInstanceOf[DbCodec[A]] + case _: PrimitiveType.Byte => byteCodec.asInstanceOf[DbCodec[A]] + case _: PrimitiveType.Short => shortCodec.asInstanceOf[DbCodec[A]] + case _: PrimitiveType.Int => intCodec.asInstanceOf[DbCodec[A]] + case _: PrimitiveType.Long => longCodec.asInstanceOf[DbCodec[A]] + case _: PrimitiveType.Float => floatCodec.asInstanceOf[DbCodec[A]] + case _: PrimitiveType.Double => doubleCodec.asInstanceOf[DbCodec[A]] + case _: PrimitiveType.Char => charCodec.asInstanceOf[DbCodec[A]] + case _: PrimitiveType.String => stringCodec.asInstanceOf[DbCodec[A]] + case _: PrimitiveType.BigDecimal => bigDecimalCodec.asInstanceOf[DbCodec[A]] + case _: PrimitiveType.Duration => durationCodec.asInstanceOf[DbCodec[A]] + case _: PrimitiveType.Instant => instantCodec.asInstanceOf[DbCodec[A]] + case _: PrimitiveType.LocalDate => localDateCodec.asInstanceOf[DbCodec[A]] + case _: PrimitiveType.LocalDateTime => + localDateTimeCodec.asInstanceOf[DbCodec[A]] + case _: PrimitiveType.LocalTime => localTimeCodec.asInstanceOf[DbCodec[A]] + case _: PrimitiveType.UUID => uuidCodec.asInstanceOf[DbCodec[A]] + case other => + throw new UnsupportedOperationException( + s"DbCodec does not support primitive type: ${other.getClass.getSimpleName}" + ) + } + } + + override def deriveRecord[F[_, _], A]( + fields: IndexedSeq[Term[F, A, ?]], + typeId: TypeId[A], + binding: Binding.Record[A], + doc: Doc, + modifiers: Seq[Modifier.Reflect], + defaultValue: Option[A], + examples: Seq[A] + )(implicit F: HasBinding[F], D: HasInstance[F]): Lazy[DbCodec[A]] = Lazy { + val recordBinding = binding + val constructor = recordBinding.constructor + val deconstructor = recordBinding.deconstructor + val len = fields.length + + val fieldNames = new Array[String](len) + val fieldCodecs = new Array[DbCodec[Any]](len) + val fieldTransient = new Array[Boolean](len) + + val reflects = new Array[Reflect[F, ?]](len) + var idx = 0 + while (idx < len) { + reflects(idx) = fields(idx).value + idx += 1 + } + val registers: IndexedSeq[Register[Any]] = + scala.collection.immutable.ArraySeq.unsafeWrapArray( + Reflect.Record.registers(reflects.asInstanceOf[Array[Reflect[F, ?]]]) + ) + + idx = 0 + while (idx < len) { + val field = fields(idx) + val isTransient = field.modifiers.exists(_.isInstanceOf[Modifier.transient]) + fieldTransient(idx) = isTransient + + if (!isTransient) { + val renamed = field.modifiers.collectFirst { case m: Modifier.rename => m.name } + fieldNames(idx) = renamed.getOrElse(columnNameMapper(field.name)) + + fieldCodecs(idx) = D.instance(field.value.metadata).force.asInstanceOf[DbCodec[Any]] + } + idx += 1 + } + + val activeFieldIndices: Array[Int] = (0 until len).filter(i => !fieldTransient(i)).toArray + val allColumns: IndexedSeq[String] = { + val builder = IndexedSeq.newBuilder[String] + var fi = 0 + while (fi < activeFieldIndices.length) { + val i = activeFieldIndices(fi) + val codec = fieldCodecs(i) + val fieldName = fieldNames(i) + if (codec.columnCount == 1) { + builder += fieldName + } else { + val cols = codec.columns + var ci = 0 + while (ci < cols.length) { + builder += fieldName + "_" + cols(ci) + ci += 1 + } + } + fi += 1 + } + builder.result() + } + + new DbCodec[A] { + val columns: IndexedSeq[String] = allColumns + + def readValue(reader: DbResultReader, startIndex: Int): A = { + val regs = Registers(constructor.usedRegisters) + var colIdx = startIndex + var fi = 0 + while (fi < activeFieldIndices.length) { + val i = activeFieldIndices(fi) + val codec = fieldCodecs(i) + val fieldValue = codec.readValue(reader, colIdx) + registers(i).set(regs, 0, fieldValue) + colIdx += codec.columnCount + fi += 1 + } + var ti = 0 + while (ti < len) { + if (fieldTransient(ti)) { + fields(ti).value.getDefaultValue(F) match { + case Some(dv) => registers(ti).set(regs, 0, dv) + case None => + } + } + ti += 1 + } + constructor.construct(regs, 0) + } + + def writeValue(writer: DbParamWriter, startIndex: Int, value: A): Unit = { + val regs = Registers(deconstructor.usedRegisters) + deconstructor.deconstruct(regs, 0, value) + var colIdx = startIndex + var fi = 0 + while (fi < activeFieldIndices.length) { + val i = activeFieldIndices(fi) + val codec = fieldCodecs(i) + codec.writeValue(writer, colIdx, registers(i).get(regs, 0)) + colIdx += codec.columnCount + fi += 1 + } + } + + def toDbValues(value: A): IndexedSeq[DbValue] = { + val regs = Registers(deconstructor.usedRegisters) + deconstructor.deconstruct(regs, 0, value) + val builder = IndexedSeq.newBuilder[DbValue] + var fi = 0 + while (fi < activeFieldIndices.length) { + val i = activeFieldIndices(fi) + val codec = fieldCodecs(i) + builder ++= codec.toDbValues(registers(i).get(regs, 0)) + fi += 1 + } + builder.result() + } + } + } + + override def deriveWrapper[F[_, _], A, B]( + wrapped: Reflect[F, B], + typeId: TypeId[A], + binding: Binding.Wrapper[A, B], + doc: Doc, + modifiers: Seq[Modifier.Reflect], + defaultValue: Option[A], + examples: Seq[A] + )(implicit F: HasBinding[F], D: HasInstance[F]): Lazy[DbCodec[A]] = + D.instance(wrapped.metadata).map { wrappedCodec => + val wc = wrappedCodec.asInstanceOf[DbCodec[B]] + new DbCodec[A] { + val columns: IndexedSeq[String] = wc.columns + + def readValue(reader: DbResultReader, startIndex: Int): A = + binding.wrap(wc.readValue(reader, startIndex)) + + def writeValue(writer: DbParamWriter, startIndex: Int, value: A): Unit = + wc.writeValue(writer, startIndex, binding.unwrap(value)) + + def toDbValues(value: A): IndexedSeq[DbValue] = + wc.toDbValues(binding.unwrap(value)) + } + } + + override def deriveVariant[F[_, _], A]( + cases: IndexedSeq[Term[F, A, ?]], + typeId: TypeId[A], + binding: Binding.Variant[A], + doc: Doc, + modifiers: Seq[Modifier.Reflect], + defaultValue: Option[A], + examples: Seq[A] + )(implicit F: HasBinding[F], D: HasInstance[F]): Lazy[DbCodec[A]] = Lazy { + if (isOptionType(typeId, cases)) { + val someCase = cases(1) + val someRecord = someCase.value.asRecord.get + val innerField = someRecord.fields(0) + val innerCodec = D.instance(innerField.value.metadata).force.asInstanceOf[DbCodec[Any]] + + new DbCodec[A] { + val columns: IndexedSeq[String] = innerCodec.columns + + // Note: wasNull reflects the last column read by the inner codec. + // For single-column types (the common case), this is correct. + // For multi-column inner types, wasNull only reflects the last column, + // so a NULL in an earlier column may not be detected. + def readValue(reader: DbResultReader, startIndex: Int): A = { + val innerValue = innerCodec.readValue(reader, startIndex) + val result: Any = if (reader.wasNull) None else Some(innerValue) + result.asInstanceOf[A] + } + + def writeValue(writer: DbParamWriter, startIndex: Int, value: A): Unit = { + val opt = value.asInstanceOf[Option[Any]] + opt match { + case Some(v) => innerCodec.writeValue(writer, startIndex, v) + case None => + var i = 0 + while (i < innerCodec.columnCount) { + writer.setNull(startIndex + i, 0) + i += 1 + } + } + } + + def toDbValues(value: A): IndexedSeq[DbValue] = { + val opt = value.asInstanceOf[Option[Any]] + opt match { + case Some(v) => innerCodec.toDbValues(v) + case None => + val builder = IndexedSeq.newBuilder[DbValue] + var i = 0 + while (i < innerCodec.columnCount) { + builder += DbValue.DbNull + i += 1 + } + builder.result() + } + } + } + } else if (isSimpleEnum(cases)) { + val discr = binding.discriminator + val constructorByName = buildConstructorMap(cases) + val caseNames: Array[String] = collectCaseNames(cases) + val caseNamesJoined = caseNames.mkString(", ") + + new DbCodec[A] { + val columns: IndexedSeq[String] = IndexedSeq("value") + + def readValue(reader: DbResultReader, startIndex: Int): A = { + val name = reader.getString(startIndex) + constructorByName.get(name) match { + case Some(ctor) => ctor.construct(null, 0).asInstanceOf[A] + case None => + throw new IllegalArgumentException( + s"Unknown enum variant '${name}'. Expected one of: ${caseNamesJoined}" + ) + } + } + + def writeValue(writer: DbParamWriter, startIndex: Int, value: A): Unit = + writer.setString(startIndex, enumName(discr, cases, value)) + + def toDbValues(value: A): IndexedSeq[DbValue] = + IndexedSeq(DbValue.DbString(enumName(discr, cases, value))) + } + } else { + throw new UnsupportedOperationException( + "DbCodec does not support sum types (sealed trait/enum) with data fields as SQL columns" + ) + } + } + + override def deriveSequence[F[_, _], C[_], A]( + element: Reflect[F, A], + typeId: TypeId[C[A]], + binding: Binding.Seq[C, A], + doc: Doc, + modifiers: Seq[Modifier.Reflect], + defaultValue: Option[C[A]], + examples: Seq[C[A]] + )(implicit F: HasBinding[F], D: HasInstance[F]): Lazy[DbCodec[C[A]]] = + Lazy { + throw new UnsupportedOperationException( + "DbCodec does not support collection types (Seq, List, etc.) as SQL columns" + ) + } + + override def deriveMap[F[_, _], M[_, _], K, V]( + key: Reflect[F, K], + value: Reflect[F, V], + typeId: TypeId[M[K, V]], + binding: Binding.Map[M, K, V], + doc: Doc, + modifiers: Seq[Modifier.Reflect], + defaultValue: Option[M[K, V]], + examples: Seq[M[K, V]] + )(implicit F: HasBinding[F], D: HasInstance[F]): Lazy[DbCodec[M[K, V]]] = + Lazy { + throw new UnsupportedOperationException( + "DbCodec does not support Map types as SQL columns" + ) + } + + override def deriveDynamic[F[_, _]]( + binding: Binding.Dynamic, + doc: Doc, + modifiers: Seq[Modifier.Reflect], + defaultValue: Option[DynamicValue], + examples: Seq[DynamicValue] + )(implicit F: HasBinding[F], D: HasInstance[F]): Lazy[DbCodec[DynamicValue]] = + Lazy { + throw new UnsupportedOperationException( + "DbCodec does not support DynamicValue as SQL columns" + ) + } + + private def isOptionType[F[_, _], A]( + typeId: TypeId[A], + cases: IndexedSeq[Term[F, A, ?]] + ): Boolean = + typeId.owner == zio.blocks.typeid.Owner.fromPackagePath("scala") && + typeId.name == "Option" && + cases.length == 2 && + cases(1).name == "Some" + + private def isSimpleEnum[F[_, _], A](cases: IndexedSeq[Term[F, A, ?]]): Boolean = + cases.forall { case_ => + val caseReflect = case_.value + caseReflect.asRecord.exists(_.fields.isEmpty) || + (caseReflect.isVariant && caseReflect.asVariant.exists(v => isSimpleEnum(v.cases))) + } + + private def buildConstructorMap[F[_, _], A]( + cases: IndexedSeq[Term[F, A, ?]] + )(implicit F: HasBinding[F]): Map[String, Constructor[?]] = { + val builder = Map.newBuilder[String, Constructor[?]] + collectConstructors(cases, builder) + builder.result() + } + + private def collectConstructors[F[_, _], A]( + cases: IndexedSeq[Term[F, A, ?]], + builder: scala.collection.mutable.Builder[(String, Constructor[?]), Map[String, Constructor[?]]] + )(implicit F: HasBinding[F]): Unit = { + var idx = 0 + while (idx < cases.length) { + val case_ = cases(idx) + val caseReflect = case_.value + if (caseReflect.isVariant) { + val nestedVariant = caseReflect.asVariant.get + collectConstructors( + nestedVariant.cases.asInstanceOf[IndexedSeq[Term[F, A, ?]]], + builder + ) + } else { + val recordBinding = F.binding(caseReflect.asRecord.get.recordBinding) + val constructor = recordBinding.asInstanceOf[Binding.Record[Any]].constructor + builder += (case_.name -> constructor) + } + idx += 1 + } + } + + private def collectCaseNames[F[_, _], A](cases: IndexedSeq[Term[F, A, ?]]): Array[String] = { + val builder = Array.newBuilder[String] + def go(cs: IndexedSeq[Term[F, A, ?]]): Unit = { + var idx = 0 + while (idx < cs.length) { + val case_ = cs(idx) + val caseReflect = case_.value + if (caseReflect.isVariant) { + go(caseReflect.asVariant.get.cases.asInstanceOf[IndexedSeq[Term[F, A, ?]]]) + } else { + builder += case_.name + } + idx += 1 + } + } + go(cases) + builder.result() + } + + private def enumName[F[_, _], A]( + discr: Discriminator[A], + cases: IndexedSeq[Term[F, A, ?]], + value: A + )(implicit F: HasBinding[F]): String = { + val idx = discr.discriminate(value) + val case_ = cases(idx) + if (case_.value.isVariant) { + val nestedVariant = case_.value.asVariant.get.asInstanceOf[Reflect.Variant[F, A]] + val nestedDiscr = F.binding(nestedVariant.variantBinding).asInstanceOf[Binding.Variant[A]].discriminator + enumName(nestedDiscr, nestedVariant.cases.asInstanceOf[IndexedSeq[Term[F, A, ?]]], value) + } else { + case_.name + } + } + + private val unitCodec: DbCodec[Unit] = new DbCodec[Unit] { + val columns: IndexedSeq[String] = IndexedSeq.empty + def readValue(reader: DbResultReader, startIndex: Int): Unit = () + def writeValue(writer: DbParamWriter, startIndex: Int, value: Unit): Unit = () + def toDbValues(value: Unit): IndexedSeq[DbValue] = IndexedSeq.empty + override def columnCount: Int = 0 + } + + private val booleanCodec: DbCodec[Boolean] = new DbCodec[Boolean] { + val columns: IndexedSeq[String] = IndexedSeq("value") + def readValue(reader: DbResultReader, startIndex: Int): Boolean = + reader.getBoolean(startIndex) + def writeValue(writer: DbParamWriter, startIndex: Int, value: Boolean): Unit = + writer.setBoolean(startIndex, value) + def toDbValues(value: Boolean): IndexedSeq[DbValue] = + IndexedSeq(DbValue.DbBoolean(value)) + } + + private val byteCodec: DbCodec[Byte] = new DbCodec[Byte] { + val columns: IndexedSeq[String] = IndexedSeq("value") + def readValue(reader: DbResultReader, startIndex: Int): Byte = + reader.getByte(startIndex) + def writeValue(writer: DbParamWriter, startIndex: Int, value: Byte): Unit = + writer.setByte(startIndex, value) + def toDbValues(value: Byte): IndexedSeq[DbValue] = + IndexedSeq(DbValue.DbByte(value)) + } + + private val shortCodec: DbCodec[Short] = new DbCodec[Short] { + val columns: IndexedSeq[String] = IndexedSeq("value") + def readValue(reader: DbResultReader, startIndex: Int): Short = + reader.getShort(startIndex) + def writeValue(writer: DbParamWriter, startIndex: Int, value: Short): Unit = + writer.setShort(startIndex, value) + def toDbValues(value: Short): IndexedSeq[DbValue] = + IndexedSeq(DbValue.DbShort(value)) + } + + private val intCodec: DbCodec[Int] = new DbCodec[Int] { + val columns: IndexedSeq[String] = IndexedSeq("value") + def readValue(reader: DbResultReader, startIndex: Int): Int = + reader.getInt(startIndex) + def writeValue(writer: DbParamWriter, startIndex: Int, value: Int): Unit = + writer.setInt(startIndex, value) + def toDbValues(value: Int): IndexedSeq[DbValue] = + IndexedSeq(DbValue.DbInt(value)) + } + + private val longCodec: DbCodec[Long] = new DbCodec[Long] { + val columns: IndexedSeq[String] = IndexedSeq("value") + def readValue(reader: DbResultReader, startIndex: Int): Long = + reader.getLong(startIndex) + def writeValue(writer: DbParamWriter, startIndex: Int, value: Long): Unit = + writer.setLong(startIndex, value) + def toDbValues(value: Long): IndexedSeq[DbValue] = + IndexedSeq(DbValue.DbLong(value)) + } + + private val floatCodec: DbCodec[Float] = new DbCodec[Float] { + val columns: IndexedSeq[String] = IndexedSeq("value") + def readValue(reader: DbResultReader, startIndex: Int): Float = + reader.getFloat(startIndex) + def writeValue(writer: DbParamWriter, startIndex: Int, value: Float): Unit = + writer.setFloat(startIndex, value) + def toDbValues(value: Float): IndexedSeq[DbValue] = + IndexedSeq(DbValue.DbFloat(value)) + } + + private val doubleCodec: DbCodec[Double] = new DbCodec[Double] { + val columns: IndexedSeq[String] = IndexedSeq("value") + def readValue(reader: DbResultReader, startIndex: Int): Double = + reader.getDouble(startIndex) + def writeValue(writer: DbParamWriter, startIndex: Int, value: Double): Unit = + writer.setDouble(startIndex, value) + def toDbValues(value: Double): IndexedSeq[DbValue] = + IndexedSeq(DbValue.DbDouble(value)) + } + + private val charCodec: DbCodec[Char] = new DbCodec[Char] { + val columns: IndexedSeq[String] = IndexedSeq("value") + def readValue(reader: DbResultReader, startIndex: Int): Char = { + val s = reader.getString(startIndex) + if (s != null && s.length > 0) s.charAt(0) else '\u0000' + } + def writeValue(writer: DbParamWriter, startIndex: Int, value: Char): Unit = + writer.setString(startIndex, value.toString) + def toDbValues(value: Char): IndexedSeq[DbValue] = + IndexedSeq(DbValue.DbChar(value)) + } + + private val stringCodec: DbCodec[String] = new DbCodec[String] { + val columns: IndexedSeq[String] = IndexedSeq("value") + def readValue(reader: DbResultReader, startIndex: Int): String = + reader.getString(startIndex) + def writeValue(writer: DbParamWriter, startIndex: Int, value: String): Unit = + writer.setString(startIndex, value) + def toDbValues(value: String): IndexedSeq[DbValue] = + IndexedSeq(DbValue.DbString(value)) + } + + private val bigDecimalCodec: DbCodec[BigDecimal] = new DbCodec[BigDecimal] { + val columns: IndexedSeq[String] = IndexedSeq("value") + def readValue(reader: DbResultReader, startIndex: Int): BigDecimal = { + val jbd = reader.getBigDecimal(startIndex) + if (jbd != null) scala.BigDecimal(jbd) else null.asInstanceOf[BigDecimal] + } + def writeValue(writer: DbParamWriter, startIndex: Int, value: BigDecimal): Unit = + writer.setBigDecimal(startIndex, value.bigDecimal) + def toDbValues(value: BigDecimal): IndexedSeq[DbValue] = + IndexedSeq(DbValue.DbBigDecimal(value)) + } + + private val durationCodec: DbCodec[java.time.Duration] = + new DbCodec[java.time.Duration] { + val columns: IndexedSeq[String] = IndexedSeq("value") + def readValue(reader: DbResultReader, startIndex: Int): java.time.Duration = + reader.getDuration(startIndex) + def writeValue(writer: DbParamWriter, startIndex: Int, value: java.time.Duration): Unit = + writer.setDuration(startIndex, value) + def toDbValues(value: java.time.Duration): IndexedSeq[DbValue] = + IndexedSeq(DbValue.DbDuration(value)) + } + + private val instantCodec: DbCodec[java.time.Instant] = + new DbCodec[java.time.Instant] { + val columns: IndexedSeq[String] = IndexedSeq("value") + def readValue(reader: DbResultReader, startIndex: Int): java.time.Instant = + reader.getInstant(startIndex) + def writeValue(writer: DbParamWriter, startIndex: Int, value: java.time.Instant): Unit = + writer.setInstant(startIndex, value) + def toDbValues(value: java.time.Instant): IndexedSeq[DbValue] = + IndexedSeq(DbValue.DbInstant(value)) + } + + private val localDateCodec: DbCodec[java.time.LocalDate] = + new DbCodec[java.time.LocalDate] { + val columns: IndexedSeq[String] = IndexedSeq("value") + def readValue(reader: DbResultReader, startIndex: Int): java.time.LocalDate = + reader.getLocalDate(startIndex) + def writeValue(writer: DbParamWriter, startIndex: Int, value: java.time.LocalDate): Unit = + writer.setLocalDate(startIndex, value) + def toDbValues(value: java.time.LocalDate): IndexedSeq[DbValue] = + IndexedSeq(DbValue.DbLocalDate(value)) + } + + private val localDateTimeCodec: DbCodec[java.time.LocalDateTime] = + new DbCodec[java.time.LocalDateTime] { + val columns: IndexedSeq[String] = IndexedSeq("value") + def readValue(reader: DbResultReader, startIndex: Int): java.time.LocalDateTime = + reader.getLocalDateTime(startIndex) + def writeValue(writer: DbParamWriter, startIndex: Int, value: java.time.LocalDateTime): Unit = + writer.setLocalDateTime(startIndex, value) + def toDbValues(value: java.time.LocalDateTime): IndexedSeq[DbValue] = + IndexedSeq(DbValue.DbLocalDateTime(value)) + } + + private val localTimeCodec: DbCodec[java.time.LocalTime] = + new DbCodec[java.time.LocalTime] { + val columns: IndexedSeq[String] = IndexedSeq("value") + def readValue(reader: DbResultReader, startIndex: Int): java.time.LocalTime = + reader.getLocalTime(startIndex) + def writeValue(writer: DbParamWriter, startIndex: Int, value: java.time.LocalTime): Unit = + writer.setLocalTime(startIndex, value) + def toDbValues(value: java.time.LocalTime): IndexedSeq[DbValue] = + IndexedSeq(DbValue.DbLocalTime(value)) + } + + private val uuidCodec: DbCodec[java.util.UUID] = + new DbCodec[java.util.UUID] { + val columns: IndexedSeq[String] = IndexedSeq("value") + def readValue(reader: DbResultReader, startIndex: Int): java.util.UUID = + reader.getUUID(startIndex) + def writeValue(writer: DbParamWriter, startIndex: Int, value: java.util.UUID): Unit = + writer.setUUID(startIndex, value) + def toDbValues(value: java.util.UUID): IndexedSeq[DbValue] = + IndexedSeq(DbValue.DbUUID(value)) + } +} + +object DbCodecDeriver extends DbCodecDeriver(SqlNameMapper.SnakeCase) { + def withColumnNameMapper(mapper: SqlNameMapper): DbCodecDeriver = + new DbCodecDeriver(mapper) +} diff --git a/sql/shared/src/main/scala/zio/blocks/sql/DbCon.scala b/sql/shared/src/main/scala/zio/blocks/sql/DbCon.scala new file mode 100644 index 0000000000..540ad5041b --- /dev/null +++ b/sql/shared/src/main/scala/zio/blocks/sql/DbCon.scala @@ -0,0 +1,7 @@ +package zio.blocks.sql + +trait DbCon { + def connection: DbConnection + def dialect: SqlDialect + def logger: SqlLogger +} diff --git a/sql/shared/src/main/scala/zio/blocks/sql/DbConnection.scala b/sql/shared/src/main/scala/zio/blocks/sql/DbConnection.scala new file mode 100644 index 0000000000..1b045395fe --- /dev/null +++ b/sql/shared/src/main/scala/zio/blocks/sql/DbConnection.scala @@ -0,0 +1,26 @@ +package zio.blocks.sql + +trait DbConnection extends AutoCloseable { + def prepareStatement(sql: String): DbPreparedStatement + def close(): Unit + def isClosed: Boolean + def setAutoCommit(autoCommit: Boolean): Unit + def getAutoCommit: Boolean + def commit(): Unit + def rollback(): Unit +} + +trait DbPreparedStatement extends AutoCloseable { + def executeQuery(): DbResultSet + def executeUpdate(): Int + def close(): Unit + def paramWriter: DbParamWriter + def addBatch(): Unit + def executeBatch(): Array[Int] +} + +trait DbResultSet extends AutoCloseable { + def next(): Boolean + def close(): Unit + def reader: DbResultReader +} diff --git a/sql/shared/src/main/scala/zio/blocks/sql/DbTx.scala b/sql/shared/src/main/scala/zio/blocks/sql/DbTx.scala new file mode 100644 index 0000000000..a985963e3f --- /dev/null +++ b/sql/shared/src/main/scala/zio/blocks/sql/DbTx.scala @@ -0,0 +1,3 @@ +package zio.blocks.sql + +trait DbTx extends DbCon diff --git a/sql/shared/src/main/scala/zio/blocks/sql/DbValue.scala b/sql/shared/src/main/scala/zio/blocks/sql/DbValue.scala new file mode 100644 index 0000000000..d3b4f3d5d8 --- /dev/null +++ b/sql/shared/src/main/scala/zio/blocks/sql/DbValue.scala @@ -0,0 +1,27 @@ +package zio.blocks.sql + +import java.time.{Duration, Instant, LocalDate, LocalDateTime, LocalTime} +import java.util.UUID + +sealed trait DbValue + +object DbValue { + case object DbNull extends DbValue + final case class DbInt(value: Int) extends DbValue + final case class DbLong(value: Long) extends DbValue + final case class DbDouble(value: Double) extends DbValue + final case class DbFloat(value: Float) extends DbValue + final case class DbBoolean(value: Boolean) extends DbValue + final case class DbString(value: String) extends DbValue + final case class DbBigDecimal(value: scala.BigDecimal) extends DbValue + final case class DbBytes(value: Array[Byte]) extends DbValue + final case class DbShort(value: Short) extends DbValue + final case class DbByte(value: Byte) extends DbValue + final case class DbChar(value: Char) extends DbValue + final case class DbLocalDate(value: LocalDate) extends DbValue + final case class DbLocalDateTime(value: LocalDateTime) extends DbValue + final case class DbLocalTime(value: LocalTime) extends DbValue + final case class DbInstant(value: Instant) extends DbValue + final case class DbDuration(value: Duration) extends DbValue + final case class DbUUID(value: UUID) extends DbValue +} diff --git a/sql/shared/src/main/scala/zio/blocks/sql/Ddl.scala b/sql/shared/src/main/scala/zio/blocks/sql/Ddl.scala new file mode 100644 index 0000000000..e0572f190b --- /dev/null +++ b/sql/shared/src/main/scala/zio/blocks/sql/Ddl.scala @@ -0,0 +1,17 @@ +package zio.blocks.sql + +final case class ColumnDef(name: String, sqlType: String, nullable: Boolean) + +object Ddl { + + def createTable(tableName: String, columns: IndexedSeq[ColumnDef]): Frag = { + val colDefs = columns.map { col => + val nullStr = if (col.nullable) "" else " NOT NULL" + s" ${col.name} ${col.sqlType}$nullStr" + } + Frag.const(s"CREATE TABLE IF NOT EXISTS $tableName (\n${colDefs.mkString(",\n")}\n)") + } + + def dropTable(tableName: String): Frag = + Frag.const(s"DROP TABLE IF EXISTS $tableName") +} diff --git a/sql/shared/src/main/scala/zio/blocks/sql/Frag.scala b/sql/shared/src/main/scala/zio/blocks/sql/Frag.scala new file mode 100644 index 0000000000..97c9e4b707 --- /dev/null +++ b/sql/shared/src/main/scala/zio/blocks/sql/Frag.scala @@ -0,0 +1,52 @@ +package zio.blocks.sql + +final case class Frag(parts: IndexedSeq[String], params: IndexedSeq[DbValue]) { + + def ++(other: Frag): Frag = + if (parts.isEmpty) other + else if (other.parts.isEmpty) this + else { + val mergedParts = parts.init ++ IndexedSeq(parts.last + other.parts.head) ++ other.parts.tail + Frag(mergedParts, params ++ other.params) + } + + def sql(dialect: SqlDialect): String = { + val sb = new StringBuilder + var paramIdx = 1 + var i = 0 + while (i < parts.length) { + sb.append(parts(i)) + if (i < params.length) { + sb.append(dialect.paramPlaceholder(paramIdx)) + paramIdx += 1 + } + i += 1 + } + sb.toString() + } + + def queryParams: IndexedSeq[DbValue] = params + + def isEmpty: Boolean = parts.forall(_.isEmpty) && params.isEmpty +} + +object Frag { + val empty: Frag = Frag(IndexedSeq(""), IndexedSeq.empty) + + def const(sqlStr: String): Frag = Frag(IndexedSeq(sqlStr), IndexedSeq.empty) + + extension (frag: Frag) { + + def query[A](using DbCon, DbCodec[A]): List[A] = + SqlOps.query[A](frag) + + def queryOne[A](using DbCon, DbCodec[A]): Option[A] = + SqlOps.queryOne[A](frag) + + def queryLimit[A](limit: Int)(using DbCon, DbCodec[A]): List[A] = + SqlOps.queryLimit[A](frag, limit) + + def update(using DbCon): Int = + SqlOps.update(frag) + } +} diff --git a/sql/shared/src/main/scala/zio/blocks/sql/Repo.scala b/sql/shared/src/main/scala/zio/blocks/sql/Repo.scala new file mode 100644 index 0000000000..0202f55108 --- /dev/null +++ b/sql/shared/src/main/scala/zio/blocks/sql/Repo.scala @@ -0,0 +1,158 @@ +package zio.blocks.sql + +class Repo[E, ID]( + val table: Table[E], + val idColumn: String, + val idCodec: DbCodec[ID], + val getId: E => ID +) { + + private val allCols: String = table.columns.mkString(", ") + private val tbl: String = table.name + private val codec: DbCodec[E] = table.codec + + private val longCodec: DbCodec[Long] = new DbCodec[Long] { + val columns: IndexedSeq[String] = IndexedSeq("count") + def readValue(reader: DbResultReader, startIndex: Int): Long = reader.getLong(startIndex) + def writeValue(writer: DbParamWriter, startIndex: Int, value: Long): Unit = + writer.setLong(startIndex, value) + def toDbValues(value: Long): IndexedSeq[DbValue] = IndexedSeq(DbValue.DbLong(value)) + } + + // === Read Operations === + + def findAll(using con: DbCon): List[E] = { + val frag = Frag.const(s"SELECT $allCols FROM $tbl") + SqlOps.query[E](frag)(using con, codec) + } + + def findById(id: ID)(using con: DbCon): Option[E] = { + val frag = Frag( + IndexedSeq(s"SELECT $allCols FROM $tbl WHERE $idColumn = ", ""), + idCodec.toDbValues(id) + ) + SqlOps.queryOne[E](frag)(using con, codec) + } + + def existsById(id: ID)(using con: DbCon): Boolean = + findById(id).isDefined + + def count(using con: DbCon): Long = { + val frag = Frag.const(s"SELECT COUNT(*) FROM $tbl") + SqlOps.queryOne[Long](frag)(using con, longCodec).getOrElse(0L) + } + + // === Write Operations === + + def insert(entity: E)(using con: DbCon): Int = { + val values = codec.toDbValues(entity) + val frag = Repo.buildInsertFrag(tbl, allCols, values) + SqlOps.update(frag)(using con) + } + + // Inserts the entity and returns it by re-reading from the database using its ID. + // Note: for auto-generated IDs, use RETURNING clause or getGeneratedKeys instead. + def insertReturning(entity: E)(using con: DbCon): E = { + insert(entity) + findById(getId(entity)).getOrElse( + throw new NoSuchElementException(s"Entity not found after insert in table $tbl") + ) + } + + def insertAll(entities: Iterable[E])(using con: DbCon): Int = { + if (entities.isEmpty) return 0 + val first = entities.head + val values = codec.toDbValues(first) + val sqlStr = Repo.buildInsertFrag(tbl, allCols, values).sql(con.dialect) + val start = System.nanoTime() + try { + val ps = con.connection.prepareStatement(sqlStr) + try { + entities.foreach { entity => + val vals = codec.toDbValues(entity) + SqlOps.writeParams(ps.paramWriter, vals) + ps.addBatch() + } + val counts = ps.executeBatch() + val total = counts.sum + val duration = java.time.Duration.ofNanos(System.nanoTime() - start) + con.logger.onSuccess(SqlLogger.SuccessEvent(sqlStr, IndexedSeq.empty, duration, total)) + total + } finally ps.close() + } catch { + case e: Throwable => + val duration = java.time.Duration.ofNanos(System.nanoTime() - start) + con.logger.onError(SqlLogger.ErrorEvent(sqlStr, IndexedSeq.empty, duration, e)) + throw e + } + } + + def update(entity: E)(using con: DbCon): Int = { + val entityValues = codec.toDbValues(entity) + val idValues = idCodec.toDbValues(getId(entity)) + val frag = Repo.buildUpdateFrag(tbl, table.columns, entityValues, idColumn, idValues) + SqlOps.update(frag)(using con) + } + + def deleteById(id: ID)(using con: DbCon): Int = { + val frag = Frag( + IndexedSeq(s"DELETE FROM $tbl WHERE $idColumn = ", ""), + idCodec.toDbValues(id) + ) + SqlOps.update(frag)(using con) + } + + def delete(entity: E)(using con: DbCon): Int = + deleteById(getId(entity)) + + def truncate()(using con: DbCon): Int = + SqlOps.update(Frag.const(s"DELETE FROM $tbl"))(using con) +} + +object Repo { + + def apply[E, ID]( + table: Table[E], + idColumn: String, + idCodec: DbCodec[ID], + getId: E => ID + ): Repo[E, ID] = new Repo(table, idColumn, idCodec, getId) + + private[sql] def buildInsertFrag( + tableName: String, + allColumns: String, + values: IndexedSeq[DbValue] + ): Frag = + if (values.isEmpty) Frag.const(s"INSERT INTO $tableName ($allColumns) VALUES ()") + else { + val parts = + IndexedSeq(s"INSERT INTO $tableName ($allColumns) VALUES (") ++ + IndexedSeq.fill(values.size - 1)(", ") :+ + ")" + Frag(parts, values) + } + + private[sql] def buildUpdateFrag( + tableName: String, + columns: IndexedSeq[String], + entityValues: IndexedSeq[DbValue], + idColumn: String, + idValues: IndexedSeq[DbValue] + ): Frag = { + val allValues = entityValues ++ idValues + val partsB = IndexedSeq.newBuilder[String] + + partsB += s"UPDATE $tableName SET ${columns(0)} = " + + var i = 1 + while (i < columns.size) { + partsB += s", ${columns(i)} = " + i += 1 + } + + partsB += s" WHERE $idColumn = " + partsB += "" + + Frag(partsB.result(), allValues) + } +} diff --git a/sql/shared/src/main/scala/zio/blocks/sql/SqlDialect.scala b/sql/shared/src/main/scala/zio/blocks/sql/SqlDialect.scala new file mode 100644 index 0000000000..5b1b9bb3b8 --- /dev/null +++ b/sql/shared/src/main/scala/zio/blocks/sql/SqlDialect.scala @@ -0,0 +1,63 @@ +package zio.blocks.sql + +sealed trait SqlDialect { + def name: String + def typeName(dbValue: DbValue): String + def paramPlaceholder(index: Int): String +} + +object SqlDialect { + case object PostgreSQL extends SqlDialect { + val name: String = "PostgreSQL" + + def typeName(dbValue: DbValue): String = dbValue match { + case DbValue.DbNull => "NULL" + case _: DbValue.DbInt => "INTEGER" + case _: DbValue.DbLong => "BIGINT" + case _: DbValue.DbDouble => "DOUBLE PRECISION" + case _: DbValue.DbFloat => "REAL" + case _: DbValue.DbBoolean => "BOOLEAN" + case _: DbValue.DbString => "TEXT" + case _: DbValue.DbBigDecimal => "NUMERIC" + case _: DbValue.DbBytes => "BYTEA" + case _: DbValue.DbShort => "SMALLINT" + case _: DbValue.DbByte => "SMALLINT" + case _: DbValue.DbChar => "CHAR(1)" + case _: DbValue.DbLocalDate => "DATE" + case _: DbValue.DbLocalDateTime => "TIMESTAMP" + case _: DbValue.DbLocalTime => "TIME" + case _: DbValue.DbInstant => "TIMESTAMPTZ" + case _: DbValue.DbDuration => "INTERVAL" + case _: DbValue.DbUUID => "UUID" + } + + def paramPlaceholder(index: Int): String = s"$$$index" + } + + case object SQLite extends SqlDialect { + val name: String = "SQLite" + + def typeName(dbValue: DbValue): String = dbValue match { + case DbValue.DbNull => "NULL" + case _: DbValue.DbInt => "INTEGER" + case _: DbValue.DbLong => "INTEGER" + case _: DbValue.DbDouble => "REAL" + case _: DbValue.DbFloat => "REAL" + case _: DbValue.DbBoolean => "INTEGER" + case _: DbValue.DbString => "TEXT" + case _: DbValue.DbBigDecimal => "TEXT" + case _: DbValue.DbBytes => "BLOB" + case _: DbValue.DbShort => "INTEGER" + case _: DbValue.DbByte => "INTEGER" + case _: DbValue.DbChar => "TEXT" + case _: DbValue.DbLocalDate => "TEXT" + case _: DbValue.DbLocalDateTime => "TEXT" + case _: DbValue.DbLocalTime => "TEXT" + case _: DbValue.DbInstant => "TEXT" + case _: DbValue.DbDuration => "TEXT" + case _: DbValue.DbUUID => "TEXT" + } + + def paramPlaceholder(index: Int): String = "?" + } +} diff --git a/sql/shared/src/main/scala/zio/blocks/sql/SqlInterpolator.scala b/sql/shared/src/main/scala/zio/blocks/sql/SqlInterpolator.scala new file mode 100644 index 0000000000..4eb8e22e2f --- /dev/null +++ b/sql/shared/src/main/scala/zio/blocks/sql/SqlInterpolator.scala @@ -0,0 +1,94 @@ +package zio.blocks.sql + +import java.time._ +import java.util.UUID + +trait DbParam[A] { + def toDbValue(value: A): DbValue +} + +object DbParam { + def apply[A](implicit p: DbParam[A]): DbParam[A] = p + + given DbParam[Int] with { + def toDbValue(v: Int): DbValue = DbValue.DbInt(v) + } + + given DbParam[Long] with { + def toDbValue(v: Long): DbValue = DbValue.DbLong(v) + } + + given DbParam[Double] with { + def toDbValue(v: Double): DbValue = DbValue.DbDouble(v) + } + + given DbParam[Float] with { + def toDbValue(v: Float): DbValue = DbValue.DbFloat(v) + } + + given DbParam[Boolean] with { + def toDbValue(v: Boolean): DbValue = DbValue.DbBoolean(v) + } + + given DbParam[String] with { + def toDbValue(v: String): DbValue = DbValue.DbString(v) + } + + given DbParam[Short] with { + def toDbValue(v: Short): DbValue = DbValue.DbShort(v) + } + + given DbParam[Byte] with { + def toDbValue(v: Byte): DbValue = DbValue.DbByte(v) + } + + given DbParam[BigDecimal] with { + def toDbValue(v: BigDecimal): DbValue = DbValue.DbBigDecimal(v) + } + + given dbParamBytes: DbParam[Array[Byte]] with { + def toDbValue(v: Array[Byte]): DbValue = DbValue.DbBytes(v) + } + + given DbParam[LocalDate] with { + def toDbValue(v: LocalDate): DbValue = DbValue.DbLocalDate(v) + } + + given DbParam[LocalDateTime] with { + def toDbValue(v: LocalDateTime): DbValue = DbValue.DbLocalDateTime(v) + } + + given DbParam[LocalTime] with { + def toDbValue(v: LocalTime): DbValue = DbValue.DbLocalTime(v) + } + + given DbParam[Instant] with { + def toDbValue(v: Instant): DbValue = DbValue.DbInstant(v) + } + + given DbParam[Duration] with { + def toDbValue(v: Duration): DbValue = DbValue.DbDuration(v) + } + + given DbParam[UUID] with { + def toDbValue(v: UUID): DbValue = DbValue.DbUUID(v) + } + + given DbParam[DbValue] with { + def toDbValue(v: DbValue): DbValue = v + } + + given [A](using inner: DbParam[A]): DbParam[Option[A]] with { + def toDbValue(v: Option[A]): DbValue = v match { + case Some(a) => inner.toDbValue(a) + case None => DbValue.DbNull + } + } +} + +extension (sc: StringContext) { + def sql(args: DbValue*): Frag = + Frag(sc.parts.toIndexedSeq, args.toIndexedSeq) +} + +given dbParamToDbValue[A](using p: DbParam[A]): Conversion[A, DbValue] = p.toDbValue(_) diff --git a/sql/shared/src/main/scala/zio/blocks/sql/SqlLogger.scala b/sql/shared/src/main/scala/zio/blocks/sql/SqlLogger.scala new file mode 100644 index 0000000000..163c8c0428 --- /dev/null +++ b/sql/shared/src/main/scala/zio/blocks/sql/SqlLogger.scala @@ -0,0 +1,30 @@ +package zio.blocks.sql + +import java.time.Duration + +trait SqlLogger { + def onSuccess(event: SqlLogger.SuccessEvent): Unit + def onError(event: SqlLogger.ErrorEvent): Unit +} + +object SqlLogger { + + final case class SuccessEvent( + sql: String, + params: IndexedSeq[DbValue], + duration: Duration, + rowCount: Int + ) + + final case class ErrorEvent( + sql: String, + params: IndexedSeq[DbValue], + duration: Duration, + error: Throwable + ) + + val noop: SqlLogger = new SqlLogger { + def onSuccess(event: SuccessEvent): Unit = () + def onError(event: ErrorEvent): Unit = () + } +} diff --git a/sql/shared/src/main/scala/zio/blocks/sql/SqlNameMapper.scala b/sql/shared/src/main/scala/zio/blocks/sql/SqlNameMapper.scala new file mode 100644 index 0000000000..5612ef03dd --- /dev/null +++ b/sql/shared/src/main/scala/zio/blocks/sql/SqlNameMapper.scala @@ -0,0 +1,105 @@ +/* + * Copyright 2024-2026 John A. De Goes and the ZIO Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package zio.blocks.sql + +import java.lang.Character._ +import java.lang + +/** + * A sealed trait that represents a strategy for mapping column names. Classes + * or objects that extend `SqlNameMapper` provide an implementation of the + * `apply` method, which specifies how a field name is transformed to a column + * name. + */ +sealed trait SqlNameMapper extends (String => String) + +/** + * The `SqlNameMapper` object provides predefined strategies for transforming + * field names into SQL column naming conventions. + * + * This object defines several `SqlNameMapper` implementations: + * + * - `SnakeCase`: Transforms strings to snake_case (default for SQL). For + * example, "firstName" → "first_name", "userID" → "user_id". + * - `Identity`: Returns the input string as-is, performing no transformation. + * - `Custom`: Allows for user-defined transformations by applying a given + * function to the input column name. + */ +object SqlNameMapper { + + private[this] def enforceSnakeCase(s: String): String = { + val len = s.length + val sb = new lang.StringBuilder(len << 1) + var i = 0 + var isPrecedingNotUpperCased = false + while (i < len) isPrecedingNotUpperCased = { + val ch = s.charAt(i) + i += 1 + if (ch == '_' || ch == '-') { + sb.append('_') + false + } else if (!isUpperCase(ch)) { + sb.append(ch) + true + } else { + if (isPrecedingNotUpperCased || i > 1 && i < len && !isUpperCase(s.charAt(i))) + sb.append('_') + sb.append(toLowerCase(ch)) + false + } + } + sb.toString + } + + /** + * A predefined implementation of the [[SqlNameMapper]] trait that converts a + * given field name into snake_case format by replacing transitions between + * uppercase and lowercase letters with underscores (`_`) and converting all + * characters to lowercase. + * + * For example, "firstName" is transformed into "first_name", and "userID" is + * transformed into "user_id". + */ + case object SnakeCase extends SqlNameMapper { + override def apply(fieldName: String): String = enforceSnakeCase(fieldName) + } + + /** + * An implementation of the `SqlNameMapper` trait that performs no + * transformation on the provided field name. The identity operation is + * applied, where the input string is returned unchanged. + */ + case object Identity extends SqlNameMapper { + override def apply(fieldName: String): String = fieldName + } + + /** + * A case class that provides a custom implementation of the `SqlNameMapper` + * trait. + * + * The `Custom` class allows for the transformation of field names using a + * user-defined function. This transformation logic is encapsulated in the + * function `f` provided at instantiation. + * + * @param f + * A function that defines how to transform a field name. The function takes + * a string as input and returns the transformed column name as output. + */ + final case class Custom(f: String => String) extends SqlNameMapper { + override def apply(fieldName: String): String = f(fieldName) + } +} diff --git a/sql/shared/src/main/scala/zio/blocks/sql/SqlOps.scala b/sql/shared/src/main/scala/zio/blocks/sql/SqlOps.scala new file mode 100644 index 0000000000..5a7a4cb75d --- /dev/null +++ b/sql/shared/src/main/scala/zio/blocks/sql/SqlOps.scala @@ -0,0 +1,134 @@ +package zio.blocks.sql + +object SqlOps { + + def query[A](frag: Frag)(using con: DbCon, codec: DbCodec[A]): List[A] = { + val sqlStr = frag.sql(con.dialect) + val start = System.nanoTime() + try { + val ps = con.connection.prepareStatement(sqlStr) + try { + writeParams(ps.paramWriter, frag.queryParams) + val rs = ps.executeQuery() + try { + val reader = rs.reader + val builder = List.newBuilder[A] + var count = 0 + while (rs.next()) { + builder += codec.readValue(reader, 1) + count += 1 + } + val duration = java.time.Duration.ofNanos(System.nanoTime() - start) + con.logger.onSuccess(SqlLogger.SuccessEvent(sqlStr, frag.queryParams, duration, count)) + builder.result() + } finally rs.close() + } finally ps.close() + } catch { + case e: Throwable => + val duration = java.time.Duration.ofNanos(System.nanoTime() - start) + con.logger.onError(SqlLogger.ErrorEvent(sqlStr, frag.queryParams, duration, e)) + throw e + } + } + + def queryLimit[A](frag: Frag, limit: Int)(using con: DbCon, codec: DbCodec[A]): List[A] = { + val sqlStr = frag.sql(con.dialect) + val start = System.nanoTime() + try { + val ps = con.connection.prepareStatement(sqlStr) + try { + writeParams(ps.paramWriter, frag.queryParams) + val rs = ps.executeQuery() + try { + val reader = rs.reader + val builder = List.newBuilder[A] + var count = 0 + while (count < limit && rs.next()) { + builder += codec.readValue(reader, 1) + count += 1 + } + val duration = java.time.Duration.ofNanos(System.nanoTime() - start) + con.logger.onSuccess(SqlLogger.SuccessEvent(sqlStr, frag.queryParams, duration, count)) + builder.result() + } finally rs.close() + } finally ps.close() + } catch { + case e: Throwable => + val duration = java.time.Duration.ofNanos(System.nanoTime() - start) + con.logger.onError(SqlLogger.ErrorEvent(sqlStr, frag.queryParams, duration, e)) + throw e + } + } + + def queryOne[A](frag: Frag)(using con: DbCon, codec: DbCodec[A]): Option[A] = { + val sqlStr = frag.sql(con.dialect) + val start = System.nanoTime() + try { + val ps = con.connection.prepareStatement(sqlStr) + try { + writeParams(ps.paramWriter, frag.queryParams) + val rs = ps.executeQuery() + try { + val result = if (rs.next()) Some(codec.readValue(rs.reader, 1)) else None + val count = if (result.isDefined) 1 else 0 + val duration = java.time.Duration.ofNanos(System.nanoTime() - start) + con.logger.onSuccess(SqlLogger.SuccessEvent(sqlStr, frag.queryParams, duration, count)) + result + } finally rs.close() + } finally ps.close() + } catch { + case e: Throwable => + val duration = java.time.Duration.ofNanos(System.nanoTime() - start) + con.logger.onError(SqlLogger.ErrorEvent(sqlStr, frag.queryParams, duration, e)) + throw e + } + } + + def update(frag: Frag)(using con: DbCon): Int = { + val sqlStr = frag.sql(con.dialect) + val start = System.nanoTime() + try { + val ps = con.connection.prepareStatement(sqlStr) + try { + writeParams(ps.paramWriter, frag.queryParams) + val count = ps.executeUpdate() + val duration = java.time.Duration.ofNanos(System.nanoTime() - start) + con.logger.onSuccess(SqlLogger.SuccessEvent(sqlStr, frag.queryParams, duration, count)) + count + } finally ps.close() + } catch { + case e: Throwable => + val duration = java.time.Duration.ofNanos(System.nanoTime() - start) + con.logger.onError(SqlLogger.ErrorEvent(sqlStr, frag.queryParams, duration, e)) + throw e + } + } + + private[sql] def writeParams(writer: DbParamWriter, params: IndexedSeq[DbValue]): Unit = { + var i = 0 + while (i < params.length) { + val idx = i + 1 + params(i) match { + case DbValue.DbNull => writer.setNull(idx, 0) + case DbValue.DbInt(v) => writer.setInt(idx, v) + case DbValue.DbLong(v) => writer.setLong(idx, v) + case DbValue.DbDouble(v) => writer.setDouble(idx, v) + case DbValue.DbFloat(v) => writer.setFloat(idx, v) + case DbValue.DbBoolean(v) => writer.setBoolean(idx, v) + case DbValue.DbString(v) => writer.setString(idx, v) + case DbValue.DbBigDecimal(v) => writer.setBigDecimal(idx, v.bigDecimal) + case DbValue.DbBytes(v) => writer.setBytes(idx, v) + case DbValue.DbShort(v) => writer.setShort(idx, v) + case DbValue.DbByte(v) => writer.setByte(idx, v) + case DbValue.DbChar(v) => writer.setString(idx, v.toString) + case DbValue.DbLocalDate(v) => writer.setLocalDate(idx, v) + case DbValue.DbLocalDateTime(v) => writer.setLocalDateTime(idx, v) + case DbValue.DbLocalTime(v) => writer.setLocalTime(idx, v) + case DbValue.DbInstant(v) => writer.setInstant(idx, v) + case DbValue.DbDuration(v) => writer.setDuration(idx, v) + case DbValue.DbUUID(v) => writer.setUUID(idx, v) + } + i += 1 + } + } +} diff --git a/sql/shared/src/main/scala/zio/blocks/sql/Table.scala b/sql/shared/src/main/scala/zio/blocks/sql/Table.scala new file mode 100644 index 0000000000..01b469f815 --- /dev/null +++ b/sql/shared/src/main/scala/zio/blocks/sql/Table.scala @@ -0,0 +1,50 @@ +package zio.blocks.sql + +import zio.blocks.schema._ + +final case class Table[A](name: String, codec: DbCodec[A], dialect: SqlDialect) { + def columns: IndexedSeq[String] = codec.columns + + def createTable: Frag = { + val columnDefs = codec.columns.map { col => + ColumnDef(col, dialect.typeName(DbValue.DbString("")), nullable = false) + } + Ddl.createTable(name, columnDefs) + } + + def dropTable: Frag = Ddl.dropTable(name) +} + +object Table { + + def derived[A](dialect: SqlDialect)(implicit schema: Schema[A]): Table[A] = { + val codec = schema.deriving(DbCodecDeriver).derive + val tableName = deriveTableName(schema) + Table(tableName, codec, dialect) + } + + def derived[A](tableName: String, dialect: SqlDialect)(implicit schema: Schema[A]): Table[A] = { + val codec = schema.deriving(DbCodecDeriver).derive + Table(tableName, codec, dialect) + } + + private def deriveTableName[A](schema: Schema[A]): String = { + val configured = schema.reflect.modifiers.collectFirst { case Modifier.config("sql.table_name", value) => + value + } + configured.getOrElse { + val typeName = schema.reflect.typeId.name + SqlNameMapper.SnakeCase(typeName) + } + } + + def pluralize(s: String): String = + if (s.isEmpty) s + else if (s.endsWith("s") || s.endsWith("x") || s.endsWith("ch") || s.endsWith("sh") || s.endsWith("zz")) + s + "es" + else if (s.endsWith("z")) s + "zes" // quiz -> quizzes + else if (s.endsWith("y") && s.length > 1 && !isVowel(s.charAt(s.length - 2))) s.dropRight(1) + "ies" + else s + "s" + + private def isVowel(c: Char): Boolean = "aeiouAEIOU".indexOf(c) >= 0 +} diff --git a/sql/shared/src/main/scala/zio/blocks/sql/Transactor.scala b/sql/shared/src/main/scala/zio/blocks/sql/Transactor.scala new file mode 100644 index 0000000000..878e3ec433 --- /dev/null +++ b/sql/shared/src/main/scala/zio/blocks/sql/Transactor.scala @@ -0,0 +1,6 @@ +package zio.blocks.sql + +trait Transactor { + def connect[A](f: DbCon ?=> A): A + def transact[A](f: DbTx ?=> A): A +} diff --git a/sql/shared/src/test/scala/zio/blocks/sql/.gitkeep b/sql/shared/src/test/scala/zio/blocks/sql/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sql/shared/src/test/scala/zio/blocks/sql/DbCodecSpec.scala b/sql/shared/src/test/scala/zio/blocks/sql/DbCodecSpec.scala new file mode 100644 index 0000000000..c5c34573d5 --- /dev/null +++ b/sql/shared/src/test/scala/zio/blocks/sql/DbCodecSpec.scala @@ -0,0 +1,314 @@ +package zio.blocks.sql + +import zio.test._ +import zio.blocks.schema._ + +object DbCodecSpec extends ZIOSpecDefault { + + case class SimpleRecord(name: String, age: Int) + object SimpleRecord { + implicit val schema: Schema[SimpleRecord] = Schema.derived + } + + case class WithOption(id: Int, nickname: Option[String]) + object WithOption { + implicit val schema: Schema[WithOption] = Schema.derived + } + + case class WithTransient( + name: String, + @Modifier.transient() hidden: Int = 0 + ) + object WithTransient { + implicit val schema: Schema[WithTransient] = Schema.derived + } + + case class WithRename( + @Modifier.rename("user_name") name: String, + age: Int + ) + object WithRename { + implicit val schema: Schema[WithRename] = Schema.derived + } + + case class AllPrimitives( + b: Boolean, + by: Byte, + s: Short, + i: Int, + l: Long, + f: Float, + d: Double, + str: String + ) + object AllPrimitives { + implicit val schema: Schema[AllPrimitives] = Schema.derived + } + + case class CamelCaseFields(firstName: String, lastName: String) + object CamelCaseFields { + implicit val schema: Schema[CamelCaseFields] = Schema.derived + } + + enum Color { + case Red, Green, Blue + } + object Color { + implicit val schema: Schema[Color] = Schema.derived + } + + enum Status { + case Active, Inactive, Pending + } + object Status { + implicit val schema: Schema[Status] = Schema.derived + } + + case class WithEnum(id: Int, color: Color) + object WithEnum { + implicit val schema: Schema[WithEnum] = Schema.derived + } + + case class WithOptionalEnum(id: Int, status: Option[Status]) + object WithOptionalEnum { + implicit val schema: Schema[WithOptionalEnum] = Schema.derived + } + + private def deriveCodec[A](implicit s: Schema[A]): DbCodec[A] = + s.deriving(DbCodecDeriver).derive + + private def deriveCodecWithMapper[A]( + mapper: SqlNameMapper + )(implicit s: Schema[A]): DbCodec[A] = + s.deriving(DbCodecDeriver.withColumnNameMapper(mapper)).derive + + def spec: Spec[TestEnvironment, Any] = suite("DbCodecSpec")( + suite("primitive derivation")( + test("Int codec has single column") { + val codec = deriveCodec[Int] + assertTrue( + codec.columns.size == 1, + codec.columnCount == 1 + ) + }, + test("String codec has single column") { + val codec = deriveCodec[String] + assertTrue(codec.columns.size == 1) + }, + test("Boolean codec has single column") { + val codec = deriveCodec[Boolean] + assertTrue(codec.columns.size == 1) + }, + test("Long codec has single column") { + val codec = deriveCodec[Long] + assertTrue(codec.columns.size == 1) + }, + test("Int toDbValues produces DbInt") { + val codec = deriveCodec[Int] + val values = codec.toDbValues(42) + assertTrue(values == IndexedSeq(DbValue.DbInt(42))) + }, + test("String toDbValues produces DbString") { + val codec = deriveCodec[String] + val values = codec.toDbValues("hello") + assertTrue(values == IndexedSeq(DbValue.DbString("hello"))) + }, + test("Boolean toDbValues produces DbBoolean") { + val codec = deriveCodec[Boolean] + val values = codec.toDbValues(true) + assertTrue(values == IndexedSeq(DbValue.DbBoolean(true))) + }, + test("Unit codec has zero columns") { + val codec = deriveCodec[Unit] + assertTrue( + codec.columns.isEmpty, + codec.columnCount == 0, + codec.toDbValues(()) == IndexedSeq.empty + ) + } + ), + suite("record derivation")( + test("simple case class columns match field names") { + val codec = deriveCodec[SimpleRecord] + assertTrue( + codec.columns == IndexedSeq("name", "age"), + codec.columnCount == 2 + ) + }, + test("simple case class toDbValues") { + val codec = deriveCodec[SimpleRecord] + val values = codec.toDbValues(SimpleRecord("Alice", 30)) + assertTrue( + values == IndexedSeq(DbValue.DbString("Alice"), DbValue.DbInt(30)) + ) + }, + test("all primitives record has correct column count") { + val codec = deriveCodec[AllPrimitives] + assertTrue(codec.columnCount == 8) + }, + test("all primitives record toDbValues") { + val codec = deriveCodec[AllPrimitives] + val value = AllPrimitives( + b = true, + by = 1, + s = 2, + i = 3, + l = 4L, + f = 5.0f, + d = 6.0, + str = "hello" + ) + val dbValues = codec.toDbValues(value) + assertTrue( + dbValues == IndexedSeq( + DbValue.DbBoolean(true), + DbValue.DbByte(1), + DbValue.DbShort(2), + DbValue.DbInt(3), + DbValue.DbLong(4L), + DbValue.DbFloat(5.0f), + DbValue.DbDouble(6.0), + DbValue.DbString("hello") + ) + ) + } + ), + suite("option handling")( + test("Option field is nullable column") { + val codec = deriveCodec[WithOption] + assertTrue( + codec.columns == IndexedSeq("id", "nickname"), + codec.columnCount == 2 + ) + }, + test("Option Some produces inner value") { + val codec = deriveCodec[WithOption] + val values = codec.toDbValues(WithOption(1, Some("nick"))) + assertTrue( + values == IndexedSeq(DbValue.DbInt(1), DbValue.DbString("nick")) + ) + }, + test("Option None produces DbNull") { + val codec = deriveCodec[WithOption] + val values = codec.toDbValues(WithOption(1, None)) + assertTrue( + values == IndexedSeq(DbValue.DbInt(1), DbValue.DbNull) + ) + } + ), + suite("enum/sealed trait handling")( + test("simple enum produces single String column") { + val codec = deriveCodec[Color] + assertTrue( + codec.columns == IndexedSeq("value"), + codec.columnCount == 1 + ) + }, + test("enum toDbValues produces variant name as DbString") { + val codec = deriveCodec[Color] + assertTrue( + codec.toDbValues(Color.Red) == IndexedSeq(DbValue.DbString("Red")), + codec.toDbValues(Color.Green) == IndexedSeq(DbValue.DbString("Green")), + codec.toDbValues(Color.Blue) == IndexedSeq(DbValue.DbString("Blue")) + ) + }, + test("enum in record uses snake_case column name") { + val codec = deriveCodec[WithEnum] + assertTrue( + codec.columns == IndexedSeq("id", "color"), + codec.columnCount == 2 + ) + }, + test("enum in record toDbValues") { + val codec = deriveCodec[WithEnum] + val values = codec.toDbValues(WithEnum(1, Color.Blue)) + assertTrue( + values == IndexedSeq(DbValue.DbInt(1), DbValue.DbString("Blue")) + ) + }, + test("optional enum with Some produces string value") { + val codec = deriveCodec[WithOptionalEnum] + val values = codec.toDbValues(WithOptionalEnum(1, Some(Status.Active))) + assertTrue( + values == IndexedSeq(DbValue.DbInt(1), DbValue.DbString("Active")) + ) + }, + test("optional enum with None produces DbNull") { + val codec = deriveCodec[WithOptionalEnum] + val values = codec.toDbValues(WithOptionalEnum(1, None)) + assertTrue( + values == IndexedSeq(DbValue.DbInt(1), DbValue.DbNull) + ) + }, + test("all enum variants round-trip via toDbValues") { + val codec = deriveCodec[Status] + assertTrue( + codec.toDbValues(Status.Active) == IndexedSeq(DbValue.DbString("Active")), + codec.toDbValues(Status.Inactive) == IndexedSeq(DbValue.DbString("Inactive")), + codec.toDbValues(Status.Pending) == IndexedSeq(DbValue.DbString("Pending")) + ) + } + ), + suite("modifier handling")( + test("transient field excluded from columns") { + val codec = deriveCodec[WithTransient] + assertTrue( + codec.columns == IndexedSeq("name"), + codec.columnCount == 1 + ) + }, + test("transient field excluded from toDbValues") { + val codec = deriveCodec[WithTransient] + val values = codec.toDbValues(WithTransient("Alice", 42)) + assertTrue(values == IndexedSeq(DbValue.DbString("Alice"))) + }, + test("rename modifier uses custom name") { + val codec = deriveCodec[WithRename] + assertTrue( + codec.columns == IndexedSeq("user_name", "age") + ) + } + ), + suite("column name mapping")( + test("default SnakeCase mapper: camelCase fields become snake_case columns") { + val codec = deriveCodec[CamelCaseFields] + assertTrue( + codec.columns == IndexedSeq("first_name", "last_name"), + codec.columnCount == 2 + ) + }, + test("Modifier.rename overrides SnakeCase mapper") { + val codec = deriveCodec[WithRename] + assertTrue( + codec.columns == IndexedSeq("user_name", "age") + ) + }, + test("Identity mapper preserves camelCase field names") { + val codec = CamelCaseFields.schema.deriving(DbCodecDeriver.withColumnNameMapper(SqlNameMapper.Identity)).derive + assertTrue( + codec.columns == IndexedSeq("firstName", "lastName"), + codec.columnCount == 2 + ) + }, + test("Custom mapper applies custom function") { + val upperMapper = SqlNameMapper.Custom(_.toUpperCase) + val codec = CamelCaseFields.schema.deriving(DbCodecDeriver.withColumnNameMapper(upperMapper)).derive + assertTrue( + codec.columns == IndexedSeq("FIRSTNAME", "LASTNAME") + ) + } + ), + suite("unsupported types")( + test("List throws UnsupportedOperationException") { + val result = + try { + deriveCodec[List[Int]] + false + } catch { + case _: UnsupportedOperationException => true + } + assertTrue(result) + } + ) + ) +} diff --git a/sql/shared/src/test/scala/zio/blocks/sql/DbValueSpec.scala b/sql/shared/src/test/scala/zio/blocks/sql/DbValueSpec.scala new file mode 100644 index 0000000000..98a7f189e9 --- /dev/null +++ b/sql/shared/src/test/scala/zio/blocks/sql/DbValueSpec.scala @@ -0,0 +1,96 @@ +package zio.blocks.sql + +import zio.test.* +import java.time.{Duration, Instant, LocalDate, LocalDateTime, LocalTime} +import java.util.UUID + +object DbValueSpec extends ZIOSpecDefault { + def spec = suite("DbValueSpec")( + test("DbNull is a case object") { + assertTrue(DbValue.DbNull == DbValue.DbNull) + }, + test("DbInt creation and extraction") { + val v = DbValue.DbInt(42) + assertTrue(v.value == 42) + }, + test("DbLong creation and extraction") { + val v = DbValue.DbLong(9999999999L) + assertTrue(v.value == 9999999999L) + }, + test("DbDouble creation and extraction") { + val v = DbValue.DbDouble(3.14) + assertTrue(v.value == 3.14) + }, + test("DbFloat creation and extraction") { + val v = DbValue.DbFloat(2.71f) + assertTrue(v.value == 2.71f) + }, + test("DbBoolean creation and extraction") { + val v = DbValue.DbBoolean(true) + assertTrue(v.value == true) + }, + test("DbString creation and extraction") { + val v = DbValue.DbString("hello") + assertTrue(v.value == "hello") + }, + test("DbBigDecimal creation and extraction") { + val bd = scala.BigDecimal("123.45") + val v = DbValue.DbBigDecimal(bd) + assertTrue(v.value == bd) + }, + test("DbBytes creation and extraction") { + val bytes = Array[Byte](1, 2, 3) + val v = DbValue.DbBytes(bytes) + assertTrue(v.value.sameElements(bytes)) + }, + test("DbShort creation and extraction") { + val v = DbValue.DbShort(100.toShort) + assertTrue(v.value == 100.toShort) + }, + test("DbByte creation and extraction") { + val v = DbValue.DbByte(50.toByte) + assertTrue(v.value == 50.toByte) + }, + test("DbChar creation and extraction") { + val v = DbValue.DbChar('A') + assertTrue(v.value == 'A') + }, + test("DbLocalDate creation and extraction") { + val ld = LocalDate.of(2024, 3, 14) + val v = DbValue.DbLocalDate(ld) + assertTrue(v.value == ld) + }, + test("DbLocalDateTime creation and extraction") { + val ldt = LocalDateTime.of(2024, 3, 14, 12, 0) + val v = DbValue.DbLocalDateTime(ldt) + assertTrue(v.value == ldt) + }, + test("DbLocalTime creation and extraction") { + val lt = LocalTime.of(12, 30, 45) + val v = DbValue.DbLocalTime(lt) + assertTrue(v.value == lt) + }, + test("DbInstant creation and extraction") { + val inst = Instant.now() + val v = DbValue.DbInstant(inst) + assertTrue(v.value == inst) + }, + test("DbDuration creation and extraction") { + val dur = Duration.ofHours(2) + val v = DbValue.DbDuration(dur) + assertTrue(v.value == dur) + }, + test("DbUUID creation and extraction") { + val uuid = UUID.fromString("550e8400-e29b-41d4-a716-446655440000") + val v = DbValue.DbUUID(uuid) + assertTrue(v.value == uuid) + }, + test("all DbValue types are case classes or case objects") { + assertTrue( + DbValue.DbNull.isInstanceOf[DbValue] && + DbValue.DbInt(1).isInstanceOf[DbValue] && + DbValue.DbString("x").isInstanceOf[DbValue] + ) + } + ) +} diff --git a/sql/shared/src/test/scala/zio/blocks/sql/DdlSpec.scala b/sql/shared/src/test/scala/zio/blocks/sql/DdlSpec.scala new file mode 100644 index 0000000000..3813e9cdbb --- /dev/null +++ b/sql/shared/src/test/scala/zio/blocks/sql/DdlSpec.scala @@ -0,0 +1,98 @@ +package zio.blocks.sql + +import zio.test.* + +object DdlSpec extends ZIOSpecDefault { + def spec = suite("DdlSpec")( + suite("createTable")( + test("generates CREATE TABLE IF NOT EXISTS with columns") { + val columns = IndexedSeq( + ColumnDef("id", "INTEGER", false), + ColumnDef("name", "TEXT", false), + ColumnDef("email", "TEXT", true) + ) + val frag = Ddl.createTable("users", columns) + val sql = frag.sql(SqlDialect.PostgreSQL) + assertTrue( + sql.contains("CREATE TABLE IF NOT EXISTS users"), + sql.contains("id INTEGER NOT NULL"), + sql.contains("name TEXT NOT NULL"), + sql.contains("email TEXT"), + !sql.contains("email TEXT NOT NULL") + ) + }, + test("nullable columns omit NOT NULL") { + val columns = IndexedSeq(ColumnDef("bio", "TEXT", true)) + val sql = Ddl.createTable("profiles", columns).sql(SqlDialect.PostgreSQL) + assertTrue( + sql.contains("bio TEXT"), + !sql.contains("NOT NULL") + ) + }, + test("non-nullable columns include NOT NULL") { + val columns = IndexedSeq(ColumnDef("id", "INTEGER", false)) + val sql = Ddl.createTable("items", columns).sql(SqlDialect.PostgreSQL) + assertTrue(sql.contains("id INTEGER NOT NULL")) + }, + test("multiple columns are comma-separated") { + val columns = IndexedSeq( + ColumnDef("a", "INTEGER", false), + ColumnDef("b", "TEXT", false), + ColumnDef("c", "REAL", true) + ) + val sql = Ddl.createTable("t", columns).sql(SqlDialect.PostgreSQL) + assertTrue( + sql.contains("a INTEGER NOT NULL,\n"), + sql.contains("b TEXT NOT NULL,\n"), + sql.contains("c REAL\n") + ) + }, + test("PostgreSQL type names in column definitions") { + val columns = IndexedSeq( + ColumnDef("active", "BOOLEAN", false), + ColumnDef("data", "BYTEA", false), + ColumnDef("id", "UUID", false) + ) + val sql = Ddl.createTable("pg_table", columns).sql(SqlDialect.PostgreSQL) + assertTrue( + sql.contains("active BOOLEAN NOT NULL"), + sql.contains("data BYTEA NOT NULL"), + sql.contains("id UUID NOT NULL") + ) + }, + test("SQLite type names in column definitions") { + val columns = IndexedSeq( + ColumnDef("active", "INTEGER", false), + ColumnDef("data", "BLOB", false), + ColumnDef("id", "TEXT", false) + ) + val sql = Ddl.createTable("sqlite_table", columns).sql(SqlDialect.SQLite) + assertTrue( + sql.contains("active INTEGER NOT NULL"), + sql.contains("data BLOB NOT NULL"), + sql.contains("id TEXT NOT NULL") + ) + }, + test("result is a parameterless Frag") { + val columns = IndexedSeq(ColumnDef("x", "INTEGER", false)) + val frag = Ddl.createTable("t", columns) + assertTrue(frag.params.isEmpty) + } + ), + suite("dropTable")( + test("generates DROP TABLE IF EXISTS") { + val frag = Ddl.dropTable("users") + assertTrue(frag.sql(SqlDialect.PostgreSQL) == "DROP TABLE IF EXISTS users") + }, + test("works with any table name") { + assertTrue( + Ddl.dropTable("orders").sql(SqlDialect.SQLite) == "DROP TABLE IF EXISTS orders" + ) + }, + test("result is a parameterless Frag") { + val frag = Ddl.dropTable("t") + assertTrue(frag.params.isEmpty) + } + ) + ) +} diff --git a/sql/shared/src/test/scala/zio/blocks/sql/FragSpec.scala b/sql/shared/src/test/scala/zio/blocks/sql/FragSpec.scala new file mode 100644 index 0000000000..81cec12648 --- /dev/null +++ b/sql/shared/src/test/scala/zio/blocks/sql/FragSpec.scala @@ -0,0 +1,113 @@ +package zio.blocks.sql + +import zio.test.* + +object FragSpec extends ZIOSpecDefault { + def spec = suite("FragSpec")( + suite("sql interpolator")( + test("basic interpolation without params") { + val frag = sql"SELECT 1" + assertTrue( + frag.parts == IndexedSeq("SELECT 1"), + frag.params.isEmpty + ) + }, + test("single param") { + val frag = sql"SELECT * FROM t WHERE id = ${DbValue.DbInt(42)}" + assertTrue( + frag.parts == IndexedSeq("SELECT * FROM t WHERE id = ", ""), + frag.params == IndexedSeq(DbValue.DbInt(42)) + ) + }, + test("multiple params") { + val v1 = DbValue.DbString("Alice") + val v2 = DbValue.DbInt(30) + val frag = sql"INSERT INTO t (name, age) VALUES ($v1, $v2)" + assertTrue( + frag.parts == IndexedSeq("INSERT INTO t (name, age) VALUES (", ", ", ")"), + frag.params == IndexedSeq(v1, v2) + ) + }, + test("values are never in the SQL string") { + val name = DbValue.DbString("Robert'); DROP TABLE students;--") + val frag = sql"SELECT * FROM users WHERE name = $name" + val rendered = frag.sql(SqlDialect.PostgreSQL) + assertTrue( + !rendered.contains("Robert"), + !rendered.contains("DROP"), + rendered == "SELECT * FROM users WHERE name = $1" + ) + } + ), + suite("Frag.sql rendering")( + test("PostgreSQL uses numbered placeholders") { + val frag = sql"SELECT * FROM t WHERE a = ${DbValue.DbInt(1)} AND b = ${DbValue.DbString("x")}" + assertTrue(frag.sql(SqlDialect.PostgreSQL) == "SELECT * FROM t WHERE a = $1 AND b = $2") + }, + test("SQLite uses ? placeholders") { + val frag = sql"SELECT * FROM t WHERE a = ${DbValue.DbInt(1)} AND b = ${DbValue.DbString("x")}" + assertTrue(frag.sql(SqlDialect.SQLite) == "SELECT * FROM t WHERE a = ? AND b = ?") + }, + test("no params renders plain SQL") { + val frag = sql"SELECT 1" + assertTrue(frag.sql(SqlDialect.PostgreSQL) == "SELECT 1") + } + ), + suite("Frag composition")( + test("concatenation merges adjacent parts") { + val f1 = sql"SELECT * FROM t" + val f2 = sql" WHERE id = ${DbValue.DbInt(1)}" + val merged = f1 ++ f2 + assertTrue( + merged.parts == IndexedSeq("SELECT * FROM t WHERE id = ", ""), + merged.params == IndexedSeq(DbValue.DbInt(1)) + ) + }, + test("concatenation with params on both sides") { + val f1 = sql"a = ${DbValue.DbInt(1)} AND " + val f2 = sql"b = ${DbValue.DbString("x")}" + val merged = f1 ++ f2 + assertTrue( + merged.sql(SqlDialect.PostgreSQL) == "a = $1 AND b = $2", + merged.params == IndexedSeq(DbValue.DbInt(1), DbValue.DbString("x")) + ) + }, + test("Frag.empty is identity for ++") { + val frag = sql"SELECT 1" + assertTrue( + (frag ++ Frag.empty).sql(SqlDialect.PostgreSQL) == "SELECT 1", + (Frag.empty ++ frag).sql(SqlDialect.PostgreSQL) == "SELECT 1" + ) + } + ), + suite("Frag.const and Frag.empty")( + test("Frag.const creates parameterless fragment") { + val frag = Frag.const("ORDER BY id") + assertTrue( + frag.parts == IndexedSeq("ORDER BY id"), + frag.params.isEmpty, + frag.sql(SqlDialect.PostgreSQL) == "ORDER BY id" + ) + }, + test("Frag.empty is empty") { + assertTrue(Frag.empty.isEmpty) + }, + test("Frag with params is not empty") { + val frag = sql"SELECT ${DbValue.DbInt(1)}" + assertTrue(!frag.isEmpty) + } + ), + suite("queryParams")( + test("returns all params in order") { + val frag = sql"INSERT INTO t VALUES (${DbValue.DbInt(1)}, ${DbValue.DbString("a")}, ${DbValue.DbBoolean(true)})" + assertTrue( + frag.queryParams == IndexedSeq( + DbValue.DbInt(1), + DbValue.DbString("a"), + DbValue.DbBoolean(true) + ) + ) + } + ) + ) +} diff --git a/sql/shared/src/test/scala/zio/blocks/sql/RepoSpec.scala b/sql/shared/src/test/scala/zio/blocks/sql/RepoSpec.scala new file mode 100644 index 0000000000..98bbffc3c7 --- /dev/null +++ b/sql/shared/src/test/scala/zio/blocks/sql/RepoSpec.scala @@ -0,0 +1,85 @@ +package zio.blocks.sql + +import zio.test._ + +object RepoSpec extends ZIOSpecDefault { + + def spec: Spec[TestEnvironment, Any] = suite("RepoSpec")( + suite("buildInsertFrag")( + test("builds correct INSERT Frag for 3 values") { + val values = IndexedSeq( + DbValue.DbInt(1): DbValue, + DbValue.DbString("Alice"): DbValue, + DbValue.DbString("alice@example.com"): DbValue + ) + val frag = Repo.buildInsertFrag("user", "id, name, email", values) + assertTrue( + frag.sql(SqlDialect.SQLite) == "INSERT INTO user (id, name, email) VALUES (?, ?, ?)", + frag.sql(SqlDialect.PostgreSQL) == "INSERT INTO user (id, name, email) VALUES ($1, $2, $3)", + frag.queryParams == values + ) + }, + test("builds correct INSERT Frag for 1 value") { + val values = IndexedSeq(DbValue.DbInt(42): DbValue) + val frag = Repo.buildInsertFrag("t", "id", values) + assertTrue( + frag.sql(SqlDialect.SQLite) == "INSERT INTO t (id) VALUES (?)", + frag.queryParams == values + ) + }, + test("builds INSERT Frag with no values") { + val frag = Repo.buildInsertFrag("t", "", IndexedSeq.empty) + assertTrue( + frag.sql(SqlDialect.SQLite) == "INSERT INTO t () VALUES ()", + frag.queryParams.isEmpty + ) + } + ), + suite("buildUpdateFrag")( + test("builds correct UPDATE Frag for 2 columns + 1 id") { + val columns = IndexedSeq("name", "email") + val entityValues = IndexedSeq(DbValue.DbString("Bob"): DbValue, DbValue.DbString("bob@test.com"): DbValue) + val idValues = IndexedSeq(DbValue.DbInt(1): DbValue) + val frag = Repo.buildUpdateFrag("user", columns, entityValues, "id", idValues) + assertTrue( + frag.sql(SqlDialect.SQLite) == "UPDATE user SET name = ?, email = ? WHERE id = ?", + frag.sql(SqlDialect.PostgreSQL) == "UPDATE user SET name = $1, email = $2 WHERE id = $3", + frag.queryParams == entityValues ++ idValues + ) + }, + test("builds correct UPDATE Frag for 1 column + 1 id") { + val columns = IndexedSeq("name") + val entityValues = IndexedSeq(DbValue.DbString("Alice"): DbValue) + val idValues = IndexedSeq(DbValue.DbInt(5): DbValue) + val frag = Repo.buildUpdateFrag("t", columns, entityValues, "id", idValues) + assertTrue( + frag.sql(SqlDialect.SQLite) == "UPDATE t SET name = ? WHERE id = ?", + frag.queryParams == entityValues ++ idValues + ) + } + ), + suite("Repo construction")( + test("exposes table metadata") { + import zio.blocks.schema._ + case class Item(id: Int, name: String) + object Item { + implicit val schema: Schema[Item] = Schema.derived + } + val table = Table.derived[Item](SqlDialect.SQLite) + val idCodec: DbCodec[Int] = new DbCodec[Int] { + val columns: IndexedSeq[String] = IndexedSeq("value") + def readValue(reader: DbResultReader, startIndex: Int): Int = reader.getInt(startIndex) + def writeValue(writer: DbParamWriter, startIndex: Int, value: Int): Unit = + writer.setInt(startIndex, value) + def toDbValues(value: Int): IndexedSeq[DbValue] = IndexedSeq(DbValue.DbInt(value)) + } + val repo = Repo(table, "id", idCodec, (_: Item).id) + assertTrue( + repo.table.name == "item", + repo.idColumn == "id", + repo.table.columns == IndexedSeq("id", "name") + ) + } + ) + ) +} diff --git a/sql/shared/src/test/scala/zio/blocks/sql/SqlDialectSpec.scala b/sql/shared/src/test/scala/zio/blocks/sql/SqlDialectSpec.scala new file mode 100644 index 0000000000..f118569647 --- /dev/null +++ b/sql/shared/src/test/scala/zio/blocks/sql/SqlDialectSpec.scala @@ -0,0 +1,183 @@ +package zio.blocks.sql + +import zio.test.* +import java.time.{Duration, Instant, LocalDate, LocalDateTime, LocalTime} +import java.util.UUID + +object SqlDialectSpec extends ZIOSpecDefault { + def spec = suite("SqlDialectSpec")( + suite("PostgreSQL")( + test("name is PostgreSQL") { + assertTrue(SqlDialect.PostgreSQL.name == "PostgreSQL") + }, + test("DbNull -> NULL") { + assertTrue(SqlDialect.PostgreSQL.typeName(DbValue.DbNull) == "NULL") + }, + test("DbInt -> INTEGER") { + assertTrue(SqlDialect.PostgreSQL.typeName(DbValue.DbInt(1)) == "INTEGER") + }, + test("DbLong -> BIGINT") { + assertTrue(SqlDialect.PostgreSQL.typeName(DbValue.DbLong(1L)) == "BIGINT") + }, + test("DbDouble -> DOUBLE PRECISION") { + assertTrue( + SqlDialect.PostgreSQL.typeName(DbValue.DbDouble(1.0)) == "DOUBLE PRECISION" + ) + }, + test("DbFloat -> REAL") { + assertTrue(SqlDialect.PostgreSQL.typeName(DbValue.DbFloat(1.0f)) == "REAL") + }, + test("DbBoolean -> BOOLEAN") { + assertTrue(SqlDialect.PostgreSQL.typeName(DbValue.DbBoolean(true)) == "BOOLEAN") + }, + test("DbString -> TEXT") { + assertTrue(SqlDialect.PostgreSQL.typeName(DbValue.DbString("x")) == "TEXT") + }, + test("DbBigDecimal -> NUMERIC") { + assertTrue( + SqlDialect.PostgreSQL.typeName(DbValue.DbBigDecimal(scala.BigDecimal("1"))) == "NUMERIC" + ) + }, + test("DbBytes -> BYTEA") { + assertTrue( + SqlDialect.PostgreSQL.typeName(DbValue.DbBytes(Array[Byte](1))) == "BYTEA" + ) + }, + test("DbShort -> SMALLINT") { + assertTrue( + SqlDialect.PostgreSQL.typeName(DbValue.DbShort(1.toShort)) == "SMALLINT" + ) + }, + test("DbByte -> SMALLINT") { + assertTrue(SqlDialect.PostgreSQL.typeName(DbValue.DbByte(1.toByte)) == "SMALLINT") + }, + test("DbChar -> CHAR(1)") { + assertTrue(SqlDialect.PostgreSQL.typeName(DbValue.DbChar('A')) == "CHAR(1)") + }, + test("DbLocalDate -> DATE") { + assertTrue( + SqlDialect.PostgreSQL.typeName(DbValue.DbLocalDate(LocalDate.now())) == "DATE" + ) + }, + test("DbLocalDateTime -> TIMESTAMP") { + assertTrue( + SqlDialect.PostgreSQL.typeName( + DbValue.DbLocalDateTime(LocalDateTime.now()) + ) == "TIMESTAMP" + ) + }, + test("DbLocalTime -> TIME") { + assertTrue( + SqlDialect.PostgreSQL.typeName(DbValue.DbLocalTime(LocalTime.now())) == "TIME" + ) + }, + test("DbInstant -> TIMESTAMPTZ") { + assertTrue( + SqlDialect.PostgreSQL.typeName(DbValue.DbInstant(Instant.now())) == "TIMESTAMPTZ" + ) + }, + test("DbDuration -> INTERVAL") { + assertTrue( + SqlDialect.PostgreSQL.typeName(DbValue.DbDuration(Duration.ofHours(1))) == "INTERVAL" + ) + }, + test("DbUUID -> UUID") { + assertTrue( + SqlDialect.PostgreSQL.typeName( + DbValue.DbUUID(UUID.fromString("550e8400-e29b-41d4-a716-446655440000")) + ) == "UUID" + ) + }, + test("paramPlaceholder(1) -> $1") { + assertTrue(SqlDialect.PostgreSQL.paramPlaceholder(1) == "$1") + }, + test("paramPlaceholder(2) -> $2") { + assertTrue(SqlDialect.PostgreSQL.paramPlaceholder(2) == "$2") + }, + test("paramPlaceholder(42) -> $42") { + assertTrue(SqlDialect.PostgreSQL.paramPlaceholder(42) == "$42") + } + ), + suite("SQLite")( + test("name is SQLite") { + assertTrue(SqlDialect.SQLite.name == "SQLite") + }, + test("DbNull -> NULL") { + assertTrue(SqlDialect.SQLite.typeName(DbValue.DbNull) == "NULL") + }, + test("DbInt -> INTEGER") { + assertTrue(SqlDialect.SQLite.typeName(DbValue.DbInt(1)) == "INTEGER") + }, + test("DbLong -> INTEGER") { + assertTrue(SqlDialect.SQLite.typeName(DbValue.DbLong(1L)) == "INTEGER") + }, + test("DbDouble -> REAL") { + assertTrue(SqlDialect.SQLite.typeName(DbValue.DbDouble(1.0)) == "REAL") + }, + test("DbFloat -> REAL") { + assertTrue(SqlDialect.SQLite.typeName(DbValue.DbFloat(1.0f)) == "REAL") + }, + test("DbBoolean -> INTEGER") { + assertTrue(SqlDialect.SQLite.typeName(DbValue.DbBoolean(true)) == "INTEGER") + }, + test("DbString -> TEXT") { + assertTrue(SqlDialect.SQLite.typeName(DbValue.DbString("x")) == "TEXT") + }, + test("DbBigDecimal -> TEXT") { + assertTrue( + SqlDialect.SQLite.typeName(DbValue.DbBigDecimal(scala.BigDecimal("1"))) == "TEXT" + ) + }, + test("DbBytes -> BLOB") { + assertTrue(SqlDialect.SQLite.typeName(DbValue.DbBytes(Array[Byte](1))) == "BLOB") + }, + test("DbShort -> INTEGER") { + assertTrue( + SqlDialect.SQLite.typeName(DbValue.DbShort(1.toShort)) == "INTEGER" + ) + }, + test("DbByte -> INTEGER") { + assertTrue(SqlDialect.SQLite.typeName(DbValue.DbByte(1.toByte)) == "INTEGER") + }, + test("DbChar -> TEXT") { + assertTrue(SqlDialect.SQLite.typeName(DbValue.DbChar('A')) == "TEXT") + }, + test("DbLocalDate -> TEXT") { + assertTrue( + SqlDialect.SQLite.typeName(DbValue.DbLocalDate(LocalDate.now())) == "TEXT" + ) + }, + test("DbLocalDateTime -> TEXT") { + assertTrue( + SqlDialect.SQLite.typeName( + DbValue.DbLocalDateTime(LocalDateTime.now()) + ) == "TEXT" + ) + }, + test("DbLocalTime -> TEXT") { + assertTrue(SqlDialect.SQLite.typeName(DbValue.DbLocalTime(LocalTime.now())) == "TEXT") + }, + test("DbInstant -> TEXT") { + assertTrue( + SqlDialect.SQLite.typeName(DbValue.DbInstant(Instant.now())) == "TEXT" + ) + }, + test("DbDuration -> TEXT") { + assertTrue( + SqlDialect.SQLite.typeName(DbValue.DbDuration(Duration.ofHours(1))) == "TEXT" + ) + }, + test("DbUUID -> TEXT") { + assertTrue( + SqlDialect.SQLite.typeName(DbValue.DbUUID(UUID.fromString("550e8400-e29b-41d4-a716-446655440000"))) == "TEXT" + ) + }, + test("paramPlaceholder(1) -> ?") { + assertTrue(SqlDialect.SQLite.paramPlaceholder(1) == "?") + }, + test("paramPlaceholder(42) -> ?") { + assertTrue(SqlDialect.SQLite.paramPlaceholder(42) == "?") + } + ) + ) +} diff --git a/sql/shared/src/test/scala/zio/blocks/sql/SqlInterpolatorSpec.scala b/sql/shared/src/test/scala/zio/blocks/sql/SqlInterpolatorSpec.scala new file mode 100644 index 0000000000..cc6ec22e70 --- /dev/null +++ b/sql/shared/src/test/scala/zio/blocks/sql/SqlInterpolatorSpec.scala @@ -0,0 +1,174 @@ +package zio.blocks.sql + +import zio.test.* + +import scala.language.implicitConversions +import java.time._ +import java.util.UUID + +object SqlInterpolatorSpec extends ZIOSpecDefault { + def spec: Spec[TestEnvironment, Any] = suite("SqlInterpolatorSpec")( + suite("DbParam givens")( + test("Int param converts to DbInt") { + val p = DbParam[Int] + assertTrue(p.toDbValue(42) == DbValue.DbInt(42)) + }, + test("Long param converts to DbLong") { + val p = DbParam[Long] + assertTrue(p.toDbValue(42L) == DbValue.DbLong(42L)) + }, + test("Double param converts to DbDouble") { + val p = DbParam[Double] + assertTrue(p.toDbValue(3.14) == DbValue.DbDouble(3.14)) + }, + test("Float param converts to DbFloat") { + val p = DbParam[Float] + assertTrue(p.toDbValue(3.14f) == DbValue.DbFloat(3.14f)) + }, + test("Boolean param converts to DbBoolean") { + val p = DbParam[Boolean] + assertTrue(p.toDbValue(true) == DbValue.DbBoolean(true)) + }, + test("String param converts to DbString") { + val p = DbParam[String] + assertTrue(p.toDbValue("hello") == DbValue.DbString("hello")) + }, + test("Short param converts to DbShort") { + val p = DbParam[Short] + assertTrue(p.toDbValue(42.toShort) == DbValue.DbShort(42.toShort)) + }, + test("Byte param converts to DbByte") { + val p = DbParam[Byte] + assertTrue(p.toDbValue(7.toByte) == DbValue.DbByte(7.toByte)) + }, + test("BigDecimal param converts to DbBigDecimal") { + val p = DbParam[BigDecimal] + assertTrue(p.toDbValue(BigDecimal("3.14")) == DbValue.DbBigDecimal(BigDecimal("3.14"))) + }, + test("Array[Byte] param converts to DbBytes") { + val p = DbParam[Array[Byte]] + val bytes = Array[Byte](1, 2, 3) + val result = p.toDbValue(bytes) match { + case DbValue.DbBytes(v) => v.sameElements(bytes) + case _ => false + } + assertTrue(result) + }, + test("LocalDate param converts to DbLocalDate") { + val p = DbParam[LocalDate] + val date = LocalDate.of(2024, 1, 15) + assertTrue(p.toDbValue(date) == DbValue.DbLocalDate(date)) + }, + test("LocalDateTime param converts to DbLocalDateTime") { + val p = DbParam[LocalDateTime] + val dt = LocalDateTime.of(2024, 1, 15, 12, 30) + assertTrue(p.toDbValue(dt) == DbValue.DbLocalDateTime(dt)) + }, + test("LocalTime param converts to DbLocalTime") { + val p = DbParam[LocalTime] + val t = LocalTime.of(12, 30, 45) + assertTrue(p.toDbValue(t) == DbValue.DbLocalTime(t)) + }, + test("Instant param converts to DbInstant") { + val p = DbParam[Instant] + val instant = Instant.parse("2024-01-15T12:00:00Z") + assertTrue(p.toDbValue(instant) == DbValue.DbInstant(instant)) + }, + test("Duration param converts to DbDuration") { + val p = DbParam[Duration] + val dur = Duration.ofHours(2) + assertTrue(p.toDbValue(dur) == DbValue.DbDuration(dur)) + }, + test("UUID param converts to DbUUID") { + val p = DbParam[UUID] + val uuid = UUID.fromString("550e8400-e29b-41d4-a716-446655440000") + assertTrue(p.toDbValue(uuid) == DbValue.DbUUID(uuid)) + }, + test("DbValue passthrough") { + val p = DbParam[DbValue] + val v = DbValue.DbString("raw") + assertTrue(p.toDbValue(v) == DbValue.DbString("raw")) + }, + test("Option Some produces inner value") { + val p = DbParam[Option[Int]] + assertTrue(p.toDbValue(Some(42)) == DbValue.DbInt(42)) + }, + test("Option None produces DbNull") { + val p = DbParam[Option[Int]] + assertTrue(p.toDbValue(None) == DbValue.DbNull) + }, + test("Nested Option Some(Some) produces inner value") { + val p = DbParam[Option[Option[String]]] + assertTrue(p.toDbValue(Some(Some("nested"))) == DbValue.DbString("nested")) + }, + test("Nested Option Some(None) produces DbNull") { + val p = DbParam[Option[Option[String]]] + assertTrue(p.toDbValue(Some(None)) == DbValue.DbNull) + } + ), + suite("sql interpolator with DbValue params")( + test("single DbValue param") { + val frag = sql"SELECT ${DbValue.DbInt(42)}" + assertTrue( + frag.queryParams == IndexedSeq(DbValue.DbInt(42)), + frag.parts == IndexedSeq("SELECT ", "") + ) + }, + test("multiple DbValue params") { + val frag = + sql"INSERT INTO t VALUES (${DbValue.DbInt(1)}, ${DbValue.DbString("hello")}, ${DbValue.DbBoolean(true)}, ${DbValue.DbDouble(3.14)})" + assertTrue( + frag.queryParams.length == 4, + frag.queryParams(0) == DbValue.DbInt(1), + frag.queryParams(1) == DbValue.DbString("hello"), + frag.queryParams(2) == DbValue.DbBoolean(true), + frag.queryParams(3) == DbValue.DbDouble(3.14) + ) + }, + test("no params produces empty params") { + val frag = sql"SELECT 1" + assertTrue( + frag.queryParams.isEmpty, + frag.parts == IndexedSeq("SELECT 1") + ) + } + ), + suite("sql interpolator with DbParam conversion")( + test("Int param converts via DbParam") { + val frag = sql"SELECT ${42}" + assertTrue( + frag.queryParams == IndexedSeq(DbValue.DbInt(42)), + frag.parts == IndexedSeq("SELECT ", "") + ) + }, + test("String param converts via DbParam") { + val frag = sql"SELECT ${"hello"}" + assertTrue(frag.queryParams == IndexedSeq(DbValue.DbString("hello"))) + }, + test("Boolean param converts via DbParam") { + val frag = sql"SELECT ${true}" + assertTrue(frag.queryParams == IndexedSeq(DbValue.DbBoolean(true))) + }, + test("multiple mixed DbParam types") { + val frag = sql"INSERT INTO t VALUES (${1}, ${"hello"}, ${true}, ${3.14})" + assertTrue( + frag.queryParams.length == 4, + frag.queryParams(0) == DbValue.DbInt(1), + frag.queryParams(1) == DbValue.DbString("hello"), + frag.queryParams(2) == DbValue.DbBoolean(true), + frag.queryParams(3) == DbValue.DbDouble(3.14) + ) + }, + test("Option Some converts via DbParam") { + val v: Option[Int] = Some(42) + val frag = sql"SELECT ${v}" + assertTrue(frag.queryParams == IndexedSeq(DbValue.DbInt(42))) + }, + test("Option None converts via DbParam") { + val v: Option[Int] = None + val frag = sql"SELECT ${v}" + assertTrue(frag.queryParams == IndexedSeq(DbValue.DbNull)) + } + ) + ) +} diff --git a/sql/shared/src/test/scala/zio/blocks/sql/SqlNameMapperSpec.scala b/sql/shared/src/test/scala/zio/blocks/sql/SqlNameMapperSpec.scala new file mode 100644 index 0000000000..42c83efac9 --- /dev/null +++ b/sql/shared/src/test/scala/zio/blocks/sql/SqlNameMapperSpec.scala @@ -0,0 +1,70 @@ +package zio.blocks.sql + +import zio.test._ + +object SqlNameMapperSpec extends ZIOSpecDefault { + + def spec = suite("SqlNameMapperSpec")( + suite("SnakeCase")( + test("camelCase to snake_case") { + assertTrue(SqlNameMapper.SnakeCase("firstName") == "first_name") + }, + test("PascalCase to snake_case") { + assertTrue(SqlNameMapper.SnakeCase("FirstName") == "first_name") + }, + test("already snake_case unchanged") { + assertTrue(SqlNameMapper.SnakeCase("first_name") == "first_name") + }, + test("ID suffix handled correctly") { + assertTrue(SqlNameMapper.SnakeCase("userID") == "user_id") + }, + test("consecutive capitals") { + assertTrue(SqlNameMapper.SnakeCase("HTTPResponse") == "http_response") + }, + test("single word lowercase unchanged") { + assertTrue(SqlNameMapper.SnakeCase("name") == "name") + }, + test("single word uppercase to lowercase") { + assertTrue(SqlNameMapper.SnakeCase("Name") == "name") + }, + test("empty string") { + assertTrue(SqlNameMapper.SnakeCase("") == "") + }, + test("kebab-case converted to snake_case") { + assertTrue(SqlNameMapper.SnakeCase("first-name") == "first_name") + }, + test("mixed separators") { + assertTrue(SqlNameMapper.SnakeCase("first_name-value") == "first_name_value") + }, + test("numbers in name") { + assertTrue(SqlNameMapper.SnakeCase("field2Name") == "field2_name") + } + ), + suite("Identity")( + test("camelCase unchanged") { + assertTrue(SqlNameMapper.Identity("firstName") == "firstName") + }, + test("PascalCase unchanged") { + assertTrue(SqlNameMapper.Identity("FirstName") == "FirstName") + }, + test("snake_case unchanged") { + assertTrue(SqlNameMapper.Identity("first_name") == "first_name") + }, + test("empty string") { + assertTrue(SqlNameMapper.Identity("") == "") + } + ), + suite("Custom")( + test("applies custom function") { + val upper = SqlNameMapper.Custom(_.toUpperCase) + assertTrue(upper("firstName") == "FIRSTNAME") + }, + test("chains transformations") { + val snakeThenUpper = SqlNameMapper.Custom { s => + SqlNameMapper.SnakeCase(s).toUpperCase + } + assertTrue(snakeThenUpper("firstName") == "FIRST_NAME") + } + ) + ) +} diff --git a/sql/shared/src/test/scala/zio/blocks/sql/TableSpec.scala b/sql/shared/src/test/scala/zio/blocks/sql/TableSpec.scala new file mode 100644 index 0000000000..135e2d35fb --- /dev/null +++ b/sql/shared/src/test/scala/zio/blocks/sql/TableSpec.scala @@ -0,0 +1,146 @@ +package zio.blocks.sql + +import zio.test.* +import zio.blocks.schema._ + +object TableSpec extends ZIOSpecDefault { + + case class SimpleRecord(name: String, age: Int) + object SimpleRecord { + implicit val schema: Schema[SimpleRecord] = Schema.derived + } + + case class UserProfile(firstName: String, lastName: String) + object UserProfile { + implicit val schema: Schema[UserProfile] = Schema.derived + } + + case class Category(name: String) + object Category { + implicit val schema: Schema[Category] = Schema.derived + } + + case class Address(street: String, city: String) + object Address { + implicit val schema: Schema[Address] = Schema.derived + } + + case class Box(width: Int) + object Box { + implicit val schema: Schema[Box] = Schema.derived + } + + def spec = suite("TableSpec")( + suite("Table.derived")( + test("derives simple_record from SimpleRecord") { + val table = Table.derived[SimpleRecord](SqlDialect.PostgreSQL) + assertTrue(table.name == "simple_record") + }, + test("derives user_profile from UserProfile") { + val table = Table.derived[UserProfile](SqlDialect.PostgreSQL) + assertTrue(table.name == "user_profile") + }, + test("derives category from Category") { + val table = Table.derived[Category](SqlDialect.SQLite) + assertTrue(table.name == "category") + }, + test("derives address from Address") { + val table = Table.derived[Address](SqlDialect.PostgreSQL) + assertTrue(table.name == "address") + }, + test("derives box from Box") { + val table = Table.derived[Box](SqlDialect.PostgreSQL) + assertTrue(table.name == "box") + }, + test("table.columns matches codec.columns") { + val table = Table.derived[SimpleRecord](SqlDialect.PostgreSQL) + assertTrue( + table.columns == IndexedSeq("name", "age"), + table.columns.size == 2 + ) + }, + test("derived with explicit table name uses it directly") { + val table = Table.derived[SimpleRecord]("my_custom_table", SqlDialect.PostgreSQL) + assertTrue(table.name == "my_custom_table") + }, + test("derived with explicit name ignores type name") { + val table = Table.derived[UserProfile]("profiles", SqlDialect.SQLite) + assertTrue(table.name == "profiles") + } + ), + suite("pluralize")( + test("pluralizes user to users") { + assertTrue(Table.pluralize("user") == "users") + }, + test("pluralizes address to addresses") { + assertTrue(Table.pluralize("address") == "addresses") + }, + test("pluralizes category to categories") { + assertTrue(Table.pluralize("category") == "categories") + }, + test("pluralizes box to boxes") { + assertTrue(Table.pluralize("box") == "boxes") + }, + test("pluralizes bus to buses") { + assertTrue(Table.pluralize("bus") == "buses") + }, + test("pluralizes fox to foxes") { + assertTrue(Table.pluralize("fox") == "foxes") + }, + test("pluralizes church to churches") { + assertTrue(Table.pluralize("church") == "churches") + }, + test("pluralizes dish to dishes") { + assertTrue(Table.pluralize("dish") == "dishes") + }, + test("pluralizes quiz to quizzes") { + assertTrue(Table.pluralize("quiz") == "quizzes") + }, + test("pluralizes buzz to buzzes") { + assertTrue(Table.pluralize("buzz") == "buzzes") + }, + test("pluralizes fuzz to fuzzes") { + assertTrue(Table.pluralize("fuzz") == "fuzzes") + }, + test("empty string pluralizes to empty string") { + assertTrue(Table.pluralize("") == "") + } + ), + suite("Table.dropTable")( + test("generates DROP TABLE IF EXISTS") { + val table = Table.derived[SimpleRecord](SqlDialect.PostgreSQL) + val frag = table.dropTable + assertTrue(frag.sql(SqlDialect.PostgreSQL) == "DROP TABLE IF EXISTS simple_record") + }, + test("works with SQLite dialect") { + val table = Table.derived[Category](SqlDialect.SQLite) + val frag = table.dropTable + assertTrue(frag.sql(SqlDialect.SQLite) == "DROP TABLE IF EXISTS category") + } + ), + suite("Table.createTable")( + test("generates CREATE TABLE IF NOT EXISTS") { + val table = Table.derived[SimpleRecord](SqlDialect.PostgreSQL) + val frag = table.createTable + val sql = frag.sql(SqlDialect.PostgreSQL) + assertTrue( + sql.contains("CREATE TABLE IF NOT EXISTS simple_record"), + sql.contains("name"), + sql.contains("age") + ) + }, + test("works with PostgreSQL dialect") { + val table = Table.derived[Category](SqlDialect.PostgreSQL) + val frag = table.createTable + val sql = frag.sql(SqlDialect.PostgreSQL) + assertTrue(sql.contains("CREATE TABLE IF NOT EXISTS category")) + }, + test("works with SQLite dialect") { + val table = Table.derived[Category](SqlDialect.SQLite) + val frag = table.createTable + val sql = frag.sql(SqlDialect.SQLite) + assertTrue(sql.contains("CREATE TABLE IF NOT EXISTS category")) + } + ) + ) +}