diff --git a/rig-core/src/providers/mod.rs b/rig-core/src/providers/mod.rs
index 53ca700f5..2bba1e6c2 100644
--- a/rig-core/src/providers/mod.rs
+++ b/rig-core/src/providers/mod.rs
@@ -11,6 +11,7 @@
 //! - DeepSeek
 //! - Azure OpenAI
 //! - Mira
+//! - YandexGPT
 //!
 //! Each provider has its own module, which contains a `Client` implementation that can
 //! be used to initialize completion and embedding models and execute requests to those models.
@@ -64,3 +65,4 @@ pub mod perplexity;
 pub mod together;
 pub mod voyageai;
 pub mod xai;
+pub mod yandex;
diff --git a/rig-core/src/providers/yandex.rs b/rig-core/src/providers/yandex.rs
new file mode 100644
index 000000000..4943feee7
--- /dev/null
+++ b/rig-core/src/providers/yandex.rs
@@ -0,0 +1,333 @@
+//! YandexGPT OpenAI-compatible provider.
+//!
+//! This provider reuses the OpenAI-compatible request/response shapes with
+//! a custom base URL and the required `OpenAI-Project` header. The final model
+//! identifier is assembled as `gpt://<folder_id>/<model>`, where the
+//! folder ID is provided via [`ClientBuilder::folder`].
+
+use crate::client::{
+    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
+    ProviderClient,
+};
+use crate::completion::{self, CompletionError, CompletionRequest as CoreCompletionRequest};
+use crate::embeddings::{self, EmbeddingError};
+use crate::http_client;
+use crate::http_client::HttpClientExt;
+use crate::providers::openai;
+use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
+use http::header::{HeaderName, HeaderValue};
+
+const YANDEX_API_BASE_URL: &str = "https://llm.api.cloud.yandex.net/v1";
+#[allow(dead_code)]
+const YANDEX_RESPONSES_API_BASE_URL: &str = "https://rest-assistant.api.cloud.yandex.net/v1";
+
+/// `yandexgpt-lite/latest` text model.
+pub const YANDEXGPT_LITE_LATEST: &str = "yandexgpt-lite/latest";
+/// `yandexgpt/latest` text model.
+pub const YANDEXGPT_LATEST: &str = "yandexgpt/latest";
+/// `yandexgpt/rc` (YandexGPT 5.1) text model.
+pub const YANDEXGPT_RC: &str = "yandexgpt/rc";
+
+/// `text-search-doc/latest` embedding model.
+pub const YANDEX_EMBED_TEXT_SEARCH_DOC: &str = "text-search-doc/latest";
+/// `text-search-query/latest` embedding model.
+pub const YANDEX_EMBED_TEXT_SEARCH_QUERY: &str = "text-search-query/latest";
+/// `text-embeddings/latest` embedding model.
+pub const YANDEX_EMBED_TEXT_EMBEDDINGS: &str = "text-embeddings/latest";
+
+#[derive(Debug, Clone, Default)]
+pub struct YandexExt {
+    folder: Option<String>,
+}
+
+#[derive(Debug, Clone, Default)]
+pub struct YandexExtBuilder {
+    folder: Option<String>,
+}
+
+type YandexApiKey = BearerAuth;
+
+pub type Client<T = reqwest::Client> = client::Client<YandexExt, T>;
+pub type ClientBuilder<T = reqwest::Client> =
+    client::ClientBuilder<YandexExtBuilder, T>;
+
+impl YandexExt {
+    fn qualify_completion_model(&self, model: impl Into<String>) -> String {
+        let model = model.into();
+
+        if model.starts_with("gpt://") {
+            return model;
+        }
+
+        match &self.folder {
+            Some(folder) => format!("gpt://{folder}/{model}"),
+            None => model,
+        }
+    }
+
+    fn qualify_embedding_model(&self, model: impl Into<String>) -> String {
+        let model = model.into();
+
+        if model.starts_with("emb://") {
+            return model;
+        }
+
+        match &self.folder {
+            Some(folder) => format!("emb://{folder}/{model}"),
+            None => model,
+        }
+    }
+}
+
+impl From<YandexExtBuilder> for YandexExt {
+    fn from(value: YandexExtBuilder) -> Self {
+        Self {
+            folder: value.folder,
+        }
+    }
+}
+
+impl DebugExt for YandexExt {
+    fn fields(&self) -> impl Iterator<Item = (&str, &dyn std::fmt::Debug)> {
+        [("folder", (&self.folder as &dyn std::fmt::Debug))].into_iter()
+    }
+}
+
+impl Provider for YandexExt {
+    type Builder = YandexExtBuilder;
+
+    const VERIFY_PATH: &'static str = "/models";
+
+    fn build<T>(
+        builder: &client::ClientBuilder<Self::Builder, T>,
+    ) -> http_client::Result<Self> {
+        Ok(builder.ext().clone().into())
+    }
+}
+
+impl<T> Capabilities<T> for YandexExt {
+    type Completion = Capable<CompletionModel<T>>;
+    type Embeddings = Capable<EmbeddingModel<T>>;
+    type Transcription = Nothing;
+    #[cfg(feature = "image")]
+    type ImageGeneration = Nothing;
+    #[cfg(feature = "audio")]
+    type AudioGeneration = Nothing;
+}
+
+impl ProviderBuilder for YandexExtBuilder {
+    type Output = YandexExt;
+    type ApiKey = YandexApiKey;
+
+    const BASE_URL: &'static str = YANDEX_API_BASE_URL;
+
+    fn finish<T>(
+        &self,
+        mut builder: client::ClientBuilder<Self, T>,
+    ) -> http_client::Result<client::ClientBuilder<Self, T>> {
+        if let Some(folder) = &self.folder {
+            builder.headers_mut().insert(
+                HeaderName::from_static("openai-project"),
+                HeaderValue::from_str(folder)?,
+            );
+        }
+
+        *builder.ext_mut() = self.clone();
+
+        Ok(builder)
+    }
+}
+
+impl<T> ClientBuilder<T> {
+    /// Set the folder ID used for the `OpenAI-Project` header and model path.
+    pub fn folder(self, folder: impl Into<String>) -> Self {
+        self.over_ext(|mut ext| {
+            ext.folder = Some(folder.into());
+            ext
+        })
+    }
+}
+
+impl<T> Client<T> {
+    fn qualify_completion_model(&self, model: impl Into<String>) -> String {
+        self.ext().qualify_completion_model(model)
+    }
+
+    fn qualify_embedding_model(&self, model: impl Into<String>) -> String {
+        self.ext().qualify_embedding_model(model)
+    }
+}
+
+impl ProviderClient for Client {
+    type Input = YandexApiKey;
+
+    /// Create a new YandexGPT client from the `YANDEX_API_KEY` environment
+    /// variable, plus the optional `YANDEX_FOLDER_ID` and `YANDEX_BASE_URL`
+    /// overrides.
+    fn from_env() -> Self {
+        let api_key = std::env::var("YANDEX_API_KEY").expect("YANDEX_API_KEY not set");
+        let folder = std::env::var("YANDEX_FOLDER_ID").ok();
+        let base_url = std::env::var("YANDEX_BASE_URL").ok();
+
+        let mut builder = Client::builder().api_key(api_key);
+
+        if let Some(folder) = folder {
+            builder = builder.folder(folder);
+        }
+
+        if let Some(base_url) = base_url {
+            builder = builder.base_url(base_url);
+        }
+
+        builder.build().unwrap()
+    }
+
+    fn from_val(input: Self::Input) -> Self {
+        Self::new(input).unwrap()
+    }
+}
+
+fn to_openai_responses_client<T: Clone>(client: &Client<T>) -> openai::Client<T> {
+    client::Client::from_parts(
+        client.base_url().to_string(),
+        client.headers().clone(),
+        client.http_client().clone(),
+        openai::client::OpenAIResponsesExt,
+    )
+}
+
+fn to_openai_completions_client<T: Clone>(client: &Client<T>) -> openai::CompletionsClient<T> {
+    client::Client::from_parts(
+        client.base_url().to_string(),
+        client.headers().clone(),
+        client.http_client().clone(),
+        openai::client::OpenAICompletionsExt,
+    )
+}
+
+// ------------------------------------------------------------------
+// Completion wrapper
+// ------------------------------------------------------------------
+
+#[derive(Clone)]
+pub struct CompletionModel<T> {
+    inner: openai::CompletionModel<T>,
+}
+
+impl<T> CompletionModel<T> {
+    fn new(client: &Client<T>, model: impl Into<String>) -> Self
+    where
+        T: Clone + Default + std::fmt::Debug + 'static,
+    {
+        let inner = openai::CompletionModel::new(
+            to_openai_completions_client(client),
+            client.qualify_completion_model(model),
+        );
+
+        Self { inner }
+    }
+}
+
+impl<T> completion::CompletionModel for CompletionModel<T>
+where
+    T: HttpClientExt
+        + Default
+        + std::fmt::Debug
+        + Clone
+        + WasmCompatSend
+        + WasmCompatSync
+        + 'static,
+{
+    type Response = openai::CompletionResponse;
+    type StreamingResponse = openai::streaming::StreamingCompletionResponse;
+
+    type Client = Client<T>;
+
+    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
+        Self::new(client, model)
+    }
+
+    async fn completion(
+        &self,
+        completion_request: CoreCompletionRequest,
+    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
+        self.inner.completion(completion_request).await
+    }
+
+    async fn stream(
+        &self,
+        request: CoreCompletionRequest,
+    ) -> Result<
+        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
+        CompletionError,
+    > {
+        self.inner.stream(request).await
+    }
+}
+
+// ------------------------------------------------------------------
+// Embedding wrapper
+// ------------------------------------------------------------------
+
+#[derive(Clone)]
+pub struct EmbeddingModel<T> {
+    inner: openai::EmbeddingModel<T>,
+    ndims: usize,
+}
+
+impl<T> EmbeddingModel<T> {
+    fn new(client: &Client<T>, model: impl Into<String>, ndims: usize) -> Self
+    where
+        T: Clone + Default + std::fmt::Debug,
+    {
+        let inner = openai::EmbeddingModel::new(
+            to_openai_responses_client(client),
+            client.qualify_embedding_model(model),
+            ndims,
+        );
+
+        Self { inner, ndims }
+    }
+}
+
+impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
+where
+    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
+{
+    const MAX_DOCUMENTS: usize = 1024;
+
+    type Client = Client<T>;
+
+    fn make(client: &Self::Client, model: impl Into<String>, ndims: Option<usize>) -> Self {
+        Self::new(client, model, ndims.unwrap_or_default())
+    }
+
+    fn ndims(&self) -> usize {
+        self.ndims
+    }
+
+    #[cfg_attr(feature = "worker", worker::send)]
+    async fn embed_texts(
+        &self,
+        documents: impl IntoIterator<Item = String> + crate::wasm_compat::WasmCompatSend,
+    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
+        // The Yandex embeddings endpoint only accepts one string per request.
+        // Run per-item calls and reassemble.
+        let docs: Vec<String> = documents.into_iter().collect();
+        let mut results = Vec::with_capacity(docs.len());
+
+        for doc in docs {
+            let mut single = self.inner.embed_texts(vec![doc.clone()]).await?;
+            let Some(embed) = single.pop() else {
+                return Err(EmbeddingError::ResponseError(
+                    "Empty embedding response".to_string(),
+                ));
+            };
+
+            results.push(embeddings::Embedding {
+                document: doc,
+                vec: embed.vec,
+            });
+        }
+
+        Ok(results)
+    }
+}
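
---

For reviewers, a minimal usage sketch (not part of the diff). It assumes only the surface introduced above — `Client::builder()`, `api_key`, `folder`, `build`, and the model constants — and the folder ID shown is a placeholder:

```rust
use rig::providers::yandex::{Client, YANDEXGPT_LATEST};

fn main() {
    // `folder` sets both the `OpenAI-Project` header and the folder ID used
    // to qualify bare model names (placeholder value below).
    let client = Client::builder()
        .api_key("YOUR_API_KEY")
        .folder("b1g-example-folder")
        .build()
        .expect("client should build");

    // Alternatively, read YANDEX_API_KEY / YANDEX_FOLDER_ID / YANDEX_BASE_URL:
    // let client = Client::from_env();

    // Bare names such as "yandexgpt/latest" are expanded to
    // "gpt://b1g-example-folder/yandexgpt/latest" before the request is sent;
    // names already starting with "gpt://" or "emb://" pass through unchanged.
    let _model_name = YANDEXGPT_LATEST;
    let _ = client;
}
```

Note that the commented `Client::from_env()` call requires the `rig::client::ProviderClient` trait to be in scope.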