Skip to content
This repository has been archived by the owner on Aug 2, 2024. It is now read-only.

Commit

Permalink
refactor: Refactor chat completion function in prompt.rs
Browse files Browse the repository at this point in the history
- Add new function for determining which base to use for `tiktoken`
- Improve `count_tokens` function to include `cli.system_message` and accept `model` argument
- Enhance `chat_completion` function to account for all tokens used, including system message and prompt
- Add debugging information for `max_tokens` calculation in `chat_completion`
- Update smoke test script with new arguments for `gptee` command
- Set `RUST_LOG` to "warn" if not already set.
  • Loading branch information
zurawiki committed Mar 5, 2023
1 parent 0ad1879 commit 4c3960f
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 13 deletions.
6 changes: 4 additions & 2 deletions src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@ pub(crate) struct CompletionArgs {
pub(crate) async fn main() -> anyhow::Result<()> {
let cli = CompletionArgs::parse();

// This should come from env var outside the program
std::env::set_var("RUST_LOG", "warn");
// Set RUST_LOG if not set
if std::env::var("RUST_LOG").is_err() {
std::env::set_var("RUST_LOG", "warn");
}

// Setup tracing subscriber so that library can log the rate limited message
tracing_subscriber::registry()
Expand Down
35 changes: 26 additions & 9 deletions src/prompt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use async_openai::{
};

use futures::StreamExt;
use tiktoken_rs::tiktoken::p50k_base;
use tiktoken_rs::tiktoken::{cl100k_base, p50k_base};

use crate::cli::CompletionArgs;

Expand All @@ -26,8 +26,12 @@ fn model_name_to_context_size(model_name: &str) -> u16 {
}
}

fn count_tokens(prompt: &str) -> anyhow::Result<u16> {
let bpe = p50k_base().unwrap();
/// Count the number of BPE tokens in `prompt` for the given `model`.
///
/// Chat-style models (as decided by `should_use_chat_completion`) tokenize
/// with the `cl100k_base` encoding; legacy completion models use `p50k_base`.
///
/// # Errors
/// Returns an error if the tokenizer data fails to load.
fn count_tokens(model: &str, prompt: &str) -> anyhow::Result<u16> {
    let bpe = if should_use_chat_completion(model) {
        cl100k_base()
    } else {
        p50k_base()
    }?; // propagate the tokenizer-load error instead of panicking via `unwrap`
    let tokens = bpe.encode_with_special_tokens(prompt);
    // NOTE(review): cast assumes the token count fits in u16 (< 65536) — true
    // for current context sizes, but `try_into()` would be safer if that grows.
    Ok(tokens.len() as u16)
}
Expand Down Expand Up @@ -62,10 +66,22 @@ pub(crate) async fn chat_completion(
}
let request = request.messages(messages);

let max_tokens = model_name_to_context_size(model) - count_tokens(prompt)?;
let max_tokens = cli.max_tokens.unwrap_or(max_tokens);
let request = request.max_tokens(max_tokens);

// let max_tokens = cli.max_tokens.unwrap_or_else(|| {
// model_name_to_context_size(model)
// - count_tokens(
// model,
// &cli.system_message.to_owned().unwrap_or("".to_owned()),
// ).unwrap_or(0)
// - count_tokens(model, prompt).unwrap_or(0)
// // Chat completions use extra tokens for the prompt
// - 10
// });

let request = if cli.max_tokens.is_some() {
request.max_tokens(cli.max_tokens.unwrap())
} else {
request
};
let request = if !cli.stop.is_empty() {
request.stop(&cli.stop)
} else {
Expand Down Expand Up @@ -111,8 +127,9 @@ pub(crate) async fn completion(

let request = request.model(model);

let max_tokens = model_name_to_context_size(model) - count_tokens(&prompt)?;
let max_tokens = cli.max_tokens.unwrap_or(max_tokens);
let max_tokens = cli.max_tokens.unwrap_or_else(|| {
model_name_to_context_size(model) - count_tokens(model, &prompt).unwrap_or(0)
});
let request = request.max_tokens(max_tokens);

let request = if !cli.stop.is_empty() {
Expand Down
6 changes: 4 additions & 2 deletions tests/test_smoke.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#!/bin/sh
#!/usr/bin/env bash
set -eux

gptee --help
Expand All @@ -10,9 +10,11 @@ echo Tell me a joke | gptee

gptee <(echo Tell me a joke)

echo Tell me a joke | gptee -m text-davinci-003
echo Tell me a joke | gptee -m text-davinci-003 --max-tokens 2

echo Give me just a macOS zsh command to get the free space on my hard drive \
| gptee -s "Prefix each line of output with a pound sign if it not meant to be executed"

echo "Tell me I'm pretty" | gptee -s "You only speak French"
echo "Tell me I'm pretty" | gptee -s "You only speak French" --max-tokens 100
echo "Tell me I'm pretty" | gptee -s "You only speak French" --model text-davinci-003 --max-tokens 100

0 comments on commit 4c3960f

Please sign in to comment.