Commit

Merge branch 'sobelio:main' into main
andychenbruce authored Oct 13, 2023
2 parents d2beece + afef670 commit 214c2ff
Showing 12 changed files with 403 additions and 13 deletions.
8 changes: 4 additions & 4 deletions crates/llm-chain-mock/examples/simple.rs
@@ -14,9 +14,9 @@ extern crate llm_chain_mock;
 async fn main() -> Result<(), Box<dyn Error>> {
     let raw_args: Vec<String> = args().collect();
     let prompt = match &raw_args.len() {
-        1 => "Rust is a cool programming language because",
-        2 => raw_args[1].as_str(),
-        _ => panic!("Usage: cargo run --release --example simple")
+        1 => "Rust is a cool programming language because",
+        2 => raw_args[1].as_str(),
+        _ => panic!("Usage: cargo run --release --example simple"),
     };

     let exec = executor!(mock)?;
@@ -26,4 +26,4 @@ async fn main() -> Result<(), Box<dyn Error>> {

     println!("{}", res);
     Ok(())
-}
+}
27 changes: 20 additions & 7 deletions crates/llm-chain-mock/src/executor.rs
@@ -16,7 +16,7 @@ pub struct Executor {
 #[async_trait]
 impl llm_chain::traits::Executor for Executor {
     type StepTokenizer<'a> = MockTokenizer;

     fn new_with_options(options: Options) -> Result<Self, ExecutorCreationError> {
         Ok(Executor { options: options })
     }
@@ -31,7 +31,7 @@ impl llm_chain::traits::Executor for Executor {
         options: &Options,
         prompt: &Prompt,
     ) -> Result<TokenCount, PromptTokensError> {
-        let tokenizer = self.get_tokenizer(options)?;
+        let tokenizer = self.get_tokenizer(options)?;
         let input = prompt.to_text();
         let mut tokens_used = tokenizer
             .tokenize_str(&input)
@@ -72,12 +72,22 @@ impl MockTokenizer {

 impl Tokenizer for MockTokenizer {
     fn tokenize_str(&self, doc: &str) -> Result<TokenCollection, TokenizerError> {
-        let tokens: Vec<i32> = doc.as_bytes().to_vec().into_iter().map(|c| c as i32).collect();
+        let tokens: Vec<i32> = doc
+            .as_bytes()
+            .to_vec()
+            .into_iter()
+            .map(|c| c as i32)
+            .collect();
         Ok(tokens.into())
     }

     fn to_string(&self, tokens: TokenCollection) -> Result<String, TokenizerError> {
-        let bytes : Vec<u8> = tokens.as_i32().unwrap().into_iter().map(|c| c as u8).collect();
+        let bytes: Vec<u8> = tokens
+            .as_i32()
+            .unwrap()
+            .into_iter()
+            .map(|c| c as u8)
+            .collect();
         let doc = String::from_utf8(bytes).unwrap();
         Ok(doc)
     }
@@ -89,14 +99,17 @@ mod tests {
     use llm_chain::traits::Executor;
     #[test]
    fn test_mock_tokenizer() {
-        let executor: crate::Executor = Executor::new_with_options(Options::empty().clone()).unwrap();
+        let executor: crate::Executor =
+            Executor::new_with_options(Options::empty().clone()).unwrap();
         let tokenizer = executor.get_tokenizer(&executor.options).unwrap();
         let tokens = tokenizer
             .tokenize_str("Héllo world") // Notice that the UTF-8 character translates to 3 i32s
             .expect("failed to tokenize");
         println!("{:?}", tokens);
         assert_eq!(tokens.len(), 13);
-        let doc = tokenizer.to_string(tokens).expect("failed to convert back to string");
+        let doc = tokenizer
+            .to_string(tokens)
+            .expect("failed to convert back to string");
         assert_eq!(doc, "Héllo world");
     }
-}
+}
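
Why 13 tokens for "Héllo world"? The mock tokenizer emits one i32 per UTF-8 byte, so the expected count works out only when the é in the test source is the decomposed form, an e followed by a combining acute accent (U+0301). A quick standalone check of that arithmetic (illustrative, not part of this commit):

    fn main() {
        // One token per UTF-8 byte: with a decomposed é (e + U+0301),
        // the string is 13 bytes, matching the test's expectation.
        assert_eq!("He\u{301}llo world".as_bytes().len(), 13);
        // The precomposed é (U+00E9) would give only 12 bytes.
        assert_eq!("H\u{e9}llo world".as_bytes().len(), 12);
    }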
11 changes: 11 additions & 0 deletions crates/llm-chain-sagemaker-endpoint/CHANGELOG.md
@@ -0,0 +1,11 @@
# Changelog
All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## 0.12.3 - 2023-08-15
### Added
- Initial release
- Support for Falcon 7B Instruct and Falcon 40B Instruct models from SageMaker JumpStart
- Initial version number chosen to match the llm-chain crate version
28 changes: 28 additions & 0 deletions crates/llm-chain-sagemaker-endpoint/Cargo.toml
@@ -0,0 +1,28 @@
[package]
name = "llm-chain-sagemaker-endpoint"
version = "0.12.3"
edition = "2021"
description = "Use `llm-chain` with a SageMaker Endpoint backend."
license = "MIT"
keywords = ["llm", "langchain", "chain"]
categories = ["science"]
authors = ["Shing Lyu <[email protected]>"]
repository = "https://github.com/sobelio/llm-chain/"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
async-trait = "0.1.68"
aws-config = "0.56.0"
aws-sdk-sagemakerruntime = "0.29.0"
futures = "0.3.28"
llm-chain = { path = "../llm-chain", version = "0.12.3", default-features = false }
serde = "1.0.183"
serde_json = "1.0.104"
serde_with = "3.2.0"
strum = "0.25.0"
strum_macros = "0.25.2"
thiserror = "1.0.40"

[dev-dependencies]
tokio = { version = "1.28.2", features = ["macros", "rt"] }
10 changes: 10 additions & 0 deletions crates/llm-chain-sagemaker-endpoint/README.md
@@ -0,0 +1,10 @@
# llm-chain-sagemaker-endpoint

Amazon SageMaker Endpoint driver. Lets you invoke a model hosted on an Amazon SageMaker endpoint, including Amazon SageMaker JumpStart models.

# Getting Started
1. This crate uses the AWS SDK for Rust to communicate with Amazon SageMaker. You need to set up credentials and a region by following [this guide](https://docs.aws.amazon.com/sdk-for-rust/latest/dg/credentials.html).
1. Follow the SageMaker JumpStart documentation to [find an LLM](https://docs.aws.amazon.com/sagemaker/latest/dg/jumpstart-foundation-models-use.html), then [deploy it](https://docs.aws.amazon.com/sagemaker/latest/dg/jumpstart-deploy.html).
1. Note down the SageMaker endpoint name created by SageMaker JumpStart.
1. Some models are included in this crate; see `model::Model::<model_name>`. Select one in the `Model` field of your executor options. See `examples/simple.rs` for an example.
1. For custom models or models not included in `model::Model`, use `model::Model::Other(<model_name>)`, where `model_name` is the SageMaker endpoint name (see the sketch below).
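
For illustration, a minimal executor setup for a custom endpoint might look like the following sketch; the endpoint name is a placeholder, and it assumes `Model::Other` takes the endpoint name as a `String`:

    use llm_chain::{executor, options};
    use llm_chain_sagemaker_endpoint::model::Model;

    // Placeholder endpoint name; substitute the name of your deployed endpoint.
    let opts = options!(
        Model: Model::Other("my-custom-endpoint".to_string()),
        MaxTokens: 50usize,
        Temperature: 0.8
    );
    let exec = executor!(sagemaker_endpoint, opts)?;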
36 changes: 36 additions & 0 deletions crates/llm-chain-sagemaker-endpoint/examples/simple.rs
@@ -0,0 +1,36 @@
use llm_chain::executor;
use llm_chain::options;
use llm_chain::options::Options;
use std::{env::args, error::Error};

use llm_chain::{prompt::Data, traits::Executor};

extern crate llm_chain_sagemaker_endpoint;
use llm_chain_sagemaker_endpoint::model::Model;

/// This example demonstrates how to use the llm-chain-sagemaker-endpoint crate to generate text using a model hosted on a SageMaker endpoint.
///
/// Usage: cargo run --release --package llm-chain-sagemaker-endpoint --example simple <optional prompt>
///
#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), Box<dyn Error>> {
    let raw_args: Vec<String> = args().collect();
    let prompt = match &raw_args.len() {
        1 => "Rust is a cool programming language because",
        2 => raw_args[1].as_str(),
        _ => panic!("Usage: cargo run --release --example simple <optional prompt>"),
    };

    let opts = options!(
        Model: Model::Falcon7BInstruct, // You need to deploy the Falcon 7B Instruct model using SageMaker JumpStart
        MaxTokens: 50usize,
        Temperature: 0.8
    );
    let exec = executor!(sagemaker_endpoint, opts)?;
    let res = exec
        .execute(Options::empty(), &Data::Text(String::from(prompt)))
        .await?;

    println!("{}", res);
    Ok(())
}
116 changes: 116 additions & 0 deletions crates/llm-chain-sagemaker-endpoint/src/executor.rs
@@ -0,0 +1,116 @@
use crate::model::Formatter;
use crate::model::Model;
use async_trait::async_trait;

use llm_chain::options::Opt;
use llm_chain::options::Options;
use llm_chain::options::OptionsCascade;
use llm_chain::output::Output;
use llm_chain::prompt::Prompt;
use llm_chain::tokens::{
    PromptTokensError, TokenCollection, TokenCount, Tokenizer, TokenizerError,
};
use llm_chain::traits::{ExecutorCreationError, ExecutorError};

use std::str::FromStr;

/// Executor is responsible for running the LLM and managing its context.
pub struct Executor {
    #[allow(dead_code)]
    options: Options,
    sagemaker_client: aws_sdk_sagemakerruntime::Client,
}

impl Executor {
    fn get_model_from_invocation_options(&self, opts: &OptionsCascade) -> Model {
        let Some(Opt::Model(model)) = opts.get(llm_chain::options::OptDiscriminants::Model) else {
            // TODO: fail gracefully
            panic!("The Model option must not be empty. This option does not have a default.");
        };
        Model::from_str(&model.to_name()).unwrap()
    }

    fn cascade<'a>(&'a self, opts: Option<&'a Options>) -> OptionsCascade<'a> {
        let mut v: Vec<&'a Options> = vec![&self.options];
        if let Some(o) = opts {
            v.push(o);
        }
        OptionsCascade::from_vec(v)
    }
}

#[async_trait]
impl llm_chain::traits::Executor for Executor {
    type StepTokenizer<'a> = SageMakerEndpointTokenizer;

    fn new_with_options(options: Options) -> Result<Self, ExecutorCreationError> {
        let config = futures::executor::block_on(aws_config::load_from_env());
        let client = aws_sdk_sagemakerruntime::Client::new(&config);
        Ok(Executor {
            options,
            sagemaker_client: client,
        })
    }

    async fn execute(&self, options: &Options, prompt: &Prompt) -> Result<Output, ExecutorError> {
        let opts = self.cascade(Some(options));
        let model = self.get_model_from_invocation_options(&opts);

        let body_blob = model.format_request(prompt, &opts);

        let result = self
            .sagemaker_client
            .invoke_endpoint()
            .endpoint_name(model.to_jumpstart_endpoint_name())
            .content_type(model.request_content_type())
            .body(body_blob)
            .send()
            .await;
        let response = result.map_err(|e| ExecutorError::InnerError(e.into()))?;
        let generated_text = model.parse_response(response);

        Ok(Output::new_immediate(Prompt::text(generated_text)))
    }

    fn tokens_used(
        &self,
        _options: &Options,
        _prompt: &Prompt,
    ) -> Result<TokenCount, PromptTokensError> {
        // Not all models expose this information.
        unimplemented!();
    }

    fn max_tokens_allowed(&self, _: &Options) -> i32 {
        // Not all models expose this information.
        unimplemented!();
    }

    fn answer_prefix(&self, _prompt: &Prompt) -> Option<String> {
        // Not all models expose this information.
        unimplemented!();
    }

    fn get_tokenizer(&self, _: &Options) -> Result<Self::StepTokenizer<'_>, TokenizerError> {
        // Not all models expose this information.
        unimplemented!();
    }
}

pub struct SageMakerEndpointTokenizer {}

impl SageMakerEndpointTokenizer {
    pub fn new(_executor: &Executor) -> Self {
        SageMakerEndpointTokenizer {}
    }
}

impl Tokenizer for SageMakerEndpointTokenizer {
    fn tokenize_str(&self, _doc: &str) -> Result<TokenCollection, TokenizerError> {
        unimplemented!();
    }

    fn to_string(&self, _tokens: TokenCollection) -> Result<String, TokenizerError> {
        unimplemented!();
    }
}
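
Note that `get_model_from_invocation_options` panics when no `Model` option is in scope, so a `Model` must be supplied either to `new_with_options` or per call; since `cascade` pushes the per-call options after the executor's own, the per-call value presumably takes precedence. A minimal invocation sketch, assuming `exec` was built as in `examples/simple.rs`:

    // Sketch only: Model has no default, so every call needs one in scope.
    let call_opts = options!(Model: Model::Falcon7BInstruct);
    let output = exec.execute(&call_opts, &Prompt::text("Hello")).await?;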
3 changes: 3 additions & 0 deletions crates/llm-chain-sagemaker-endpoint/src/lib.rs
@@ -0,0 +1,3 @@
mod executor;
pub use executor::Executor;
pub mod model;