Baidu Qianfan/Ernie backend #287

Open · wants to merge 1 commit into base: main
1,105 changes: 1,022 additions & 83 deletions Cargo.lock

Large diffs are not rendered by default.

24 changes: 24 additions & 0 deletions crates/llm-chain-ernie/Cargo.toml
@@ -0,0 +1,24 @@
[package]
name = "llm-chain-ernie"
version = "0.1.0"
edition = "2021"
description = "Use `llm-chain` with Baidu Qianfan(Ernie/Wenxin) platform."
license = "MIT"
keywords = ["llm", "langchain", "chain"]
categories = ["science"]
authors = ["Wanqi Chen <[email protected]>"]
repository = "https://github.com/sobelio/llm-chain/"


# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]

llm-chain = { path = "../llm-chain", version = "0.13.0", default-features = false }
erniebot-rs = "0.4.1"
async-trait = "0.1.68"
serde = { version = "1.0.163", features = ["derive"] }
strum = "0.26.2"
strum_macros = "0.26.2"
tokio = { version = "1.28.0", features = ["rt-multi-thread"] }
tokio-stream = "0.1.14"
21 changes: 21 additions & 0 deletions crates/llm-chain-ernie/README.md
@@ -0,0 +1,21 @@
# llm_chain_ernie

Baidu Qianfan (also referred to as Ernie/Wenxin) platform integration. This enables you to seamlessly access and use models hosted on the Baidu Qianfan platform.

Powered by the [erniebot-rs](https://github.com/chenwanqq/erniebot-rs) Rust SDK, this integration provides a smooth bridge between your applications and Baidu's AI capabilities. Please note that both this integration and erniebot-rs are community-supported and *not* officially endorsed by Baidu.

Currently, this integration primarily supports chat models. Support for embedding models, along with broader testing across a wider range of use cases, is planned for future development.

## Getting Started

1. **Set up Baidu AI Cloud Platform**: Begin by following [this detailed guide](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/7ltgucw50) to set up your account and access the necessary services.

2. **Configure Environment Variables**: Before running any applications, ensure you have exported your `QIANFAN_AK` and `QIANFAN_SK` as environment variables.

```bash
export QIANFAN_AK=<your_access_key>
export QIANFAN_SK=<your_secret_key>
```
3. **Follow the Example**: Refer to the example provided in [simple_generator.rs](./examples/simple_generator.rs). The library includes predefined models such as `ErnieBot`, `ErnieBotTurbo`, and `Ernie40`, but you also have the flexibility to use other models.

To use a different model, identify its name from the API path. For instance, in the URL `https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-8k-preview`, the model name is `ernie-4.0-8k-preview`. Use this name to specify the desired model in your application code, as in the sketch below.
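
A minimal sketch (not part of this PR) of selecting a custom model by name via `Model::Other`, following the same pattern as `simple_generator.rs`:

```rust
use llm_chain::executor;
use llm_chain::options;
use llm_chain::options::Options;
use llm_chain::{prompt::Data, traits::Executor};
use llm_chain_ernie::model::Model;
use std::error::Error;

#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), Box<dyn Error>> {
    // `Model::Other` takes the model name exactly as it appears in the API path,
    // e.g. "ernie-4.0-8k-preview" from .../wenxinworkshop/chat/ernie-4.0-8k-preview.
    let opts = options!(
        Model: Model::Other("ernie-4.0-8k-preview".to_string()),
        MaxTokens: 50usize,
        Temperature: 0.8
    );
    let exec = executor!(ernie_endpoint, opts)?;
    let res = exec
        .execute(
            Options::empty(),
            &Data::Text("Rust is a cool programming language because".to_string()),
        )
        .await?;
    println!("{}", res);
    Ok(())
}
```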
39 changes: 39 additions & 0 deletions crates/llm-chain-ernie/examples/simple_generator.rs
@@ -0,0 +1,39 @@
use llm_chain::executor;
use llm_chain::options;
use llm_chain::options::Options;
use std::{env::args, error::Error};

use llm_chain::{prompt::Data, traits::Executor};

extern crate llm_chain_ernie;
use llm_chain_ernie::model::Model;

/// This example demonstrates how to use the llm-chain-ernie crate to generate text.
///
/// Usage: before running this example code, you need to export QIANFAN_AK and QIANFAN_SK:
/// export QIANFAN_AK=<YOUR_AK>
/// export QIANFAN_SK=<YOUR_SK>
/// cargo run --release --package llm-chain-ernie --example simple_generator <optional prompt>
///
#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), Box<dyn Error>> {
let raw_args: Vec<String> = args().collect();
let prompt = match &raw_args.len() {
1 => "Rust is a cool programming language because",
2 => raw_args[1].as_str(),
_ => panic!("Usage: cargo run --release --example simple <optional prompt>"),
};

let opts = options!(
Model: Model::ErnieBotTurbo,
MaxTokens: 50usize,
Temperature: 0.8
);
let exec = executor!(ernie_endpoint, opts)?;
let res = exec
.execute(Options::empty(), &Data::Text(String::from(prompt)))
.await?;

println!("{}", res);
Ok(())
}
165 changes: 165 additions & 0 deletions crates/llm-chain-ernie/src/executor.rs
@@ -0,0 +1,165 @@
use async_trait::async_trait;
use erniebot_rs::chat::{ChatEndpoint, ChatOpt};
use llm_chain::options::{Opt, Options, OptionsCascade};
use llm_chain::output::{Output, StreamSegment};
use llm_chain::prompt::Prompt;
use llm_chain::tokens::{
PromptTokensError, TokenCount, Tokenizer as TokenizerTrait, TokenizerError,
};
use llm_chain::traits::{Executor as ExecutorTrait, ExecutorCreationError, ExecutorError};
use tokio;
use tokio_stream::StreamExt;

use crate::model::Model;
use crate::prompt::create_message;

#[derive(Clone)]
pub struct Executor {
options: Options,
}

impl Executor {
fn get_model_from_invocation_options(&self, opts: &OptionsCascade) -> String {
let Some(Opt::Model(model)) = opts.get(llm_chain::options::OptDiscriminants::Model) else {
return Model::ErnieBotTurbo.to_string();
};
model.to_name()
}

fn cascade<'a>(&'a self, opts: Option<&'a Options>) -> OptionsCascade<'a> {
let mut v: Vec<&'a Options> = vec![&self.options];
if let Some(o) = opts {
v.push(o);
}
OptionsCascade::from_vec(v)
}
/// transform the options into a vector of ChatOpts, to be used in the chat endpoint
fn option_transform(&self, opts: &OptionsCascade) -> Vec<ChatOpt> {
let mut chat_opts = Vec::new();
// TODO: this is verbose; is there a way to enumerate the set options instead of matching each variant individually?
if let Some(Opt::Temperature(temp)) =
opts.get(llm_chain::options::OptDiscriminants::Temperature)
{
chat_opts.push(ChatOpt::Temperature(*temp));
}
if let Some(Opt::TopK(top_k)) = opts.get(llm_chain::options::OptDiscriminants::TopK) {
chat_opts.push(ChatOpt::TopK(*top_k as u32));
}
if let Some(Opt::TopP(top_p)) = opts.get(llm_chain::options::OptDiscriminants::TopP) {
chat_opts.push(ChatOpt::TopP(*top_p));
}
if let Some(Opt::RepeatPenalty(repeat_penalty)) =
opts.get(llm_chain::options::OptDiscriminants::RepeatPenalty)
{
chat_opts.push(ChatOpt::PenaltyScore(*repeat_penalty));
}
if let Some(Opt::StopSequence(stop_sequence)) =
opts.get(llm_chain::options::OptDiscriminants::StopSequence)
{
chat_opts.push(ChatOpt::Stop(stop_sequence.clone()));
}
if let Some(Opt::MaxTokens(max_tokens)) =
opts.get(llm_chain::options::OptDiscriminants::MaxTokens)
{
chat_opts.push(ChatOpt::MaxOutputTokens(*max_tokens as u32));
}
chat_opts
}
}

#[async_trait]
impl ExecutorTrait for Executor {
type StepTokenizer<'a> = ErnieTokenizer;
fn new_with_options(options: Options) -> Result<Self, ExecutorCreationError> {
Ok(Executor { options })
}

async fn execute(&self, options: &Options, prompt: &Prompt) -> Result<Output, ExecutorError> {
let opts = self.cascade(Some(options));
let model = self.get_model_from_invocation_options(&opts);
let chat_endpoint =
if let Ok(chat_endpoint) = ChatEndpoint::new_with_custom_endpoint(&model) {
chat_endpoint
} else {
return Err(ExecutorError::InvalidOptions);
};
let chat_opts = self.option_transform(&opts);
let messages = create_message(prompt);
if opts.is_streaming() {
let mut stream_response = chat_endpoint
.astream(&messages, &chat_opts)
.await
.map_err(|e| ExecutorError::InnerError(Box::new(e)))?;
let (sender, result_stream) = Output::new_stream();
tokio::spawn(async move {
while let Some(chunk) = stream_response.next().await {
let segment = match chunk.get_chat_result() {
Ok(result) => StreamSegment::Content(result),
Err(e) => StreamSegment::Err(ExecutorError::InnerError(Box::new(e))),
};
if sender.send(segment).is_err() {
break;
}
}
});
Ok(result_stream)
} else {
let response = chat_endpoint
.ainvoke(&messages, &chat_opts)
.await
.map_err(|e| ExecutorError::InnerError(Box::new(e)))?;
let chat_result = response
.get_chat_result()
.map_err(|e| ExecutorError::InnerError(Box::new(e)))?;
Ok(Output::new_immediate(Prompt::text(chat_result)))
}
}

fn tokens_used(
&self,
_options: &Options,
_prompt: &Prompt,
) -> Result<TokenCount, PromptTokensError> {
// Not all models expose this information.
unimplemented!();
}

fn max_tokens_allowed(&self, _: &Options) -> i32 {
// Not all models expose this information.
unimplemented!();
}

fn answer_prefix(&self, _prompt: &Prompt) -> Option<String> {
// Not all models expose this information.
unimplemented!();
}

fn get_tokenizer(&self, _: &Options) -> Result<Self::StepTokenizer<'_>, TokenizerError> {
// A local tokenizer is not available for Ernie models.
unimplemented!();
}
}

pub struct ErnieTokenizer {}

impl ErnieTokenizer {
pub fn new(_executor: &Executor) -> Self {
ErnieTokenizer {}
}
}

impl TokenizerTrait for ErnieTokenizer {
fn tokenize_str(
&self,
_doc: &str,
) -> Result<llm_chain::tokens::TokenCollection, llm_chain::tokens::TokenizerError> {
unimplemented!()
}

fn to_string(
&self,
_tokens: llm_chain::tokens::TokenCollection,
) -> Result<String, llm_chain::tokens::TokenizerError> {
unimplemented!()
}
}
5 changes: 5 additions & 0 deletions crates/llm-chain-ernie/src/lib.rs
@@ -0,0 +1,5 @@
pub mod executor;
pub mod model;
mod prompt;

pub use executor::Executor;
67 changes: 67 additions & 0 deletions crates/llm-chain-ernie/src/model.rs
@@ -0,0 +1,67 @@
use llm_chain::options::{ModelRef, Opt};
use serde::{Deserialize, Serialize};
use strum_macros::{Display, EnumString};

#[derive(Debug, Default, Clone, Serialize, Deserialize, EnumString, Display, PartialEq, Eq)]
#[non_exhaustive]
pub enum Model {
/// Ernie 3.5 turbo
#[default]
#[strum(serialize = "eb-instant")]
#[serde(rename = "eb-instant")]
ErnieBotTurbo,
/// ernie 3.5
#[strum(serialize = "completions")]
#[serde(rename = "completions")]
ErnieBot,
/// ernie 4.0
#[strum(serialize = "completions_pro")]
#[serde(rename = "completions_pro")]
Ernie40,
/// A variant that allows you to specify a custom model name as a string, in case new models
/// are introduced or you have access to specialized models.
#[strum(default)]
Other(String),
}

/// Conversion from Model to ModelRef
impl From<Model> for ModelRef {
fn from(value: Model) -> Self {
ModelRef::from_model_name(value.to_string())
}
}

/// Conversion from Model to Option
impl From<Model> for Opt {
fn from(value: Model) -> Self {
Opt::Model(value.into())
}
}

#[cfg(test)]
mod tests {
#[test]
fn test_model_to_string() {
use super::Model;
assert_eq!(Model::ErnieBotTurbo.to_string(), "eb-instant");
assert_eq!(Model::ErnieBot.to_string(), "completions");
assert_eq!(Model::Ernie40.to_string(), "completions_pro");
assert_eq!(
Model::Other("your_custom_model_name".to_string()).to_string(),
"your_custom_model_name"
);
}

#[test]
fn test_model_from_string() {
use super::Model;
use std::str::FromStr;
assert_eq!(Model::from_str("eb-instant").unwrap(), Model::ErnieBotTurbo);
assert_eq!(Model::from_str("completions").unwrap(), Model::ErnieBot);
assert_eq!(Model::from_str("completions_pro").unwrap(), Model::Ernie40);
assert_eq!(
Model::from_str("your_custom_model_name").unwrap(),
Model::Other("your_custom_model_name".to_string())
);
}
}
31 changes: 31 additions & 0 deletions crates/llm-chain-ernie/src/prompt.rs
@@ -0,0 +1,31 @@
use erniebot_rs::chat::{Message, Role};
use llm_chain::prompt::Prompt;

/// Creates a chat message from a prompt.
pub fn create_message(prompt: &Prompt) -> Vec<Message> {
match prompt {
Prompt::Text(text) => vec![Message {
role: Role::User,
content: text.clone(),
..Default::default()
}],
Prompt::Chat(chat) => {
let mut messages = Vec::new();
for message in chat.iter() {
let role = match message.role() {
llm_chain::prompt::ChatRole::User => Role::User,
llm_chain::prompt::ChatRole::Assistant => Role::Assistant,
llm_chain::prompt::ChatRole::System => Role::Assistant, // the Ernie chat API has no system role in messages, so map System to Assistant
llm_chain::prompt::ChatRole::Other(_) => todo!(),
};
let content = message.body();
messages.push(Message {
role,
content: content.clone(),
..Default::default()
});
}
messages
}
}
}
4 changes: 4 additions & 0 deletions crates/llm-chain/src/executor.rs
@@ -80,4 +80,8 @@ macro_rules! executor {
use llm_chain::traits::Executor;
llm_chain_sagemaker_endpoint::Executor::new_with_options($options)
}};
(ernie_endpoint,$options:expr) => {{
use llm_chain::traits::Executor;
llm_chain_ernie::Executor::new_with_options($options)
}};
}
5 changes: 5 additions & 0 deletions crates/llm-chain/src/options.rs
@@ -410,6 +410,11 @@ pub enum Opt {
UseMmap(bool),
// Force the system to keep the model in memory for llm-chain-llama.
UseMlock(bool),

// Disable real-time online search for llm-chain-ernie
DisableSearch(bool),
// Enable returning citation info for llm-chain-ernie
EnableCitation(bool),
}

// Helper function to extract environment variables
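
Note: in this diff, the new `DisableSearch` and `EnableCitation` variants are declared but not yet mapped to `ChatOpt`s in the Ernie executor's `option_transform`. A minimal sketch (an assumption, not part of this PR) of how a caller might set them once they are wired through, using the same `options!` pattern as the other variants:

```rust
use llm_chain::options;
use llm_chain::options::Options;
use llm_chain_ernie::model::Model;

// Hypothetical helper: builds options that turn off Baidu's real-time online
// search and ask the API to include citation info in responses.
fn ernie_search_options() -> Options {
    options!(
        Model: Model::ErnieBot,
        DisableSearch: true,
        EnableCitation: true
    )
}
```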