Skip to content

Commit

Permalink
moved to reqwest and added streaming event source as response
Browse files Browse the repository at this point in the history
  • Loading branch information
KCaverly committed Dec 4, 2023
1 parent 08deff4 commit a3f0a5e
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 70 deletions.
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ readme = "README.md"
[dependencies]
anyhow = "1.0.75"
lazy_static = "1.4.0"
isahc = "1.7.2"
serde = {version = "1.0.193", features = ["derive"]}
serde_json = "1.0"
erased-serde = "0.3.31"
futures-lite = "2.0.1"
reqwest = {version = "0.11.22", features = ["stream"]}
eventsource-stream = "0.2.3"
bytes = "1.5.0"

[dev-dependencies]
tokio = { version = "1.34.0", features = ["rt-multi-thread", "macros"] }
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# replicate-rs
A "work in progress" un-official minimal async client for [Replicate](https://replicate.com/).
Provides a simple wrapper for interacting with Replicate models with [serde](https://serde.rs/) and [isahc](https://docs.rs/isahc/latest/isahc/).
Provides a simple wrapper for interacting with Replicate models with [serde](https://serde.rs/) and [reqwest](https://crates.io/crates/reqwest).

<a href="https://crates.io/crates/replicate-rs"><img src="https://img.shields.io/crates/v/replicate-rs"></a>
<a href="https://docs.rs/replicate-rs/latest/replicate_rs/"><img src="https://img.shields.io/docsrs/replicate-rs"></a>
Expand Down
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
//! .create(
//! "replicate",
//! "hello-world",
//! json!({"text": "kyle"})
//! json!({"text": "kyle"}),
//! false
//! )
//! .await
//! .unwrap();
Expand Down
59 changes: 27 additions & 32 deletions src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
//! - [List all Public Models](https://replicate.com/docs/reference/http#models.list)
//!
use anyhow::anyhow;
use futures_lite::io::AsyncReadExt;
use isahc::{prelude::*, Request};
use serde::Deserialize;
use serde_json::Value;

Expand Down Expand Up @@ -99,16 +97,15 @@ impl ModelClient {
let api_key = self.client.get_api_key()?;
let base_url = self.client.get_base_url();
let endpoint = format!("{base_url}/models/{owner}/{name}");
let response = Request::get(endpoint)
let client = reqwest::Client::new();
let response = client
.get(endpoint)
.header("Authorization", format!("Token {api_key}"))
.body({})?
.send_async()
.send()
.await?;

let mut bytes = Vec::new();
response.into_body().read_to_end(&mut bytes).await?;

let model: Model = serde_json::from_slice(&bytes)?;
let data = response.text().await?;
let model: Model = serde_json::from_str(&data)?;
anyhow::Ok(model)
}

Expand All @@ -122,16 +119,15 @@ impl ModelClient {
let api_key = self.client.get_api_key()?;
let base_url = self.client.get_base_url();
let endpoint = format!("{base_url}/models/{owner}/{name}/versions/{version_id}");
let response = Request::get(endpoint)
let client = reqwest::Client::new();
let response = client
.get(endpoint)
.header("Authorization", format!("Token {api_key}"))
.body({})?
.send_async()
.send()
.await?;

let mut bytes = Vec::new();
response.into_body().read_to_end(&mut bytes).await?;

let model: Model = serde_json::from_slice(&bytes)?;
let data = response.text().await?;
let model: Model = serde_json::from_str(&data)?;
anyhow::Ok(model)
}

Expand All @@ -154,20 +150,20 @@ impl ModelClient {
let base_url = self.client.get_base_url();
let api_key = self.client.get_api_key()?;
let endpoint = format!("{base_url}/models/{owner}/{name}/versions");
let mut response = Request::get(endpoint)
let client = reqwest::Client::new();
let response = client
.get(endpoint)
.header("Authorization", format!("Token {api_key}"))
.body({})?
.send_async()
.send()
.await?;

let mut bytes = Vec::new();
response.body_mut().read_to_end(&mut bytes).await?;

if response.status().is_success() {
let data: ModelVersions = serde_json::from_slice(&bytes)?;
let status = response.status();
let data = response.text().await?;
if status.is_success() {
let data: ModelVersions = serde_json::from_str(&data)?;
anyhow::Ok(data)
} else {
let data: ModelVersionError = serde_json::from_slice(&bytes)?;
let data: ModelVersionError = serde_json::from_str(&data)?;
Err(anyhow!(data.detail))
}
}
Expand All @@ -177,16 +173,15 @@ impl ModelClient {
let base_url = self.client.get_base_url();
let api_key = self.client.get_api_key()?;
let endpoint = format!("{base_url}/models");
let mut response = Request::get(endpoint)
let client = reqwest::Client::new();
let response = client
.get(endpoint)
.header("Authorization", format!("Token {api_key}"))
.body({})?
.send_async()
.send()
.await?;

let mut bytes = Vec::new();
response.body_mut().read_to_end(&mut bytes).await?;

let models: Models = serde_json::from_slice(&bytes)?;
let data = response.text().await?;
let models: Models = serde_json::from_str(&data)?;
anyhow::Ok(models)
}
}
Expand Down
101 changes: 66 additions & 35 deletions src/predictions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
use crate::config::ReplicateConfig;

use erased_serde::Serialize;
use futures_lite::io::AsyncReadExt;
use isahc::{prelude::*, Request};
use anyhow::anyhow;
use bytes::Bytes;
use eventsource_stream::{EventStream, Eventsource};
use serde_json::Value;

use crate::models::ModelClient;
Expand Down Expand Up @@ -41,6 +41,8 @@ pub struct PredictionUrls {
pub cancel: String,
/// Url endpoint to retrieve the specific prediction
pub get: String,
/// Url endpoint to receive streamed output
pub stream: Option<String>,
}

/// Details for a specific prediction
Expand Down Expand Up @@ -80,19 +82,40 @@ impl Prediction {
pub async fn reload(&mut self) -> anyhow::Result<()> {
let api_key = api_key()?;
let endpoint = self.urls.get.clone();
let mut response = Request::get(endpoint)
let client = reqwest::Client::new();
let response = client
.get(endpoint)
.header("Authorization", format!("Token {api_key}"))
.body({})?
.send_async()
.send()
.await?;

let mut data = String::new();
response.body_mut().read_to_string(&mut data).await?;

let data = response.text().await?;
let prediction: Prediction = serde_json::from_str(data.as_str())?;
*self = prediction;
anyhow::Ok(())
}

/// Get the stream from a prediction
pub async fn get_stream(
&mut self,
) -> anyhow::Result<EventStream<impl futures_lite::stream::Stream<Item = reqwest::Result<Bytes>>>>
{
if let Some(stream_url) = self.urls.stream.clone() {
let api_key = api_key()?;
let client = reqwest::Client::new();
let stream = client
.get(stream_url)
.header("Autorization", format!("Token {api_key}"))
.send()
.await?
.bytes_stream()
.eventsource();

return anyhow::Ok(stream);
} else {
return Err(anyhow!("prediction has no stream url available"));
}
}
}

/// A client for interacting with 'predictions' endpoint
Expand All @@ -105,6 +128,7 @@ pub struct PredictionClient {
struct PredictionInput {
version: String,
input: serde_json::Value,
stream: bool,
}

impl PredictionClient {
Expand All @@ -118,6 +142,7 @@ impl PredictionClient {
owner: &str,
name: &str,
input: serde_json::Value,
stream: bool,
) -> anyhow::Result<Prediction> {
let api_key = api_key()?;
let base_url = base_url();
Expand All @@ -126,17 +151,22 @@ impl PredictionClient {
let version = model_client.get_latest_version(owner, name).await?.id;

let endpoint = format!("{base_url}/predictions");
let input = PredictionInput { version, input };
let input = PredictionInput {
version,
input,
stream,
};
let body = serde_json::to_string(&input)?;
let response = Request::post(endpoint)
let client = reqwest::Client::new();
let response = client
.post(endpoint)
.header("Authorization", format!("Token {api_key}"))
.body(body)?
.send_async()
.body(body)
.send()
.await?;

let mut bytes = Vec::new();
response.into_body().read_to_end(&mut bytes).await?;
let prediction: Prediction = serde_json::from_slice(&bytes)?;
let data = response.text().await?;
let prediction: Prediction = serde_json::from_str(&data)?;

anyhow::Ok(prediction)
}
Expand All @@ -147,16 +177,15 @@ impl PredictionClient {
let base_url = self.config.get_base_url();

let endpoint = format!("{base_url}/predictions/{id}");
let response = Request::get(endpoint)
let client = reqwest::Client::new();
let response = client
.get(endpoint)
.header("Authorization", format!("Token {api_key}"))
.body({})?
.send_async()
.send()
.await?;

let mut bytes = Vec::new();
response.into_body().read_to_end(&mut bytes).await?;

let prediction: Prediction = serde_json::from_slice(&bytes)?;
let data = response.text().await?;
let prediction: Prediction = serde_json::from_str(&data)?;

anyhow::Ok(prediction)
}
Expand All @@ -167,15 +196,15 @@ impl PredictionClient {
let base_url = self.config.get_base_url();

let endpoint = format!("{base_url}/predictions");
let response = Request::get(endpoint)
let client = reqwest::Client::new();
let response = client
.get(endpoint)
.header("Authorization", format!("Token {api_key}"))
.body({})?
.send_async()
.send()
.await?;

let mut bytes = Vec::new();
response.into_body().read_to_end(&mut bytes).await?;
let predictions: Predictions = serde_json::from_slice(&bytes)?;
let data = response.text().await?;
let predictions: Predictions = serde_json::from_str(&data)?;

anyhow::Ok(predictions)
}
Expand All @@ -185,15 +214,15 @@ impl PredictionClient {
let api_key = self.config.get_api_key()?;
let base_url = self.config.get_base_url();
let endpoint = format!("{base_url}/predictions/{id}/cancel");
let response = Request::post(endpoint)
let client = reqwest::Client::new();
let response = client
.post(endpoint)
.header("Authorization", format!("Token {api_key}"))
.body({})?
.send_async()
.send()
.await?;

let mut bytes = Vec::new();
response.into_body().read_to_end(&mut bytes).await?;
let prediction: Prediction = serde_json::from_slice(&bytes)?;
let data = response.text().await?;
let prediction: Prediction = serde_json::from_str(&data)?;

anyhow::Ok(prediction)
}
Expand Down Expand Up @@ -290,6 +319,7 @@ mod tests {
"replicate",
"hello-world",
json!({"text": "This is test input"}),
false,
)
.await
.unwrap();
Expand Down Expand Up @@ -397,6 +427,7 @@ mod tests {
"replicate",
"hello-world",
json!({"text": "This is test input"}),
false,
)
.await
.unwrap();
Expand Down

0 comments on commit a3f0a5e

Please sign in to comment.