Skip to content

Commit a3864e6

Browse files
committed
Tmp
1 parent 8096d46 commit a3864e6

File tree

6 files changed

+166
-34
lines changed

6 files changed

+166
-34
lines changed

flake.lock

Lines changed: 61 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

flake.nix

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Nix flake for Inference Benchmarker.
# Its only output is a default development shell, built for each system
# enumerated by flake-utils' `eachDefaultSystem`.
{
  description = "Inference Benchmarker - A terminal-based benchmarker for LLMs";

  inputs = {
    nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
    flake-utils.url = "github:numtide/flake-utils";
  };

  outputs =
    {
      nixpkgs,
      flake-utils,
      ...
    }:
    flake-utils.lib.eachDefaultSystem (
      system:
      let
        pkgs = import nixpkgs {
          inherit system;
        };
      in
      {
        devShells.default = pkgs.mkShell {
          # Tools that run on the build machine (toolchain manager, pkg-config).
          nativeBuildInputs = with pkgs; [
            rustup
            pkg-config
          ];

          # Libraries the project links against.
          # NOTE(review): presumably discovered through pkg-config by the
          # Rust build; confirm against the crate's dependencies.
          buildInputs = with pkgs; [
            openssl
          ];
        };
      }
    );
}

src/benchmark.rs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,63 @@ impl Benchmark {
337337
}
338338

339339
pub async fn run_perf(&mut self) -> anyhow::Result<()> {
340+
info!("Running performance benchmark");
341+
342+
let id = "performance".to_string();
343+
344+
// notify start event
345+
self.event_bus.send(Event::BenchmarkStart(BenchmarkEvent {
346+
id: id.clone(),
347+
scheduler_type: ExecutorType::ConstantVUs,
348+
request_throughput: None,
349+
progress: 0.0,
350+
results: None,
351+
successful_requests: 0,
352+
failed_requests: 0,
353+
}))?;
354+
355+
// create progress handler
356+
let tx = self.handle_progress(id.clone()).await;
357+
358+
let mut successful_requests = 0u64;
359+
let mut failed_requests = 0u64;
360+
361+
for i in (1usize..2).map(|i| i.pow(2)) {
362+
// start scheduler
363+
let mut scheduler = scheduler::Scheduler::new(
364+
id.clone(),
365+
self.backend.clone(),
366+
ExecutorType::ConstantVUs,
367+
executors::ExecutorConfig {
368+
max_vus: i as u64,
369+
duration: self.config.duration,
370+
rate: None,
371+
},
372+
self.requests.clone(),
373+
tx.clone(),
374+
self.stop_sender.clone(),
375+
);
376+
scheduler.run().await?;
377+
let results = scheduler.get_results().lock().await.clone();
378+
info!("Result {results:?}");
379+
self.report.add_benchmark_result(results.clone());
380+
successful_requests += results.successful_requests() as u64;
381+
failed_requests += results.failed_requests() as u64;
382+
}
383+
384+
// send None to close the progress handler
385+
tx.send(None).await.unwrap();
386+
387+
// notify end event
388+
self.event_bus.send(Event::BenchmarkEnd(BenchmarkEvent {
389+
id: id.clone(),
390+
scheduler_type: ExecutorType::ConstantVUs,
391+
request_throughput: Some(0.0),
392+
progress: 100.0,
393+
results: None,
394+
successful_requests,
395+
failed_requests,
396+
}))?;
340397
Ok(())
341398
}
342399

