From 74abd46ac8ddf2ba86b54562d64cf31429d0e701 Mon Sep 17 00:00:00 2001
From: Mathieu Ancelin
Date: Fri, 26 Jul 2024 16:42:35 +0200
Subject: [PATCH] fix #7, fix #8

---
Review notes (ignored by git am, not part of the commit message):
- loadbalancer.scala: response-time measurement now starts from
  System.currentTimeMillis() (was System.console()); `private[models]` replaced
  by `private` (no enclosing `models` scope in this package); round-robin index
  uses Math.floorMod so it stays valid after counter overflow;
  AtomicAverage.average returns 0 instead of dividing by zero; the client
  answers "no provider configured" when no ref resolves to a provider.
  The `index` blob id below still names the original content.
- AiProvidersPage.js: NOTE(review) the third hunk is reproduced as found; its
  header announces 15 added lines but only 10 are present (the text between
  '<<<' and '>>>Tester' looks stripped). Restore it from the original commit
  before applying.
- Leading whitespace inside every hunk was lost in this copy and is re-created
  conventionally; apply with --ignore-whitespace or re-indent the context.

 .../extensions/ai/AiProvidersPage.js          |  31 ++++++
 .../aigateway/decorators/loadbalancer.scala   | 105 ++++++++++++++++++
 .../aigateway/entities/provider.scala         |   3 +-
 3 files changed, 138 insertions(+), 1 deletion(-)
 create mode 100644 src/main/scala/com/cloud/apim/otoroshi/extensions/aigateway/decorators/loadbalancer.scala

diff --git a/src/main/resources/cloudapim/extensions/ai/AiProvidersPage.js b/src/main/resources/cloudapim/extensions/ai/AiProvidersPage.js
index 4be930e..2062bf4 100644
--- a/src/main/resources/cloudapim/extensions/ai/AiProvidersPage.js
+++ b/src/main/resources/cloudapim/extensions/ai/AiProvidersPage.js
@@ -127,6 +127,7 @@ class AiProvidersPage extends Component {
           { 'label': 'Mistral', value: 'mistral' },
           { 'label': 'Ollama', value: 'ollama' },
           { 'label': 'Anthropic', value: 'anthropic' },
+          { 'label': 'Loadbalancer', value: 'loadbalancer' },
         ] }
       },
       'connection.resource_name': {
@@ -264,6 +265,21 @@ class AiProvidersPage extends Component {
             { label: 'Semantic', value: 'semantic' },
           ] },
       },
