Add Tokenizer support for SageMaker Endpoint Backend #213

Closed · wants to merge 7 commits
1 change: 1 addition & 0 deletions crates/llm-chain-sagemaker-endpoint/Cargo.toml
@@ -23,6 +23,7 @@ serde_with = "3.2.0"
 strum = "0.25.0"
 strum_macros = "0.25.2"
 thiserror = "1.0.40"
+tokenizers = { version = "0.13.4", features = ["http"] }

 [dev-dependencies]
 tokio = { version = "1.28.2", features = ["macros", "rt"] }
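
Reviewer note: the `http` feature is what lets the `tokenizers` crate fetch tokenizer definitions from the Hugging Face Hub at runtime instead of requiring a local file. A minimal standalone sketch of what this enables (the model id mirrors the mapping added in model.rs below; first run needs network access):

```rust
use tokenizers::tokenizer::Tokenizer;

fn main() -> tokenizers::Result<()> {
    // With the "http" feature, `from_pretrained` downloads (and caches)
    // the tokenizer files for a Hugging Face Hub identifier.
    let tokenizer = Tokenizer::from_pretrained("tiiuae/falcon-7b-instruct", None)?;
    let encoding = tokenizer.encode("Hello, SageMaker!", false)?;
    println!("{} token ids: {:?}", encoding.get_ids().len(), encoding.get_ids());
    Ok(())
}
```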
142 changes: 120 additions & 22 deletions crates/llm-chain-sagemaker-endpoint/src/executor.rs
@@ -2,16 +2,17 @@
 use crate::model::Formatter;
 use crate::model::Model;
 use async_trait::async_trait;

-use llm_chain::options::Opt;
-use llm_chain::options::Options;
-use llm_chain::options::OptionsCascade;
+use llm_chain::options; // for the top-level macro
+use llm_chain::options::{Opt, Options, OptionsCascade};
 use llm_chain::output::Output;
 use llm_chain::prompt::Prompt;
 use llm_chain::tokens::{
     PromptTokensError, TokenCollection, TokenCount, Tokenizer, TokenizerError,
 };
 use llm_chain::traits::{ExecutorCreationError, ExecutorError};

+use tokenizers::tokenizer::Tokenizer as HuggingFaceTokenizer;
+
 use std::str::FromStr;

 /// Executor is responsible for running the LLM and managing its context.
@@ -74,43 +75,140 @@ impl llm_chain::traits::Executor for Executor {

     fn tokens_used(
         &self,
-        _options: &Options,
-        _prompt: &Prompt,
+        options: &Options,
+        prompt: &Prompt,
     ) -> Result<TokenCount, PromptTokensError> {
-        // Not all models expose this information.
-        unimplemented!();
+        let tokenizer = self.get_tokenizer(options)?;
+        let input = prompt.to_text();
+
+        let tokens_used = tokenizer
+            .tokenize_str(&input)
+            .map_err(|_e| PromptTokensError::UnableToCompute)?
+            .len() as i32;
+        let max_tokens = self.max_tokens_allowed(options);
+        Ok(TokenCount::new(max_tokens, tokens_used))
     }

-    fn max_tokens_allowed(&self, _: &Options) -> i32 {
-        // Not all models expose this information.
-        unimplemented!();
+    fn max_tokens_allowed(&self, options: &Options) -> i32 {
+        let opts = self.cascade(Some(options));
+        let model = self.get_model_from_invocation_options(&opts);
+        model.context_window_size().unwrap_or_else(|| {
+            unimplemented!("This model does not expose max tokens allowed information.")
+        })
     }

     fn answer_prefix(&self, _prompt: &Prompt) -> Option<String> {
-        // Not all models expose this information.
-        unimplemented!();
+        None
     }

-    fn get_tokenizer(&self, _: &Options) -> Result<Self::StepTokenizer<'_>, TokenizerError> {
-        // Not all models expose this information.
-        unimplemented!();
+    fn get_tokenizer(&self, options: &Options) -> Result<Self::StepTokenizer<'_>, TokenizerError> {
+        Ok(SageMakerEndpointTokenizer::new(self.cascade(Some(options))))
     }
 }

-pub struct SageMakerEndpointTokenizer {}
+pub struct SageMakerEndpointTokenizer {
+    tokenizer: Option<HuggingFaceTokenizer>,
+}

 impl SageMakerEndpointTokenizer {
-    pub fn new(_executor: &Executor) -> Self {
-        SageMakerEndpointTokenizer {}
+    pub fn new(options: OptionsCascade) -> Self {
+        let optional_tokenizer = match options.get(llm_chain::options::OptDiscriminants::Model) {
+            Some(Opt::Model(model)) => {
+                let model_struct = Model::from_str(&model.to_name()).unwrap();
+                Some(
+                    HuggingFaceTokenizer::from_pretrained(
+                        &model_struct.to_huggingface_name(),
+                        None,
+                    )
+                    .unwrap(),
+                ) // TODO: no tokenizer options are passed yet
+            }
+            _ => None,
+        };
+
+        SageMakerEndpointTokenizer {
+            tokenizer: optional_tokenizer,
+        }
     }
 }

 impl Tokenizer for SageMakerEndpointTokenizer {
-    fn tokenize_str(&self, _doc: &str) -> Result<TokenCollection, TokenizerError> {
-        unimplemented!();
+    fn tokenize_str(&self, doc: &str) -> Result<TokenCollection, TokenizerError> {
+        match &self.tokenizer {
+            Some(tokenizer) => {
+                let encoding = tokenizer
+                    .encode(doc, false)
+                    .map_err(|_| TokenizerError::TokenizationError)?;
+                let ids: Vec<_> = encoding.get_ids().iter().map(|x| *x as i32).collect();
+                Ok(TokenCollection::from(ids))
+            }
+            None => unimplemented!("This model does not have a tokenizer implementation."),
+        }
     }

-    fn to_string(&self, _tokens: TokenCollection) -> Result<String, TokenizerError> {
-        unimplemented!();
+    fn to_string(&self, tokens: TokenCollection) -> Result<String, TokenizerError> {
+        match &self.tokenizer {
+            Some(tokenizer) => {
+                let ids: Vec<_> = tokens.as_i32().unwrap().iter().map(|x| *x as u32).collect();
+                Ok(tokenizer
+                    .decode(ids.as_slice(), false)
+                    .map_err(|_| TokenizerError::TokenizationError)?)
+            }
+            None => unimplemented!("This model does not have a tokenizer implementation."),
+        }
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use llm_chain::traits::Executor;
+
+    #[test]
+    fn test_tokenizer() {
+        let opts = options!(
+            Model: Model::Falcon7BInstruct
+        );
+        let executor: super::Executor = Executor::new_with_options(opts.clone()).unwrap();
+        let opts_cascade = executor.cascade(Some(&opts));
+        let tokenizer = SageMakerEndpointTokenizer::new(opts_cascade);
+        let doc = "This is a example string to be tokenized";
+        let tokens = vec![1182, 304, 241, 1945, 3821, 271, 314, 10930, 1190];
+
+        assert_eq!(tokenizer.tokenize_str(doc).unwrap().len(), 9);
+        assert_eq!(
+            tokenizer.tokenize_str(doc).unwrap().as_i32().unwrap(),
+            tokens
+        );
+
+        assert_eq!(
+            tokenizer.to_string(TokenCollection::from(tokens)).unwrap(),
+            doc
+        );
+    }
+
+    #[test]
+    fn test_max_token_allowed() {
+        let opts = options!(
+            Model: Model::Falcon7BInstruct
+        );
+        let executor: super::Executor = Executor::new_with_options(opts.clone()).unwrap();
+
+        assert_eq!(executor.max_tokens_allowed(&opts), 2048);
+    }
+
+    #[test]
+    fn test_token_used() {
+        let opts = options!(
+            Model: Model::Falcon7BInstruct
+        );
+        let executor: super::Executor = Executor::new_with_options(opts.clone()).unwrap();
+        let doc = "This is a example string to be tokenized"; // 9 tokens
+        let prompt = Prompt::text(doc.to_string());
+
+        assert_eq!(
+            executor.tokens_used(&opts, &prompt).unwrap(),
+            TokenCount::new(2048, 9)
+        );
+    }
+}
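
Reviewer note: the net effect is that token counting now works locally for SageMaker-hosted Falcon models. A minimal caller-side sketch, not part of the diff, assuming the crate exposes `executor` and `model` modules as the file paths above suggest (the first call needs network access to fetch the tokenizer):

```rust
use llm_chain::options;
use llm_chain::prompt::Prompt;
use llm_chain::traits::Executor as _;
use llm_chain_sagemaker_endpoint::{executor::Executor, model::Model};

fn main() {
    let opts = options!(Model: Model::Falcon7BInstruct);
    let executor = Executor::new_with_options(opts.clone()).unwrap();

    // Count tokens against the model's 2048-token context window
    // before paying for an endpoint invocation.
    let prompt = Prompt::text("Summarize this document.".to_string());
    let count = executor.tokens_used(&opts, &prompt).unwrap();
    println!("{:?}", count);
}
```

One design caveat worth flagging: `get_tokenizer` constructs a fresh `SageMakerEndpointTokenizer`, which calls `from_pretrained` on every invocation, so each `tokens_used` call re-resolves the tokenizer (cached on disk, but still per-call work).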
18 changes: 17 additions & 1 deletion crates/llm-chain-sagemaker-endpoint/src/model.rs
@@ -119,7 +119,7 @@ impl Formatter for Model {

     fn parse_response(&self, response: InvokeEndpointOutput) -> String {
         match self {
-            Model::Falcon7BInstruct => {
+            Model::Falcon7BInstruct | Model::Falcon40BInstruct => {
                 let output = String::from_utf8(response.body.unwrap().into_inner()).unwrap();
                 let output_json: serde_json::Value = serde_json::from_str(&output).unwrap();

@@ -141,6 +141,22 @@ impl Model {
             _ => self.to_string(),
         }
     }
+    /// Convert the model to its HuggingFace model name.
+    pub fn to_huggingface_name(&self) -> String {
+        match &self {
+            Model::Falcon7BInstruct => "tiiuae/falcon-7b-instruct".to_string(),
+            Model::Falcon40BInstruct => "tiiuae/falcon-40b-instruct".to_string(),
+            _ => self.to_string(),
+        }
+    }
+
+    pub fn context_window_size(&self) -> Option<i32> {
+        match &self {
+            Model::Falcon7BInstruct => Some(2048),
+            Model::Falcon40BInstruct => Some(2048),
+            _ => None,
+        }
+    }
 }

 /// The `Model` enum implements the `ToString` trait, allowing you to easily convert it to a string.
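
Reviewer note: these two helpers keep model metadata in one place. A small sketch of how they behave, assuming the crate's `model` module path:

```rust
use llm_chain_sagemaker_endpoint::model::Model;

fn main() {
    let model = Model::Falcon40BInstruct;
    // Hub identifier consumed by the tokenizer loader in executor.rs.
    assert_eq!(model.to_huggingface_name(), "tiiuae/falcon-40b-instruct");
    // Both Falcon variants expose the same 2048-token context window.
    assert_eq!(model.context_window_size(), Some(2048));
}
```

Note that the `_ => self.to_string()` fallback in `to_huggingface_name` will generally not be a valid Hub identifier, so `from_pretrained` would fail at runtime for variants without an explicit mapping; returning `Option<String>`, as `context_window_size` does, might be safer.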
1 change: 1 addition & 0 deletions crates/llm-chain/src/tokens.rs
@@ -76,6 +76,7 @@ impl<E: traits::Executor> ExecutorTokenCountExt for E {}

 /// Struct representing token count information, including the maximum tokens allowed and the
 /// total number of tokens used.
+#[derive(Debug, PartialEq)]
 pub struct TokenCount {
     /// The maximum number of tokens allowed.
     max_tokens: i32,
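
Reviewer note: this derive exists so the new executor tests can compare `TokenCount` values directly; `assert_eq!` needs `PartialEq` for the comparison and `Debug` to print both sides on failure. A tiny illustration:

```rust
use llm_chain::tokens::TokenCount;

#[test]
fn token_count_supports_assert_eq() {
    // Without the derive this would not compile: no `==` comparison and
    // no `{:?}` formatting for the mismatch message.
    assert_eq!(TokenCount::new(2048, 9), TokenCount::new(2048, 9));
}
```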