src/main.rs

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
use anyhow::Result;
12
use clap::error::ErrorKind::InvalidValue;
23
use clap::{ArgGroup, Error, Parser};
34
use inference_benchmarker::{run, BenchmarkKind, RunConfiguration, TokenizeOptions};
4-
use log::{debug, error};
5+
use log::debug;
56
use reqwest::Url;
67
use std::collections::HashMap;
78
use std::time::Duration;
@@ -23,7 +24,7 @@ struct Args {
2324
#[clap(default_value = "128", short, long, env, group = "group_manual")]
2425
max_vus: u64,
2526
/// The duration of each benchmark step
26-
#[clap(default_value = "120s", short, long, env, group = "group_manual")]
27+
#[clap(default_value = "10s", short, long, env, group = "group_manual")]
2728
#[arg(value_parser = parse_duration)]
2829
duration: Duration,
2930
/// A list of rates of requests to send per second (only valid for the ConstantArrivalRate benchmark).
@@ -38,7 +39,7 @@ struct Args {
3839
profile: Option<String>,
3940
/// The kind of benchmark to run (throughput, sweep, optimum)
4041
#[clap(
41-
default_value = "sweep",
42+
default_value = "perf",
4243
short,
4344
long,
4445
env,
@@ -176,7 +177,7 @@ fn parse_tokenizer_options(s: &str) -> Result<TokenizeOptions, Error> {
176177
}
177178

178179
#[tokio::main]
179-
async fn main() {
180+
async fn main() -> Result<()> {
180181
let args = Args::parse();
181182
let git_sha = option_env!("VERGEN_GIT_SHA").unwrap_or("unknown");
182183
println!(
@@ -234,14 +235,5 @@ async fn main() {
234235
model_name,
235236
run_id,
236237
};
237-
let main_thread = tokio::spawn(async move {
238-
match run(run_config, stop_sender_clone).await {
239-
Ok(_) => {}
240-
Err(e) => {
241-
error!("Fatal: {:?}", e);
242-
println!("Fatal: {:?}", e)
243-
}
244-
};
245-
});
246-
let _ = main_thread.await;
238+
run(run_config, stop_sender_clone).await
247239
}

src/requests.rs

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ impl TextGenerationBackend for OpenAITextGenerationBackend {
140140
let body = OpenAITextGenerationRequest {
141141
model: self.model_name.clone(),
142142
messages,
143-
max_tokens: request.num_decode_tokens,
143+
max_tokens: Some(20),
144144
stream: true,
145145
stop: None,
146146
temperature: 0.0,
@@ -154,7 +154,6 @@ impl TextGenerationBackend for OpenAITextGenerationBackend {
154154
)
155155
.json(&serde_json::json!(body))
156156
.timeout(self.timeout);
157-
info!("Sending request");
158157
// start timer
159158
aggregated_response.start();
160159
let mut es = EventSource::new(req).unwrap();
@@ -218,24 +217,14 @@ impl TextGenerationBackend for OpenAITextGenerationBackend {
218217
};
219218
}
220219
Err(e) => {
221-
error!("Got SSE error : {e}");
222220
match e {
223-
Error::Utf8(_) => {
224-
aggregated_response.fail();
225-
}
226-
Error::Parser(_) => {
227-
aggregated_response.fail();
228-
}
229-
Error::Transport(_) => {
230-
aggregated_response.fail();
231-
}
232-
Error::InvalidContentType(_, _) => {
233-
aggregated_response.fail();
234-
}
235-
Error::InvalidStatusCode(_, _) => {
236-
aggregated_response.fail();
237-
}
238-
Error::InvalidLastEventId(_) => {
221+
Error::Utf8(_)
222+
| Error::Parser(_)
223+
| Error::Transport(_)
224+
| Error::InvalidContentType(_, _)
225+
| Error::InvalidStatusCode(_, _)
226+
| Error::InvalidLastEventId(_) => {
227+
error!("Got SSE error : {e}");
239228
aggregated_response.fail();
240229
}
241230
Error::StreamEnded => {

src/results.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ impl BenchmarkResults {
241241
/// Calculate the quantile of a given data set using interpolation method
242242
/// Results are similar to `numpy.percentile`
243243
fn quantile_duration(&self, mut data: Vec<Duration>, quantile: f64) -> anyhow::Result<f64> {
244-
if self.is_ready() {
244+
if self.is_ready() && data.len() > 1 {
245245
data.sort();
246246
let i = (quantile * (data.len() - 1) as f64).floor();
247247
let delta = (data.len() - 1) as f64 * quantile - i;

0 commit comments

Comments
 (0)