diff --git a/crates/llm-chain-mock/examples/simple.rs b/crates/llm-chain-mock/examples/simple.rs index aa85661c..7866b625 100644 --- a/crates/llm-chain-mock/examples/simple.rs +++ b/crates/llm-chain-mock/examples/simple.rs @@ -14,9 +14,9 @@ extern crate llm_chain_mock; async fn main() -> Result<(), Box> { let raw_args: Vec = args().collect(); let prompt = match &raw_args.len() { - 1 => "Rust is a cool programming language because", - 2 => raw_args[1].as_str(), - _ => panic!("Usage: cargo run --release --example simple") + 1 => "Rust is a cool programming language because", + 2 => raw_args[1].as_str(), + _ => panic!("Usage: cargo run --release --example simple"), }; let exec = executor!(mock)?; @@ -26,4 +26,4 @@ async fn main() -> Result<(), Box> { println!("{}", res); Ok(()) -} \ No newline at end of file +} diff --git a/crates/llm-chain-mock/src/executor.rs b/crates/llm-chain-mock/src/executor.rs index 885b82ca..0ad94c71 100644 --- a/crates/llm-chain-mock/src/executor.rs +++ b/crates/llm-chain-mock/src/executor.rs @@ -16,7 +16,7 @@ pub struct Executor { #[async_trait] impl llm_chain::traits::Executor for Executor { type StepTokenizer<'a> = MockTokenizer; - + fn new_with_options(options: Options) -> Result { Ok(Executor { options: options }) } @@ -31,7 +31,7 @@ impl llm_chain::traits::Executor for Executor { options: &Options, prompt: &Prompt, ) -> Result { - let tokenizer = self.get_tokenizer(options)?; + let tokenizer = self.get_tokenizer(options)?; let input = prompt.to_text(); let mut tokens_used = tokenizer .tokenize_str(&input) @@ -72,12 +72,22 @@ impl MockTokenizer { impl Tokenizer for MockTokenizer { fn tokenize_str(&self, doc: &str) -> Result { - let tokens: Vec = doc.as_bytes().to_vec().into_iter().map(|c| c as i32).collect(); + let tokens: Vec = doc + .as_bytes() + .to_vec() + .into_iter() + .map(|c| c as i32) + .collect(); Ok(tokens.into()) } fn to_string(&self, tokens: TokenCollection) -> Result { - let bytes : Vec = tokens.as_i32().unwrap().into_iter().map(|c| c as u8).collect(); + let bytes: Vec = tokens + .as_i32() + .unwrap() + .into_iter() + .map(|c| c as u8) + .collect(); let doc = String::from_utf8(bytes).unwrap(); Ok(doc) } @@ -89,14 +99,17 @@ mod tests { use llm_chain::traits::Executor; #[test] fn test_mock_tokenizer() { - let executor: crate::Executor = Executor::new_with_options(Options::empty().clone()).unwrap(); + let executor: crate::Executor = + Executor::new_with_options(Options::empty().clone()).unwrap(); let tokenizer = executor.get_tokenizer(&executor.options).unwrap(); let tokens = tokenizer .tokenize_str("Héllo world") //Notice that the UTF8 character translates to x3 i32s .expect("failed to tokenize"); println!("{:?}", tokens); assert_eq!(tokens.len(), 13); - let doc = tokenizer.to_string(tokens).expect("failed to convert back to string"); + let doc = tokenizer + .to_string(tokens) + .expect("failed to convert back to string"); assert_eq!(doc, "Héllo world"); } -} \ No newline at end of file +} diff --git a/crates/llm-chain-sagemaker-endpoint/CHANGELOG.md b/crates/llm-chain-sagemaker-endpoint/CHANGELOG.md new file mode 100644 index 00000000..345f8e05 --- /dev/null +++ b/crates/llm-chain-sagemaker-endpoint/CHANGELOG.md @@ -0,0 +1,11 @@ +# Changelog +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
+
+## 0.12.3 - 2023-08-15
+### Added
+- Initial release
+- Support for Falcon 7B Instruct and Falcon 40B Instruct models from SageMaker JumpStart
+- The initial version number matches the `llm-chain` crate version
diff --git a/crates/llm-chain-sagemaker-endpoint/Cargo.toml b/crates/llm-chain-sagemaker-endpoint/Cargo.toml
new file mode 100644
index 00000000..b833daac
--- /dev/null
+++ b/crates/llm-chain-sagemaker-endpoint/Cargo.toml
@@ -0,0 +1,28 @@
+[package]
+name = "llm-chain-sagemaker-endpoint"
+version = "0.12.3"
+edition = "2021"
+description = "Use `llm-chain` with a SageMaker Endpoint backend."
+license = "MIT"
+keywords = ["llm", "langchain", "chain"]
+categories = ["science"]
+authors = ["Shing Lyu "]
+repository = "https://github.com/sobelio/llm-chain/"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+async-trait = "0.1.68"
+aws-config = "0.56.0"
+aws-sdk-sagemakerruntime = "0.29.0"
+futures = "0.3.28"
+llm-chain = { path = "../llm-chain", version = "0.12.3", default-features = false }
+serde = "1.0.183"
+serde_json = "1.0.104"
+serde_with = "3.2.0"
+strum = "0.25.0"
+strum_macros = "0.25.2"
+thiserror = "1.0.40"
+
+[dev-dependencies]
+tokio = { version = "1.28.2", features = ["macros", "rt"] }
diff --git a/crates/llm-chain-sagemaker-endpoint/README.md b/crates/llm-chain-sagemaker-endpoint/README.md
new file mode 100644
index 00000000..96bf769c
--- /dev/null
+++ b/crates/llm-chain-sagemaker-endpoint/README.md
@@ -0,0 +1,10 @@
+# llm-chain-sagemaker-endpoint
+
+Amazon SageMaker Endpoint driver. Allows you to invoke a model hosted on an Amazon SageMaker endpoint, including Amazon SageMaker JumpStart models.
+
+# Getting Started
+1. This crate uses the AWS SDK for Rust to communicate with Amazon SageMaker. You need to set up credentials and a region by following [this guide](https://docs.aws.amazon.com/sdk-for-rust/latest/dg/credentials.html).
+1. Follow the SageMaker JumpStart documentation to [find an LLM](https://docs.aws.amazon.com/sagemaker/latest/dg/jumpstart-foundation-models-use.html), then [deploy it](https://docs.aws.amazon.com/sagemaker/latest/dg/jumpstart-deploy.html).
+1. Note down the SageMaker endpoint name created by SageMaker JumpStart.
+1. Some models are included in this crate; see `model::Model`. Select one in your executor options' `Model` field. See `examples/simple.rs` for an example.
+1. For custom models or models not included in `model::Model`, use `model::Model::Other(model_name)`, where `model_name` is the SageMaker endpoint name.
diff --git a/crates/llm-chain-sagemaker-endpoint/examples/simple.rs b/crates/llm-chain-sagemaker-endpoint/examples/simple.rs
new file mode 100644
index 00000000..088d2b9d
--- /dev/null
+++ b/crates/llm-chain-sagemaker-endpoint/examples/simple.rs
@@ -0,0 +1,36 @@
+use llm_chain::executor;
+use llm_chain::options;
+use llm_chain::options::Options;
+use std::{env::args, error::Error};
+
+use llm_chain::{prompt::Data, traits::Executor};
+
+extern crate llm_chain_sagemaker_endpoint;
+use llm_chain_sagemaker_endpoint::model::Model;
+
+/// This example demonstrates how to use the llm-chain-sagemaker-endpoint crate to generate text using a model hosted on a SageMaker endpoint.
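+/// Before running it, deploy the Falcon 7B Instruct model with SageMaker JumpStart and configure
+/// AWS credentials and a region for the AWS SDK for Rust (see the crate README).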
+///
+/// Usage: cargo run --release --package llm-chain-sagemaker-endpoint --example simple
+///
+#[tokio::main(flavor = "current_thread")]
+async fn main() -> Result<(), Box<dyn Error>> {
+    let raw_args: Vec<String> = args().collect();
+    let prompt = match &raw_args.len() {
+        1 => "Rust is a cool programming language because",
+        2 => raw_args[1].as_str(),
+        _ => panic!("Usage: cargo run --release --example simple "),
+    };
+
+    let opts = options!(
+        Model: Model::Falcon7BInstruct, // You need to deploy the Falcon 7B Instruct model using SageMaker JumpStart
+        MaxTokens: 50usize,
+        Temperature: 0.8
+    );
+    let exec = executor!(sagemaker_endpoint, opts)?;
+    let res = exec
+        .execute(Options::empty(), &Data::Text(String::from(prompt)))
+        .await?;
+
+    println!("{}", res);
+    Ok(())
+}
diff --git a/crates/llm-chain-sagemaker-endpoint/src/executor.rs b/crates/llm-chain-sagemaker-endpoint/src/executor.rs
new file mode 100644
index 00000000..2d1129a0
--- /dev/null
+++ b/crates/llm-chain-sagemaker-endpoint/src/executor.rs
@@ -0,0 +1,116 @@
+use crate::model::Formatter;
+use crate::model::Model;
+use async_trait::async_trait;
+
+use llm_chain::options::Opt;
+use llm_chain::options::Options;
+use llm_chain::options::OptionsCascade;
+use llm_chain::output::Output;
+use llm_chain::prompt::Prompt;
+use llm_chain::tokens::{
+    PromptTokensError, TokenCollection, TokenCount, Tokenizer, TokenizerError,
+};
+use llm_chain::traits::{ExecutorCreationError, ExecutorError};
+
+use std::str::FromStr;
+
+/// Executor is responsible for running the LLM and managing its context.
+pub struct Executor {
+    #[allow(dead_code)]
+    options: Options,
+    sagemaker_client: aws_sdk_sagemakerruntime::Client,
+}
+
+impl Executor {
+    fn get_model_from_invocation_options(&self, opts: &OptionsCascade) -> Model {
+        let Some(Opt::Model(model)) = opts.get(llm_chain::options::OptDiscriminants::Model) else {
+            // TODO: fail gracefully
+            panic!("The Model option must not be empty. This option does not have a default.");
+        };
+        Model::from_str(&model.to_name()).unwrap()
+    }
+
+    fn cascade<'a>(&'a self, opts: Option<&'a Options>) -> OptionsCascade<'a> {
+        let mut v: Vec<&'a Options> = vec![&self.options];
+        if let Some(o) = opts {
+            v.push(o);
+        }
+        OptionsCascade::from_vec(v)
+    }
+}
+
+#[async_trait]
+impl llm_chain::traits::Executor for Executor {
+    type StepTokenizer<'a> = SageMakerEndpointTokenizer;
+
+    fn new_with_options(options: Options) -> Result<Self, ExecutorCreationError> {
+        let config = futures::executor::block_on(aws_config::load_from_env());
+        let client = aws_sdk_sagemakerruntime::Client::new(&config);
+        Ok(Executor {
+            options,
+            sagemaker_client: client,
+        })
+    }
+
+    async fn execute(&self, options: &Options, prompt: &Prompt) -> Result<Output, ExecutorError> {
+        let opts = self.cascade(Some(options));
+        let model = self.get_model_from_invocation_options(&opts);
+
+        let body_blob = model.format_request(prompt, &opts);
+
+        let result = self
+            .sagemaker_client
+            .invoke_endpoint()
+            .endpoint_name(model.to_jumpstart_endpoint_name())
+            .content_type(model.request_content_type())
+            .body(body_blob)
+            .send()
+            .await;
+        let response = result.map_err(|e| ExecutorError::InnerError(e.into()))?;
+        let generated_text = model.parse_response(response);
+
+        Ok(Output::new_immediate(Prompt::text(generated_text)))
+    }
+
+    fn tokens_used(
+        &self,
+        _options: &Options,
+        _prompt: &Prompt,
+    ) -> Result<TokenCount, PromptTokensError> {
+        // Not all models expose this information.
+        unimplemented!();
+    }
+
+    fn max_tokens_allowed(&self, _: &Options) -> i32 {
+        // Not all models expose this information.
+        unimplemented!();
+    }
+
+    fn answer_prefix(&self, _prompt: &Prompt) -> Option<String> {
+        // Not all models expose this information.
+        unimplemented!();
+    }
+
+    fn get_tokenizer(&self, _: &Options) -> Result<Self::StepTokenizer<'_>, TokenizerError> {
+        // Not all models expose this information.
+        unimplemented!();
+    }
+}
+
+pub struct SageMakerEndpointTokenizer {}
+
+impl SageMakerEndpointTokenizer {
+    pub fn new(_executor: &Executor) -> Self {
+        SageMakerEndpointTokenizer {}
+    }
+}
+
+impl Tokenizer for SageMakerEndpointTokenizer {
+    fn tokenize_str(&self, _doc: &str) -> Result<TokenCollection, TokenizerError> {
+        unimplemented!();
+    }
+
+    fn to_string(&self, _tokens: TokenCollection) -> Result<String, TokenizerError> {
+        unimplemented!();
+    }
+}
diff --git a/crates/llm-chain-sagemaker-endpoint/src/lib.rs b/crates/llm-chain-sagemaker-endpoint/src/lib.rs
new file mode 100644
index 00000000..94f777a8
--- /dev/null
+++ b/crates/llm-chain-sagemaker-endpoint/src/lib.rs
@@ -0,0 +1,3 @@
+mod executor;
+pub use executor::Executor;
+pub mod model;
diff --git a/crates/llm-chain-sagemaker-endpoint/src/model.rs b/crates/llm-chain-sagemaker-endpoint/src/model.rs
new file mode 100644
index 00000000..0378e775
--- /dev/null
+++ b/crates/llm-chain-sagemaker-endpoint/src/model.rs
@@ -0,0 +1,170 @@
+use aws_sdk_sagemakerruntime::operation::invoke_endpoint::InvokeEndpointOutput;
+use aws_sdk_sagemakerruntime::primitives::Blob;
+use llm_chain::options::{ModelRef, Opt, OptDiscriminants, OptionsCascade};
+use llm_chain::prompt::Prompt;
+use serde::{Deserialize, Serialize};
+use serde_json::json;
+use serde_with::skip_serializing_none;
+use strum_macros::EnumString;
+
+/// The `Model` enum represents the available SageMaker Endpoint models.
+/// Use SageMaker JumpStart to deploy the models listed here, or use Model::Other
+/// to reference your custom models. For Model::Other, you need to write your own
+/// formatting logic for the request and response.
+///
+/// # Example
+///
+/// ```
+/// use llm_chain_sagemaker_endpoint::model::Model;
+///
+/// let falcon_model = Model::Falcon7BInstruct;
+/// let custom_model = Model::Other("your_custom_model_name".to_string());
+/// ```
+#[derive(Debug, Default, Clone, Serialize, Deserialize, EnumString, PartialEq, Eq)]
+#[non_exhaustive]
+pub enum Model {
+    /// Falcon 7B Instruct BF16
+    /// https://huggingface.co/tiiuae/falcon-7b-instruct
+    #[default]
+    #[strum(serialize = "falcon-7b-instruct")]
+    Falcon7BInstruct,
+
+    /// Falcon 40B Instruct BF16
+    /// https://huggingface.co/tiiuae/falcon-40b-instruct
+    #[strum(serialize = "falcon-40b-instruct")]
+    Falcon40BInstruct,
+
+    /// A variant that allows you to specify a custom model name as a string, in case new models
+    /// are introduced or you have access to specialized models.
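+    /// The string is used as the SageMaker endpoint name when the request is sent.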
+    #[strum(default)]
+    Other(String),
+}
+
+pub trait Formatter {
+    fn format_request(&self, prompt: &Prompt, options: &OptionsCascade) -> Blob;
+    fn request_content_type(&self) -> String;
+    fn parse_response(&self, response: InvokeEndpointOutput) -> String;
+}
+
+impl Formatter for Model {
+    fn format_request(&self, prompt: &Prompt, options: &OptionsCascade) -> Blob {
+        match self {
+            Model::Falcon7BInstruct | Model::Falcon40BInstruct => {
+                #[skip_serializing_none]
+                #[derive(Serialize)]
+                struct Parameters {
+                    max_new_tokens: Option<usize>,
+                    max_length: Option<usize>,
+                    temperature: Option<f32>,
+                    top_k: Option<i32>,
+                    top_p: Option<f32>,
+                    stop: Option<Vec<String>>,
+                    // TODO: num_beams, no_repeat_ngram_size, early_stopping, do_sample, return_full_text
+                }
+
+                let parameters = Parameters {
+                    max_new_tokens: options.get(OptDiscriminants::MaxTokens).map(|x| match x {
+                        Opt::MaxTokens(i) => *i,
+                        _ => unreachable!("options.get should restrict the enum variant."),
+                    }),
+                    max_length: options
+                        .get(OptDiscriminants::MaxContextSize)
+                        .map(|x| match x {
+                            Opt::MaxContextSize(i) => *i,
+                            _ => unreachable!("options.get should restrict the enum variant."),
+                        }),
+                    temperature: options.get(OptDiscriminants::Temperature).map(|x| match x {
+                        Opt::Temperature(i) => *i,
+                        _ => unreachable!("options.get should restrict the enum variant."),
+                    }),
+                    top_k: options.get(OptDiscriminants::TopK).map(|x| match x {
+                        Opt::TopK(i) => *i,
+                        _ => unreachable!("options.get should restrict the enum variant."),
+                    }),
+                    top_p: options.get(OptDiscriminants::TopP).map(|x| match x {
+                        Opt::TopP(i) => *i,
+                        _ => unreachable!("options.get should restrict the enum variant."),
+                    }),
+                    stop: options
+                        .get(OptDiscriminants::StopSequence)
+                        .map(|x| match x {
+                            Opt::StopSequence(i) => i.clone(),
+                            _ => unreachable!("options.get should restrict the enum variant."),
+                        }),
+                };
+
+                let body_json = json!({
+                    "inputs": prompt.to_string(),
+                    "parameters": parameters
+                });
+
+                let body_string = body_json.to_string();
+                let body_blob = Blob::new(body_string.as_bytes().to_vec());
+                body_blob
+            }
+            _ => {
+                unimplemented!("This model does not have a default formatter. Please format the request with your own code.");
+            }
+        }
+    }
+
+    fn request_content_type(&self) -> String {
+        match self {
+            Model::Falcon7BInstruct | Model::Falcon40BInstruct => "application/json".to_string(),
+            _ => {
+                unimplemented!("This model does not have a default formatter. Please format the request with your own code.");
+            }
+        }
+    }
+
+    fn parse_response(&self, response: InvokeEndpointOutput) -> String {
+        match self {
+            Model::Falcon7BInstruct => {
+                let output = String::from_utf8(response.body.unwrap().into_inner()).unwrap();
+                let output_json: serde_json::Value = serde_json::from_str(&output).unwrap();
+
+                output_json[0]["generated_text"].to_string()
+            }
+            _ => {
+                unimplemented!("This model does not have a default formatter. Please format the response with your own code.");
+            }
+        }
+    }
+}
+
+impl Model {
+    /// Convert the model to its SageMaker JumpStart default endpoint name
+    pub fn to_jumpstart_endpoint_name(&self) -> String {
+        match &self {
+            Model::Falcon7BInstruct => "jumpstart-dft-hf-llm-falcon-7b-instruct-bf16".to_string(),
+            Model::Falcon40BInstruct => "jumpstart-dft-hf-llm-falcon-40b-instruct-bf16".to_string(),
+            _ => self.to_string(),
+        }
+    }
+}
+
+/// The `Model` enum implements the `ToString` trait, allowing you to easily convert it to a string.
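+/// The string form matches the `strum` serializations above, so `Model::from_str` can round-trip it.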
+impl ToString for Model {
+    fn to_string(&self) -> String {
+        match &self {
+            Model::Falcon7BInstruct => "falcon-7b-instruct".to_string(),
+            Model::Falcon40BInstruct => "falcon-40b-instruct".to_string(),
+            //jumpstart-dft-hf-llm-falcon-7b-instruct-bf16
+            Model::Other(model) => model.to_string(),
+        }
+    }
+}
+
+/// Conversion from Model to ModelRef
+impl From<Model> for ModelRef {
+    fn from(value: Model) -> Self {
+        ModelRef::from_model_name(value.to_string())
+    }
+}
+
+/// Conversion from Model to Opt
+impl From<Model> for Opt {
+    fn from(value: Model) -> Self {
+        Opt::Model(value.into())
+    }
+}
diff --git a/crates/llm-chain/src/executor.rs b/crates/llm-chain/src/executor.rs
index 504c5b77..dbf044cb 100644
--- a/crates/llm-chain/src/executor.rs
+++ b/crates/llm-chain/src/executor.rs
@@ -72,4 +72,8 @@ macro_rules! executor {
         use llm_chain::traits::Executor;
         llm_chain_mock::Executor::new()
     }};
+    (sagemaker_endpoint, $options:expr) => {{
+        use llm_chain::traits::Executor;
+        llm_chain_sagemaker_endpoint::Executor::new_with_options($options)
+    }};
 }
diff --git a/crates/llm-chain/src/options.rs b/crates/llm-chain/src/options.rs
index 21eba146..88d905f8 100644
--- a/crates/llm-chain/src/options.rs
+++ b/crates/llm-chain/src/options.rs
@@ -406,7 +406,6 @@ pub enum Opt {
     /// RoPE frequency scale
     /// Only for llm-chain-llama
     RopeFrequencyScale(f32),
-
 }
 
 // Helper function to extract environment variables
diff --git a/website/docs/llama-tutorial.md b/website/docs/llama-tutorial.md
index 75e46c0c..b1fc46b2 100644
--- a/website/docs/llama-tutorial.md
+++ b/website/docs/llama-tutorial.md
@@ -188,7 +188,7 @@ pip install -r requirements.txt
 With the Python dependencies installed, you need to run the conversion script that will convert the Alpaca model to a binary format that llama.cpp can read. To do that, run the following command in your terminal:
 
 ```
-python convert.py /models/alpaca-native
+python convert.py ./models/alpaca-native
 ```
 
 This will run the `convert.py` script that is located in the `llama.cpp` directory. The script will take the Alpaca model directory as an argument and output a binary file called `ggml-model-f32.bin` in the same directory.