Skip to content

Commit

Permalink
Merge pull request #23 from cloud-apim/gemini-provider
Browse files Browse the repository at this point in the history
Gemini provider
  • Loading branch information
mathieuancelin authored Sep 19, 2024
2 parents 0e6e9f0 + a2d85ae commit 856fe66
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 0 deletions.
54 changes: 54 additions & 0 deletions src/main/resources/cloudapim/extensions/ai/AiProvidersPage.js
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,44 @@ class AiProvidersPage extends Component {
'metadata',
];
}
// Form layout for the Gemini provider: ordered list of field keys rendered
// by the provider edit page. Entries prefixed with '<<<' or '>>>' are section
// headers rather than fields — presumably '<<<' sections render expanded and
// '>>>' sections collapsed by default; TODO confirm against the form renderer.
if (state.provider === "gemini") {
  return [
    '_loc', 'id', 'name', 'description',
    '<<<Provider',
    'provider',
    '<<<API Connection',
    'connection.model',
    'connection.token',
    'connection.timeout',
    '<<<Connection options',
    // Generation options mapped to GeminiChatClientOptions on the backend.
    'options.maxOutputTokens',
    'options.temperature',
    'options.topP',
    'options.topK',
    'options.stopSequences',
    '>>>Provider fallback',
    'provider_fallback',
    '>>>Cache',
    'cache.strategy',
    'cache.ttl',
    // The score threshold only applies to semantic caching; hide it otherwise.
    state.cache.strategy === 'semantic' ? 'cache.score' : null,
    '>>>Regex validation',
    'regex_validation.allow',
    'regex_validation.deny',
    '>>>LLM Based validation',
    'llm_validation.provider',
    'llm_validation.prompt',
    '>>>External validation',
    'http_validation.url',
    'http_validation.headers',
    'http_validation.ttl',
    '>>>Tester',
    'tester',
    '>>>Metadata and tags',
    'tags',
    'metadata',
  ];
}
if (state.provider === "azure-openai") {
return [
'_loc', 'id', 'name', 'description',
Expand Down Expand Up @@ -1003,6 +1041,22 @@ class AiProvidersPage extends Component {
},
options: ClientOptions.hugging,
});
} else if (state.provider === 'gemini') {
update({
id: state.id,
name: state.name,
description: state.description,
tags: state.tags,
metadata: state.metadata,
provider: 'gemini',
connection: {
base_url: BaseUrls.gemini,
model: 'model name',
token: 'xxx',
timeout: 10000,
},
options: ClientOptions.gemini,
});
} else if (state.provider === 'azure-openai') {
update({
id: state.id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ case class AiProvider(
val opts = CloudflareChatClientOptions.fromJson(options)
new CloudflareChatClient(api, opts, id).some
}
// Builds the Gemini chat client for this provider entity.
case "gemini" => {
  // Gemini embeds the model in the request path, so it is read from the
  // connection config here. NOTE(review): the UI also stores a `base_url`
  // for this provider, but GeminiApi ignores it — confirm this is intended.
  val model = connection.select("model").as[String]
  val api = new GeminiApi(model, token, timeout.getOrElse(10.seconds), env = env)
  val opts = GeminiChatClientOptions.fromJson(options)
  new GeminiChatClient(api, opts, id).some
}
case "mistral" => {
val api = new MistralAiApi(baseUrl.getOrElse(MistralAiApi.baseUrl), token, timeout.getOrElse(10.seconds), env = env)
val opts = MistralAiChatClientOptions.fromJson(options)
Expand Down Expand Up @@ -271,6 +277,21 @@ object AiProvider {
),
options = HuggingFaceChatClientOptions().json
).json
// Default template returned for a newly created Gemini provider.
case Some("gemini") => AiProvider(
  id = IdGenerator.namedId("provider", env),
  name = "Gemini provider",
  description = "A Gemini LLM api provider",
  metadata = Map.empty,
  tags = Seq.empty,
  location = EntityLocation.default,
  provider = "gemini",
  connection = Json.obj(
    "model" -> GeminiModels.GEMINI_1_5_FLASH,
    "token" -> "xxxxx",
    "timeout" -> 10000,
  ),
  // BUG FIX: this previously seeded `MistralAiChatClientOptions().json`,
  // giving new Gemini providers Mistral-shaped options that
  // GeminiChatClientOptions.fromJson cannot fully read back. Use the
  // Gemini defaults, consistent with every other provider template.
  options = GeminiChatClientOptions().json
).json
case _ => AiProvider(
id = IdGenerator.namedId("provider", env),
name = "OpenAI provider",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ class AiExtension(val env: Env) extends AdminExtension {
| mistral: ${MistralAiChatClientOptions().json.stringify},
| ollama: ${OllamaAiChatClientOptions().json.stringify},
| groq: ${GroqChatClientOptions().json.stringify},
| gemini: ${GeminiChatClientOptions().json.stringify},
| 'azure-openai': ${AzureOpenAiChatClientOptions().json.stringify},
| 'cohere': ${CohereAiChatClientOptions().json.stringify},
| ovh: ${OVHAiEndpointsChatClientOptions().json.stringify},
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package com.cloud.apim.otoroshi.extensions.aigateway.providers

import com.cloud.apim.otoroshi.extensions.aigateway._
import dev.langchain4j.data.segment.TextSegment
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore
import otoroshi.env.Env
import otoroshi.utils.TypedMap
import otoroshi.utils.syntax.implicits._
import play.api.libs.json.{JsObject, JsValue, Json}
import dev.langchain4j.model.embedding.onnx.allminilml6v2.AllMiniLmL6V2EmbeddingModel
import dev.langchain4j.store.embedding.EmbeddingSearchRequest

import scala.concurrent.duration.{DurationInt, FiniteDuration}
import scala.concurrent.{ExecutionContext, Future}
import scala.jdk.CollectionConverters.asScalaBufferConverter

// Raw HTTP response from the Gemini API: status code, (single-valued)
// headers, and the body already parsed as JSON.
case class GeminiApiResponse(status: Int, headers: Map[String, String], body: JsValue) {
  // JSON view of the full response, handy for logging and debugging.
  def json: JsValue = Json.obj(
    "status" -> status,
    "headers" -> headers,
    "body" -> body,
  )
}
/** Known Gemini model identifiers, usable as the `model` connection setting. */
object GeminiModels {
  // Default model for newly created Gemini providers.
  val GEMINI_1_5_FLASH = "gemini-1.5-flash"
}

object GeminiApi {
  /**
   * Builds the `generateContent` endpoint URL for the given model.
   *
   * The API key is passed as the `key` query parameter. Both segments are
   * URL-encoded so an unusual model name or token cannot break the URL or
   * inject extra query parameters.
   *
   * NOTE(review): carrying the key in the query string can leak it into
   * access logs and proxies; consider the `x-goog-api-key` header instead.
   */
  def url(model: String, token: String): String = {
    val encodedModel = java.net.URLEncoder.encode(model, "UTF-8")
    val encodedToken = java.net.URLEncoder.encode(token, "UTF-8")
    s"https://generativelanguage.googleapis.com/v1beta/models/${encodedModel}:generateContent?key=${encodedToken}"
  }
}

// Thin HTTP client for Google's Gemini `generateContent` endpoint.
// The model and API key are baked into the URL (see GeminiApi.url), so one
// instance is bound to a single model/token pair.
class GeminiApi(model: String, token: String, timeout: FiniteDuration = 10.seconds, env: Env) {

  // Performs one HTTP call and wraps status/headers/JSON body into a
  // GeminiApiResponse. `method` only selects the HTTP verb — the endpoint is
  // always generateContent. When `body` is present it is sent as JSON with a
  // matching Content-Type header.
  // NOTE(review): `resp.json` will throw if the server returns non-JSON
  // (e.g. an HTML error page) — confirm upstream error handling covers this.
  def call(method: String, body: Option[JsValue])(implicit ec: ExecutionContext): Future[GeminiApiResponse] = {
    env.Ws
      .url(s"${GeminiApi.url(model, token)}")
      .withHttpHeaders(
        "Accept" -> "application/json",
      ).applyOnWithOpt(body) {
        case (builder, body) => builder
          .addHttpHeaders("Content-Type" -> "application/json")
          .withBody(body)
      }
      .withMethod(method)
      .withRequestTimeout(timeout)
      .execute()
      .map { resp =>
        // Keep only the last value of each (possibly multi-valued) header.
        // NOTE(review): `.mapValues` is deprecated/lazy on Scala 2.13 —
        // confirm the target Scala version tolerates this.
        GeminiApiResponse(resp.status, resp.headers.mapValues(_.last), resp.json)
      }
  }
}

object GeminiChatClientOptions {
  /**
   * Reads generation options from their JSON representation, applying the
   * documented Gemini defaults for missing fields.
   */
  def fromJson(json: JsValue): GeminiChatClientOptions = {
    GeminiChatClientOptions(
      maxOutputTokens = json.select("maxOutputTokens").asOpt[Int],
      temperature = json.select("temperature").asOpt[Float].getOrElse(1.0f),
      topP = json.select("topP").asOpt[Float].getOrElse(0.95f),
      // BUG FIX: the serializer (GeminiChatClientOptions.json) and the UI both
      // use the key "topK", but this reader looked up "top_k", so a configured
      // topK was silently ignored. Read "topK" first; keep "top_k" as a legacy
      // fallback for any options persisted with the old key.
      topK = json.select("topK").asOpt[Int].orElse(json.select("top_k").asOpt[Int]).getOrElse(40),
      stopSequences = json.select("stopSequences").asOpt[Array[String]],
    )
  }
}

// Generation options forwarded to Gemini's `generationConfig`.
// NOTE(review): the default topK here (1) differs from fromJson's fallback
// (40) — confirm which default is intended and align them.
// NOTE(review): Array[String] compares by reference, so case-class equality
// is broken for stopSequences; consider Seq[String] if equality matters.
case class GeminiChatClientOptions(
  maxOutputTokens: Option[Int] = None,
  temperature: Float = 1,
  topP: Float = 0.95f,
  topK: Int = 1,
  stopSequences: Option[Array[String]] = None
) extends ChatOptions {
  // Serializes the options; these key names are the ones fromJson must read.
  override def json: JsObject = Json.obj(
    "maxOutputTokens" -> maxOutputTokens,
    "temperature" -> temperature,
    "topP" -> topP,
    "topK" -> topK,
    "stopSequences" -> stopSequences
  )
}

/**
 * ChatClient implementation backed by the Gemini `generateContent` API.
 * Sends the prompt, records usage metadata into the request attributes for
 * analytics, and converts Gemini candidates into ChatGeneration values.
 */
class GeminiChatClient(api: GeminiApi, options: GeminiChatClientOptions, id: String) extends ChatClient {

  override def call(prompt: ChatPrompt, attrs: TypedMap)(implicit ec: ExecutionContext, env: Env): Future[Either[JsValue, ChatResponse]] = {
    // Per-request options override the client defaults.
    val mergedOptions = options.json.deepMerge(prompt.options.map(_.json).getOrElse(Json.obj()))
    // BUG FIX: the previous body spread the generation options at the top
    // level of the request AND nested the UNmerged defaults under
    // "generationConfig". Gemini only accepts generation parameters inside
    // "generationConfig", so per-prompt overrides were silently dropped.
    // NOTE(review): "contents"/"parts" is built from prompt.json — confirm
    // its shape matches Gemini's expected parts array.
    api.call("POST", Some(Json.obj(
      "contents" -> Json.obj("parts" -> prompt.json),
      "generationConfig" -> mergedOptions
    ))).map { resp =>
      val usage = ChatResponseMetadata(
        ChatResponseMetadataRateLimit.empty,
        ChatResponseMetadataUsage(
          promptTokens = resp.body.select("usageMetadata").select("promptTokenCount").asOpt[Long].getOrElse(-1L),
          // BUG FIX: generation tokens are reported as "candidatesTokenCount";
          // "totalTokenCount" includes the prompt tokens too. Keep the old
          // field as a fallback for older/odd responses.
          generationTokens = resp.body.select("usageMetadata").select("candidatesTokenCount").asOpt[Long]
            .orElse(resp.body.select("usageMetadata").select("totalTokenCount").asOpt[Long])
            .getOrElse(-1L),
        ),
      )
      val duration: Long = resp.headers.getIgnoreCase("gemini-processing-ms").map(_.toLong).getOrElse(0L)
      // Analytics payload describing this provider call.
      val slug = Json.obj(
        "provider_kind" -> "gemini",
        "provider" -> id,
        "duration" -> duration,
        "rate_limit" -> usage.rateLimit.json,
        "usage" -> usage.usage.json
      )
      attrs.update(ChatClient.ApiUsageKey -> usage)
      // Append the slug to the "ai" array of the extra analytics data.
      attrs.update(otoroshi.plugins.Keys.ExtraAnalyticsDataKey) {
        case Some(obj @ JsObject(_)) => {
          val arr = obj.select("ai").asOpt[Seq[JsObject]].getOrElse(Seq.empty)
          val newArr = arr ++ Seq(slug)
          obj ++ Json.obj("ai" -> newArr)
        }
        case Some(other) => other
        case None => Json.obj("ai" -> Seq(slug))
      }
      // BUG FIX: candidate parts are objects like {"text": "..."}, not plain
      // strings, so the previous `asOpt[Seq[String]]` always produced empty
      // content. Extract the "text" field from each part instead.
      val messages = resp.body.select("candidates").asOpt[Seq[JsObject]].getOrElse(Seq.empty).map { candidate =>
        val role = candidate.select("content").select("role").asOpt[String].getOrElse("user")
        val content = candidate.select("content").select("parts").asOpt[Seq[JsObject]].getOrElse(Seq.empty)
          .flatMap(part => part.select("text").asOpt[String])
          .mkString(" ")
        ChatGeneration(ChatMessage(role, content))
      }
      Right(ChatResponse(messages, usage))
    }
  }
}

0 comments on commit 856fe66

Please sign in to comment.