diff --git a/README.md b/README.md index e3c9d5b..88e49c2 100644 --- a/README.md +++ b/README.md @@ -14,11 +14,11 @@ Provides simple type safe functionality for interacting with Replicate models wi #### Models - [ ] [Create a Model](https://replicate.com/docs/reference/http#models.create) -- [ ] [Get a Model](https://replicate.com/docs/reference/http#models.get) -- [ ] [Get a Model Version](https://replicate.com/docs/reference/http#models.versions.get) -- [ ] [List a Model's Versions](https://replicate.com/docs/reference/http#models.versions.list) +- [x] [Get a Model](https://replicate.com/docs/reference/http#models.get) +- [x] [Get a Model Version](https://replicate.com/docs/reference/http#models.versions.get) +- [x] [List a Model's Versions](https://replicate.com/docs/reference/http#models.versions.list) - [ ] [Delete a Model Version](https://replicate.com/docs/reference/http#models.versions.delete) -- [ ] [List Public Models](https://replicate.com/docs/reference/http#models.list) +- [x] [List Public Models](https://replicate.com/docs/reference/http#models.list) #### Collections - [ ] [Get a Collection of Models](https://replicate.com/docs/reference/http#collections.get) diff --git a/src/models.rs b/src/models.rs index edd9498..008aaad 100644 --- a/src/models.rs +++ b/src/models.rs @@ -7,50 +7,75 @@ use serde_json::Value; use crate::config::ReplicateConfig; #[derive(Debug, Deserialize)] -pub struct ModelVersionError { +struct ModelVersionError { detail: String, } +/// Version details for a particular model #[derive(Debug, Deserialize, Clone)] pub struct ModelVersion { + /// Id of the model pub id: String, + /// Time in which the model was created pub created_at: String, + /// Version of cog used to create the model pub cog_version: String, + /// OpenAPI Schema of model input and outputs pub openapi_schema: serde_json::Value, } +/// Paginated view of all versions for a particular model #[derive(Debug, Deserialize)] pub struct ModelVersions { + /// Place in pagination pub next: Option, + /// Place in pagination pub previous: Option, + /// List of all versions available pub results: Vec, } +/// All details available for a particular Model #[derive(Deserialize, Debug)] pub struct Model { - url: String, - owner: String, - name: String, - description: String, - visibility: String, - github_url: String, - paper_url: Option, - license_url: Option, - run_count: usize, - cover_image_url: String, - default_example: Value, - pub(crate) latest_version: ModelVersion, + /// URL for model homepage + pub url: String, + /// The owner of the model + pub owner: String, + /// The name of the model + pub name: String, + /// A brief description of the model + pub description: String, + /// Whether the model is public or private + pub visibility: String, + /// Github URL for the associated repo + pub github_url: String, + /// Url for an associated paper + pub paper_url: Option, + /// Url for the model's license + pub license_url: Option, + /// How many times the model has been run + pub run_count: usize, + /// Image URL to show on Replicate's Model page + pub cover_image_url: String, + /// A simple example to show model's use + pub default_example: Value, + /// The latest version's details + pub latest_version: ModelVersion, } +/// A client for interacting with `models` endpoints pub struct ModelClient { client: ReplicateConfig, } impl ModelClient { + /// Create a new `ModelClient` based upon a `ReplicateConfig` object pub fn from(client: ReplicateConfig) -> Self { ModelClient { client } } + /// Retrieve details for a specific model pub async fn get(&self, owner: &str, name: &str) -> anyhow::Result { let api_key = self.client.get_api_key()?; let base_url = self.client.get_base_url(); @@ -68,6 +93,30 @@ impl ModelClient { anyhow::Ok(model) } + /// Retrieve details for a specific model's version + pub async fn get_specific_version( + &self, + owner: &str, + name: &str, + version_id: &str, + ) -> anyhow::Result { + let api_key = self.client.get_api_key()?; + let base_url = self.client.get_base_url(); + let endpoint = format!("{base_url}/models/{owner}/{name}/versions/{version_id}"); + let mut response = Request::get(endpoint) + .header("Authorization", format!("Token {api_key}")) + .body({})? + .send_async() + .await?; + + let mut data = String::new(); + response.body_mut().read_to_string(&mut data).await?; + + let model: Model = serde_json::from_str(data.as_str())?; + anyhow::Ok(model) + } + + /// Retrieve details for latest version of a specific model pub async fn get_latest_version( &self, owner: &str, @@ -81,6 +130,7 @@ impl ModelClient { anyhow::Ok(latest_version.clone()) } + /// Retrieve list of all available versions of a specific model pub async fn list_versions(&self, owner: &str, name: &str) -> anyhow::Result { let base_url = self.client.get_base_url(); let api_key = self.client.get_api_key()?; @@ -139,11 +189,48 @@ mod tests { let client = ReplicateConfig::test(mock_server.base_url()).unwrap(); let model_client = ModelClient::from(client); - let model = model_client.get("replicate", "hello-world").await.unwrap(); + model_client.get("replicate", "hello-world").await.unwrap(); model_mock.assert(); } + #[tokio::test] + async fn test_get_specific_version() { + let mock_server = MockServer::start(); + + let model_mock = mock_server.mock(|when, then| { + when.method(GET) + .path("/models/replicate/hello-world/versions/1234"); + then.status(200).json_body_obj(&json!({ + "url": "https://replicate.com/replicate/hello-world", + "owner": "replicate", + "name": "hello-world", + "description": "A tiny model that says hello", + "visibility": "public", + "github_url": "https://github.com/replicate/cog-examples", + "paper_url": null, + "license_url": null, + "run_count": 5681081, + "cover_image_url": "...", + "default_example": null, + "latest_version": { + "id": "1234", + "created_at": "2022-04-26T19:29:04.418669Z", + "cog_version": "0.3.0", + "openapi_schema": {} + } + })); + }); + + let client = ReplicateConfig::test(mock_server.base_url()).unwrap(); + let model_client = ModelClient::from(client); + model_client + .get_specific_version("replicate", "hello-world", "1234") + .await + .unwrap(); + + model_mock.assert(); + } #[tokio::test] async fn test_list_model_versions() { let mock_server = MockServer::start(); @@ -167,7 +254,7 @@ mod tests { let client = ReplicateConfig::test(mock_server.base_url()).unwrap(); let model_client = ModelClient::from(client); - let model = model_client + model_client .list_versions("replicate", "hello-world") .await .unwrap(); @@ -198,7 +285,7 @@ mod tests { let client = ReplicateConfig::test(mock_server.base_url()).unwrap(); let model_client = ModelClient::from(client); - let model = model_client + model_client .get_latest_version("replicate", "hello-world") .await .unwrap(); diff --git a/src/predictions.rs b/src/predictions.rs index b8c5566..c94b649 100644 --- a/src/predictions.rs +++ b/src/predictions.rs @@ -93,7 +93,7 @@ impl Prediction { } } -/// A client namespace for interacting with 'predictions' endpoint +/// A client for interacting with 'predictions' endpoint #[derive(Debug)] pub struct PredictionClient { config: ReplicateConfig,