-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
1031112
commit b9b0317
Showing
10 changed files
with
308 additions
and
77 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
18 changes: 18 additions & 0 deletions
18
src/main/scala/com/cloud/apim/otoroshi/extensions/aigateway/decorators/decorators.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
package com.cloud.apim.otoroshi.extensions.aigateway.decorators | ||
|
||
import com.cloud.apim.otoroshi.extensions.aigateway.ChatClient | ||
import com.cloud.apim.otoroshi.extensions.aigateway.entities.AiProvider | ||
|
||
object ChatClientDecorators { | ||
|
||
val possibleDecorators: Seq[Function[(AiProvider, ChatClient), ChatClient]] = Seq( | ||
ChatClientWithRegexValidation.applyIfPossible, | ||
ChatClientWithLlmValidation.applyIfPossible, | ||
) | ||
|
||
def apply(provider: AiProvider, client: ChatClient): ChatClient = { | ||
possibleDecorators.foldLeft(client) { | ||
case (client, predicate) => predicate((provider, client)) | ||
} | ||
} | ||
} |
5 changes: 5 additions & 0 deletions
5
...in/scala/com/cloud/apim/otoroshi/extensions/aigateway/decorators/externalvalidation.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
package com.cloud.apim.otoroshi.extensions.aigateway.decorators | ||
|
||
class externalvalidation { | ||
|
||
} |
76 changes: 76 additions & 0 deletions
76
src/main/scala/com/cloud/apim/otoroshi/extensions/aigateway/decorators/llmvalidation.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
package com.cloud.apim.otoroshi.extensions.aigateway.decorators | ||
|
||
import com.cloud.apim.otoroshi.extensions.aigateway.entities.AiProvider | ||
import com.cloud.apim.otoroshi.extensions.aigateway.{ChatClient, ChatMessage, ChatPrompt, ChatResponse} | ||
import otoroshi.env.Env | ||
import otoroshi.utils.TypedMap | ||
import otoroshi.utils.syntax.implicits._ | ||
import otoroshi_plugins.com.cloud.apim.extensions.aigateway.AiExtension | ||
import play.api.libs.json.{JsValue, Json} | ||
|
||
import scala.concurrent.{ExecutionContext, Future} | ||
|
||
object ChatClientWithLlmValidation { | ||
def applyIfPossible(tuple: (AiProvider, ChatClient)): ChatClient = { | ||
if (tuple._1.validatorRef.isDefined && tuple._1.validatorPrompt.isDefined) { | ||
new ChatClientWithLlmValidation(tuple._1, tuple._2) | ||
} else { | ||
tuple._2 | ||
} | ||
} | ||
} | ||
|
||
class ChatClientWithLlmValidation(originalProvider: AiProvider, chatClient: ChatClient) extends ChatClient { | ||
|
||
override def call(originalPrompt: ChatPrompt, attrs: TypedMap)(implicit ec: ExecutionContext, env: Env): Future[Either[JsValue, ChatResponse]] = { | ||
|
||
def pass(): Future[Either[JsValue, ChatResponse]] = chatClient.call(originalPrompt, attrs) | ||
|
||
def fail(idx: Int): Future[Either[JsValue, ChatResponse]] = Left(Json.obj("error" -> "bad_request", "error_description" -> s"request content did not pass llm validation (${idx})")).vfuture | ||
|
||
originalProvider.validatorRef match { | ||
case None => pass() | ||
case Some(ref) if ref == originalProvider.id => pass() | ||
case Some(ref) => { | ||
originalProvider.validatorPrompt match { | ||
case None => pass() | ||
case Some(pref) => env.adminExtensions.extension[AiExtension].flatMap(_.states.prompt(pref)) match { | ||
case None => Left(Json.obj("error" -> "validation prompt not found")).vfuture | ||
case Some(prompt) => { | ||
env.adminExtensions.extension[AiExtension].flatMap(_.states.provider(ref).flatMap(_.getChatClient())) match { | ||
case None => Left(Json.obj("error" -> "validation provider not found")).vfuture | ||
case Some(validationClient) => { | ||
validationClient.call(ChatPrompt(Seq( | ||
ChatMessage("system", prompt.prompt) | ||
) ++ originalPrompt.messages), attrs).flatMap { | ||
case Left(err) => fail(2) | ||
case Right(resp) => { | ||
val content = resp.generations.head.message.content.toLowerCase().trim.replace("\n", " ") | ||
println(s"content: '${content}'") | ||
if (content == "true") { | ||
pass() | ||
} else if (content == "false") { | ||
fail(3) | ||
} else if (content.startsWith("{") && content.endsWith("}")) { | ||
if (Json.parse(content).select("result").asOpt[Boolean].getOrElse(false)) { | ||
pass() | ||
} else { | ||
fail(4) | ||
} | ||
} else { | ||
content.split(" ").headOption match { | ||
case Some("true") => pass() | ||
case _ => fail(5) | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} |
46 changes: 46 additions & 0 deletions
46
src/main/scala/com/cloud/apim/otoroshi/extensions/aigateway/decorators/regexvalidation.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
package com.cloud.apim.otoroshi.extensions.aigateway.decorators | ||
|
||
import com.cloud.apim.otoroshi.extensions.aigateway.entities.AiProvider | ||
import com.cloud.apim.otoroshi.extensions.aigateway.{ChatClient, ChatPrompt, ChatResponse} | ||
import otoroshi.env.Env | ||
import otoroshi.utils.syntax.implicits._ | ||
import otoroshi.utils.{RegexPool, TypedMap} | ||
import play.api.libs.json.{JsValue, Json} | ||
|
||
import scala.concurrent.{ExecutionContext, Future} | ||
|
||
object ChatClientWithRegexValidation { | ||
def applyIfPossible(tuple: (AiProvider, ChatClient)): ChatClient = { | ||
if (tuple._1.allow.nonEmpty || tuple._1.deny.nonEmpty) { | ||
new ChatClientWithRegexValidation(tuple._1, tuple._2) | ||
} else { | ||
tuple._2 | ||
} | ||
} | ||
} | ||
|
||
class ChatClientWithRegexValidation(originalProvider: AiProvider, chatClient: ChatClient) extends ChatClient { | ||
|
||
private val allow = originalProvider.allow | ||
private val deny = originalProvider.deny | ||
|
||
private def validate(content: String): Boolean = { | ||
val allowed = if (allow.isEmpty) true else allow.exists(al => RegexPool.regex(al).matches(content)) | ||
val denied = if (deny.isEmpty) false else deny.exists(dn => RegexPool.regex(dn).matches(content)) | ||
!denied && allowed | ||
} | ||
|
||
override def call(originalPrompt: ChatPrompt, attrs: TypedMap)(implicit ec: ExecutionContext, env: Env): Future[Either[JsValue, ChatResponse]] = { | ||
|
||
def pass(): Future[Either[JsValue, ChatResponse]] = chatClient.call(originalPrompt, attrs) | ||
|
||
def fail(): Future[Either[JsValue, ChatResponse]] = Left(Json.obj("error" -> "bad_request", "error_description" -> s"request content did not pass regex validation")).vfuture | ||
|
||
val contents = originalPrompt.messages.map(_.content) | ||
if (!contents.forall(content => validate(content))) { | ||
fail() | ||
} else { | ||
pass() | ||
} | ||
} | ||
} |
83 changes: 83 additions & 0 deletions
83
src/main/scala/com/cloud/apim/otoroshi/extensions/aigateway/decorators/semanticcache.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
package com.cloud.apim.otoroshi.extensions.aigateway.decorators | ||
|
||
import com.cloud.apim.otoroshi.extensions.aigateway.{ChatClient, ChatGeneration, ChatMessage, ChatPrompt, ChatResponse, ChatResponseMetadata} | ||
import com.cloud.apim.otoroshi.extensions.aigateway.entities.AiProvider | ||
import com.github.blemale.scaffeine.Scaffeine | ||
import dev.langchain4j.data.segment.TextSegment | ||
import dev.langchain4j.model.embedding.onnx.allminilml6v2.AllMiniLmL6V2EmbeddingModel | ||
import dev.langchain4j.store.embedding.EmbeddingSearchRequest | ||
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore | ||
import otoroshi.utils.syntax.implicits._ | ||
import otoroshi.env.Env | ||
import otoroshi.utils.TypedMap | ||
import play.api.libs.json.JsValue | ||
|
||
import scala.collection.concurrent.TrieMap | ||
import scala.concurrent.{ExecutionContext, Future} | ||
import scala.concurrent.duration.{DurationInt, FiniteDuration} | ||
import scala.jdk.CollectionConverters._ | ||
|
||
|
||
object ChatClientWithSemanticCache { | ||
val embeddingStores = new TrieMap[String, InMemoryEmbeddingStore[TextSegment]]() | ||
val embeddingModel = new AllMiniLmL6V2EmbeddingModel() | ||
val cache = Scaffeine() | ||
.expireAfter[String, (FiniteDuration, ChatResponse)]( | ||
create = (key, value) => value._1, | ||
update = (key, value, currentDuration) => currentDuration, | ||
read = (key, value, currentDuration) => currentDuration | ||
) | ||
.maximumSize(5000) | ||
.build[String, (FiniteDuration, ChatResponse)]() | ||
def applyIfPossible(tuple: (AiProvider, ChatClient)): ChatClient = { | ||
if (tuple._1.cacheStrategy.contains("semantic")) { | ||
new ChatClientWithSemanticCache(tuple._1, tuple._2) | ||
} else { | ||
tuple._2 | ||
} | ||
} | ||
} | ||
|
||
class ChatClientWithSemanticCache(originalProvider: AiProvider, chatClient: ChatClient) extends ChatClient { | ||
|
||
private val ttl = originalProvider.ttl.getOrElse(24.hours) | ||
|
||
override def call(originalPrompt: ChatPrompt, attrs: TypedMap)(implicit ec: ExecutionContext, env: Env): Future[Either[JsValue, ChatResponse]] = { | ||
val query = originalPrompt.messages.filter(_.role.toLowerCase().trim == "user").map(_.content).mkString(", ") | ||
val key = query.sha512 | ||
ChatClientWithSemanticCache.cache.getIfPresent(key) match { | ||
case Some((_, response)) => | ||
println("using semantic cached response") | ||
response.rightf | ||
case None => { | ||
val embeddingModel = ChatClientWithSemanticCache.embeddingModel | ||
val embeddingStore = ChatClientWithSemanticCache.embeddingStores.getOrUpdate(originalProvider.id) { | ||
new InMemoryEmbeddingStore[TextSegment]() | ||
} | ||
val queryEmbedding = embeddingModel.embed(query).content() | ||
val relevant = embeddingStore.search(EmbeddingSearchRequest.builder().queryEmbedding(queryEmbedding).maxResults(1).build()) | ||
val matches = relevant.matches().asScala | ||
if (matches.nonEmpty) { | ||
val resp = matches.head | ||
val id = resp.embeddingId() | ||
val text = resp.embedded().text() | ||
println("using semantic response") | ||
val chatResponse = ChatResponse(Seq(ChatGeneration(ChatMessage("assistant", text))), ChatResponseMetadata.empty) | ||
ChatClientWithSemanticCache.cache.put(key, (ttl, chatResponse)) | ||
chatResponse.rightf | ||
} else { | ||
chatClient.call(originalPrompt, attrs).map { | ||
case Left(err) => err.left | ||
case Right(resp) => { | ||
val segment = TextSegment.from(resp.generations.head.message.content) | ||
val embedding = embeddingModel.embed(segment).content() | ||
embeddingStore.add(key, embedding, segment) | ||
ChatClientWithSemanticCache.cache.put(key, (ttl, resp)) | ||
resp.right | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} |
53 changes: 53 additions & 0 deletions
53
src/main/scala/com/cloud/apim/otoroshi/extensions/aigateway/decorators/simplecache.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
package com.cloud.apim.otoroshi.extensions.aigateway.decorators | ||
|
||
import com.cloud.apim.otoroshi.extensions.aigateway.entities.AiProvider | ||
import com.cloud.apim.otoroshi.extensions.aigateway.{ChatClient, ChatPrompt, ChatResponse} | ||
import com.github.blemale.scaffeine.Scaffeine | ||
import otoroshi.env.Env | ||
import otoroshi.utils.TypedMap | ||
import otoroshi.utils.syntax.implicits._ | ||
import play.api.libs.json.JsValue | ||
|
||
import scala.concurrent.duration.{DurationInt, FiniteDuration} | ||
import scala.concurrent.{ExecutionContext, Future} | ||
|
||
object ChatClientWithSimpleCache { | ||
val cache = Scaffeine() | ||
.expireAfter[String, (FiniteDuration, ChatResponse)]( | ||
create = (key, value) => value._1, | ||
update = (key, value, currentDuration) => currentDuration, | ||
read = (key, value, currentDuration) => currentDuration | ||
) | ||
.maximumSize(5000) | ||
.build[String, (FiniteDuration, ChatResponse)]() | ||
def applyIfPossible(tuple: (AiProvider, ChatClient)): ChatClient = { | ||
if (tuple._1.cacheStrategy.contains("simple")) { | ||
new ChatClientWithSimpleCache(tuple._1, tuple._2) | ||
} else { | ||
tuple._2 | ||
} | ||
} | ||
} | ||
|
||
class ChatClientWithSimpleCache(originalProvider: AiProvider, chatClient: ChatClient) extends ChatClient { | ||
|
||
private val ttl = originalProvider.ttl.getOrElse(24.hours) | ||
|
||
override def call(originalPrompt: ChatPrompt, attrs: TypedMap)(implicit ec: ExecutionContext, env: Env): Future[Either[JsValue, ChatResponse]] = { | ||
val key = originalPrompt.messages.map(m => s"${m.role}:${m.content}").mkString(",").sha512 | ||
ChatClientWithSimpleCache.cache.getIfPresent(key) match { | ||
case Some((_, response)) => | ||
println("using simple cache response") | ||
response.rightf | ||
case None => { | ||
chatClient.call(originalPrompt, attrs).map { | ||
case Left(err) => err.left | ||
case Right(resp) => { | ||
ChatClientWithSimpleCache.cache.put(key, (ttl, resp)) | ||
resp.right | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.