Skip to content

Commit

Permalink
updated predictions
Browse files Browse the repository at this point in the history
  • Loading branch information
KCaverly committed Nov 29, 2023
1 parent 65be8c2 commit 826f245
Showing 1 changed file with 134 additions and 62 deletions.
196 changes: 134 additions & 62 deletions src/predictions.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,29 @@
use std::io::Read;
//! Utilities for interacting with all prediction endpoints.
//!
//! This includes the following:
//! - [Create Prediction](https://replicate.com/docs/reference/http#predictions.create)
//! - [Get Prediction](https://replicate.com/docs/reference/http#predictions.get)
//!
//! # Example
//! ```rust
//! use replicate_rs::client::ReplicateClient;
//! use replicate_rs::predictions::PredictionClient;
//! ```
use crate::client::ReplicateClient;

use erased_serde::Serialize;
use futures_lite::io::AsyncReadExt;
use isahc::{prelude::*, Request};
use serde_json::Value;

use crate::api_key;
use crate::models::Model;
use crate::models::ModelClient;
use crate::{api_key, base_url};

/// Client for the prediction endpoints of the Replicate HTTP API.
///
/// Thin wrapper holding a [`ReplicateClient`]; construct one with
/// [`PredictionClient::from`]. Used by the tests below to issue
/// `POST /predictions` requests against a mock server.
#[derive(Debug)]
pub struct PredictionClient {
    // Underlying client; cloned into `ModelClient::from` when resolving
    // a model's latest version before creating a prediction.
    client: ReplicateClient,
}

#[derive(serde::Deserialize, Debug)]
pub struct Prediction {
Expand All @@ -19,8 +36,8 @@ pub struct Prediction {
pub urls: PredictionUrls,
}

#[serde(rename_all = "lowercase")]
#[derive(serde::Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum PredictionStatus {
Starting,
Processing,
Expand All @@ -41,21 +58,24 @@ struct PredictionInput {
input: Box<dyn Serialize>,
}

impl Prediction {
pub async fn create_from_model_details(
impl PredictionClient {
pub fn from(client: ReplicateClient) -> Self {
PredictionClient { client }
}
pub async fn create(
&self,
owner: &str,
name: &str,
input: Box<dyn Serialize>,
) -> anyhow::Result<Prediction> {
let api_key = api_key()?;
let model = Model::get(owner, name).await?;

let version = model.latest_version.id;
let base_url = base_url();

let endpoint = "https://api.replicate.com/v1/predictions";
let model_client = ModelClient::from(self.client.clone());
let version = model_client.get_latest_version(owner, name).await?.id;

let endpoint = format!("{base_url}/predictions");
let input = PredictionInput { version, input };

let body = serde_json::to_string(&input)?;
let mut response = Request::post(endpoint)
.header("Authorization", format!("Token {api_key}"))
Expand All @@ -66,79 +86,131 @@ impl Prediction {
let mut data = String::new();
response.body_mut().read_to_string(&mut data).await?;

let prediction: Prediction = serde_json::from_str(data.as_str())?;

anyhow::Ok(prediction)
}

pub async fn reload(&mut self) -> anyhow::Result<()> {
let api_key = api_key()?;
let mut response = Request::get(&self.urls.get)
.header("Authorization", format!("Token {api_key}"))
.body({})?
.send_async()
.await?;

let mut data = String::new();
response.body_mut().read_to_string(&mut data).await?;
dbg!(&data);

let prediction: Prediction = serde_json::from_str(data.as_str())?;
*self = prediction;

anyhow::Ok(())
anyhow::Ok(prediction)
}
}

#[cfg(test)]
mod tests {
use httpmock::prelude::*;
use serde::Serialize;
use serde_json::json;

use super::*;

#[tokio::test]
async fn test_create_prediction() {
#[derive(Serialize)]
struct EmbeddingsInput {
texts: String,
batch_size: usize,
normalize_embeddings: bool,
convert_to_numpy: bool,
}

let input = Box::new(EmbeddingsInput {
texts: r#"["In the water, fish are swimming.", "Fish swim in the water."]"#.to_string(),
batch_size: 32,
normalize_embeddings: true,
convert_to_numpy: false,
async fn test_create() {
let server = MockServer::start();

let prediction_mock = server.mock(|when, then| {
when.method(POST).path("/predictions");
then.status(200).json_body_obj(&json!(
{
"id": "gm3qorzdhgbfurvjtvhg6dckhu",
"model": "replicate/hello-world",
"version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
"input": {
"text": "Alice"
},
"logs": "",
"error": null,
"status": "starting",
"created_at": "2023-09-08T16:19:34.765994657Z",
"urls": {
"cancel": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu/cancel",
"get": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu"
}
}
));
});

Prediction::create_from_model_details("nateraw", "bge-large-en-v1.5", input)
let model_mock = server.mock(|when, then| {
when.method(GET)
.path("/models/replicate/hello-world/versions");

then.status(200).json_body_obj(&json!({
"next": null,
"previous": null,
"results": [{
"id": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
"created_at": "2022-04-26T19:29:04.418669Z",
"cog_version": "0.3.0",
"openapi_schema": null
}]
}));
});

let mock_url = server.base_url();
let client = ReplicateClient::test(server.base_url()).unwrap();

let prediction_client = PredictionClient::from(client);
let prediction = prediction_client
.create(
"replicate",
"hello-world",
Box::new(json!({"text": "This is test input"})),
)
.await
.unwrap();
}

#[tokio::test]
async fn test_create_and_reload_prediction() {
#[derive(Serialize)]
struct EmbeddingsInput {
texts: String,
batch_size: usize,
normalize_embeddings: bool,
convert_to_numpy: bool,
}

let input = Box::new(EmbeddingsInput {
texts: r#"["In the water, fish are swimming.", "Fish swim in the water."]"#.to_string(),
batch_size: 32,
normalize_embeddings: true,
convert_to_numpy: false,
async fn test_create_and_reload() {
let server = MockServer::start();

let prediction_mock = server.mock(|when, then| {
when.method(POST).path("/predictions");
then.status(200).json_body_obj(&json!(
{
"id": "gm3qorzdhgbfurvjtvhg6dckhu",
"model": "replicate/hello-world",
"version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
"input": {
"text": "Alice"
},
"logs": "",
"error": null,
"status": "starting",
"created_at": "2023-09-08T16:19:34.765994657Z",
"urls": {
"cancel": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu/cancel",
"get": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu"
}
}
));
});

let model_mock = server.mock(|when, then| {
when.method(GET)
.path("/models/replicate/hello-world/versions");

then.status(200).json_body_obj(&json!({
"next": null,
"previous": null,
"results": [{
"id": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
"created_at": "2022-04-26T19:29:04.418669Z",
"cog_version": "0.3.0",
"openapi_schema": null
}]
}));
});

let mut prediction =
Prediction::create_from_model_details("nateraw", "bge-large-en-v1.5", input)
.await
.unwrap();
let mock_url = server.base_url();
let client = ReplicateClient::test(server.base_url()).unwrap();

prediction.reload().await.unwrap();
let prediction_client = PredictionClient::from(client);
let prediction = prediction_client
.create(
"replicate",
"hello-world",
Box::new(json!({"text": "This is test input"})),
)
.await
.unwrap();
}
}

0 comments on commit 826f245

Please sign in to comment.