+      'options.loadbalancing': {
+        type: 'select',
+        props: {
+          label: 'Load Balancing strategy',
+          possibleValues: [
+            { label: 'Round robin', value: 'round_robin' },
+            { label: 'Random', value: 'random' },
+            { label: 'Best response time', value: 'best_response_time' },
+          ]
+        },
+      },
+      'options.ratio': {
+        type: 'number',
+        props: { label: 'TTL', suffix: 'millis.' },
+      },
       'llm_validation.provider': {
         type: 'select',
         props: {
@@ -343,6 +359,21 @@ class AiProvidersPage extends Component {
         'metadata',
       ]
     }
+    if (state.provider === "loadbalancer") {
+      return [
+        '_loc', 'id', 'name', 'description',
+        '<<>>Tester',
+        'tester',
+        '>>>Metadata and tags',
+        'tags',
+        'metadata',
+      ];
+    }
     if (state.provider === "ollama") {
       return [
         '_loc', 'id', 'name', 'description',
diff --git a/src/main/scala/com/cloud/apim/otoroshi/extensions/aigateway/decorators/loadbalancer.scala b/src/main/scala/com/cloud/apim/otoroshi/extensions/aigateway/decorators/loadbalancer.scala
new file mode 100644
index 0000000..9f2690a
--- /dev/null
+++ b/src/main/scala/com/cloud/apim/otoroshi/extensions/aigateway/decorators/loadbalancer.scala
@@ -0,0 +1,105 @@
+package com.cloud.apim.otoroshi.extensions.aigateway.decorators
+
+import com.cloud.apim.otoroshi.extensions.aigateway.entities.AiProvider
+import com.cloud.apim.otoroshi.extensions.aigateway.{ChatClient, ChatPrompt, ChatResponse}
+import otoroshi.env.Env
+import otoroshi.utils.TypedMap
+import otoroshi.utils.cache.types.UnboundedTrieMap
+import otoroshi.utils.syntax.implicits._
+import otoroshi_plugins.com.cloud.apim.extensions.aigateway.AiExtension
+import play.api.libs.json.{JsValue, Json}
+
+import java.util.concurrent.atomic.{AtomicInteger, AtomicLong}
+import scala.concurrent.{ExecutionContext, Future}
+
+/** Strategy used to pick one provider among the targets of a load balancer. */
+trait LoadBalancing {
+  /** Picks one of `targets`; every implementation throws when `targets` is empty. */
+  def select(reqId: String, targets: Seq[AiProvider])(implicit env: Env): AiProvider
+}
+
+/** Cycles through the targets with a single counter shared by all callers. */
+object RoundRobin extends LoadBalancing {
+  private val reqCounter = new AtomicInteger(0)
+  override def select(reqId: String, targets: Seq[AiProvider])(implicit env: Env): AiProvider = {
+    // floorMod keeps the index non-negative after the counter overflows
+    val index: Int = Math.floorMod(reqCounter.incrementAndGet(), math.max(targets.size, 1))
+    targets.apply(index)
+  }
+}
+
+/** Picks a target uniformly at random. */
+object Random extends LoadBalancing {
+  private val random = new scala.util.Random
+  override def select(reqId: String, targets: Seq[AiProvider])(implicit env: Env): AiProvider = {
+    val index = random.nextInt(targets.length)
+    targets.apply(index)
+  }
+}
+
+/** Running sum and sample count from which an integer average is derived. */
+case class AtomicAverage(count: AtomicLong, sum: AtomicLong) {
+  def incrBy(v: Long): Unit = {
+    count.incrementAndGet()
+    sum.addAndGet(v)
+  }
+  /** Integer average of the recorded values, or 0 while nothing was recorded. */
+  def average: Long = {
+    val samples = count.get
+    if (samples == 0L) 0L else sum.get / samples
+  }
+}
+
+/** Tries targets without recorded timings first, then the lowest average. */
+object BestResponseTime extends LoadBalancing {
+
+  private val responseTimes = new UnboundedTrieMap[String, AtomicAverage]()
+
+  /** Adds one measured response time to the running average kept for `desc`. */
+  def incrementAverage(desc: AiProvider, responseTime: Long): Unit = {
+    val avg = responseTimes.getOrElseUpdate(desc.id, AtomicAverage(new AtomicLong(0), new AtomicLong(0)))
+    avg.incrBy(responseTime)
+  }
+
+  override def select(reqId: String, targets: Seq[AiProvider])(implicit env: Env): AiProvider = {
+    targets.find(t => !responseTimes.contains(t.id)).getOrElse {
+      targets.minBy(t => responseTimes.get(t.id).map(_.average).getOrElse(Long.MaxValue))
+    }
+  }
+}
+
+object LoadBalancerChatClient {
+  val counter = new AtomicLong(0L)
+}
+
+/** Chat client that forwards each call to one of the providers listed in `options.refs`. */
+class LoadBalancerChatClient(provider: AiProvider) extends ChatClient {
+
+  override def call(prompt: ChatPrompt, attrs: TypedMap)(implicit ec: ExecutionContext, env: Env): Future[Either[JsValue, ChatResponse]] = {
+    val refs = provider.options.select("refs").asOpt[Seq[String]].getOrElse(Seq.empty)
+    val loadBalancing: LoadBalancing = provider.options.select("loadbalancing").asOpt[String].map(_.toLowerCase()).getOrElse("round_robin") match {
+      case "random" => Random
+      case "best_response_time" => BestResponseTime
+      case _ => RoundRobin
+    }
+    // refs that do not resolve to a known provider are dropped here
+    val providers: Seq[AiProvider] = refs.flatMap(r => env.adminExtensions.extension[AiExtension].flatMap(_.states.provider(r)))
+    if (providers.isEmpty) {
+      // empty `refs`, or none of them resolved: nothing to select from
+      Json.obj("error" -> "no provider configured").leftf
+    } else {
+      val selected = loadBalancing.select(LoadBalancerChatClient.counter.incrementAndGet().toString, providers)
+      selected.getChatClient() match {
+        case None => Json.obj("error" -> "no client found").leftf
+        case Some(client) => {
+          // wall-clock start used to measure the response time of this call
+          val start = System.currentTimeMillis()
+          client.call(prompt, attrs).map { resp =>
+            BestResponseTime.incrementAverage(selected, System.currentTimeMillis() - start)
+            resp
+          }
+        }
+      }
+    }
+  }
+}
diff --git a/src/main/scala/com/cloud/apim/otoroshi/extensions/aigateway/entities/provider.scala b/src/main/scala/com/cloud/apim/otoroshi/extensions/aigateway/entities/provider.scala
index 5af9979..a8fb300 100644
--- a/src/main/scala/com/cloud/apim/otoroshi/extensions/aigateway/entities/provider.scala
+++ b/src/main/scala/com/cloud/apim/otoroshi/extensions/aigateway/entities/provider.scala
@@ -1,7 +1,7 @@
 package com.cloud.apim.otoroshi.extensions.aigateway.entities
 
 import com.cloud.apim.otoroshi.extensions.aigateway.ChatClient
-import com.cloud.apim.otoroshi.extensions.aigateway.decorators.ChatClientDecorators
+import com.cloud.apim.otoroshi.extensions.aigateway.decorators.{ChatClientDecorators, LoadBalancerChatClient}
 import com.cloud.apim.otoroshi.extensions.aigateway.providers._
 import otoroshi.api.{GenericResourceAccessApiWithState, Resource, ResourceVersion}
 import otoroshi.env.Env
@@ -94,6 +94,7 @@ case class AiProvider(
         val opts = AnthropicChatClientOptions.fromJson(options)
         new AnthropicChatClient(api, opts, id).some
       }
+      case "loadbalancer" => new LoadBalancerChatClient(this).some
       case _ => None
     }
     rawClient.map(c => ChatClientDecorators(this, c))