
Commit cc569c7

Narsil and Hugoch authored

Clean url validation (#18)

* Suggest removing Docker, and instead installing the CLI.
* Putting back the profile.
* Getting the API from the environment (to catch other things like `HF_HOME`).
* Since we want a URL from the start, we can actually use a URL all the way.
* Fixing the URL handling.

Co-authored-by: Hugo Larcher <[email protected]>
1 parent b8b086f commit cc569c7

4 files changed: 18 additions, 21 deletions


src/lib.rs (3 additions, 2 deletions)

@@ -14,6 +14,7 @@ pub use crate::requests::TokenizeOptions;
 use chrono::Local;
 use crossterm::ExecutableCommand;
 use log::{debug, error, info, warn, Level, LevelFilter};
+use reqwest::Url;
 use tokenizers::{FromPretrainedParameters, Tokenizer};
 use tokio::sync::broadcast::Sender;
 use tokio::sync::Mutex;
@@ -32,7 +33,7 @@ mod table;
 mod writers;

 pub struct RunConfiguration {
-    pub url: String,
+    pub url: Url,
     pub tokenizer_name: String,
     pub profile: Option<String>,
     pub max_vus: u64,
@@ -85,7 +86,7 @@ pub async fn run(mut run_config: RunConfiguration, stop_sender: Sender<()>) -> a
     let tokenizer = Arc::new(tokenizer);
     let backend = OpenAITextGenerationBackend::try_new(
         "".to_string(),
-        run_config.url.clone(),
+        run_config.url,
         run_config.model_name.clone(),
         tokenizer,
         run_config.duration,
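
The RunConfiguration field now holds a parsed reqwest::Url instead of a raw String, and the call site moves the Url into the backend constructor rather than cloning a string. A minimal sketch of what that looks like from the caller's side, with RunConfiguration reduced to the one field this hunk touches (everything else is omitted and the names are illustrative):

// Sketch only: the real RunConfiguration has many more fields.
use reqwest::Url;

struct RunConfigurationSketch {
    url: Url, // previously a String
}

// Stand-in for OpenAITextGenerationBackend::try_new taking the Url by value.
fn consume(url: Url) {
    println!("benchmarking against {url}");
}

fn main() {
    let config = RunConfigurationSketch {
        // Url::parse validates up front, so a malformed URL never reaches
        // the backend constructor.
        url: Url::parse("http://localhost:8000").expect("valid URL"),
    };
    consume(config.url); // moved, mirroring `run_config.url` without `.clone()`
}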

src/main.rs (2 additions, 9 deletions)

@@ -45,8 +45,8 @@ struct Args {
     warmup: Duration,
     /// The URL of the backend to benchmark. Must be compatible with OpenAI Message API
     #[clap(default_value = "http://localhost:8000", short, long, env)]
-    #[arg(value_parser = parse_url)]
-    url: String,
+    url: Url,
+
     /// Disable console UI
     #[clap(short, long, env)]
     no_console: bool,
@@ -115,13 +115,6 @@ fn parse_duration(s: &str) -> Result<Duration, Error> {
     humantime::parse_duration(s).map_err(|_| Error::new(InvalidValue))
 }

-fn parse_url(s: &str) -> Result<String, Error> {
-    match Url::parse(s) {
-        Ok(_) => Ok(s.to_string()),
-        Err(_) => Err(Error::new(InvalidValue)),
-    }
-}
-
 fn parse_key_val(s: &str) -> Result<HashMap<String, String>, Error> {
     let mut key_val_map = HashMap::new();
     let items = s.split(",").collect::<Vec<&str>>();
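
Dropping parse_url works because clap's derive can build a value parser for any type that implements FromStr with a displayable error, which reqwest::Url does, so malformed input is now rejected at argument-parsing time with a regular clap usage error. A stripped-down sketch under that assumption, keeping only the url field of Args (clap's "derive" and "env" features are assumed, as the existing attributes already require):

// Sketch only: a reduced Args with just the field changed in this commit.
use clap::Parser;
use reqwest::Url;

#[derive(Parser)]
struct ArgsSketch {
    /// The URL of the backend to benchmark. Must be compatible with OpenAI Message API
    #[clap(default_value = "http://localhost:8000", short, long, env)]
    url: Url, // parsed via Url's FromStr impl; invalid values abort with a clap error
}

fn main() {
    let args = ArgsSketch::parse();
    println!("target: {}", args.url);
}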

src/requests.rs (12 additions, 9 deletions)

@@ -6,6 +6,7 @@ use log::{debug, error, info, trace, warn};
 use rand_distr::Distribution;
 use rayon::iter::split;
 use rayon::prelude::*;
+use reqwest::Url;
 use reqwest_eventsource::{Error, Event, EventSource};
 use serde::{Deserialize, Serialize};
 use std::cmp::Ordering;
@@ -58,7 +59,7 @@ impl Clone for Box<dyn TextGenerationBackend + Send + Sync> {
 #[derive(Debug, Clone)]
 pub struct OpenAITextGenerationBackend {
     pub api_key: String,
-    pub base_url: String,
+    pub base_url: Url,
     pub model_name: String,
     pub client: reqwest::Client,
     pub tokenizer: Arc<Tokenizer>,
@@ -101,7 +102,7 @@ pub struct OpenAITextGenerationRequest {
 impl OpenAITextGenerationBackend {
     pub fn try_new(
         api_key: String,
-        base_url: String,
+        base_url: Url,
         model_name: String,
         tokenizer: Arc<Tokenizer>,
         timeout: time::Duration,
@@ -128,7 +129,9 @@ impl TextGenerationBackend for OpenAITextGenerationBackend {
         request: Arc<TextGenerationRequest>,
         sender: Sender<TextGenerationAggregatedResponse>,
     ) {
-        let url = format!("{base_url}/v1/chat/completions", base_url = self.base_url);
+        let mut url = self.base_url.clone();
+        url.set_path("/v1/chat/completions");
+        // let url = format!("{base_url}", base_url = self.base_url);
         let mut aggregated_response = TextGenerationAggregatedResponse::new(request.clone());
         let messages = vec![OpenAITextGenerationMessage {
             role: "user".to_string(),
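
One behavioral detail of the new code: Url::set_path replaces the entire path component of the base URL, whereas the old format! call appended the endpoint to whatever string the user supplied. A small sketch of the difference (the "/proxy" prefix is an invented example, not something from this repository):

// Sketch contrasting string concatenation with Url::set_path.
use reqwest::Url;

fn main() {
    let base = Url::parse("http://localhost:8000/proxy").unwrap();

    // Old style: the existing path is kept and the endpoint is appended.
    let concatenated = format!("{base}/v1/chat/completions");
    assert_eq!(concatenated, "http://localhost:8000/proxy/v1/chat/completions");

    // New style: set_path overwrites the path component entirely.
    let mut url = base.clone();
    url.set_path("/v1/chat/completions");
    assert_eq!(url.as_str(), "http://localhost:8000/v1/chat/completions");
}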
@@ -829,7 +832,7 @@ mod tests {
                 w.write_all(b"data: [DONE]\n\n")
             })
             .create_async().await;
-        let url = s.url();
+        let url = s.url().parse().unwrap();
         let tokenizer = Arc::new(Tokenizer::from_pretrained("gpt2", None).unwrap());
         let backend = OpenAITextGenerationBackend::try_new(
             "".to_string(),
@@ -890,7 +893,7 @@ mod tests {
                 w.write_all(b"data: [DONE]\n\n")
             })
             .create_async().await;
-        let url = s.url();
+        let url = s.url().parse().unwrap();
         let tokenizer = Arc::new(Tokenizer::from_pretrained("gpt2", None).unwrap());
         let backend = OpenAITextGenerationBackend::try_new(
             "".to_string(),
@@ -975,7 +978,7 @@
             .with_chunked_body(|w| w.write_all(b"data: {\"error\": \"Internal server error\"}\n\n"))
             .create_async()
             .await;
-        let url = s.url();
+        let url = s.url().parse().unwrap();
         let tokenizer = Arc::new(Tokenizer::from_pretrained("gpt2", None).unwrap());
         let backend = OpenAITextGenerationBackend::try_new(
             "".to_string(),
@@ -1021,7 +1024,7 @@
             .with_chunked_body(|w| w.write_all(b"this is wrong\n\n"))
             .create_async()
             .await;
-        let url = s.url();
+        let url = s.url().parse().unwrap();
         let tokenizer = Arc::new(Tokenizer::from_pretrained("gpt2", None).unwrap());
         let backend = OpenAITextGenerationBackend::try_new(
             "".to_string(),
@@ -1067,7 +1070,7 @@
             .with_chunked_body(|w| w.write_all(b"data: {\"foo\": \"bar\"}\n\n"))
             .create_async()
             .await;
-        let url = s.url();
+        let url = s.url().parse().unwrap();
         let tokenizer = Arc::new(Tokenizer::from_pretrained("gpt2", None).unwrap());
         let backend = OpenAITextGenerationBackend::try_new(
             "".to_string(),
@@ -1117,7 +1120,7 @@
                 w.write_all(b"data: [DONE]\n\n")
             })
             .create_async().await;
-        let url = s.url();
+        let url = s.url().parse().unwrap();
         let tokenizer = Arc::new(Tokenizer::from_pretrained("gpt2", None).unwrap());
         let backend = OpenAITextGenerationBackend::try_new(
             "".to_string(),

src/scheduler.rs (1 addition, 1 deletion)

@@ -232,7 +232,7 @@ mod tests {
                 w.write_all(b"data: [DONE]\n\n")
             })
             .create_async().await;
-        let url = s.url();
+        let url = s.url().parse().unwrap();
         let tokenizer = Arc::new(Tokenizer::from_pretrained("gpt2", None).unwrap());
         let backend = OpenAITextGenerationBackend::try_new(
             "".to_string(),

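The test-side change is the same in every suite: the mock server's address comes back from s.url() as a plain string, so the tests now parse it into a Url before handing it to try_new; .parse() resolves to Url through its FromStr implementation, with the target type inferred from try_new's signature. A tiny standalone sketch of that conversion (the literal address below stands in for s.url()):

// Sketch: parsing a server address string into a reqwest::Url, as the
// updated tests do with the value returned by the mock server.
use reqwest::Url;

fn main() {
    let raw = String::from("http://127.0.0.1:1234"); // stand-in for s.url()
    // String -> Url via FromStr; unwrap mirrors the tests, where a
    // malformed mock address would itself be a bug.
    let url: Url = raw.parse().unwrap();
    assert_eq!(url.scheme(), "http");
    assert_eq!(url.port(), Some(1234));
}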