From a3f0a5eae4c029e8665d060badbba20bf1647152 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Sun, 3 Dec 2023 20:51:38 -0500 Subject: [PATCH] moved to reqwest and added streaming event source as response --- Cargo.toml | 4 +- README.md | 2 +- src/lib.rs | 3 +- src/models.rs | 59 ++++++++++++-------------- src/predictions.rs | 101 +++++++++++++++++++++++++++++---------------- 5 files changed, 99 insertions(+), 70 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6425fe3..bf46186 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,11 +12,13 @@ readme = "README.md" [dependencies] anyhow = "1.0.75" lazy_static = "1.4.0" -isahc = "1.7.2" serde = {version = "1.0.193", features = ["derive"]} serde_json = "1.0" erased-serde = "0.3.31" futures-lite = "2.0.1" +reqwest = {version = "0.11.22", features = ["stream"]} +eventsource-stream = "0.2.3" +bytes = "1.5.0" [dev-dependencies] tokio = { version = "1.34.0", features = ["rt-multi-thread", "macros"] } diff --git a/README.md b/README.md index 161c7f8..0490a5c 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # replicate-rs A "work in progress" un-official minimal async client for [Replicate](https://replicate.com/). -Provides a simple wrapper for interacting with Replicate models with [serde](https://serde.rs/) and [isahc](https://docs.rs/isahc/latest/isahc/). +Provides a simple wrapper for interacting with Replicate models with [serde](https://serde.rs/) and [reqwest](https://crates.io/crates/reqwest). diff --git a/src/lib.rs b/src/lib.rs index 079f9c4..1b1d359 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -35,7 +35,8 @@ //! .create( //! "replicate", //! "hello-world", -//! json!({"text": "kyle"}) +//! json!({"text": "kyle"}), +//! false //! ) //! .await //! .unwrap(); diff --git a/src/models.rs b/src/models.rs index ebe9109..de8accf 100644 --- a/src/models.rs +++ b/src/models.rs @@ -7,8 +7,6 @@ //! - [List all Public Models](https://replicate.com/docs/reference/http#models.list) //! use anyhow::anyhow; -use futures_lite::io::AsyncReadExt; -use isahc::{prelude::*, Request}; use serde::Deserialize; use serde_json::Value; @@ -99,16 +97,15 @@ impl ModelClient { let api_key = self.client.get_api_key()?; let base_url = self.client.get_base_url(); let endpoint = format!("{base_url}/models/{owner}/{name}"); - let response = Request::get(endpoint) + let client = reqwest::Client::new(); + let response = client + .get(endpoint) .header("Authorization", format!("Token {api_key}")) - .body({})? - .send_async() + .send() .await?; - let mut bytes = Vec::new(); - response.into_body().read_to_end(&mut bytes).await?; - - let model: Model = serde_json::from_slice(&bytes)?; + let data = response.text().await?; + let model: Model = serde_json::from_str(&data)?; anyhow::Ok(model) } @@ -122,16 +119,15 @@ impl ModelClient { let api_key = self.client.get_api_key()?; let base_url = self.client.get_base_url(); let endpoint = format!("{base_url}/models/{owner}/{name}/versions/{version_id}"); - let response = Request::get(endpoint) + let client = reqwest::Client::new(); + let response = client + .get(endpoint) .header("Authorization", format!("Token {api_key}")) - .body({})? - .send_async() + .send() .await?; - let mut bytes = Vec::new(); - response.into_body().read_to_end(&mut bytes).await?; - - let model: Model = serde_json::from_slice(&bytes)?; + let data = response.text().await?; + let model: Model = serde_json::from_str(&data)?; anyhow::Ok(model) } @@ -154,20 +150,20 @@ impl ModelClient { let base_url = self.client.get_base_url(); let api_key = self.client.get_api_key()?; let endpoint = format!("{base_url}/models/{owner}/{name}/versions"); - let mut response = Request::get(endpoint) + let client = reqwest::Client::new(); + let response = client + .get(endpoint) .header("Authorization", format!("Token {api_key}")) - .body({})? - .send_async() + .send() .await?; - let mut bytes = Vec::new(); - response.body_mut().read_to_end(&mut bytes).await?; - - if response.status().is_success() { - let data: ModelVersions = serde_json::from_slice(&bytes)?; + let status = response.status(); + let data = response.text().await?; + if status.is_success() { + let data: ModelVersions = serde_json::from_str(&data)?; anyhow::Ok(data) } else { - let data: ModelVersionError = serde_json::from_slice(&bytes)?; + let data: ModelVersionError = serde_json::from_str(&data)?; Err(anyhow!(data.detail)) } } @@ -177,16 +173,15 @@ impl ModelClient { let base_url = self.client.get_base_url(); let api_key = self.client.get_api_key()?; let endpoint = format!("{base_url}/models"); - let mut response = Request::get(endpoint) + let client = reqwest::Client::new(); + let response = client + .get(endpoint) .header("Authorization", format!("Token {api_key}")) - .body({})? - .send_async() + .send() .await?; - let mut bytes = Vec::new(); - response.body_mut().read_to_end(&mut bytes).await?; - - let models: Models = serde_json::from_slice(&bytes)?; + let data = response.text().await?; + let models: Models = serde_json::from_str(&data)?; anyhow::Ok(models) } } diff --git a/src/predictions.rs b/src/predictions.rs index 7892f9f..b278528 100644 --- a/src/predictions.rs +++ b/src/predictions.rs @@ -9,9 +9,9 @@ use crate::config::ReplicateConfig; -use erased_serde::Serialize; -use futures_lite::io::AsyncReadExt; -use isahc::{prelude::*, Request}; +use anyhow::anyhow; +use bytes::Bytes; +use eventsource_stream::{EventStream, Eventsource}; use serde_json::Value; use crate::models::ModelClient; @@ -41,6 +41,8 @@ pub struct PredictionUrls { pub cancel: String, /// Url endpoint to retrieve the specific prediction pub get: String, + /// Url endpoint to receive streamed output + pub stream: Option, } /// Details for a specific prediction @@ -80,19 +82,40 @@ impl Prediction { pub async fn reload(&mut self) -> anyhow::Result<()> { let api_key = api_key()?; let endpoint = self.urls.get.clone(); - let mut response = Request::get(endpoint) + let client = reqwest::Client::new(); + let response = client + .get(endpoint) .header("Authorization", format!("Token {api_key}")) - .body({})? - .send_async() + .send() .await?; - let mut data = String::new(); - response.body_mut().read_to_string(&mut data).await?; - + let data = response.text().await?; let prediction: Prediction = serde_json::from_str(data.as_str())?; *self = prediction; anyhow::Ok(()) } + + /// Get the stream from a prediction + pub async fn get_stream( + &mut self, + ) -> anyhow::Result>>> + { + if let Some(stream_url) = self.urls.stream.clone() { + let api_key = api_key()?; + let client = reqwest::Client::new(); + let stream = client + .get(stream_url) + .header("Autorization", format!("Token {api_key}")) + .send() + .await? + .bytes_stream() + .eventsource(); + + return anyhow::Ok(stream); + } else { + return Err(anyhow!("prediction has no stream url available")); + } + } } /// A client for interacting with 'predictions' endpoint @@ -105,6 +128,7 @@ pub struct PredictionClient { struct PredictionInput { version: String, input: serde_json::Value, + stream: bool, } impl PredictionClient { @@ -118,6 +142,7 @@ impl PredictionClient { owner: &str, name: &str, input: serde_json::Value, + stream: bool, ) -> anyhow::Result { let api_key = api_key()?; let base_url = base_url(); @@ -126,17 +151,22 @@ impl PredictionClient { let version = model_client.get_latest_version(owner, name).await?.id; let endpoint = format!("{base_url}/predictions"); - let input = PredictionInput { version, input }; + let input = PredictionInput { + version, + input, + stream, + }; let body = serde_json::to_string(&input)?; - let response = Request::post(endpoint) + let client = reqwest::Client::new(); + let response = client + .post(endpoint) .header("Authorization", format!("Token {api_key}")) - .body(body)? - .send_async() + .body(body) + .send() .await?; - let mut bytes = Vec::new(); - response.into_body().read_to_end(&mut bytes).await?; - let prediction: Prediction = serde_json::from_slice(&bytes)?; + let data = response.text().await?; + let prediction: Prediction = serde_json::from_str(&data)?; anyhow::Ok(prediction) } @@ -147,16 +177,15 @@ impl PredictionClient { let base_url = self.config.get_base_url(); let endpoint = format!("{base_url}/predictions/{id}"); - let response = Request::get(endpoint) + let client = reqwest::Client::new(); + let response = client + .get(endpoint) .header("Authorization", format!("Token {api_key}")) - .body({})? - .send_async() + .send() .await?; - let mut bytes = Vec::new(); - response.into_body().read_to_end(&mut bytes).await?; - - let prediction: Prediction = serde_json::from_slice(&bytes)?; + let data = response.text().await?; + let prediction: Prediction = serde_json::from_str(&data)?; anyhow::Ok(prediction) } @@ -167,15 +196,15 @@ impl PredictionClient { let base_url = self.config.get_base_url(); let endpoint = format!("{base_url}/predictions"); - let response = Request::get(endpoint) + let client = reqwest::Client::new(); + let response = client + .get(endpoint) .header("Authorization", format!("Token {api_key}")) - .body({})? - .send_async() + .send() .await?; - let mut bytes = Vec::new(); - response.into_body().read_to_end(&mut bytes).await?; - let predictions: Predictions = serde_json::from_slice(&bytes)?; + let data = response.text().await?; + let predictions: Predictions = serde_json::from_str(&data)?; anyhow::Ok(predictions) } @@ -185,15 +214,15 @@ impl PredictionClient { let api_key = self.config.get_api_key()?; let base_url = self.config.get_base_url(); let endpoint = format!("{base_url}/predictions/{id}/cancel"); - let response = Request::post(endpoint) + let client = reqwest::Client::new(); + let response = client + .post(endpoint) .header("Authorization", format!("Token {api_key}")) - .body({})? - .send_async() + .send() .await?; - let mut bytes = Vec::new(); - response.into_body().read_to_end(&mut bytes).await?; - let prediction: Prediction = serde_json::from_slice(&bytes)?; + let data = response.text().await?; + let prediction: Prediction = serde_json::from_str(&data)?; anyhow::Ok(prediction) } @@ -290,6 +319,7 @@ mod tests { "replicate", "hello-world", json!({"text": "This is test input"}), + false, ) .await .unwrap(); @@ -397,6 +427,7 @@ mod tests { "replicate", "hello-world", json!({"text": "This is test input"}), + false, ) .await .unwrap();