Skip to content

Commit

Permalink
add decorators
Browse files Browse the repository at this point in the history
  • Loading branch information
mathieuancelin committed Jul 26, 2024
1 parent 1031112 commit b9b0317
Show file tree
Hide file tree
Showing 10 changed files with 308 additions and 77 deletions.
12 changes: 10 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,23 @@ ThisBuild / version := "1.0.0-dev"
ThisBuild / organization := "com.cloud-apim"
ThisBuild / organizationName := "Cloud-APIM"

// Spring AI version string (not referenced in the visible part of this build file).
lazy val springAiVersion = "0.8.1-SNAPSHOT"

// Excludes jackson-databind and every artifact of the io.opentelemetry organization.
lazy val jackson = Seq(
  ExclusionRule("com.fasterxml.jackson.core", "jackson-databind"),
  ExclusionRule("io.opentelemetry")
)

// Excludes every artifact of the slf4j and logback organizations.
lazy val slf4j = Seq("org.slf4j", "ch.qos.logback").map(org => ExclusionRule(org))

// Excludes the native netty transports (epoll and kqueue).
lazy val netty = Seq(
  "netty-transport-native-epoll",
  "netty-transport-native-kqueue"
).map(artifact => ExclusionRule("io.netty", artifact))

// Combined exclusions: the jackson rules followed by the slf4j rules.
lazy val all = jackson ++ slf4j

