Skip to content

Commit

Permalink
fix #16, fix #14, fix #4
Browse files Browse the repository at this point in the history
  • Loading branch information
mathieuancelin committed Jul 26, 2024
1 parent b9b0317 commit 8a17218
Show file tree
Hide file tree
Showing 9 changed files with 262 additions and 68 deletions.
140 changes: 106 additions & 34 deletions src/main/resources/cloudapim/extensions/ai/AiProvidersPage.js
Original file line number Diff line number Diff line change
Expand Up @@ -232,15 +232,39 @@ class AiProvidersPage extends Component {
type: 'number',
props: { label: 'Context size' },
},
'deny': {
'regex_validation.deny': {
type: 'array',
props: { label: 'Deny', suffix: 'regex' },
},
'allow': {
'regex_validation.allow': {
type: 'array',
props: { label: 'Allow', suffix: 'regex' },
},
'validator_ref': {
'http_validation.url': {
type: 'string',
props: { label: 'URL' },
},
'http_validation.headers': {
type: 'object',
props: { label: 'Headers' },
},
'http_validation.ttl': {
type: 'object',
props: { label: 'TTL', suffix: 'millis.' },
},
'cache.ttl': {
type: 'number',
props: { label: 'TTL', suffix: 'millis.' },
},
'cache.strategy': {
type: 'select',
props: { label: 'Cache strategy', possibleValues: [
{ label: 'None', value: 'none' },
{ label: 'Simple', value: 'simple' },
{ label: 'Semantic', value: 'semantic' },
] },
},
'llm_validation.provider': {
type: 'select',
props: {
label: 'Validator provider',
Expand All @@ -252,7 +276,7 @@ class AiProvidersPage extends Component {
}),
}
},
'validator_prompt': {
'llm_validation.prompt': {
type: 'select',
props: {
label: 'Validator prompt',
Expand Down Expand Up @@ -284,11 +308,19 @@ class AiProvidersPage extends Component {
'description',
'<<<Provider',
'provider',
'<<<Validation',
'allow',
'deny',
'validator_ref',
'validator_prompt',
'>>>Cache',
'cache.strategy',
'cache.ttl',
'>>>Regex validation',
'regex_validation.allow',
'regex_validation.deny',
'>>>LLM Based validation',
'llm_validation.provider',
'llm_validation.prompt',
'>>>External validation',
'http_validation.url',
'http_validation.headers',
'http_validation.ttl',
'>>>Metadata and tags',
'tags',
'metadata',
Expand Down Expand Up @@ -316,11 +348,19 @@ class AiProvidersPage extends Component {
'options.num_gpu',
'options.num_gqa',
'options.num_ctx',
'<<<Validation',
'allow',
'deny',
'validator_ref',
'validator_prompt',
'>>>Cache',
'cache.strategy',
'cache.ttl',
'>>>Regex validation',
'regex_validation.allow',
'regex_validation.deny',
'>>>LLM Based validation',
'llm_validation.provider',
'llm_validation.prompt',
'>>>External validation',
'http_validation.url',
'http_validation.headers',
'http_validation.ttl',
'>>>Tester',
'tester',
'>>>Metadata and tags',
Expand All @@ -344,11 +384,19 @@ class AiProvidersPage extends Component {
'options.safe_prompt',
'options.temperature',
'options.top_p',
'<<<Validation',
'allow',
'deny',
'validator_ref',
'validator_prompt',
'>>>Cache',
'cache.strategy',
'cache.ttl',
'>>>Regex validation',
'regex_validation.allow',
'regex_validation.deny',
'>>>LLM Based validation',
'llm_validation.provider',
'llm_validation.prompt',
'>>>External validation',
'http_validation.url',
'http_validation.headers',
'http_validation.ttl',
'>>>Tester',
'tester',
'>>>Metadata and tags',
Expand All @@ -371,11 +419,19 @@ class AiProvidersPage extends Component {
'options.temperature',
'options.top_p',
'options.top_k',
'<<<Validation',
'allow',
'deny',
'validator_ref',
'validator_prompt',
'>>>Cache',
'cache.strategy',
'cache.ttl',
'>>>Regex validation',
'regex_validation.allow',
'regex_validation.deny',
'>>>LLM Based validation',
'llm_validation.provider',
'llm_validation.prompt',
'>>>External validation',
'http_validation.url',
'http_validation.headers',
'http_validation.ttl',
'>>>Tester',
'tester',
'>>>Metadata and tags',
Expand All @@ -398,11 +454,19 @@ class AiProvidersPage extends Component {
'options.n',
'options.temperature',
'options.topP',
'<<<Validation',
'allow',
'deny',
'validator_ref',
'validator_prompt',
'>>>Cache',
'cache.strategy',
'cache.ttl',
'>>>Regex validation',
'regex_validation.allow',
'regex_validation.deny',
'>>>LLM Based validation',
'llm_validation.provider',
'llm_validation.prompt',
'>>>External validation',
'http_validation.url',
'http_validation.headers',
'http_validation.ttl',
'>>>Tester',
'tester',
'>>>Metadata and tags',
Expand All @@ -424,11 +488,19 @@ class AiProvidersPage extends Component {
'options.n',
'options.temperature',
'options.topP',
'<<<Validation',
'allow',
'deny',
'validator_ref',
'validator_prompt',
'>>>Cache',
'cache.strategy',
'cache.ttl',
'>>>Regex validation',
'regex_validation.allow',
'regex_validation.deny',
'>>>LLM Based validation',
'llm_validation.provider',
'llm_validation.prompt',
'>>>External validation',
'http_validation.url',
'http_validation.headers',
'http_validation.ttl',
'>>>Tester',
'tester',
'>>>Metadata and tags',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ object ChatClientDecorators {
val possibleDecorators: Seq[Function[(AiProvider, ChatClient), ChatClient]] = Seq(
ChatClientWithRegexValidation.applyIfPossible,
ChatClientWithLlmValidation.applyIfPossible,
ChatClientWithSimpleCache.applyIfPossible,
ChatClientWithSemanticCache.applyIfPossible,
ChatClientWithSemanticCache.applyIfPossible,
ChatClientWithHttpValidation.applyIfPossible,
)

def apply(provider: AiProvider, client: ChatClient): ChatClient = {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,64 @@
package com.cloud.apim.otoroshi.extensions.aigateway.decorators

class externalvalidation {
import com.cloud.apim.otoroshi.extensions.aigateway.{ChatClient, ChatPrompt, ChatResponse}
import com.cloud.apim.otoroshi.extensions.aigateway.entities.AiProvider
import com.github.blemale.scaffeine.Scaffeine
import otoroshi.env.Env
import otoroshi.utils.TypedMap
import play.api.libs.json.{JsValue, Json}
import otoroshi.utils.syntax.implicits._

import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.FiniteDuration

object ChatClientWithHttpValidation {
val cache = Scaffeine()
.expireAfter[String, (FiniteDuration, Boolean)](
create = (key, value) => value._1,
update = (key, value, currentDuration) => currentDuration,
read = (key, value, currentDuration) => currentDuration
)
.maximumSize(5000)
.build[String, (FiniteDuration, Boolean)]()
def applyIfPossible(tuple: (AiProvider, ChatClient)): ChatClient = {
if (tuple._1.httpValidation.url.isDefined) {
new ChatClientWithHttpValidation(tuple._1, tuple._2)
} else {
tuple._2
}
}
}

class ChatClientWithHttpValidation(originalProvider: AiProvider, chatClient: ChatClient) extends ChatClient {

private val ttl = originalProvider.httpValidation.ttl


override def call(originalPrompt: ChatPrompt, attrs: TypedMap)(implicit ec: ExecutionContext, env: Env): Future[Either[JsValue, ChatResponse]] = {

def pass(): Future[Either[JsValue, ChatResponse]] = chatClient.call(originalPrompt, attrs)

def fail(): Future[Either[JsValue, ChatResponse]] = Left(Json.obj("error" -> "bad_request", "error_description" -> s"request content did not pass http validation")).vfuture

val key = originalPrompt.messages.map(m => s"${m.role}:${m.content}").mkString(",").sha512
ChatClientWithHttpValidation.cache.getIfPresent(key) match {
case Some((_, true)) => pass()
case Some((_, false)) => fail()
case None => {
env.Ws
.url(originalProvider.httpValidation.url.get)
.withHttpHeaders(originalProvider.httpValidation.headers.toSeq: _*)
.post(originalPrompt.json).flatMap { resp =>
if (resp.status != 200) {
ChatClientWithHttpValidation.cache.put(key, (ttl, false))
fail()
} else {
val value = resp.json.select("result").asOpt[Boolean].getOrElse(false)
ChatClientWithHttpValidation.cache.put(key, (ttl, value))
pass()
}
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import scala.concurrent.{ExecutionContext, Future}

object ChatClientWithLlmValidation {
def applyIfPossible(tuple: (AiProvider, ChatClient)): ChatClient = {
if (tuple._1.validatorRef.isDefined && tuple._1.validatorPrompt.isDefined) {
if (tuple._1.llmValidation.provider.isDefined && tuple._1.llmValidation.prompt.isDefined) {
new ChatClientWithLlmValidation(tuple._1, tuple._2)
} else {
tuple._2
Expand All @@ -28,11 +28,11 @@ class ChatClientWithLlmValidation(originalProvider: AiProvider, chatClient: Chat

def fail(idx: Int): Future[Either[JsValue, ChatResponse]] = Left(Json.obj("error" -> "bad_request", "error_description" -> s"request content did not pass llm validation (${idx})")).vfuture

originalProvider.validatorRef match {
originalProvider.llmValidation.provider match {
case None => pass()
case Some(ref) if ref == originalProvider.id => pass()
case Some(ref) => {
originalProvider.validatorPrompt match {
originalProvider.llmValidation.prompt match {
case None => pass()
case Some(pref) => env.adminExtensions.extension[AiExtension].flatMap(_.states.prompt(pref)) match {
case None => Left(Json.obj("error" -> "validation prompt not found")).vfuture
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import scala.concurrent.{ExecutionContext, Future}

object ChatClientWithRegexValidation {
def applyIfPossible(tuple: (AiProvider, ChatClient)): ChatClient = {
if (tuple._1.allow.nonEmpty || tuple._1.deny.nonEmpty) {
if (tuple._1.regexValidation.allow.nonEmpty || tuple._1.regexValidation.deny.nonEmpty) {
new ChatClientWithRegexValidation(tuple._1, tuple._2)
} else {
tuple._2
Expand All @@ -21,8 +21,8 @@ object ChatClientWithRegexValidation {

class ChatClientWithRegexValidation(originalProvider: AiProvider, chatClient: ChatClient) extends ChatClient {

private val allow = originalProvider.allow
private val deny = originalProvider.deny
private val allow = originalProvider.regexValidation.allow
private val deny = originalProvider.regexValidation.deny

private def validate(content: String): Boolean = {
val allowed = if (allow.isEmpty) true else allow.exists(al => RegexPool.regex(al).matches(content))
Expand Down
Loading

0 comments on commit 8a17218

Please sign in to comment.