diff --git a/README.md b/README.md index 3d8fb93..9115ccd 100644 --- a/README.md +++ b/README.md @@ -169,6 +169,123 @@ implementation("com.locanara:locanara:1.0.0") --- +## Pipeline DSL + +Compose multiple AI steps into a single type-safe workflow. Each step's output becomes the next step's input, and the return type is determined by the last step. + +### Basic Pipeline (two steps) + +**Swift** + +```swift +import Locanara + +let model = FoundationLanguageModel() + +// Step 1: fix typos +let proofread = try await model.proofread( + "Ths is a tset of on-devce AI." +) + +// Step 2: translate the corrected text +let translated = try await model.translate( + proofread.correctedText, to: "ko" +) +print(translated.translatedText) +``` + +**Kotlin** + +```kotlin +import com.locanara.dsl.* +import com.locanara.platform.PromptApiModel + +suspend fun example(context: Context) { + val model = PromptApiModel(context) + + // Step 1: fix typos + val proofread = model.proofread( + "Ths is a tset of on-devce AI." + ) + + // Step 2: translate the corrected text + val translated = model.translate( + proofread.correctedText, to = "ko" + ) + println(translated.translatedText) +} +``` + +### Declarative Pipeline Builder (Swift) + +Swift's `@PipelineBuilder` result builder enforces return types at compile time. The compiler rejects pipelines with incompatible step types, making multi-step workflows safe to refactor. + +```swift +import Locanara + +let model = FoundationLanguageModel() + +// Two-step: proofread → translate +// Return type is TranslateResult — compiler enforced +let result = try await model.pipeline { + Proofread() + Translate(to: "ko") +}.run("Ths is a tset sentece about on-devce AI.") + +print(result.translatedText) // "이것은 온디바이스 AI에 관한 테스트 문장입니다." +print(result.targetLanguage) // "ko" + +// Three-step: summarize → proofread → translate +let threeStep = try await model.pipeline { + Summarize(bulletCount: 3) + Proofread() + Translate(to: "ja") +}.run(longArticle) +// Returns TranslateResult (last step determines the type) +``` + +### Kotlin Pipeline DSL + +```kotlin +import com.locanara.dsl.* +import com.locanara.platform.PromptApiModel + +suspend fun pipelineExample(context: Context) { + val model = PromptApiModel(context) + + // Fluent pipeline API + val result = model.pipeline() + .proofread() + .translate(to = "ko") + .run("Ths is a tset sentece about on-devce AI.") + + // result is TranslateResult (last step determines type) + println(result.translatedText) + + // Three-step pipeline + val threeStep = model.pipeline() + .summarize(bulletCount = 3) + .proofread() + .translate(to = "ja") + .run(longArticle) +} +``` + +### Available Pipeline Steps + +| Step | Swift | Kotlin | Output | +| --------- | ------------------------- | -------------------------- | ----------------- | +| Summarize | `Summarize(bulletCount:)` | `.summarize(bulletCount:)` | `SummarizeResult` | +| Classify | `Classify(categories:)` | `.classify(categories:)` | `ClassifyResult` | +| Translate | `Translate(to:)` | `.translate(to:)` | `TranslateResult` | +| Proofread | `Proofread()` | `.proofread()` | `ProofreadResult` | +| Rewrite | `Rewrite(style:)` | `.rewrite(style:)` | `RewriteResult` | +| Extract | `Extract(entityTypes:)` | `.extract(entityTypes:)` | `ExtractResult` | + +> **Full tutorial**: [locanara.com/docs/tutorials/pipeline](https://locanara.com/docs/tutorials/pipeline) + +--- + ## Packages - [**apple**](packages/apple) — iOS/macOS SDK diff --git a/packages/android/locanara/src/main/kotlin/com/locanara/rag/DocumentChunker.kt b/packages/android/locanara/src/main/kotlin/com/locanara/rag/DocumentChunker.kt index 09db5e1..7ec08fc 100644 --- a/packages/android/locanara/src/main/kotlin/com/locanara/rag/DocumentChunker.kt +++ b/packages/android/locanara/src/main/kotlin/com/locanara/rag/DocumentChunker.kt @@ -218,7 +218,7 @@ class DocumentChunker( } // Move forward with overlap - val moveDistance = maxOf(1, chunkSize - config.chunkOverlap) + val moveDistance = maxOf(1, config.targetChunkSize - config.chunkOverlap) currentIndex += moveDistance } diff --git a/packages/android/locanara/src/test/kotlin/com/locanara/ChainsTests.kt b/packages/android/locanara/src/test/kotlin/com/locanara/ChainsTests.kt new file mode 100644 index 0000000..2c8e58f --- /dev/null +++ b/packages/android/locanara/src/test/kotlin/com/locanara/ChainsTests.kt @@ -0,0 +1,148 @@ +package com.locanara + +import com.locanara.builtin.ChatChain +import com.locanara.builtin.ClassifyChain +import com.locanara.builtin.ExtractChain +import com.locanara.builtin.ProofreadChain +import com.locanara.builtin.RewriteChain +import com.locanara.builtin.SummarizeChain +import com.locanara.builtin.TranslateChain +import com.locanara.composable.BufferMemory +import com.locanara.core.ChainInput +import kotlinx.coroutines.runBlocking +import org.junit.Assert.assertEquals +import org.junit.Assert.assertFalse +import org.junit.Assert.assertNotNull +import org.junit.Assert.assertTrue +import org.junit.Test + +// MARK: - Built-in Chain Tests + +class SummarizeChainTest { + @Test + fun `run returns typed result`() = runBlocking { + val model = MockModel { "This is a summary." } + val chain = SummarizeChain(model = model, bulletCount = 1) + + val result = chain.run("Long article text here...") + + assertEquals("This is a summary.", result.summary) + assertEquals("Long article text here...".length, result.originalLength) + } + + @Test + fun `invoke returns chain output`() = runBlocking { + val model = MockModel { "Summary text" } + val chain = SummarizeChain(model = model) + + val output = chain.invoke(ChainInput(text = "input")) + + assertEquals("Summary text", output.text) + assertNotNull(output.typed()) + } +} + +class ClassifyChainTest { + @Test + fun `run returns classify result`() = runBlocking { + val model = MockModel { "positive" } + val chain = ClassifyChain( + model = model, + categories = listOf("positive", "negative") + ) + + val result = chain.run("Great product!") + + assertEquals("positive", result.topClassification.label) + assertEquals(1.0, result.topClassification.score, 0.001) + } +} + +class TranslateChainTest { + @Test + fun `run returns translate result`() = runBlocking { + val model = MockModel { "안녕하세요" } + val chain = TranslateChain(model = model, targetLanguage = "ko") + + val result = chain.run("Hello") + + assertEquals("안녕하세요", result.translatedText) + assertEquals("en", result.sourceLanguage) + assertEquals("ko", result.targetLanguage) + } +} + +class RewriteChainTest { + @Test + fun `run returns rewrite result`() = runBlocking { + val model = MockModel { "Good day, how may I assist you?" } + val chain = RewriteChain(model = model, style = RewriteOutputType.PROFESSIONAL) + + val result = chain.run("hey whats up") + + assertEquals("Good day, how may I assist you?", result.rewrittenText) + assertEquals(RewriteOutputType.PROFESSIONAL, result.style) + } +} + +class ProofreadChainTest { + @Test + fun `run returns proofread result`() = runBlocking { + val model = MockModel { "This is a test." } + val chain = ProofreadChain(model = model) + + val result = chain.run("Ths is a tset.") + + assertEquals("This is a test.", result.correctedText) + assertTrue(result.hasCorrections) + } + + @Test + fun `no corrections detected`() = runBlocking { + val model = MockModel { "Already correct." } + val chain = ProofreadChain(model = model) + + val result = chain.run("Already correct.") + + assertFalse(result.hasCorrections) + } +} + +class ChatChainTest { + @Test + fun `run returns chat result`() = runBlocking { + val model = MockModel { "Hi there!" } + val chain = ChatChain(model = model) + + val result = chain.run("Hello!") + + assertEquals("Hi there!", result.message) + assertTrue(result.canContinue) + } + + @Test + fun `chat with memory saves entries`() = runBlocking { + val model = MockModel { "First response" } + val memory = BufferMemory(maxEntries = 10) + val chain = ChatChain(model = model, memory = memory) + + chain.run("First message") + + val entries = memory.load(ChainInput(text = "test")) + assertEquals(2, entries.size) // user + assistant + } +} + +class ExtractChainTest { + @Test + fun `run returns extract result`() = runBlocking { + val model = MockModel { "Tim Cook\nCupertino" } + val chain = ExtractChain(model = model, entityTypes = listOf("person", "location")) + + val result = chain.run("Tim Cook lives in Cupertino") + + assertEquals(2, result.entities.size) + assertEquals("Tim Cook", result.entities[0].value) + assertEquals("Cupertino", result.entities[1].value) + } +} diff --git a/packages/android/locanara/src/test/kotlin/com/locanara/ComposableTests.kt b/packages/android/locanara/src/test/kotlin/com/locanara/ComposableTests.kt new file mode 100644 index 0000000..0755761 --- /dev/null +++ b/packages/android/locanara/src/test/kotlin/com/locanara/ComposableTests.kt @@ -0,0 +1,154 @@ +package com.locanara + +import com.locanara.builtin.ProofreadChain +import com.locanara.builtin.RewriteChain +import com.locanara.builtin.SummarizeChain +import com.locanara.composable.BufferMemory +import com.locanara.composable.ContentFilterGuardrail +import com.locanara.composable.GuardrailResult +import com.locanara.composable.InputLengthGuardrail +import com.locanara.composable.SequentialChain +import com.locanara.core.ChainInput +import com.locanara.core.ChainOutput +import com.locanara.runtime.ChainExecutor +import kotlinx.coroutines.runBlocking +import org.junit.Assert.assertEquals +import org.junit.Assert.assertTrue +import org.junit.Test + +// MARK: - Memory Tests + +class MemoryTest { + @Test + fun `buffer memory save and load`() = runBlocking { + val memory = BufferMemory(maxEntries = 5) + val input = ChainInput(text = "Hello") + val output = ChainOutput(value = "Hi", text = "Hi") + + memory.save(input, output) + + val entries = memory.load(ChainInput(text = "test")) + assertEquals(2, entries.size) + assertEquals("user", entries[0].role) + assertEquals("Hello", entries[0].content) + assertEquals("assistant", entries[1].role) + assertEquals("Hi", entries[1].content) + } + + @Test + fun `buffer memory trimming`() = runBlocking { + val memory = BufferMemory(maxEntries = 2) + + for (i in 0 until 5) { + memory.save( + ChainInput(text = "msg $i"), + ChainOutput(value = "resp $i", text = "resp $i") + ) + } + + val entries = memory.load(ChainInput(text = "test")) + assertTrue(entries.size <= 4) // maxEntries * 2 + } + + @Test + fun `buffer memory clear`() = runBlocking { + val memory = BufferMemory() + memory.save(ChainInput(text = "hello"), ChainOutput(value = "hi", text = "hi")) + memory.clear() + + val entries = memory.load(ChainInput(text = "test")) + assertEquals(0, entries.size) + } +} + +// MARK: - Guardrail Tests + +class GuardrailTest { + @Test + fun `input length passes`() = runBlocking { + val guardrail = InputLengthGuardrail(maxCharacters = 100) + val result = guardrail.checkInput(ChainInput(text = "short")) + assertTrue(result is GuardrailResult.Passed) + } + + @Test + fun `input length truncates`() = runBlocking { + val guardrail = InputLengthGuardrail(maxCharacters = 5, truncate = true) + val result = guardrail.checkInput(ChainInput(text = "longer text")) + assertTrue(result is GuardrailResult.Modified) + assertEquals("longe", (result as GuardrailResult.Modified).newText) + } + + @Test + fun `input length blocks`() = runBlocking { + val guardrail = InputLengthGuardrail(maxCharacters = 5, truncate = false) + val result = guardrail.checkInput(ChainInput(text = "longer text")) + assertTrue(result is GuardrailResult.Blocked) + } + + @Test + fun `content filter blocks`() = runBlocking { + val guardrail = ContentFilterGuardrail(blockedPatterns = listOf("password", "secret")) + val blocked = guardrail.checkInput(ChainInput(text = "my password is 123")) + assertTrue(blocked is GuardrailResult.Blocked) + + val passed = guardrail.checkInput(ChainInput(text = "Hello world")) + assertTrue(passed is GuardrailResult.Passed) + } +} + +// MARK: - Chain Executor Tests + +class ChainExecutorTest { + @Test + fun `execute records history`() = runBlocking { + val model = MockModel { "result" } + val chain = SummarizeChain(model = model) + val executor = ChainExecutor(maxRetries = 0) + + executor.execute(chain, ChainInput(text = "test")) + + val history = executor.getHistory() + assertEquals(1, history.size) + assertEquals("SummarizeChain", history[0].chainName) + assertTrue(history[0].success) + assertEquals(1, history[0].attempt) + } + + @Test + fun `clear history`() = runBlocking { + val model = MockModel { "result" } + val chain = SummarizeChain(model = model) + val executor = ChainExecutor() + + executor.execute(chain, ChainInput(text = "test")) + executor.clearHistory() + + assertEquals(0, executor.getHistory().size) + } +} + +// MARK: - Sequential Chain Tests + +class SequentialChainTest { + @Test + fun `sequential execution`() = runBlocking { + var callCount = 0 + val model = MockModel { + callCount++ + "step$callCount" + } + + val chain = SequentialChain( + chains = listOf( + ProofreadChain(model = model), + RewriteChain(model = model, style = RewriteOutputType.PROFESSIONAL) + ) + ) + + val output = chain.invoke(ChainInput(text = "input")) + + assertEquals(2, callCount) + assertEquals("step2", output.text) + } +} diff --git a/packages/android/locanara/src/test/kotlin/com/locanara/CoreTests.kt b/packages/android/locanara/src/test/kotlin/com/locanara/CoreTests.kt new file mode 100644 index 0000000..dbc7441 --- /dev/null +++ b/packages/android/locanara/src/test/kotlin/com/locanara/CoreTests.kt @@ -0,0 +1,85 @@ +package com.locanara + +import com.locanara.builtin.SummarizeChain +import com.locanara.core.ChainInput +import com.locanara.core.ChainOutput +import com.locanara.core.OutputParser +import com.locanara.core.PromptTemplate +import com.locanara.core.TextOutputParser +import org.junit.Assert.assertEquals +import org.junit.Assert.assertNotNull +import org.junit.Assert.assertNull +import org.junit.Test + +// MARK: - Core Layer Tests + +class PromptTemplateTest { + @Test + fun `basic formatting`() { + val template = PromptTemplate( + templateString = "Summarize this: {text}", + inputVariables = listOf("text") + ) + val result = template.format(mapOf("text" to "Hello world")) + assertEquals("Summarize this: Hello world", result) + } + + @Test + fun `multiple variables`() { + val template = PromptTemplate( + templateString = "Translate from {source} to {target}: {text}", + inputVariables = listOf("source", "target", "text") + ) + val result = template.format( + mapOf("source" to "English", "target" to "Korean", "text" to "Hello") + ) + assertEquals("Translate from English to Korean: Hello", result) + } + + @Test(expected = IllegalArgumentException::class) + fun `missing variable throws`() { + val template = PromptTemplate( + templateString = "Hello {name}", + inputVariables = listOf("name") + ) + template.format(emptyMap()) + } + + @Test + fun `auto detection`() { + val template = PromptTemplate.from("Hello {name}, welcome to {place}") + assertEquals(listOf("name", "place"), template.inputVariables) + val result = template.format(mapOf("name" to "Alice", "place" to "Locanara")) + assertEquals("Hello Alice, welcome to Locanara", result) + } +} + +class OutputParserTest { + @Test + fun `text parser trims whitespace`() { + val parser = TextOutputParser() + val result = parser.parse(" hello world ") + assertEquals("hello world", result) + } +} + +class SchemaTest { + @Test + fun `chain input creation`() { + val input = ChainInput(text = "hello", metadata = mutableMapOf("key" to "value")) + assertEquals("hello", input.text) + assertEquals("value", input.metadata["key"]) + } + + @Test + fun `chain output typed`() { + val result = SummarizeResult( + summary = "test", originalLength = 100, summaryLength = 4 + ) + val output = ChainOutput(value = result, text = "test", processingTimeMs = 5) + + assertNotNull(output.typed()) + assertEquals("test", output.typed()?.summary) + assertNull(output.typed()) // wrong type + } +} diff --git a/packages/android/locanara/src/test/kotlin/com/locanara/DSLTests.kt b/packages/android/locanara/src/test/kotlin/com/locanara/DSLTests.kt new file mode 100644 index 0000000..0809c33 --- /dev/null +++ b/packages/android/locanara/src/test/kotlin/com/locanara/DSLTests.kt @@ -0,0 +1,110 @@ +package com.locanara + +import com.locanara.dsl.pipeline +import com.locanara.dsl.proofread +import com.locanara.dsl.summarize +import com.locanara.dsl.translate +import kotlinx.coroutines.runBlocking +import org.junit.Assert.assertEquals +import org.junit.Assert.assertTrue +import org.junit.Test + +// MARK: - Pipeline DSL Tests + +class PipelineTest { + @Test + fun `single step pipeline`() = runBlocking { + val model = MockModel { "Summary of input." } + + val result = model.pipeline() + .summarize(bulletCount = 1) + .run("Long text here") + + // Compile-time: result is SummarizeResult + assertEquals("Summary of input.", result.summary) + } + + @Test + fun `multi step pipeline type safety`() = runBlocking { + var callCount = 0 + val model = MockModel { + callCount++ + if (callCount == 1) "Summarized text" else "번역된 텍스트" + } + + val result = model.pipeline() + .summarize(bulletCount = 3) + .translate(to = "ko") + .run("Long article in English") + + // Compile-time: result is TranslateResult (last step) + assertEquals("번역된 텍스트", result.translatedText) + assertEquals("ko", result.targetLanguage) + assertEquals(2, callCount) + } + + @Test + fun `three step pipeline`() = runBlocking { + var callCount = 0 + val model = MockModel { + callCount++ + when (callCount) { + 1 -> "Corrected text" + 2 -> "Professionally written text" + else -> "unexpected" + } + } + + val result = model.pipeline() + .proofread() + .rewrite(style = RewriteOutputType.PROFESSIONAL) + .run("messy text with erors") + + // Compile-time: result is RewriteResult + assertEquals("Professionally written text", result.rewrittenText) + assertEquals(RewriteOutputType.PROFESSIONAL, result.style) + } + + @Test + fun `pipeline passes text between steps`() = runBlocking { + val receivedPrompts = mutableListOf() + val model = MockModel { prompt -> + receivedPrompts.add(prompt) + if (receivedPrompts.size == 1) "step1 output" else "step2 output" + } + + model.pipeline() + .proofread() + .rewrite(style = RewriteOutputType.FRIENDLY) + .run("original input") + + // Second step should receive first step's output in its prompt + assertTrue(receivedPrompts[1].contains("step1 output")) + } +} + +// MARK: - Model Extension Tests + +class ModelExtensionTest { + @Test + fun `summarize extension`() = runBlocking { + val model = MockModel { "Short summary." } + val result = model.summarize("Long text", bulletCount = 2) + assertEquals("Short summary.", result.summary) + } + + @Test + fun `translate extension`() = runBlocking { + val model = MockModel { "Hola" } + val result = model.translate("Hello", to = "es") + assertEquals("Hola", result.translatedText) + assertEquals("es", result.targetLanguage) + } + + @Test + fun `proofread extension`() = runBlocking { + val model = MockModel { "Fixed text." } + val result = model.proofread("Brkn text.") + assertEquals("Fixed text.", result.correctedText) + } +} diff --git a/packages/android/locanara/src/test/kotlin/com/locanara/ErrorHandlingTests.kt b/packages/android/locanara/src/test/kotlin/com/locanara/ErrorHandlingTests.kt new file mode 100644 index 0000000..f2ebcfb --- /dev/null +++ b/packages/android/locanara/src/test/kotlin/com/locanara/ErrorHandlingTests.kt @@ -0,0 +1,172 @@ +package com.locanara + +import com.locanara.builtin.ChatChain +import com.locanara.builtin.ClassifyChain +import com.locanara.builtin.ExtractChain +import com.locanara.builtin.ProofreadChain +import com.locanara.builtin.RewriteChain +import com.locanara.builtin.SummarizeChain +import com.locanara.builtin.TranslateChain +import kotlinx.coroutines.runBlocking +import org.junit.Assert.assertEquals + +import org.junit.Assert.assertTrue +import org.junit.Assert.fail +import org.junit.Test + +// MARK: - Error Handling Tests + +class ErrorHandlingTest { + + // --- LocanaraException property tests --- + + @Test + fun `ModelBusy has correct message and code`() { + val exception = LocanaraException.ModelBusy + val message = requireNotNull(exception.message) + assertTrue(message.contains("busy")) + assertEquals(ErrorCode.MODEL_BUSY, exception.code) + } + + @Test + fun `BackgroundUseBlocked has correct message and code`() { + val exception = LocanaraException.BackgroundUseBlocked + val message = requireNotNull(exception.message) + assertTrue(message.contains("foreground")) + assertEquals(ErrorCode.BACKGROUND_USE_BLOCKED, exception.code) + } + + @Test + fun `ExecutionFailed preserves reason in message`() { + val exception = LocanaraException.ExecutionFailed("something went wrong") + val message = requireNotNull(exception.message) + assertTrue(message.contains("something went wrong")) + assertEquals(ErrorCode.EXECUTION_FAILED, exception.code) + } + + @Test + fun `ExecutionFailed preserves cause`() { + val cause = RuntimeException("root cause") + val exception = LocanaraException.ExecutionFailed("wrapped", cause) + assertEquals(cause, exception.cause) + } + + @Test + fun `InvalidInput has correct code`() { + val exception = LocanaraException.InvalidInput("too short") + val message = requireNotNull(exception.message) + assertTrue(message.contains("too short")) + assertEquals(ErrorCode.INVALID_INPUT, exception.code) + } + + @Test + fun `DeviceNotSupported has correct code`() { + val exception: LocanaraException = LocanaraException.DeviceNotSupported + assertEquals(ErrorCode.DEVICE_NOT_SUPPORTED, exception.code) + } + + @Test + fun `PermissionDenied has correct code`() { + val exception = LocanaraException.PermissionDenied + assertEquals(ErrorCode.PERMISSION_DENIED, exception.code) + } + + // --- Chain error propagation tests --- + + @Test + fun `SummarizeChain propagates LocanaraException from model`() { + val chain = SummarizeChain(model = failingModel(LocanaraException.ExecutionFailed("model timeout"))) + try { + runBlocking { chain.run("test text") } + fail("Expected LocanaraException.ExecutionFailed") + } catch (e: LocanaraException.ExecutionFailed) { + val message = requireNotNull(e.message) + assertTrue(message.contains("model timeout")) + } + } + + @Test + fun `ClassifyChain propagates LocanaraException from model`() { + val chain = ClassifyChain(model = failingModel(LocanaraException.ModelBusy), categories = listOf("a", "b")) + try { + runBlocking { chain.run("text") } + fail("Expected LocanaraException.ModelBusy") + } catch (e: LocanaraException) { + assertTrue(e is LocanaraException.ModelBusy) + } + } + + @Test + fun `TranslateChain propagates LocanaraException from model`() { + val chain = TranslateChain(model = failingModel(LocanaraException.BackgroundUseBlocked), targetLanguage = "ko") + try { + runBlocking { chain.run("hello") } + fail("Expected LocanaraException.BackgroundUseBlocked") + } catch (e: LocanaraException) { + assertTrue(e is LocanaraException.BackgroundUseBlocked) + } + } + + @Test + fun `ProofreadChain propagates LocanaraException from model`() { + val chain = ProofreadChain(model = failingModel(LocanaraException.ExecutionFailed("inference failed"))) + try { + runBlocking { chain.run("text") } + fail("Expected LocanaraException.ExecutionFailed") + } catch (e: LocanaraException.ExecutionFailed) { + val message = requireNotNull(e.message) + assertTrue(message.contains("inference failed")) + } + } + + @Test + fun `ChatChain propagates LocanaraException from model`() { + val chain = ChatChain(model = failingModel(LocanaraException.ExecutionFailed("chat failed"))) + try { + runBlocking { chain.run("hello") } + fail("Expected LocanaraException.ExecutionFailed") + } catch (e: LocanaraException.ExecutionFailed) { + val message = requireNotNull(e.message) + assertTrue(message.contains("chat failed")) + } + } + + @Test + fun `RewriteChain propagates LocanaraException from model`() { + val chain = RewriteChain( + model = failingModel(LocanaraException.ExecutionFailed("rewrite failed")), + style = RewriteOutputType.FRIENDLY + ) + try { + runBlocking { chain.run("text") } + fail("Expected LocanaraException.ExecutionFailed") + } catch (e: LocanaraException.ExecutionFailed) { + val message = requireNotNull(e.message) + assertTrue(message.contains("rewrite failed")) + } + } + + @Test + fun `ExtractChain propagates LocanaraException from model`() { + val chain = ExtractChain( + model = failingModel(LocanaraException.ExecutionFailed("extract failed")), + entityTypes = listOf("person") + ) + try { + runBlocking { chain.run("Tim Cook") } + fail("Expected LocanaraException.ExecutionFailed") + } catch (e: LocanaraException.ExecutionFailed) { + val message = requireNotNull(e.message) + assertTrue(message.contains("extract failed")) + } + } + + // --- LocanaraException is-a Exception --- + + @Test + fun `LocanaraException is catchable as Exception`() { + val ex: Exception = LocanaraException.ModelBusy + assertTrue(ex is LocanaraException) + assertTrue(ex is Exception) + } +} diff --git a/packages/android/locanara/src/test/kotlin/com/locanara/FrameworkTest.kt b/packages/android/locanara/src/test/kotlin/com/locanara/FrameworkTest.kt deleted file mode 100644 index fc96725..0000000 --- a/packages/android/locanara/src/test/kotlin/com/locanara/FrameworkTest.kt +++ /dev/null @@ -1,662 +0,0 @@ -package com.locanara - -import com.locanara.builtin.ChatChain -import com.locanara.builtin.ClassifyChain -import com.locanara.builtin.ExtractChain -import com.locanara.builtin.ProofreadChain -import com.locanara.builtin.RewriteChain -import com.locanara.builtin.SummarizeChain -import com.locanara.builtin.TranslateChain -import org.junit.Assert.fail -import com.locanara.composable.BufferMemory -import com.locanara.composable.Chain -import com.locanara.composable.ContentFilterGuardrail -import com.locanara.composable.GuardrailResult -import com.locanara.composable.InputLengthGuardrail -import com.locanara.composable.SequentialChain -import com.locanara.core.ChainInput -import com.locanara.core.ChainOutput -import com.locanara.core.GenerationConfig -import com.locanara.core.LocanaraModel -import com.locanara.core.ModelResponse -import com.locanara.core.OutputParser -import com.locanara.core.PromptTemplate -import com.locanara.core.TextOutputParser -import com.locanara.dsl.pipeline -import com.locanara.dsl.summarize -import com.locanara.dsl.translate -import com.locanara.dsl.proofread -import com.locanara.runtime.ChainExecutor -import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.emptyFlow -import kotlinx.coroutines.runBlocking -import org.junit.Assert.assertEquals -import org.junit.Assert.assertFalse -import org.junit.Assert.assertNotNull -import org.junit.Assert.assertNull -import org.junit.Assert.assertTrue -import org.junit.Test - -// MARK: - Mock Model - -class MockModel( - private val responseGenerator: (String) -> String = { "mock response" } -) : LocanaraModel { - override val name = "MockModel" - override val isReady = true - override val maxContextTokens = 4000 - - override suspend fun generate(prompt: String, config: GenerationConfig?): ModelResponse { - val text = responseGenerator(prompt) - return ModelResponse(text = text, processingTimeMs = 5) - } - - override fun stream(prompt: String, config: GenerationConfig?): Flow = emptyFlow() -} - -// MARK: - Core Layer Tests - -class PromptTemplateTest { - @Test - fun `basic formatting`() { - val template = PromptTemplate( - templateString = "Summarize this: {text}", - inputVariables = listOf("text") - ) - val result = template.format(mapOf("text" to "Hello world")) - assertEquals("Summarize this: Hello world", result) - } - - @Test - fun `multiple variables`() { - val template = PromptTemplate( - templateString = "Translate from {source} to {target}: {text}", - inputVariables = listOf("source", "target", "text") - ) - val result = template.format( - mapOf("source" to "English", "target" to "Korean", "text" to "Hello") - ) - assertEquals("Translate from English to Korean: Hello", result) - } - - @Test(expected = IllegalArgumentException::class) - fun `missing variable throws`() { - val template = PromptTemplate( - templateString = "Hello {name}", - inputVariables = listOf("name") - ) - template.format(emptyMap()) - } - - @Test - fun `auto detection`() { - val template = PromptTemplate.from("Hello {name}, welcome to {place}") - val result = template.format(mapOf("name" to "Alice", "place" to "Locanara")) - assertEquals("Hello Alice, welcome to Locanara", result) - } -} - -class OutputParserTest { - @Test - fun `text parser trims whitespace`() { - val parser = TextOutputParser() - val result = parser.parse(" hello world ") - assertEquals("hello world", result) - } -} - -class SchemaTest { - @Test - fun `chain input creation`() { - val input = ChainInput(text = "hello", metadata = mutableMapOf("key" to "value")) - assertEquals("hello", input.text) - assertEquals("value", input.metadata["key"]) - } - - @Test - fun `chain output typed`() { - val result = SummarizeResult( - summary = "test", originalLength = 100, summaryLength = 4 - ) - val output = ChainOutput(value = result, text = "test", processingTimeMs = 5) - - assertNotNull(output.typed()) - assertEquals("test", output.typed()?.summary) - assertNull(output.typed()) // wrong type - } -} - -// MARK: - Built-in Chain Tests - -class SummarizeChainTest { - @Test - fun `run returns typed result`() = runBlocking { - val model = MockModel { "This is a summary." } - val chain = SummarizeChain(model = model, bulletCount = 1) - - val result = chain.run("Long article text here...") - - assertEquals("This is a summary.", result.summary) - assertEquals("Long article text here...".length, result.originalLength) - } - - @Test - fun `invoke returns chain output`() = runBlocking { - val model = MockModel { "Summary text" } - val chain = SummarizeChain(model = model) - - val output = chain.invoke(ChainInput(text = "input")) - - assertEquals("Summary text", output.text) - assertNotNull(output.typed()) - } -} - -class ClassifyChainTest { - @Test - fun `run returns classify result`() = runBlocking { - val model = MockModel { "positive" } - val chain = ClassifyChain( - model = model, - categories = listOf("positive", "negative") - ) - - val result = chain.run("Great product!") - - assertEquals("positive", result.topClassification.label) - assertEquals(1.0, result.topClassification.score, 0.001) - } -} - -class TranslateChainTest { - @Test - fun `run returns translate result`() = runBlocking { - val model = MockModel { "안녕하세요" } - val chain = TranslateChain(model = model, targetLanguage = "ko") - - val result = chain.run("Hello") - - assertEquals("안녕하세요", result.translatedText) - assertEquals("en", result.sourceLanguage) - assertEquals("ko", result.targetLanguage) - } -} - -class RewriteChainTest { - @Test - fun `run returns rewrite result`() = runBlocking { - val model = MockModel { "Good day, how may I assist you?" } - val chain = RewriteChain(model = model, style = RewriteOutputType.PROFESSIONAL) - - val result = chain.run("hey whats up") - - assertEquals("Good day, how may I assist you?", result.rewrittenText) - assertEquals(RewriteOutputType.PROFESSIONAL, result.style) - } -} - -class ProofreadChainTest { - @Test - fun `run returns proofread result`() = runBlocking { - val model = MockModel { "This is a test." } - val chain = ProofreadChain(model = model) - - val result = chain.run("Ths is a tset.") - - assertEquals("This is a test.", result.correctedText) - assertTrue(result.hasCorrections) - } - - @Test - fun `no corrections detected`() = runBlocking { - val model = MockModel { "Already correct." } - val chain = ProofreadChain(model = model) - - val result = chain.run("Already correct.") - - assertFalse(result.hasCorrections) - } -} - -class ChatChainTest { - @Test - fun `run returns chat result`() = runBlocking { - val model = MockModel { "Hi there!" } - val chain = ChatChain(model = model) - - val result = chain.run("Hello!") - - assertEquals("Hi there!", result.message) - assertTrue(result.canContinue) - } - - @Test - fun `chat with memory saves entries`() = runBlocking { - val model = MockModel { "First response" } - val memory = BufferMemory(maxEntries = 10) - val chain = ChatChain(model = model, memory = memory) - - chain.run("First message") - - val entries = memory.load(ChainInput(text = "test")) - assertEquals(2, entries.size) // user + assistant - } -} - -class ExtractChainTest { - @Test - fun `run returns extract result`() = runBlocking { - val model = MockModel { "Tim Cook\nCupertino" } - val chain = ExtractChain(model = model, entityTypes = listOf("person", "location")) - - val result = chain.run("Tim Cook lives in Cupertino") - - assertEquals(2, result.entities.size) - assertEquals("Tim Cook", result.entities[0].value) - assertEquals("Cupertino", result.entities[1].value) - } -} - -// MARK: - Pipeline Tests - -class PipelineTest { - @Test - fun `single step pipeline`() = runBlocking { - val model = MockModel { "Summary of input." } - - val result = model.pipeline() - .summarize(bulletCount = 1) - .run("Long text here") - - // Compile-time: result is SummarizeResult - assertEquals("Summary of input.", result.summary) - } - - @Test - fun `multi step pipeline type safety`() = runBlocking { - var callCount = 0 - val model = MockModel { - callCount++ - if (callCount == 1) "Summarized text" else "번역된 텍스트" - } - - val result = model.pipeline() - .summarize(bulletCount = 3) - .translate(to = "ko") - .run("Long article in English") - - // Compile-time: result is TranslateResult (last step) - assertEquals("번역된 텍스트", result.translatedText) - assertEquals("ko", result.targetLanguage) - assertEquals(2, callCount) - } - - @Test - fun `three step pipeline`() = runBlocking { - var callCount = 0 - val model = MockModel { - callCount++ - when (callCount) { - 1 -> "Corrected text" - 2 -> "Professionally written text" - else -> "unexpected" - } - } - - val result = model.pipeline() - .proofread() - .rewrite(style = RewriteOutputType.PROFESSIONAL) - .run("messy text with erors") - - // Compile-time: result is RewriteResult - assertEquals("Professionally written text", result.rewrittenText) - assertEquals(RewriteOutputType.PROFESSIONAL, result.style) - } - - @Test - fun `pipeline passes text between steps`() = runBlocking { - val receivedPrompts = mutableListOf() - val model = MockModel { prompt -> - receivedPrompts.add(prompt) - if (receivedPrompts.size == 1) "step1 output" else "step2 output" - } - - model.pipeline() - .proofread() - .rewrite(style = RewriteOutputType.FRIENDLY) - .run("original input") - - // Second step should receive first step's output in its prompt - assertTrue(receivedPrompts[1].contains("step1 output")) - } -} - -// MARK: - Model Extension Tests - -class ModelExtensionTest { - @Test - fun `summarize extension`() = runBlocking { - val model = MockModel { "Short summary." } - val result = model.summarize("Long text", bulletCount = 2) - assertEquals("Short summary.", result.summary) - } - - @Test - fun `translate extension`() = runBlocking { - val model = MockModel { "Hola" } - val result = model.translate("Hello", to = "es") - assertEquals("Hola", result.translatedText) - assertEquals("es", result.targetLanguage) - } - - @Test - fun `proofread extension`() = runBlocking { - val model = MockModel { "Fixed text." } - val result = model.proofread("Brkn text.") - assertEquals("Fixed text.", result.correctedText) - } -} - -// MARK: - Composable Layer Tests - -class MemoryTest { - @Test - fun `buffer memory save and load`() = runBlocking { - val memory = BufferMemory(maxEntries = 5) - val input = ChainInput(text = "Hello") - val output = ChainOutput(value = "Hi", text = "Hi") - - memory.save(input, output) - - val entries = memory.load(ChainInput(text = "test")) - assertEquals(2, entries.size) - assertEquals("user", entries[0].role) - assertEquals("Hello", entries[0].content) - assertEquals("assistant", entries[1].role) - assertEquals("Hi", entries[1].content) - } - - @Test - fun `buffer memory trimming`() = runBlocking { - val memory = BufferMemory(maxEntries = 2) - - for (i in 0 until 5) { - memory.save( - ChainInput(text = "msg $i"), - ChainOutput(value = "resp $i", text = "resp $i") - ) - } - - val entries = memory.load(ChainInput(text = "test")) - assertTrue(entries.size <= 4) // maxEntries * 2 - } - - @Test - fun `buffer memory clear`() = runBlocking { - val memory = BufferMemory() - memory.save(ChainInput(text = "hello"), ChainOutput(value = "hi", text = "hi")) - memory.clear() - - val entries = memory.load(ChainInput(text = "test")) - assertEquals(0, entries.size) - } -} - -class GuardrailTest { - @Test - fun `input length passes`() = runBlocking { - val guardrail = InputLengthGuardrail(maxCharacters = 100) - val result = guardrail.checkInput(ChainInput(text = "short")) - assertTrue(result is GuardrailResult.Passed) - } - - @Test - fun `input length truncates`() = runBlocking { - val guardrail = InputLengthGuardrail(maxCharacters = 5, truncate = true) - val result = guardrail.checkInput(ChainInput(text = "longer text")) - assertTrue(result is GuardrailResult.Modified) - assertEquals("longe", (result as GuardrailResult.Modified).newText) - } - - @Test - fun `input length blocks`() = runBlocking { - val guardrail = InputLengthGuardrail(maxCharacters = 5, truncate = false) - val result = guardrail.checkInput(ChainInput(text = "longer text")) - assertTrue(result is GuardrailResult.Blocked) - } - - @Test - fun `content filter blocks`() = runBlocking { - val guardrail = ContentFilterGuardrail(blockedPatterns = listOf("password", "secret")) - val blocked = guardrail.checkInput(ChainInput(text = "my password is 123")) - assertTrue(blocked is GuardrailResult.Blocked) - - val passed = guardrail.checkInput(ChainInput(text = "Hello world")) - assertTrue(passed is GuardrailResult.Passed) - } -} - -// MARK: - Chain Executor Tests - -class ChainExecutorTest { - @Test - fun `execute records history`() = runBlocking { - val model = MockModel { "result" } - val chain = SummarizeChain(model = model) - val executor = ChainExecutor(maxRetries = 0) - - executor.execute(chain, ChainInput(text = "test")) - - val history = executor.getHistory() - assertEquals(1, history.size) - assertEquals("SummarizeChain", history[0].chainName) - assertTrue(history[0].success) - assertEquals(1, history[0].attempt) - } - - @Test - fun `clear history`() = runBlocking { - val model = MockModel { "result" } - val chain = SummarizeChain(model = model) - val executor = ChainExecutor() - - executor.execute(chain, ChainInput(text = "test")) - executor.clearHistory() - - assertEquals(0, executor.getHistory().size) - } -} - -// MARK: - Sequential Chain Tests - -class SequentialChainTest { - @Test - fun `sequential execution`() = runBlocking { - var callCount = 0 - val model = MockModel { - callCount++ - "step$callCount" - } - - val chain = SequentialChain( - chains = listOf( - ProofreadChain(model = model), - RewriteChain(model = model, style = RewriteOutputType.PROFESSIONAL) - ) - ) - - val output = chain.invoke(ChainInput(text = "input")) - - assertEquals(2, callCount) - assertEquals("step2", output.text) - } -} - -// MARK: - Error Handling Tests - -/** Helper: a LocanaraModel that always throws the given exception. */ -private fun failingModel(error: LocanaraException): LocanaraModel = object : LocanaraModel { - override val name = "FailingModel" - override val isReady = true - override val maxContextTokens = 4000 - override suspend fun generate(prompt: String, config: GenerationConfig?): ModelResponse = throw error - override fun stream(prompt: String, config: GenerationConfig?): Flow = emptyFlow() -} - -class ErrorHandlingTest { - - // --- LocanaraException property tests --- - - @Test - fun `ModelBusy has correct message and code`() { - val exception = LocanaraException.ModelBusy - val message = exception.message - assertNotNull(message) - assertTrue(message!!.contains("busy")) - assertEquals(ErrorCode.MODEL_BUSY, exception.code) - } - - @Test - fun `BackgroundUseBlocked has correct message and code`() { - val exception = LocanaraException.BackgroundUseBlocked - val message = exception.message - assertNotNull(message) - assertTrue(message!!.contains("foreground")) - assertEquals(ErrorCode.BACKGROUND_USE_BLOCKED, exception.code) - } - - @Test - fun `ExecutionFailed preserves reason in message`() { - val exception = LocanaraException.ExecutionFailed("something went wrong") - val message = exception.message - assertNotNull(message) - assertTrue(message!!.contains("something went wrong")) - assertEquals(ErrorCode.EXECUTION_FAILED, exception.code) - } - - @Test - fun `ExecutionFailed preserves cause`() { - val cause = RuntimeException("root cause") - val exception = LocanaraException.ExecutionFailed("wrapped", cause) - assertEquals(cause, exception.cause) - } - - @Test - fun `InvalidInput has correct code`() { - val exception = LocanaraException.InvalidInput("too short") - val message = exception.message - assertNotNull(message) - assertTrue(message!!.contains("too short")) - assertEquals(ErrorCode.INVALID_INPUT, exception.code) - } - - @Test - fun `DeviceNotSupported has correct code`() { - val exception: LocanaraException = LocanaraException.DeviceNotSupported - assertEquals(ErrorCode.DEVICE_NOT_SUPPORTED, exception.code) - } - - @Test - fun `PermissionDenied has correct code`() { - val exception = LocanaraException.PermissionDenied - assertEquals(ErrorCode.PERMISSION_DENIED, exception.code) - } - - // --- Chain error propagation tests --- - - @Test - fun `SummarizeChain propagates LocanaraException from model`() { - val chain = SummarizeChain(model = failingModel(LocanaraException.ExecutionFailed("model timeout"))) - try { - runBlocking { chain.run("test text") } - fail("Expected LocanaraException.ExecutionFailed") - } catch (e: LocanaraException.ExecutionFailed) { - val message = e.message - assertNotNull(message) - assertTrue(message!!.contains("model timeout")) - } - } - - @Test - fun `ClassifyChain propagates LocanaraException from model`() { - val chain = ClassifyChain(model = failingModel(LocanaraException.ModelBusy), categories = listOf("a", "b")) - try { - runBlocking { chain.run("text") } - fail("Expected LocanaraException.ModelBusy") - } catch (e: LocanaraException) { - assertTrue(e is LocanaraException.ModelBusy) - } - } - - @Test - fun `TranslateChain propagates LocanaraException from model`() { - val chain = TranslateChain(model = failingModel(LocanaraException.BackgroundUseBlocked), targetLanguage = "ko") - try { - runBlocking { chain.run("hello") } - fail("Expected LocanaraException.BackgroundUseBlocked") - } catch (e: LocanaraException) { - assertTrue(e is LocanaraException.BackgroundUseBlocked) - } - } - - @Test - fun `ProofreadChain propagates LocanaraException from model`() { - val chain = ProofreadChain(model = failingModel(LocanaraException.ExecutionFailed("inference failed"))) - try { - runBlocking { chain.run("text") } - fail("Expected LocanaraException.ExecutionFailed") - } catch (e: LocanaraException.ExecutionFailed) { - val message = e.message - assertNotNull(message) - assertTrue(message!!.contains("inference failed")) - } - } - - @Test - fun `ChatChain propagates LocanaraException from model`() { - val chain = ChatChain(model = failingModel(LocanaraException.ExecutionFailed("chat failed"))) - try { - runBlocking { chain.run("hello") } - fail("Expected LocanaraException.ExecutionFailed") - } catch (e: LocanaraException.ExecutionFailed) { - val message = e.message - assertNotNull(message) - assertTrue(message!!.contains("chat failed")) - } - } - - @Test - fun `RewriteChain propagates LocanaraException from model`() { - val chain = RewriteChain(model = failingModel(LocanaraException.ExecutionFailed("rewrite failed")), style = RewriteOutputType.FRIENDLY) - try { - runBlocking { chain.run("text") } - fail("Expected LocanaraException.ExecutionFailed") - } catch (e: LocanaraException.ExecutionFailed) { - val message = e.message - assertNotNull(message) - assertTrue(message!!.contains("rewrite failed")) - } - } - - @Test - fun `ExtractChain propagates LocanaraException from model`() { - val chain = ExtractChain(model = failingModel(LocanaraException.ExecutionFailed("extract failed")), entityTypes = listOf("person")) - try { - runBlocking { chain.run("Tim Cook") } - fail("Expected LocanaraException.ExecutionFailed") - } catch (e: LocanaraException.ExecutionFailed) { - val message = e.message - assertNotNull(message) - assertTrue(message!!.contains("extract failed")) - } - } - - // --- LocanaraException is-a Exception --- - - @Test - fun `LocanaraException is catchable as Exception`() { - val ex: Exception = LocanaraException.ModelBusy - assertTrue(ex is LocanaraException) - assertTrue(ex is Exception) - } -} diff --git a/packages/android/locanara/src/test/kotlin/com/locanara/RAGTests.kt b/packages/android/locanara/src/test/kotlin/com/locanara/RAGTests.kt new file mode 100644 index 0000000..47888e9 --- /dev/null +++ b/packages/android/locanara/src/test/kotlin/com/locanara/RAGTests.kt @@ -0,0 +1,189 @@ +package com.locanara + +import com.locanara.rag.ChunkingConfig +import com.locanara.rag.DocumentChunker +import org.junit.Assert.assertEquals +import org.junit.Assert.assertFalse +import org.junit.Assert.assertNotNull +import org.junit.Assert.assertTrue +import org.junit.Test + +// MARK: - RAG Layer Tests +// +// VectorStore and RAGManager require an Android Context and are covered by +// instrumentation tests. This suite focuses on the pure-JVM DocumentChunker, +// ChunkingConfig, and DocumentChunk which can run on the host JVM. + +class ChunkingConfigTest { + @Test + fun `default config has sensible values`() { + val config = ChunkingConfig.DEFAULT + assertTrue(config.targetChunkSize > 0) + assertTrue(config.chunkOverlap >= 0) + assertTrue(config.chunkOverlap < config.targetChunkSize) + assertTrue(config.minChunkSize >= 0) + } + + @Test + fun `long document config has larger chunk size`() { + assertTrue(ChunkingConfig.LONG_DOCUMENT.targetChunkSize > ChunkingConfig.DEFAULT.targetChunkSize) + } + + @Test(expected = IllegalArgumentException::class) + fun `non-positive targetChunkSize throws`() { + ChunkingConfig(targetChunkSize = 0) + } + + @Test(expected = IllegalArgumentException::class) + fun `negative chunkOverlap throws`() { + ChunkingConfig(targetChunkSize = 512, chunkOverlap = -1) + } + + @Test(expected = IllegalArgumentException::class) + fun `chunkOverlap equal to targetChunkSize throws`() { + ChunkingConfig(targetChunkSize = 100, chunkOverlap = 100) + } + + @Test(expected = IllegalArgumentException::class) + fun `chunkOverlap greater than targetChunkSize throws`() { + ChunkingConfig(targetChunkSize = 100, chunkOverlap = 150) + } +} + +class DocumentChunkerTest { + private val chunker = DocumentChunker( + ChunkingConfig( + targetChunkSize = 100, + chunkOverlap = 10, + respectSentences = false, + minChunkSize = 10 + ) + ) + + @Test + fun `empty text returns no chunks`() { + val chunks = chunker.chunk("") + assertTrue(chunks.isEmpty()) + } + + @Test + fun `short text produces single chunk`() { + val text = "Hello world." + val chunks = chunker.chunk(text) + assertEquals(1, chunks.size) + assertEquals(text, chunks[0].content) + } + + @Test + fun `chunk index starts at zero`() { + val text = "Hello world." + val chunks = chunker.chunk(text) + assertEquals(0, chunks[0].index) + } + + @Test + fun `chunks have unique ids`() { + val text = "A".repeat(500) + val chunks = chunker.chunk(text) + val ids = chunks.map { it.id }.toSet() + assertEquals(chunks.size, ids.size) + } + + @Test + fun `long text produces multiple chunks`() { + val text = "A".repeat(500) + val chunks = chunker.chunk(text) + assertTrue(chunks.size > 1) + } + + @Test + fun `metadata is attached to all chunks`() { + val text = "A".repeat(300) + val metadata = mapOf("source" to "test-doc", "author" to "Alice") + val chunks = chunker.chunk(text, metadata = metadata) + chunks.forEach { chunk -> + assertNotNull(chunk.metadata) + assertEquals("test-doc", chunk.metadata?.get("source")) + assertEquals("Alice", chunk.metadata?.get("author")) + } + } + + @Test + fun `chunk offsets are non-negative`() { + val text = "B".repeat(300) + val chunks = chunker.chunk(text) + chunks.forEach { chunk -> + assertTrue(chunk.startOffset >= 0) + assertTrue(chunk.endOffset > chunk.startOffset) + } + } + + @Test + fun `sentence-respecting chunker splits on sentences`() { + val sentenceChunker = DocumentChunker( + ChunkingConfig( + targetChunkSize = 60, + chunkOverlap = 0, + respectSentences = true, + minChunkSize = 0 + ) + ) + // Three short sentences that together exceed targetChunkSize + val text = "First sentence. Second sentence. Third sentence. Fourth sentence." + val chunks = sentenceChunker.chunk(text) + assertTrue(chunks.isNotEmpty()) + // Each chunk should not dramatically exceed the targetChunkSize + chunks.forEach { chunk -> + assertTrue( + "Chunk length ${chunk.content.length} should be reasonable", + chunk.content.length <= 200 // generous bound + ) + } + } + + @Test + fun `estimate chunk count is positive for non-empty text`() { + val count = chunker.estimateChunkCount("Hello world, this is a test sentence.") + assertTrue(count >= 1) + } + + @Test + fun `estimate chunk count is zero for empty text`() { + val count = chunker.estimateChunkCount("") + assertEquals(0, count) + } + + @Test + fun `chunking stats are correct`() { + val text = "A".repeat(300) + val chunks = chunker.chunk(text) + val stats = chunker.getChunkingStats(chunks) + + assertEquals(chunks.size, stats.count) + assertTrue(stats.minSize > 0) + assertTrue(stats.maxSize >= stats.minSize) + assertTrue(stats.avgSize > 0) + assertTrue(stats.totalSize > 0) + } + + @Test + fun `empty chunks list returns zero stats`() { + val stats = chunker.getChunkingStats(emptyList()) + assertEquals(0, stats.count) + assertEquals(0, stats.minSize) + assertEquals(0, stats.maxSize) + assertEquals(0, stats.avgSize) + assertEquals(0, stats.totalSize) + } + + @Test + fun `chunk content is non-empty`() { + val text = "The quick brown fox jumps over the lazy dog. " + + "Pack my box with five dozen liquor jugs. " + + "How vexingly quick daft zebras jump." + val chunks = chunker.chunk(text) + chunks.forEach { chunk -> + assertFalse("Chunk content should not be empty", chunk.content.isEmpty()) + } + } +} diff --git a/packages/android/locanara/src/test/kotlin/com/locanara/TestHelpers.kt b/packages/android/locanara/src/test/kotlin/com/locanara/TestHelpers.kt new file mode 100644 index 0000000..80666c6 --- /dev/null +++ b/packages/android/locanara/src/test/kotlin/com/locanara/TestHelpers.kt @@ -0,0 +1,40 @@ +package com.locanara + +import com.locanara.core.GenerationConfig +import com.locanara.core.LocanaraModel +import com.locanara.core.ModelResponse +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.emptyFlow + +// MARK: - Shared Mock Helpers + +/** + * A configurable mock model for testing. + * @param responseGenerator A function that returns a response for a given prompt. + */ +class MockModel( + private val responseGenerator: (String) -> String = { "mock response" } +) : LocanaraModel { + override val name = "MockModel" + override val isReady = true + override val maxContextTokens = 4000 + + override suspend fun generate(prompt: String, config: GenerationConfig?): ModelResponse { + val text = responseGenerator(prompt) + return ModelResponse(text = text, processingTimeMs = 5) + } + + override fun stream(prompt: String, config: GenerationConfig?): Flow = emptyFlow() +} + +/** + * Creates a [LocanaraModel] stub that always throws the given [LocanaraException] when + * [generate] is called. Used to test error propagation through chains. + */ +fun failingModel(error: LocanaraException): LocanaraModel = object : LocanaraModel { + override val name = "FailingModel" + override val isReady = true + override val maxContextTokens = 4000 + override suspend fun generate(prompt: String, config: GenerationConfig?): ModelResponse = throw error + override fun stream(prompt: String, config: GenerationConfig?): Flow = emptyFlow() +}