refactor: Refactor tokenization and update dependencies.
- Update dependencies in Cargo.toml
- Refactor prompt.rs to use tiktoken_rs for tokenization
- Add GPT-4 to smoke test and increase max tokens for test prompt in test_smoke.sh
zurawiki committed Mar 18, 2023
1 parent 21681ba commit 428b538
Showing 4 changed files with 26 additions and 57 deletions.
13 changes: 7 additions & 6 deletions Cargo.lock

(Generated file; diff not rendered.)

6 changes: 3 additions & 3 deletions Cargo.toml
@@ -18,11 +18,11 @@ path = "src/main.rs"

[dependencies]
anyhow = "1.0.69"
-async-openai = "0.9.4"
+async-openai = "0.9.5"
backoff = { version = "0.4.0", features = ["tokio"] }
-clap = { version = "4.1.9", features = ["derive"] }
+clap = { version = "4.1.10", features = ["derive"] }
futures = "0.3.27"
tempfile = "3.4.0"
-tiktoken-rs = "0.2.2"
+tiktoken-rs = "0.3.1"
tokio = {version = "1.26.0", features = ["full"]}
tracing-subscriber = { version = "0.3.16", features = ["env-filter"]}
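
The tiktoken-rs bump from 0.2.2 to 0.3.1 drives the prompt.rs changes below: the BPE constructors leave the tiktoken submodule, and the crate exposes the get_*_max_tokens helpers this commit adopts. A minimal sketch of direct token counting against that assumed 0.3.x crate-root API (inferred from the removed import and the new calls below, not verified here):

    use tiktoken_rs::cl100k_base;

    fn main() -> anyhow::Result<()> {
        // Build the BPE used by chat models and count a prompt's tokens,
        // roughly what the removed count_tokens helper did by hand.
        let bpe = cl100k_base()?;
        let tokens = bpe.encode_with_special_tokens("Tell me a joke");
        println!("{} tokens", tokens.len());
        Ok(())
    }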
61 changes: 13 additions & 48 deletions src/prompt.rs
@@ -8,36 +8,11 @@ use async_openai::{
};

use futures::StreamExt;
-use tiktoken_rs::tiktoken::{cl100k_base, p50k_base};

use crate::cli::CompletionArgs;

-/// Calculate the maximum number of tokens possible to generate for a model
-fn model_name_to_context_size(model_name: &str) -> u16 {
-    match model_name {
-        "text-davinci-003" => 4000,
-        "text-davinci-002" => 4000,
-        "text-curie-001" => 2048,
-        "text-babbage-001" => 2048,
-        "text-ada-001" => 2048,
-        "code-davinci-002" => 4000,
-        "code-cushman-001" => 2048,
-        _ => 4096,
-    }
-}
-
-fn count_tokens(model: &str, prompt: &str) -> anyhow::Result<u16> {
-    let bpe = match should_use_chat_completion(model) {
-        true => cl100k_base(),
-        false => p50k_base(),
-    }
-    .unwrap();
-    let tokens = bpe.encode_with_special_tokens(prompt);
-    Ok(tokens.len() as u16)
-}
-
pub(crate) fn should_use_chat_completion(model: &str) -> bool {
-    model.to_lowercase().starts_with("gpt-3.5-turbo")
+    model.to_lowercase().starts_with("gpt-4") || model.to_lowercase().starts_with("gpt-3.5-turbo")
}
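
The widened predicate now routes gpt-4 models to the chat API as well as gpt-3.5-turbo. A hypothetical sanity test (not part of this commit); the prefix match also covers dated snapshots:

    #[test]
    fn routes_chat_models_to_chat_completion() {
        assert!(should_use_chat_completion("gpt-4"));
        assert!(should_use_chat_completion("gpt-4-0314"));
        assert!(should_use_chat_completion("GPT-3.5-Turbo-0301"));
        assert!(!should_use_chat_completion("text-davinci-003"));
    }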

pub(crate) async fn chat_completion(
@@ -64,24 +39,12 @@ pub(crate) async fn chat_completion(
                .build()?,
        );
    }
-    let request = request.messages(messages);

-    // let max_tokens = cli.max_tokens.unwrap_or_else(|| {
-    //     model_name_to_context_size(model)
-    //         - count_tokens(
-    //             model,
-    //             &cli.system_message.to_owned().unwrap_or("".to_owned()),
-    //         ).unwrap_or(0)
-    //         - count_tokens(model, prompt).unwrap_or(0)
-    //         // Chat completions use extra tokens for the prompt
-    //         - 10
-    // });

-    let request = if cli.max_tokens.is_some() {
-        request.max_tokens(cli.max_tokens.unwrap())
-    } else {
-        request
-    };
+    let request = request.messages(messages.to_owned());
+    let max_tokens = cli.max_tokens.unwrap_or_else(|| {
+        tiktoken_rs::get_chat_completion_max_tokens(model, &messages).unwrap() as u16
+    });

+    let request = request.max_tokens(max_tokens);
    let request = if !cli.stop.is_empty() {
        request.stop(&cli.stop)
    } else {
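
One asymmetry worth noting: this chat path unwrap()s the helper's Result, while the completion path below propagates errors with ?. A sketch of an error-propagating variant, assuming chat_completion returns anyhow::Result as the .build()? above suggests:

    // Drop-in for the unwrap_or_else block above (sketch, not in the commit).
    let max_tokens = match cli.max_tokens {
        Some(n) => n,
        None => tiktoken_rs::get_chat_completion_max_tokens(model, &messages)? as u16,
    };
    let request = request.max_tokens(max_tokens);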
@@ -127,10 +90,12 @@ pub(crate) async fn completion(

    let request = request.model(model);

-    let max_tokens = cli.max_tokens.unwrap_or_else(|| {
-        model_name_to_context_size(model) - count_tokens(model, &prompt).unwrap_or(0)
-    });
-    let request = request.max_tokens(max_tokens);
+    let request = if let Some(max_tokens) = cli.max_tokens {
+        request.max_tokens(max_tokens)
+    } else {
+        let max_tokens = tiktoken_rs::get_completion_max_tokens(model, &prompt)? as u16;
+        request.max_tokens(max_tokens)
+    };

    let request = if !cli.stop.is_empty() {
        request.stop(&cli.stop)
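
For the non-chat path, the helper folds the old two-step arithmetic (context size minus prompt tokens) into one call. A small standalone sketch; the model name and prompt are illustrative, and the signature is taken from the call above:

    use anyhow::Result;

    fn main() -> Result<()> {
        let prompt = "Tell me a joke";
        // Remaining generation budget = model context window - prompt tokens.
        let budget = tiktoken_rs::get_completion_max_tokens("text-davinci-003", prompt)?;
        println!("can generate up to {budget} tokens");
        Ok(())
    }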
3 changes: 3 additions & 0 deletions tests/test_smoke.sh
@@ -15,6 +15,9 @@ echo Tell me a joke | gptee -m text-davinci-003 --max-tokens 2
echo Give me just a macOS zsh command to get the free space on my hard drive \
| gptee -s "Prefix each line of output with a pound sign if it is not meant to be executed"

+echo Give me just a macOS zsh command to get the free space on my hard drive \
+| gptee -s "Prefix each line of output with a pound sign if it is not meant to be executed" --model gpt-4

echo "Tell me I'm pretty" | gptee -s "You only speak French"
echo "Tell me I'm pretty" | gptee -s "You only speak French" --max-tokens 100
echo "Tell me I'm pretty" | gptee -s "You only speak French" --model text-davinci-003 --max-tokens 100
