42 changes: 35 additions & 7 deletions crates/http-api-bindings/src/chat/mod.rs
@@ -4,16 +4,49 @@ use async_openai_alt::config::OpenAIConfig;
use tabby_common::config::HttpModelConfig;
use tabby_inference::{ChatCompletionStream, ExtendedOpenAIConfig};

use super::multi::MultiChatStream;
use super::rate_limit;
use crate::{create_reqwest_client, AZURE_API_VERSION};

pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
let mut multi_chat_stream = MultiChatStream::new();
add_engine(&mut multi_chat_stream, model);
Arc::new(multi_chat_stream)
}

pub async fn create_multiple(models: &[HttpModelConfig]) -> Arc<dyn ChatCompletionStream> {
let mut multi_chat_stream = MultiChatStream::new();
models.iter().for_each(|model| {
add_engine(&mut multi_chat_stream, model);
});
Arc::new(multi_chat_stream)
}

fn add_engine(multi_chat_stream: &mut MultiChatStream, model: &HttpModelConfig) {
let (model_title, model_name) = model.model_title_and_name();
let engine = Arc::new(rate_limit::new_chat(
create_engine(model),
model.rate_limit.request_per_minute,
));

// Register the primary model first so its title becomes the default model
multi_chat_stream.add_chat_stream(model_title, model_name, engine.clone());

if let (Some(supported_models), Some(model_name)) = (&model.supported_models, &model.model_name)
{
for m in supported_models.iter().filter(|m| model_name != *m) {
multi_chat_stream.add_chat_stream(m, m, engine.clone());
}
}
}

fn create_engine(model: &HttpModelConfig) -> Box<dyn ChatCompletionStream> {
let api_endpoint = model
.api_endpoint
.as_deref()
.expect("api_endpoint is required");

let engine: Box<dyn ChatCompletionStream> = match model.kind.as_str() {
match model.kind.as_str() {
"azure/chat" => {
let config = async_openai_alt::config::AzureConfig::new()
.with_api_base(api_endpoint)
@@ -45,10 +78,5 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
)
}
_ => panic!("Unsupported model kind: {}", model.kind),
};

Arc::new(rate_limit::new_chat(
engine,
model.rate_limit.request_per_minute,
))
}
}
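Review note, not part of the diff: a minimal sketch of how a caller might drive the new `create_multiple` entry point. The function `demo`, the `anyhow` error type, and the model names are illustrative; the request-builder names follow upstream `async_openai` and are assumed to carry over to the `async_openai_alt` fork.

```rust
use async_openai_alt::types::{
    ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs,
};
use http_api_bindings::create_multiple_chat;
use tabby_common::config::HttpModelConfig;

// Hypothetical caller: `configs` would come from tabby's usual config parsing.
async fn demo(configs: &[HttpModelConfig]) -> anyhow::Result<()> {
    // One engine is built per config; each engine is registered under its
    // model title plus every entry in `supported_models`, so `request.model`
    // selects the backend at call time.
    let chat = create_multiple_chat(configs).await;

    let request = CreateChatCompletionRequestArgs::default()
        // An empty string falls back to the first registered title;
        // "gpt-4o" is an illustrative name that must match a registered model.
        .model("gpt-4o")
        .messages([ChatCompletionRequestUserMessageArgs::default()
            .content("Hello")
            .build()?
            .into()])
        .build()?;

    let _response = chat.chat(request).await?;
    Ok(())
}
```

Note that `request.model` is a plain `String`, so "unspecified" is modeled as the empty string rather than an `Option`, which is why the routing below checks `model.is_empty()`.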
2 changes: 2 additions & 0 deletions crates/http-api-bindings/src/lib.rs
@@ -2,8 +2,10 @@ mod chat;
mod completion;
mod embedding;
mod rate_limit;
mod multi;

pub use chat::create as create_chat;
pub use chat::create_multiple as create_multiple_chat;
pub use completion::{build_completion_prompt, create};
pub use embedding::create as create_embedding;

116 changes: 116 additions & 0 deletions crates/http-api-bindings/src/multi.rs
@@ -0,0 +1,116 @@
use async_openai_alt::{
error::OpenAIError,
types::{
ChatCompletionResponseStream, CreateChatCompletionRequest, CreateChatCompletionResponse,
},
};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use tabby_inference::ChatCompletionStream;

struct ChatStreamWrapper {
model_name: String,
chat_stream: Arc<dyn ChatCompletionStream>,
}

impl ChatStreamWrapper {
fn new(model_name: String, chat_stream: Arc<dyn ChatCompletionStream>) -> Self {
Self {
model_name,
chat_stream,
}
}

fn process_request(
&self,
mut request: CreateChatCompletionRequest,
) -> CreateChatCompletionRequest {
request.model = self.model_name.clone();
request
}
}

#[async_trait]
impl ChatCompletionStream for ChatStreamWrapper {
async fn chat(
&self,
request: CreateChatCompletionRequest,
) -> Result<CreateChatCompletionResponse, OpenAIError> {
let request = self.process_request(request);
self.chat_stream.chat(request).await
}

async fn chat_stream(
&self,
request: CreateChatCompletionRequest,
) -> Result<ChatCompletionResponseStream, OpenAIError> {
let request = self.process_request(request);
self.chat_stream.chat_stream(request).await
}
}

pub struct MultiChatStream {
chat_streams: HashMap<String, Box<dyn ChatCompletionStream>>,

/// Fallback model used when a request leaves `model` empty; set to the
/// model title passed to the first [`MultiChatStream::add_chat_stream`] call.
default_model: Option<String>,
}

impl MultiChatStream {
pub fn new() -> MultiChatStream {
Self {
chat_streams: HashMap::new(),
default_model: None,
}
}

pub fn add_chat_stream(
&mut self,
model_title: impl Into<String>,
model: impl Into<String>,
completion: Arc<dyn ChatCompletionStream>,
) {
let model_title = model_title.into();
if self.default_model.is_none() {
self.default_model = Some(model_title.to_owned());
}
self.chat_streams.insert(
model_title,
Box::new(ChatStreamWrapper::new(model.into(), completion)),
);
}

fn get_chat_stream(&self, model: &str) -> Result<&Box<dyn ChatCompletionStream>, OpenAIError> {
let model = if model.is_empty() {
self.default_model
.as_ref()
.ok_or_else(|| OpenAIError::InvalidArgument("No available model".to_owned()))?
} else {
model
};
self.chat_streams
.get(model)
.ok_or_else(|| OpenAIError::InvalidArgument(format!("Model {} does not exist", model)))
}
}

#[async_trait]
impl ChatCompletionStream for MultiChatStream {
async fn chat(
&self,
request: CreateChatCompletionRequest,
) -> Result<CreateChatCompletionResponse, OpenAIError> {
let chat_stream = self.get_chat_stream(&request.model)?;
chat_stream.chat(request).await
}

async fn chat_stream(
&self,
request: CreateChatCompletionRequest,
) -> Result<ChatCompletionResponseStream, OpenAIError> {
let chat_stream = self.get_chat_stream(&request.model)?;
chat_stream.chat_stream(request).await
}
}
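Review note, not part of the diff: because `mod multi` is crate-private, its routing rules are easiest to pin down with an in-crate test. A hedged sketch, assuming `CreateChatCompletionRequest` derives `Default` and that `tokio` (with test macros) is available as a dev-dependency:

```rust
#[cfg(test)]
mod tests {
    use super::*;

    /// Stub backend that reports back which model name it was handed,
    /// by smuggling it out through the error variant.
    struct Stub;

    #[async_trait]
    impl ChatCompletionStream for Stub {
        async fn chat(
            &self,
            request: CreateChatCompletionRequest,
        ) -> Result<CreateChatCompletionResponse, OpenAIError> {
            Err(OpenAIError::InvalidArgument(request.model))
        }

        async fn chat_stream(
            &self,
            _request: CreateChatCompletionRequest,
        ) -> Result<ChatCompletionResponseStream, OpenAIError> {
            Err(OpenAIError::InvalidArgument("unused".to_owned()))
        }
    }

    #[tokio::test]
    async fn empty_model_falls_back_and_is_rewritten() {
        let mut multi = MultiChatStream::new();
        // Registered under title "gpt4o-proxy"; forwarded upstream as "gpt-4o".
        multi.add_chat_stream("gpt4o-proxy", "gpt-4o", Arc::new(Stub));

        let mut request = CreateChatCompletionRequest::default();
        request.model = String::new(); // empty => routed to default_model

        // ChatStreamWrapper rewrites `model` to "gpt-4o" before delegating.
        match multi.chat(request).await {
            Err(OpenAIError::InvalidArgument(m)) => assert_eq!(m, "gpt-4o"),
            other => panic!("unexpected result: {other:?}"),
        }
    }
}
```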