lazy val root = (project in file("."))
.settings(
name := "otoroshi-llm-extension",
Expand All @@ -27,6 +32,9 @@ lazy val root = (project in file("."))
),
libraryDependencies ++= Seq(
"fr.maif" %% "otoroshi" % "16.18.4" % "provided" excludeAll (netty: _*),
"dev.langchain4j" % "langchain4j" % "0.33.0" excludeAll(all: _*),
"dev.langchain4j" % "langchain4j-embeddings" % "0.33.0" excludeAll(all: _*),
"dev.langchain4j" % "langchain4j-embeddings-all-minilm-l6-v2" % "0.33.0" excludeAll(all: _*),
"io.netty" % "netty-transport-native-kqueue" % "4.1.107.Final" % "provided" excludeAll(jackson: _*),
"io.netty" % "netty-transport-native-epoll" % "4.1.107.Final" % "provided" excludeAll(jackson: _*),
munit % Test
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package com.cloud.apim.otoroshi.extensions.aigateway.decorators

import com.cloud.apim.otoroshi.extensions.aigateway.ChatClient
import com.cloud.apim.otoroshi.extensions.aigateway.entities.AiProvider

/**
 * Entry point that wraps a raw [[ChatClient]] with every decorator applicable to a provider.
 */
object ChatClientDecorators {

  /**
   * Decorator factories, in application order. Each one receives the provider and the
   * client built so far and returns either a wrapping client or the client unchanged.
   */
  val possibleDecorators: Seq[Function[(AiProvider, ChatClient), ChatClient]] = Seq(
    ChatClientWithRegexValidation.applyIfPossible,
    ChatClientWithLlmValidation.applyIfPossible
  )

  /**
   * Threads `client` through all the factories of `possibleDecorators`, left to right.
   *
   * @param provider the provider whose settings drive each decorator
   * @param client   the undecorated client
   * @return the client wrapped by every decorator that chose to apply
   */
  def apply(provider: AiProvider, client: ChatClient): ChatClient = {
    possibleDecorators.foldLeft(client) { (decorated, decorate) =>
      decorate((provider, decorated))
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package com.cloud.apim.otoroshi.extensions.aigateway.decorators

/**
 * Empty class with no members.
 * NOTE(review): looks like a placeholder for a future external-validation decorator - confirm.
 */
class externalvalidation
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package com.cloud.apim.otoroshi.extensions.aigateway.decorators

import com.cloud.apim.otoroshi.extensions.aigateway.entities.AiProvider
import com.cloud.apim.otoroshi.extensions.aigateway.{ChatClient, ChatMessage, ChatPrompt, ChatResponse}
import otoroshi.env.Env
import otoroshi.utils.TypedMap
import otoroshi.utils.syntax.implicits._
import otoroshi_plugins.com.cloud.apim.extensions.aigateway.AiExtension
import play.api.libs.json.{JsValue, Json}

import scala.concurrent.{ExecutionContext, Future}

object ChatClientWithLlmValidation {

  /**
   * Wraps the client with LLM based validation when the provider declares both a
   * validator reference and a validator prompt; otherwise returns the client untouched.
   *
   * @param tuple the provider and the client to decorate
   * @return the decorated client, or the original one
   */
  def applyIfPossible(tuple: (AiProvider, ChatClient)): ChatClient = {
    val (provider, client) = tuple
    val configured = provider.validatorRef.nonEmpty && provider.validatorPrompt.nonEmpty
    if (configured) new ChatClientWithLlmValidation(provider, client) else client
  }
}

/**
 * Decorator that submits the incoming prompt to a second provider (the "validator") and
 * only forwards the prompt to the wrapped client when the validator approves it.
 *
 * @param originalProvider the provider carrying `validatorRef` and `validatorPrompt`
 * @param chatClient       the client that is called once validation passes
 */
class ChatClientWithLlmValidation(originalProvider: AiProvider, chatClient: ChatClient) extends ChatClient {

  /**
   * Validates `originalPrompt` through the validator provider, then delegates to the wrapped client.
   *
   * Validation is skipped when no validator or no validator prompt is configured, or when the
   * validator reference points at this very provider.
   *
   * @return the wrapped client's result when validation passes (or is skipped), otherwise a
   *         `Left` holding a JSON error
   */
  override def call(originalPrompt: ChatPrompt, attrs: TypedMap)(implicit ec: ExecutionContext, env: Env): Future[Either[JsValue, ChatResponse]] = {

    def pass(): Future[Either[JsValue, ChatResponse]] = chatClient.call(originalPrompt, attrs)

    def fail(idx: Int): Future[Either[JsValue, ChatResponse]] = Left(Json.obj("error" -> "bad_request", "error_description" -> s"request content did not pass llm validation (${idx})")).vfuture

    // Interprets the validator's raw answer. Accepted forms: the literal "true"/"false",
    // a JSON object with a boolean "result" field, or free text whose first word is "true".
    def decide(rawVerdict: String): Future[Either[JsValue, ChatResponse]] = {
      val content = rawVerdict.toLowerCase().trim.replace("\n", " ")
      if (content == "true") {
        pass()
      } else if (content == "false") {
        fail(3)
      } else if (content.startsWith("{") && content.endsWith("}")) {
        // The validator output is untrusted text: malformed JSON must reject the request
        // instead of throwing out of the future callback.
        val approved = scala.util.Try(Json.parse(content).select("result").asOpt[Boolean]).toOption.flatten.getOrElse(false)
        if (approved) pass() else fail(4)
      } else {
        content.split(" ").headOption match {
          case Some("true") => pass()
          case _ => fail(5)
        }
      }
    }

    originalProvider.validatorRef match {
      case None => pass()
      // A provider validating itself would recurse, so validation is skipped in that case.
      case Some(ref) if ref == originalProvider.id => pass()
      case Some(ref) => {
        originalProvider.validatorPrompt match {
          case None => pass()
          case Some(pref) => env.adminExtensions.extension[AiExtension].flatMap(_.states.prompt(pref)) match {
            case None => Left(Json.obj("error" -> "validation prompt not found")).vfuture
            case Some(prompt) => {
              env.adminExtensions.extension[AiExtension].flatMap(_.states.provider(ref).flatMap(_.getChatClient())) match {
                case None => Left(Json.obj("error" -> "validation provider not found")).vfuture
                case Some(validationClient) => {
                  // The validation prompt is sent as a system message ahead of the original messages.
                  val validationMessages = Seq(ChatMessage("system", prompt.prompt)) ++ originalPrompt.messages
                  validationClient.call(ChatPrompt(validationMessages), attrs).flatMap {
                    case Left(_) => fail(2)
                    case Right(resp) =>
                      // An empty generation list is treated as an empty answer, which is rejected (5).
                      decide(resp.generations.headOption.map(_.message.content).getOrElse(""))
                  }
                }
              }
            }
          }
        }
      }
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package com.cloud.apim.otoroshi.extensions.aigateway.decorators

import com.cloud.apim.otoroshi.extensions.aigateway.entities.AiProvider
import com.cloud.apim.otoroshi.extensions.aigateway.{ChatClient, ChatPrompt, ChatResponse}
import otoroshi.env.Env
import otoroshi.utils.syntax.implicits._
import otoroshi.utils.{RegexPool, TypedMap}
import play.api.libs.json.{JsValue, Json}

import scala.concurrent.{ExecutionContext, Future}

object ChatClientWithRegexValidation {

  /**
   * Wraps the client with regex validation when the provider declares at least one
   * allow or deny pattern; otherwise returns the client untouched.
   *
   * @param tuple the provider and the client to decorate
   * @return the decorated client, or the original one
   */
  def applyIfPossible(tuple: (AiProvider, ChatClient)): ChatClient = {
    val (provider, client) = tuple
    val hasPatterns = provider.allow.nonEmpty || provider.deny.nonEmpty
    if (hasPatterns) new ChatClientWithRegexValidation(provider, client) else client
  }
}

/**
 * Decorator that checks every message of a prompt against the provider's allow and deny
 * patterns before handing the prompt to the wrapped client.
 *
 * @param originalProvider the provider carrying the `allow` and `deny` pattern lists
 * @param chatClient       the client that is called once validation passes
 */
class ChatClientWithRegexValidation(originalProvider: AiProvider, chatClient: ChatClient) extends ChatClient {

  private val allow = originalProvider.allow
  private val deny = originalProvider.deny

  // A content is valid when it matches at least one allow pattern (or no allow pattern is
  // configured) and matches none of the deny patterns.
  private def validate(content: String): Boolean = {
    val hits = (pattern: String) => RegexPool.regex(pattern).matches(content)
    val allowed = allow.isEmpty || allow.exists(hits)
    val denied = deny.nonEmpty && deny.exists(hits)
    allowed && !denied
  }

  /**
   * Delegates to the wrapped client when all message contents are valid, otherwise
   * answers with a `bad_request` JSON error without calling the wrapped client.
   */
  override def call(originalPrompt: ChatPrompt, attrs: TypedMap)(implicit ec: ExecutionContext, env: Env): Future[Either[JsValue, ChatResponse]] = {
    val everyMessageValid = originalPrompt.messages.map(_.content).forall(validate)
    if (everyMessageValid) {
      chatClient.call(originalPrompt, attrs)
    } else {
      Left(Json.obj("error" -> "bad_request", "error_description" -> "request content did not pass regex validation")).vfuture
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package com.cloud.apim.otoroshi.extensions.aigateway.decorators

import com.cloud.apim.otoroshi.extensions.aigateway.{ChatClient, ChatGeneration, ChatMessage, ChatPrompt, ChatResponse, ChatResponseMetadata}
import com.cloud.apim.otoroshi.extensions.aigateway.entities.AiProvider
import com.github.blemale.scaffeine.Scaffeine
import dev.langchain4j.data.segment.TextSegment
import dev.langchain4j.model.embedding.onnx.allminilml6v2.AllMiniLmL6V2EmbeddingModel
import dev.langchain4j.store.embedding.EmbeddingSearchRequest
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore
import otoroshi.utils.syntax.implicits._
import otoroshi.env.Env
import otoroshi.utils.TypedMap
import play.api.libs.json.JsValue

import scala.collection.concurrent.TrieMap
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.{DurationInt, FiniteDuration}
import scala.jdk.CollectionConverters._


/**
 * Shared state and factory for the semantic cache decorator.
 */
object ChatClientWithSemanticCache {

  // One in-memory embedding store per key; the decorator class keys it by provider id.
  val embeddingStores = new TrieMap[String, InMemoryEmbeddingStore[TextSegment]]()

  // Embedding model instance shared by every decorated client.
  val embeddingModel = new AllMiniLmL6V2EmbeddingModel()

  // Response cache, capped at 5000 entries. The expiry of an entry is the duration stored
  // next to the response at creation time; reads and updates leave it unchanged.
  val cache = Scaffeine()
    .expireAfter[String, (FiniteDuration, ChatResponse)](
      create = (_, entry) => entry._1,
      update = (_, _, remaining) => remaining,
      read = (_, _, remaining) => remaining
    )
    .maximumSize(5000)
    .build[String, (FiniteDuration, ChatResponse)]()

  /**
   * Wraps the client with the semantic cache when the provider's cache strategy is
   * "semantic"; otherwise returns the client untouched.
   *
   * @param tuple the provider and the client to decorate
   * @return the decorated client, or the original one
   */
  def applyIfPossible(tuple: (AiProvider, ChatClient)): ChatClient = {
    val (provider, client) = tuple
    if (provider.cacheStrategy.contains("semantic")) new ChatClientWithSemanticCache(provider, client)
    else client
  }
}

/**
 * Decorator that answers a prompt from previously seen, semantically similar prompts.
 *
 * Lookup happens in two steps: an exact match on the hash of the user messages, then a
 * similarity search over the embeddings of earlier queries for the same provider. Only
 * when both miss is the wrapped client called, and its answer recorded for later reuse.
 *
 * @param originalProvider the provider supplying the cache ttl and the embedding store key
 * @param chatClient       the client called on a cache miss
 */
class ChatClientWithSemanticCache(originalProvider: AiProvider, chatClient: ChatClient) extends ChatClient {

  private val ttl = originalProvider.ttl.getOrElse(24.hours)

  // Minimum similarity (0.0 to 1.0) required to reuse a stored answer. Without a threshold
  // the search returns the nearest entry however unrelated it is to the query.
  private val minScore: Double = 0.8

  /**
   * Serves the prompt from the cache when possible, otherwise delegates to the wrapped client.
   *
   * @return the cached or freshly computed response, or the wrapped client's error
   */
  override def call(originalPrompt: ChatPrompt, attrs: TypedMap)(implicit ec: ExecutionContext, env: Env): Future[Either[JsValue, ChatResponse]] = {
    val query = originalPrompt.messages.filter(_.role.toLowerCase().trim == "user").map(_.content).mkString(", ")
    val key = query.sha512
    ChatClientWithSemanticCache.cache.getIfPresent(key) match {
      case Some((_, response)) =>
        println("using semantic cached response")
        response.rightf
      case None => {
        val embeddingModel = ChatClientWithSemanticCache.embeddingModel
        val embeddingStore = ChatClientWithSemanticCache.embeddingStores.getOrUpdate(originalProvider.id) {
          new InMemoryEmbeddingStore[TextSegment]()
        }
        val queryEmbedding = embeddingModel.embed(query).content()
        val searchRequest = EmbeddingSearchRequest.builder()
          .queryEmbedding(queryEmbedding)
          .maxResults(1)
          .minScore(minScore)
          .build()
        embeddingStore.search(searchRequest).matches().asScala.headOption match {
          case Some(hit) =>
            println("using semantic response")
            val chatResponse = ChatResponse(Seq(ChatGeneration(ChatMessage("assistant", hit.embedded().text()))), ChatResponseMetadata.empty)
            ChatClientWithSemanticCache.cache.put(key, (ttl, chatResponse))
            chatResponse.rightf
          case None =>
            chatClient.call(originalPrompt, attrs).map {
              case Left(err) => err.left
              case Right(resp) => {
                // Index the answer under the embedding of the query, so that later similar
                // queries find it. Blank answers are not indexed because a text segment
                // cannot be built from blank text.
                resp.generations.headOption
                  .map(_.message.content)
                  .filter(_.trim.nonEmpty)
                  .foreach { answer =>
                    embeddingStore.add(key, queryEmbedding, TextSegment.from(answer))
                  }
                ChatClientWithSemanticCache.cache.put(key, (ttl, resp))
                resp.right
              }
            }
        }
      }
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package com.cloud.apim.otoroshi.extensions.aigateway.decorators

import com.cloud.apim.otoroshi.extensions.aigateway.entities.AiProvider
import com.cloud.apim.otoroshi.extensions.aigateway.{ChatClient, ChatPrompt, ChatResponse}
import com.github.blemale.scaffeine.Scaffeine
import otoroshi.env.Env
import otoroshi.utils.TypedMap
import otoroshi.utils.syntax.implicits._
import play.api.libs.json.JsValue

import scala.concurrent.duration.{DurationInt, FiniteDuration}
import scala.concurrent.{ExecutionContext, Future}

/**
 * Shared state and factory for the exact-match cache decorator.
 */
object ChatClientWithSimpleCache {

  // Response cache, capped at 5000 entries. The expiry of an entry is the duration stored
  // next to the response at creation time; reads and updates leave it unchanged.
  val cache = Scaffeine()
    .expireAfter[String, (FiniteDuration, ChatResponse)](
      create = (_, entry) => entry._1,
      update = (_, _, remaining) => remaining,
      read = (_, _, remaining) => remaining
    )
    .maximumSize(5000)
    .build[String, (FiniteDuration, ChatResponse)]()

  /**
   * Wraps the client with the simple cache when the provider's cache strategy is
   * "simple"; otherwise returns the client untouched.
   *
   * @param tuple the provider and the client to decorate
   * @return the decorated client, or the original one
   */
  def applyIfPossible(tuple: (AiProvider, ChatClient)): ChatClient = {
    val (provider, client) = tuple
    if (provider.cacheStrategy.contains("simple")) new ChatClientWithSimpleCache(provider, client)
    else client
  }
}

/**
 * Decorator that caches successful responses, keyed by a hash of the exact prompt messages.
 *
 * @param originalProvider the provider supplying the cache ttl (24 hours when unset)
 * @param chatClient       the client called on a cache miss
 */
class ChatClientWithSimpleCache(originalProvider: AiProvider, chatClient: ChatClient) extends ChatClient {

  private val ttl = originalProvider.ttl.getOrElse(24.hours)

  /**
   * Returns the cached response for this exact prompt when there is one; otherwise calls
   * the wrapped client and stores its response when the call succeeds. Errors are not cached.
   */
  override def call(originalPrompt: ChatPrompt, attrs: TypedMap)(implicit ec: ExecutionContext, env: Env): Future[Either[JsValue, ChatResponse]] = {
    // The key covers the role and content of every message, in order.
    val fingerprint = originalPrompt.messages.map(msg => s"${msg.role}:${msg.content}").mkString(",")
    val key = fingerprint.sha512
    ChatClientWithSimpleCache.cache.getIfPresent(key) match {
      case None =>
        chatClient.call(originalPrompt, attrs).map { outcome =>
          outcome.foreach(resp => ChatClientWithSimpleCache.cache.put(key, (ttl, resp)))
          outcome
        }
      case Some((_, hit)) =>
        println("using simple cache response")
        hit.rightf
    }
  }
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.cloud.apim.otoroshi.extensions.aigateway.entities

import com.cloud.apim.otoroshi.extensions.aigateway.decorators.ChatClientDecorators
import com.cloud.apim.otoroshi.extensions.aigateway.{ChatClient, ChatClientWithValidation}
import com.cloud.apim.otoroshi.extensions.aigateway.providers._
import otoroshi.api.{GenericResourceAccessApiWithState, Resource, ResourceVersion}
Expand Down Expand Up @@ -31,6 +32,8 @@ case class AiProvider(
deny: Seq[String] = Seq.empty,
validatorRef: Option[String] = None,
validatorPrompt: Option[String] = None,
cacheStrategy: Option[String],// = None,
ttl: Option[FiniteDuration],// = None,
) extends EntityLocationSupport {
override def internalId: String = id
override def json: JsValue = AiProvider.format.writes(this)
Expand All @@ -42,37 +45,38 @@ case class AiProvider(
val baseUrl = connection.select("base_url").asOpt[String]
val token = connection.select("token").asOpt[String].getOrElse("xxx")
val timeout = connection.select("timeout").asOpt[Long].map(FiniteDuration(_, TimeUnit.MILLISECONDS))
provider.toLowerCase() match {
val rawClient = provider.toLowerCase() match {
case "openai" => {
val api = new OpenAiApi(baseUrl.getOrElse(OpenAiApi.baseUrl), token, timeout.getOrElse(10.seconds), env = env)
val opts = OpenAiChatClientOptions.fromJson(options)
new ChatClientWithValidation(this, new OpenAiChatClient(api, opts, id)).some
new OpenAiChatClient(api, opts, id).some
}
case "azure-openai" => {
val resourceName = connection.select("resource_name").as[String]
val deploymentId = connection.select("deployment_id").as[String]
val apikey = connection.select("api_key").as[String]
val api = new AzureOpenAiApi(resourceName, deploymentId, apikey, timeout.getOrElse(10.seconds), env = env)
val opts = AzureOpenAiChatClientOptions.fromJson(options)
new ChatClientWithValidation(this, new AzureOpenAiChatClient(api, opts, id)).some
new AzureOpenAiChatClient(api, opts, id).some
}
case "mistral" => {
val api = new MistralAiApi(baseUrl.getOrElse(OpenAiApi.baseUrl), token, timeout.getOrElse(10.seconds), env = env)
val opts = MistralAiChatClientOptions.fromJson(options)
new ChatClientWithValidation(this, new MistralAiChatClient(api, opts, id)).some
new MistralAiChatClient(api, opts, id).some
}
case "ollama" => {
val api = new OllamaAiApi(baseUrl.getOrElse(OpenAiApi.baseUrl), token.some.filterNot(_ == "xxx"), timeout.getOrElse(10.seconds), env = env)
val opts = OllamaAiChatClientOptions.fromJson(options)
new ChatClientWithValidation(this, new OllamaAiChatClient(api, opts, id)).some
new OllamaAiChatClient(api, opts, id).some
}
case "anthropic" => {
val api = new AnthropicApi(baseUrl.getOrElse(AnthropicApi.baseUrl), token, timeout.getOrElse(10.seconds), env = env)
val opts = AnthropicChatClientOptions.fromJson(options)
new ChatClientWithValidation(this, new AnthropicChatClient(api, opts, id)).some
new AnthropicChatClient(api, opts, id).some
}
case _ => None
}
rawClient.map(c => ChatClientDecorators(this, c))
}
}

Expand Down
Loading

0 comments on commit b9b0317

Please sign in to comment.