-
Notifications
You must be signed in to change notification settings - Fork 132
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'sobelio:main' into main
- Loading branch information
Showing
12 changed files
with
403 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# Changelog | ||
All notable changes to this project will be documented in this file. | ||
|
||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), | ||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). | ||
|
||
## 0.12.3 - 2023-08-15 | ||
### Added | ||
- Initial release | ||
- Support for Falcon 7B Instruct and Falcon 40B Instruct models from SageMaker JumpStart | ||
- The initial version number is chosen to match the `llm-chain` version |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
[package] | ||
name = "llm-chain-sagemaker-endpoint" | ||
version = "0.12.3" | ||
edition = "2021" | ||
description = "Use `llm-chain` with a SageMaker Endpoint backend." | ||
license = "MIT" | ||
keywords = ["llm", "langchain", "chain"] | ||
categories = ["science"] | ||
authors = ["Shing Lyu <[email protected]>"] | ||
repository = "https://github.com/sobelio/llm-chain/" | ||
|
||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | ||
|
||
[dependencies] | ||
async-trait = "0.1.68" | ||
aws-config = "0.56.0" | ||
aws-sdk-sagemakerruntime = "0.29.0" | ||
futures = "0.3.28" | ||
llm-chain = { path = "../llm-chain", version = "0.12.3", default-features = false } | ||
serde = "1.0.183" | ||
serde_json = "1.0.104" | ||
serde_with = "3.2.0" | ||
strum = "0.25.0" | ||
strum_macros = "0.25.2" | ||
thiserror = "1.0.40" | ||
|
||
[dev-dependencies] | ||
tokio = { version = "1.28.2", features = ["macros", "rt"] } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# llm-chain-sagemaker-endpoint | ||
|
||
Amazon SageMaker Endpoint driver. Allows you to invoke a model hosted on an Amazon SageMaker Endpoint, including Amazon SageMaker JumpStart models. | ||
|
||
# Getting Started | ||
1. This crate uses the AWS SDK for Rust to communicate with Amazon SageMaker. You need to set up the credentials and a region following [this guide](https://docs.aws.amazon.com/sdk-for-rust/latest/dg/credentials.html) | ||
1. Follow the SageMaker JumpStart documentation to [find an LLM](https://docs.aws.amazon.com/sagemaker/latest/dg/jumpstart-foundation-models-use.html), then [deploy it](https://docs.aws.amazon.com/sagemaker/latest/dg/jumpstart-deploy.html). | ||
1. Note down the SageMaker Endpoint name created by SageMaker JumpStart. | ||
1. Some models are included in this crate, see `model::Model::<model_name>`. Select one in your executor options' `Model` field. See `examples/simple.rs` for an example. | ||
1. For custom models or models not included in `model::Model`, use `model::Model::Other(<model_name>)`, where `model_name` is the SageMaker endpoint name. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
use llm_chain::executor; | ||
use llm_chain::options; | ||
use llm_chain::options::Options; | ||
use std::{env::args, error::Error}; | ||
|
||
use llm_chain::{prompt::Data, traits::Executor}; | ||
|
||
extern crate llm_chain_sagemaker_endpoint; | ||
use llm_chain_sagemaker_endpoint::model::Model; | ||
|
||
/// This example demonstrates how to use the llm-chain-mock crate to generate text using a mock model. | ||
/// | ||
/// Usage: cargo run --release --package llm-chain-mock --example simple <optional prompt> | ||
/// | ||
#[tokio::main(flavor = "current_thread")] | ||
async fn main() -> Result<(), Box<dyn Error>> { | ||
let raw_args: Vec<String> = args().collect(); | ||
let prompt = match &raw_args.len() { | ||
1 => "Rust is a cool programming language because", | ||
2 => raw_args[1].as_str(), | ||
_ => panic!("Usage: cargo run --release --example simple <optional prompt>"), | ||
}; | ||
|
||
let opts = options!( | ||
Model: Model::Falcon7BInstruct, // You need to deploy the Falcon 7B Instruct model using SageMaker JumpStart | ||
MaxTokens: 50usize, | ||
Temperature: 0.8 | ||
); | ||
let exec = executor!(sagemaker_endpoint, opts)?; | ||
let res = exec | ||
.execute(Options::empty(), &Data::Text(String::from(prompt))) | ||
.await?; | ||
|
||
println!("{}", res); | ||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
use crate::model::Formatter; | ||
use crate::model::Model; | ||
use async_trait::async_trait; | ||
|
||
use llm_chain::options::Opt; | ||
use llm_chain::options::Options; | ||
use llm_chain::options::OptionsCascade; | ||
use llm_chain::output::Output; | ||
use llm_chain::prompt::Prompt; | ||
use llm_chain::tokens::{ | ||
PromptTokensError, TokenCollection, TokenCount, Tokenizer, TokenizerError, | ||
}; | ||
use llm_chain::traits::{ExecutorCreationError, ExecutorError}; | ||
|
||
use std::str::FromStr; | ||
|
||
/// Executor is responsible for running the LLM and managing its context. | ||
pub struct Executor { | ||
#[allow(dead_code)] | ||
options: Options, | ||
sagemaker_client: aws_sdk_sagemakerruntime::Client, | ||
} | ||
|
||
impl Executor { | ||
fn get_model_from_invocation_options(&self, opts: &OptionsCascade) -> Model { | ||
let Some(Opt::Model(model)) = opts.get(llm_chain::options::OptDiscriminants::Model) else { | ||
// TODO: fail gracefully | ||
panic!("The Model option must not be empty. This option does not have a default."); | ||
}; | ||
Model::from_str(&model.to_name()).unwrap() | ||
} | ||
|
||
fn cascade<'a>(&'a self, opts: Option<&'a Options>) -> OptionsCascade<'a> { | ||
let mut v: Vec<&'a Options> = vec![&self.options]; | ||
if let Some(o) = opts { | ||
v.push(o); | ||
} | ||
OptionsCascade::from_vec(v) | ||
} | ||
} | ||
|
||
#[async_trait] | ||
impl llm_chain::traits::Executor for Executor { | ||
type StepTokenizer<'a> = SageMakerEndpointTokenizer; | ||
|
||
fn new_with_options(options: Options) -> Result<Self, ExecutorCreationError> { | ||
let config = futures::executor::block_on(aws_config::load_from_env()); | ||
let client = aws_sdk_sagemakerruntime::Client::new(&config); | ||
Ok(Executor { | ||
options, | ||
sagemaker_client: client, | ||
}) | ||
} | ||
|
||
async fn execute(&self, options: &Options, prompt: &Prompt) -> Result<Output, ExecutorError> { | ||
let opts = self.cascade(Some(options)); | ||
let model = self.get_model_from_invocation_options(&opts); | ||
|
||
let body_blob = model.format_request(prompt, &opts); | ||
|
||
let result = self | ||
.sagemaker_client | ||
.invoke_endpoint() | ||
.endpoint_name(model.to_jumpstart_endpoint_name()) | ||
.content_type(model.request_content_type()) | ||
.body(body_blob) | ||
.send() | ||
.await; | ||
let response = result.map_err(|e| ExecutorError::InnerError(e.into()))?; | ||
let generated_text = model.parse_response(response); | ||
|
||
Ok(Output::new_immediate(Prompt::text(generated_text))) | ||
} | ||
|
||
fn tokens_used( | ||
&self, | ||
_options: &Options, | ||
_prompt: &Prompt, | ||
) -> Result<TokenCount, PromptTokensError> { | ||
// Not all models expose this information. | ||
unimplemented!(); | ||
} | ||
|
||
fn max_tokens_allowed(&self, _: &Options) -> i32 { | ||
// Not all models expose this information. | ||
unimplemented!(); | ||
} | ||
|
||
fn answer_prefix(&self, _prompt: &Prompt) -> Option<String> { | ||
// Not all models expose this information. | ||
unimplemented!(); | ||
} | ||
|
||
fn get_tokenizer(&self, _: &Options) -> Result<Self::StepTokenizer<'_>, TokenizerError> { | ||
// Not all models expose this information. | ||
unimplemented!(); | ||
} | ||
} | ||
|
||
pub struct SageMakerEndpointTokenizer {} | ||
|
||
impl SageMakerEndpointTokenizer { | ||
pub fn new(_executor: &Executor) -> Self { | ||
SageMakerEndpointTokenizer {} | ||
} | ||
} | ||
|
||
impl Tokenizer for SageMakerEndpointTokenizer { | ||
fn tokenize_str(&self, _doc: &str) -> Result<TokenCollection, TokenizerError> { | ||
unimplemented!(); | ||
} | ||
|
||
fn to_string(&self, _tokens: TokenCollection) -> Result<String, TokenizerError> { | ||
unimplemented!(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
// The executor module is private; only its `Executor` type is re-exported from the crate root.
mod executor;
pub use executor::Executor;
// Public module providing `Model` and `Formatter`, which the executor imports.
pub mod model;
Oops, something went wrong.