diff --git a/src/commonMain/kotlin/uk/gibby/driver/Surreal.kt b/src/commonMain/kotlin/uk/gibby/driver/Surreal.kt index 825af94..2907bf5 100644 --- a/src/commonMain/kotlin/uk/gibby/driver/Surreal.kt +++ b/src/commonMain/kotlin/uk/gibby/driver/Surreal.kt @@ -6,13 +6,11 @@ import io.ktor.util.collections.* import io.ktor.websocket.* import kotlinx.coroutines.* import kotlinx.coroutines.channels.Channel -import kotlinx.coroutines.flow.receiveAsFlow +import kotlinx.coroutines.flow.* import kotlinx.serialization.json.* -import uk.gibby.driver.rpc.model.RpcRequest -import uk.gibby.driver.rpc.model.RpcResponse -import uk.gibby.driver.rpc.model.RpcResponseSerializer -import kotlinx.serialization.encodeToString -import uk.gibby.driver.rpc.model.LiveQueryAction +import uk.gibby.driver.rpc.exception.LiveQueryKilledException +import uk.gibby.driver.rpc.functions.kill +import uk.gibby.driver.rpc.model.* /** * SurrealDB driver @@ -26,7 +24,7 @@ class Surreal(private val host: String, private val port: Int = 8000) { private var count = 0L private var connection: DefaultClientWebSocketSession? = null private val requests = ConcurrentMap>() - private val liveQueries = ConcurrentMap>() + private val liveQueries = ConcurrentMap>>() private val context = CoroutineScope(Dispatchers.Default) /** @@ -46,44 +44,53 @@ class Surreal(private val host: String, private val port: Int = 8000) { val response = try { surrealJson.decodeFromString(RpcResponseSerializer, it.readText()) } catch (e: Exception) { + // In theory this could be an error for any request, so we cancel all of them requests.forEach { (_, r) -> r.cancel(CancellationException("Failed to decode incoming response: ${it.readText()}\n${e.message}"))} throw e } - if(response is RpcResponse.Notification) { - val action = response.result - val liveQuery = liveQueries[action.id] - if (liveQuery != null) { - liveQuery.send(response.result) - } - else { - println("Couldn't find live query with id ${action.id}") - requests.forEach { - (_, r) -> r.cancel(CancellationException("Received a request with an unknown id: ${response.id} body: $response")) - } - } - } - if(response.id != null) { - val request = requests[response.id] - if (request != null) { - when(response) { - is RpcResponse.Success -> request.send(response.result) - is RpcResponse.Error -> request.cancel(CancellationException("SurrealDB responded with an error: '${response.error}'")) - else -> TODO() - } - requests.remove(response.id) - } - else { - if (response.id == null) println("SurrealDB: Received a response with no id: $response") - else requests.forEach { - (_, r) -> r.cancel(CancellationException("Received a request with an unknown id: ${response.id} body: $response")) - } - } + when(response) { + is RpcResponse.Success -> handleSuccess(response) + is RpcResponse.Error -> handleError(response) + is RpcResponse.Notification -> handleNotification(response) } } } } } + private suspend fun handleSuccess(response: RpcResponse.Success) { + val request = requests[response.id] + if (request != null) { + request.send(response.result) + } else { + requests.forEach { + // In theory this could be an error for any request, so we cancel all of them + (_, r) -> + r.cancel(CancellationException("Received a request with an unknown id: ${response.id} body: $response")) + } + } + } + + private fun handleError(response: RpcResponse.Error) { + val request = requests[response.id] + if (request != null) { + request.cancel(CancellationException("SurrealDB responded with an error: '${response.error}'")) + } else { + requests.forEach { + // In theory this could be an error for any request, so we cancel all of them + (_, r) -> + r.cancel(CancellationException("Received a request with an unknown id: ${response.id} body: $response")) + } + } + requests.remove(response.id) + } + + private suspend fun handleNotification(response: RpcResponse.Notification) { + val action = response.result + val liveQuery = liveQueries.getOrPut(action.id) { Channel() } + context.launch { liveQuery.send(response.result) } + } + internal suspend fun sendRequest(method: String, params: JsonArray): JsonElement { val id = count++.toString() val request = RpcRequest(id, method, params) @@ -93,15 +100,24 @@ class Surreal(private val host: String, private val port: Int = 8000) { return channel.receive() } - internal fun subscribe(liveQueryId: String): Channel { - val channel = Channel() - println("Live query $liveQueryId created") - liveQueries[liveQueryId] = channel.apply { - invokeOnClose { - println("Live query $liveQueryId closed") - liveQueries.remove(liveQueryId) - } - } - return channel + fun subscribeAsJson(liveQueryId: String): Flow> { + val channel = liveQueries.getOrPut(liveQueryId) { Channel() } + return channel.receiveAsFlow() + } + + inline fun subscribe(liveQueryId: String): Flow> { + return subscribeAsJson(liveQueryId).map { it.asType() } + } + + fun unsubscribe(liveQueryId: String) { + val channel = liveQueries[liveQueryId] + channel?.cancel(LiveQueryKilledException) + liveQueries.remove(liveQueryId) + } + + internal fun triggerKill(liveQueryId: String) { + context.launch { kill(liveQueryId) } } -} \ No newline at end of file + +} + diff --git a/src/commonMain/kotlin/uk/gibby/driver/rpc/exception/LiveQueryKilledException.kt b/src/commonMain/kotlin/uk/gibby/driver/rpc/exception/LiveQueryKilledException.kt new file mode 100644 index 0000000..1ea775b --- /dev/null +++ b/src/commonMain/kotlin/uk/gibby/driver/rpc/exception/LiveQueryKilledException.kt @@ -0,0 +1,5 @@ +package uk.gibby.driver.rpc.exception + +import kotlinx.coroutines.CancellationException + +object LiveQueryKilledException: CancellationException("Live query has been killed") \ No newline at end of file diff --git a/src/commonMain/kotlin/uk/gibby/driver/rpc/functions/Kill.kt b/src/commonMain/kotlin/uk/gibby/driver/rpc/functions/Kill.kt new file mode 100644 index 0000000..b2e642c --- /dev/null +++ b/src/commonMain/kotlin/uk/gibby/driver/rpc/functions/Kill.kt @@ -0,0 +1,10 @@ +package uk.gibby.driver.rpc.functions + +import kotlinx.serialization.json.add +import kotlinx.serialization.json.buildJsonArray +import uk.gibby.driver.Surreal + +suspend fun Surreal.kill(liveQueryId: String) { + sendRequest("kill", buildJsonArray { add(liveQueryId) }) +} + diff --git a/src/commonMain/kotlin/uk/gibby/driver/rpc/functions/Live.kt b/src/commonMain/kotlin/uk/gibby/driver/rpc/functions/Live.kt new file mode 100644 index 0000000..342c1a7 --- /dev/null +++ b/src/commonMain/kotlin/uk/gibby/driver/rpc/functions/Live.kt @@ -0,0 +1,56 @@ +package uk.gibby.driver.rpc.functions + +import io.ktor.utils.io.core.* +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.receiveAsFlow +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.add +import kotlinx.serialization.json.buildJsonArray +import kotlinx.serialization.json.decodeFromJsonElement +import uk.gibby.driver.Surreal +import uk.gibby.driver.rpc.model.LiveQueryAction +import uk.gibby.driver.rpc.model.asType +import uk.gibby.driver.surrealJson +import kotlin.jvm.JvmName + + +suspend fun Surreal.live(table: String): String { + val response = sendRequest("live", buildJsonArray { add(table) }) + return surrealJson.decodeFromJsonElement(response) +} + + +@JvmName("observeJson") +suspend fun Surreal.observeAsJson(liveQueryId: String): LiveQueryFlow { + val id = live(liveQueryId) + return LiveQueryFlow( + flow = subscribe(id), + id = id, + connection = this + ) +} + +suspend inline fun Surreal.observe(table: String): LiveQueryFlow { + val jsonFlow = observeAsJson(table) + return jsonFlow.map { it.asType() } +} + +class LiveQueryFlow( + private val flow: Flow>, + val id: String, + private val connection: Surreal +): Flow> by flow, Closeable { + override fun close() { + connection.unsubscribe(id) + connection.triggerKill(id) + } + + fun map(transform: (LiveQueryAction) -> LiveQueryAction): LiveQueryFlow { + return LiveQueryFlow( + flow = flow.map { transform(it) }, + id = id, + connection = connection + ) + } +} \ No newline at end of file diff --git a/src/commonMain/kotlin/uk/gibby/driver/rpc/functions/LiveQuery.kt b/src/commonMain/kotlin/uk/gibby/driver/rpc/functions/LiveQuery.kt deleted file mode 100644 index fc2306b..0000000 --- a/src/commonMain/kotlin/uk/gibby/driver/rpc/functions/LiveQuery.kt +++ /dev/null @@ -1,18 +0,0 @@ -package uk.gibby.driver.rpc.functions - -import kotlinx.coroutines.channels.consumeEach -import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.consumeAsFlow -import kotlinx.serialization.json.JsonElement -import kotlinx.serialization.json.add -import kotlinx.serialization.json.buildJsonArray -import kotlinx.serialization.json.decodeFromJsonElement -import uk.gibby.driver.Surreal -import uk.gibby.driver.rpc.model.LiveQueryAction -import uk.gibby.driver.surrealJson - -suspend fun Surreal.live(table: String): Flow { - val response = sendRequest("live", buildJsonArray{ add(table) }) - val id: String = surrealJson.decodeFromJsonElement(response) - return subscribe(id).consumeAsFlow() -} \ No newline at end of file diff --git a/src/commonMain/kotlin/uk/gibby/driver/rpc/model/LiveQueryAction.kt b/src/commonMain/kotlin/uk/gibby/driver/rpc/model/LiveQueryAction.kt index 3c186e1..40b2546 100644 --- a/src/commonMain/kotlin/uk/gibby/driver/rpc/model/LiveQueryAction.kt +++ b/src/commonMain/kotlin/uk/gibby/driver/rpc/model/LiveQueryAction.kt @@ -1,7 +1,66 @@ package uk.gibby.driver.rpc.model -import kotlinx.serialization.Serializable -import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.* +import kotlinx.serialization.builtins.serializer +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.descriptors.buildClassSerialDescriptor +import kotlinx.serialization.encoding.CompositeDecoder +import kotlinx.serialization.encoding.Decoder +import kotlinx.serialization.encoding.Encoder +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.decodeFromJsonElement +import uk.gibby.driver.surrealJson -@Serializable -data class LiveQueryAction(val action: String, val id: String, val result: JsonObject? = null) \ No newline at end of file +@Serializable(with = LiveQueryActionSerializer::class) +sealed class LiveQueryAction(val id: String) { + class Create(id: String, val result: T): LiveQueryAction(id) + class Update(id: String, val result: T): LiveQueryAction(id) + class Delete(id: String, val deletedId: String): LiveQueryAction(id) + +} + + +inline fun LiveQueryAction.asType(): LiveQueryAction { + return when(this) { + is LiveQueryAction.Delete -> this + is LiveQueryAction.Create -> LiveQueryAction.Create(id, surrealJson.decodeFromJsonElement(result)) + is LiveQueryAction.Update -> LiveQueryAction.Update(id, surrealJson.decodeFromJsonElement(result)) + } +} + +class LiveQueryActionSerializer(private val resultSerializer: KSerializer): KSerializer> { + override val descriptor: SerialDescriptor = buildClassSerialDescriptor("LiveQueryAction") { + element("action", String.serializer().descriptor) + element("id", String.serializer().descriptor) + element("result", resultSerializer.descriptor) + } + + override fun deserialize(decoder: Decoder): LiveQueryAction { + val input = decoder.beginStructure(descriptor) + var action: String? = null + var id: String? = null + var result: Thing? = null + loop@ while (true) { + when (val i = input.decodeElementIndex(descriptor)) { + CompositeDecoder.DECODE_DONE -> break@loop + 0 -> action = input.decodeStringElement(descriptor, i) + 1 -> id = input.decodeStringElement(descriptor, i) + 2 -> result = input.decodeSerializableElement(descriptor, i, ThingSerializer(resultSerializer)) + else -> throw SerializationException("Unknown index $i") + } + } + input.endStructure(descriptor) + if (action == null || id == null || result == null) throw SerializationException("Missing fields") + return when(action) { + "CREATE" -> LiveQueryAction.Create(id, (result as Thing.Record).result) + "UPDATE" -> LiveQueryAction.Update(id, (result as Thing.Record).result) + "DELETE" -> LiveQueryAction.Delete(id, (result as Thing.Reference).id) + else -> throw SerializationException("Unknown action $action") + } + } + + override fun serialize(encoder: Encoder, value: LiveQueryAction) { + TODO("Not yet implemented") + } + +} \ No newline at end of file diff --git a/src/commonMain/kotlin/uk/gibby/driver/rpc/model/RpcResponse.kt b/src/commonMain/kotlin/uk/gibby/driver/rpc/model/RpcResponse.kt index 6e4bc3a..0270832 100644 --- a/src/commonMain/kotlin/uk/gibby/driver/rpc/model/RpcResponse.kt +++ b/src/commonMain/kotlin/uk/gibby/driver/rpc/model/RpcResponse.kt @@ -15,7 +15,7 @@ sealed class RpcResponse { data class Error(override val id: String, val error: JsonElement): RpcResponse() @Serializable - data class Notification(val result: LiveQueryAction): RpcResponse() { + data class Notification(val result: LiveQueryAction): RpcResponse() { override val id: String? = null } } diff --git a/src/commonTest/kotlin/KillTest.kt b/src/commonTest/kotlin/KillTest.kt new file mode 100644 index 0000000..32427ae --- /dev/null +++ b/src/commonTest/kotlin/KillTest.kt @@ -0,0 +1,21 @@ +import kotlinx.coroutines.test.runTest +import uk.gibby.driver.Surreal +import uk.gibby.driver.rpc.functions.kill +import uk.gibby.driver.rpc.functions.live +import uk.gibby.driver.rpc.functions.signin +import uk.gibby.driver.rpc.functions.use +import utils.cleanDatabase +import kotlin.test.Test + +class KillTest { + @Test + fun testKill() = runTest { + cleanDatabase() + val connection = Surreal("localhost", 8000) + connection.connect() + connection.signin("root", "root") + connection.use("test", "test") + val liveQueryId = connection.live("test") + connection.kill(liveQueryId) + } +} \ No newline at end of file diff --git a/src/commonTest/kotlin/LiveQueryTest.kt b/src/commonTest/kotlin/LiveQueryTest.kt index c815b47..5dbc770 100644 --- a/src/commonTest/kotlin/LiveQueryTest.kt +++ b/src/commonTest/kotlin/LiveQueryTest.kt @@ -1,16 +1,13 @@ -import kotlinx.coroutines.delay -import kotlinx.coroutines.flow.filter import kotlinx.coroutines.flow.first -import kotlinx.coroutines.flow.map -import kotlinx.coroutines.launch import kotlinx.coroutines.test.runTest -import kotlinx.serialization.json.buildJsonObject -import kotlinx.serialization.json.put import uk.gibby.driver.Surreal +import uk.gibby.driver.rpc.exception.LiveQueryKilledException import uk.gibby.driver.rpc.functions.* +import uk.gibby.driver.rpc.model.LiveQueryAction import uk.gibby.driver.rpc.model.bind +import uk.gibby.driver.rpc.model.data import utils.cleanDatabase -import kotlin.test.Test +import kotlin.test.* class LiveQueryTest { @Test @@ -20,8 +17,137 @@ class LiveQueryTest { connection.connect() connection.signin("root", "root") connection.use("test", "test") - val incoming = connection.live("test") - connection.create("test").content(buildJsonObject { put("thing", "thing") }) - println(incoming.first()) + + val incoming = connection.observe("test") + connection.create("test", "first").content(TestClass("thing", 1)) + connection.create("test", "second").content(TestClass("thing", 2)) + connection.update("test", "first").patch { + replace("myText", "thing2") + } + connection.delete("test", "second") + + val first = incoming.first() + assertTrue { first is LiveQueryAction.Create } + first as LiveQueryAction.Create + assertEquals("thing", first.result.myText) + assertEquals(1, first.result.myNumber) + + val second = incoming.first() + assertTrue { second is LiveQueryAction.Create } + second as LiveQueryAction.Create + assertEquals("thing", second.result.myText) + assertEquals(2, second.result.myNumber) + + val updated = incoming.first() + assertTrue { updated is LiveQueryAction.Update } + updated as LiveQueryAction.Update + assertEquals("thing2", updated.result.myText) + assertEquals(1, updated.result.myNumber) + + val deleted = incoming.first() + assertTrue { deleted is LiveQueryAction.Delete } + deleted as LiveQueryAction.Delete + assertEquals("test:second", deleted.deletedId) + + incoming.close() + assertFailsWith { + incoming.first() + } } + + @Test + fun testLiveQueryAsPartOfRegularQuery() = runTest { + cleanDatabase() + val connection = Surreal("localhost", 8000) + connection.connect() + connection.signin("root", "root") + connection.use("test", "test") + val result = connection.query("LIVE SELECT * FROM test;") + val liveQueryId = result.first().data() + val incoming = connection.subscribe(liveQueryId) + + connection.create("test", "first").content(TestClass("thing", 1)) + connection.create("test", "second").content(TestClass("thing", 2)) + connection.update("test", "first").patch { + replace("myText", "thing2") + } + connection.delete("test", "second") + + val first = incoming.first() + assertTrue { first is LiveQueryAction.Create } + first as LiveQueryAction.Create + assertEquals("thing", first.result.myText) + assertEquals(1, first.result.myNumber) + + val second = incoming.first() + assertTrue { second is LiveQueryAction.Create } + second as LiveQueryAction.Create + assertEquals("thing", second.result.myText) + assertEquals(2, second.result.myNumber) + + val updated = incoming.first() + assertTrue { updated is LiveQueryAction.Update } + updated as LiveQueryAction.Update + assertEquals("thing2", updated.result.myText) + assertEquals(1, updated.result.myNumber) + + val deleted = incoming.first() + assertTrue { deleted is LiveQueryAction.Delete } + deleted as LiveQueryAction.Delete + assertEquals("test:second", deleted.deletedId) + + connection.query("KILL \$liveQueryId;", bind("liveQueryId", liveQueryId)) + connection.unsubscribe(liveQueryId) + assertFailsWith { + incoming.first() + } + } + + @Test + fun testLiveQueryUsingRpcFunctions() = runTest { + cleanDatabase() + val connection = Surreal("localhost", 8000) + connection.connect() + connection.signin("root", "root") + connection.use("test", "test") + val liveQueryId = connection.live("test") + val incoming = connection.subscribe(liveQueryId) + + connection.create("test", "first").content(TestClass("thing", 1)) + connection.create("test", "second").content(TestClass("thing", 2)) + connection.update("test", "first").patch { + replace("myText", "thing2") + } + connection.delete("test", "second") + + val first = incoming.first() + assertTrue { first is LiveQueryAction.Create } + first as LiveQueryAction.Create + assertEquals("thing", first.result.myText) + assertEquals(1, first.result.myNumber) + + val second = incoming.first() + assertTrue { second is LiveQueryAction.Create } + second as LiveQueryAction.Create + assertEquals("thing", second.result.myText) + assertEquals(2, second.result.myNumber) + + val updated = incoming.first() + assertTrue { updated is LiveQueryAction.Update } + updated as LiveQueryAction.Update + assertEquals("thing2", updated.result.myText) + assertEquals(1, updated.result.myNumber) + + val deleted = incoming.first() + assertTrue { deleted is LiveQueryAction.Delete } + deleted as LiveQueryAction.Delete + assertEquals("test:second", deleted.deletedId) + + connection.kill(liveQueryId) + connection.unsubscribe(liveQueryId) + assertFailsWith { + incoming.first() + } + } + } \ No newline at end of file