Skip to content

Commit c2ce8b4

Browse files
author
Dongri Jin
committed
Initial commit
0 parents  commit c2ce8b4

12 files changed

+599
-0
lines changed

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
/target
2+
/Cargo.lock

Cargo.toml

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
[package]
2+
name = "openai-rs"
3+
version = "0.1.0"
4+
edition = "2021"
5+
6+
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
7+
8+
[dependencies]
9+
reqwest = { version = "0.11", features = ["json"] }
10+
tokio = { version = "1", features = ["full"] }
11+
serde = { version = "1", features = ["derive"] }

examples/completion.rs

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
use openai_rs::v1::completion::{self, CompletionRequest};
2+
use openai_rs::v1::api::Client;
3+
use std::env;
4+
5+
#[tokio::main]
6+
async fn main() -> Result<(), Box<dyn std::error::Error>> {
7+
let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string());
8+
let req = CompletionRequest {
9+
model: completion::GPT3_TEXT_DAVINCI_003.to_string(),
10+
prompt: Some(String::from("NFTとは何か?")),
11+
suffix: None,
12+
max_tokens: Some(3000),
13+
temperature: Some(0.9),
14+
top_p: Some(1.0),
15+
n: None,
16+
stream: None,
17+
logprobs: None,
18+
echo: None,
19+
stop: Some(vec![String::from(" Human:"), String::from(" AI:")]),
20+
presence_penalty: Some(0.6),
21+
frequency_penalty: Some(0.0),
22+
best_of: None,
23+
logit_bias: None,
24+
user: None,
25+
};
26+
let completion_response = client.completion(req).await?;
27+
println!("{:?}", completion_response.choices[0].text);
28+
29+
Ok(())
30+
}
31+
32+
// cargo run --package openai-rs --example completion

src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pub mod v1;

src/v1/api.rs

+250
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
2+
use crate::v1::completion::{CompletionRequest, CompletionResponse};
3+
use crate::v1::edit::{EditRequest, EditResponse};
4+
use crate::v1::image::{
5+
ImageGenerationRequest,
6+
ImageGenerationResponse,
7+
ImageEditRequest,
8+
ImageEditResponse,
9+
ImageVariationRequest,
10+
ImageVariationResponse,
11+
};
12+
use crate::v1::embedding::{EmbeddingRequest, EmbeddingResponse};
13+
use crate::v1::file::{
14+
FileListResponse,
15+
FileUploadRequest,
16+
FileUploadResponse,
17+
FileDeleteRequest,
18+
FileDeleteResponse,
19+
FileRetrieveRequest,
20+
FileRetrieveResponse,
21+
FileRetrieveContentRequest,
22+
FileRetrieveContentResponse,
23+
};
24+
use reqwest::Response;
25+
26+
const APU_URL_V1: &str = "https://api.openai.com/v1";
27+
28+
pub struct Client {
29+
pub api_key: String,
30+
}
31+
32+
impl Client {
33+
pub fn new(api_key: String) -> Self {
34+
Self { api_key }
35+
}
36+
37+
pub async fn post<T:serde::ser::Serialize>(&self, path: &str, params: &T) -> Result<Response, Box<dyn std::error::Error>> {
38+
let client = reqwest::Client::new();
39+
let url = format!("{}{}", APU_URL_V1, path);
40+
let res = client
41+
.post(&url)
42+
.header(reqwest::header::CONTENT_TYPE, "application/json")
43+
.header(reqwest::header::AUTHORIZATION, "Bearer ".to_owned() + &self.api_key)
44+
.json(&params)
45+
.send()
46+
.await;
47+
match res {
48+
Ok(res) => match res.status().is_success() {
49+
true => Ok(res),
50+
false => {
51+
Err(Box::new(std::io::Error::new(
52+
std::io::ErrorKind::Other,
53+
format!("{}: {}", res.status(), res.text().await.unwrap())
54+
)))
55+
},
56+
},
57+
Err(e) => Err(Box::new(e)),
58+
}
59+
}
60+
61+
pub async fn get(&self, path: &str) -> Result<Response, Box<dyn std::error::Error>> {
62+
let client = reqwest::Client::new();
63+
let url = format!("{}{}", APU_URL_V1, path);
64+
let res = client
65+
.get(&url)
66+
.header(reqwest::header::CONTENT_TYPE, "application/json")
67+
.header(reqwest::header::AUTHORIZATION, "Bearer ".to_owned() + &self.api_key)
68+
.send()
69+
.await;
70+
match res {
71+
Ok(res) => match res.status().is_success() {
72+
true => Ok(res),
73+
false => {
74+
Err(Box::new(std::io::Error::new(
75+
std::io::ErrorKind::Other,
76+
format!("{}: {}", res.status(), res.text().await.unwrap())
77+
)))
78+
},
79+
},
80+
Err(e) => Err(Box::new(e)),
81+
}
82+
}
83+
84+
pub async fn delete(&self, path: &str) -> Result<Response, Box<dyn std::error::Error>> {
85+
let client = reqwest::Client::new();
86+
let url = format!("{}{}", APU_URL_V1, path);
87+
let res = client
88+
.delete(&url)
89+
.header(reqwest::header::CONTENT_TYPE, "application/json")
90+
.header(reqwest::header::AUTHORIZATION, "Bearer ".to_owned() + &self.api_key)
91+
.send()
92+
.await;
93+
match res {
94+
Ok(res) => match res.status().is_success() {
95+
true => Ok(res),
96+
false => {
97+
Err(Box::new(std::io::Error::new(
98+
std::io::ErrorKind::Other,
99+
format!("{}: {}", res.status(), res.text().await.unwrap())
100+
)))
101+
},
102+
},
103+
Err(e) => Err(Box::new(e)),
104+
}
105+
}
106+
107+
pub async fn completion(&self, req: CompletionRequest) -> Result<CompletionResponse, Box<dyn std::error::Error>> {
108+
let res = self.post("/completions", &req).await;
109+
match res {
110+
Ok(res) => {
111+
let r = res.json::<CompletionResponse>().await?;
112+
return Ok(r);
113+
},
114+
Err(e) => {
115+
return Err(e);
116+
},
117+
}
118+
}
119+
120+
pub async fn edit(&self, req: EditRequest) -> Result<EditResponse, Box<dyn std::error::Error>> {
121+
let res = self.post("/edits", &req).await;
122+
match res {
123+
Ok(res) => {
124+
let r = res.json::<EditResponse>().await?;
125+
return Ok(r);
126+
},
127+
Err(e) => {
128+
return Err(e);
129+
},
130+
}
131+
}
132+
133+
pub async fn image_generation(&self, req: ImageGenerationRequest) -> Result<ImageGenerationResponse, Box<dyn std::error::Error>> {
134+
let res = self.post("/images/generations", &req).await;
135+
match res {
136+
Ok(res) => {
137+
let r = res.json::<ImageGenerationResponse>().await?;
138+
return Ok(r);
139+
},
140+
Err(e) => {
141+
return Err(e);
142+
},
143+
}
144+
}
145+
146+
pub async fn image_edit(&self, req: ImageEditRequest) -> Result<ImageEditResponse, Box<dyn std::error::Error>> {
147+
let res = self.post("/images/edits", &req).await;
148+
match res {
149+
Ok(res) => {
150+
let r = res.json::<ImageEditResponse>().await?;
151+
return Ok(r);
152+
},
153+
Err(e) => {
154+
return Err(e);
155+
},
156+
}
157+
}
158+
159+
pub async fn image_variation(&self, req: ImageVariationRequest) -> Result<ImageVariationResponse, Box<dyn std::error::Error>> {
160+
let res = self.post("/images/variations", &req).await;
161+
match res {
162+
Ok(res) => {
163+
let r = res.json::<ImageVariationResponse>().await?;
164+
return Ok(r);
165+
},
166+
Err(e) => {
167+
return Err(e);
168+
},
169+
}
170+
}
171+
172+
pub async fn embedding(&self, req: EmbeddingRequest) -> Result<EmbeddingResponse, Box<dyn std::error::Error>> {
173+
let res = self.post("/embeddings", &req).await;
174+
match res {
175+
Ok(res) => {
176+
let r = res.json::<EmbeddingResponse>().await?;
177+
return Ok(r);
178+
},
179+
Err(e) => {
180+
return Err(e);
181+
},
182+
}
183+
}
184+
185+
pub async fn file_list(&self) -> Result<FileListResponse, Box<dyn std::error::Error>> {
186+
let res = self.get("/files").await;
187+
match res {
188+
Ok(res) => {
189+
let r = res.json::<FileListResponse>().await?;
190+
return Ok(r);
191+
},
192+
Err(e) => {
193+
return Err(e);
194+
},
195+
}
196+
}
197+
198+
pub async fn file_upload(&self, req: FileUploadRequest) -> Result<FileUploadResponse, Box<dyn std::error::Error>> {
199+
let res = self.post("/files", &req).await;
200+
match res {
201+
Ok(res) => {
202+
let r = res.json::<FileUploadResponse>().await?;
203+
return Ok(r);
204+
},
205+
Err(e) => {
206+
return Err(e);
207+
},
208+
}
209+
}
210+
211+
pub async fn file_delete(&self, req: FileDeleteRequest) -> Result<FileDeleteResponse, Box<dyn std::error::Error>> {
212+
let res = self.delete(&format!("{}/{}", "/files", req.file_id)).await;
213+
match res {
214+
Ok(res) => {
215+
let r = res.json::<FileDeleteResponse>().await?;
216+
return Ok(r);
217+
},
218+
Err(e) => {
219+
return Err(e);
220+
},
221+
}
222+
}
223+
224+
pub async fn file_retrieve(&self, req: FileRetrieveRequest) -> Result<FileRetrieveResponse, Box<dyn std::error::Error>> {
225+
let res = self.get(&format!("{}/{}", "/files", req.file_id)).await;
226+
match res {
227+
Ok(res) => {
228+
let r = res.json::<FileRetrieveResponse>().await?;
229+
return Ok(r);
230+
},
231+
Err(e) => {
232+
return Err(e);
233+
},
234+
}
235+
}
236+
237+
pub async fn file_retrieve_content(&self, req: FileRetrieveContentRequest) -> Result<FileRetrieveContentResponse, Box<dyn std::error::Error>> {
238+
let res = self.get(&format!("{}/{}/content", "/files", req.file_id)).await;
239+
match res {
240+
Ok(res) => {
241+
let r = res.json::<FileRetrieveContentResponse>().await?;
242+
return Ok(r);
243+
},
244+
Err(e) => {
245+
return Err(e);
246+
},
247+
}
248+
}
249+
250+
}

src/v1/common.rs

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
2+
use serde::{Deserialize};
3+
4+
5+
#[derive(Debug, Deserialize)]
6+
pub struct Usage {
7+
pub prompt_tokens: i32,
8+
pub completion_tokens: i32,
9+
pub total_tokens: i32,
10+
}

src/v1/completion.rs

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
use serde::{Serialize, Deserialize};
2+
use std::option::Option;
3+
use std::collections::HashMap;
4+
5+
use crate::v1::common;
6+
7+
pub const GPT3_TEXT_DAVINCI_003: &str = "text-davinci-003";
8+
pub const GPT3_TEXT_DAVINCI_002: &str = "text-davinci-002";
9+
pub const GPT3_TEXT_CURIE_001: &str = "text-curie-001";
10+
pub const GPT3_TEXT_BABBAGE_001: &str = "text-babbage-001";
11+
pub const GPT3_TEXT_ADA_001: &str = "text-ada-001";
12+
pub const GPT3_TEXT_DAVINCI_001: &str = "text-davinci-001";
13+
pub const GPT3_DAVINCI_INSTRUCT_BETA: &str = "davinci-instruct-beta";
14+
pub const GPT3_DAVINCI: &str = "davinci";
15+
pub const GPT3_CURIE_INSTRUCT_BETA: &str = "curie-instruct-beta";
16+
pub const GPT3_CURIE: &str = "curie";
17+
pub const GPT3_ADA: &str = "ada";
18+
pub const GPT3_BABBAGE: &str = "babbage";
19+
20+
#[derive(Debug, Serialize)]
21+
pub struct CompletionRequest {
22+
pub model: String,
23+
#[serde(skip_serializing_if = "Option::is_none")]
24+
pub prompt: Option<String>,
25+
#[serde(skip_serializing_if = "Option::is_none")]
26+
pub suffix: Option<String>,
27+
#[serde(skip_serializing_if = "Option::is_none")]
28+
pub max_tokens: Option<i32>,
29+
#[serde(skip_serializing_if = "Option::is_none")]
30+
pub temperature: Option<f32>,
31+
#[serde(skip_serializing_if = "Option::is_none")]
32+
pub top_p: Option<f32>,
33+
#[serde(skip_serializing_if = "Option::is_none")]
34+
pub n: Option<i32>,
35+
#[serde(skip_serializing_if = "Option::is_none")]
36+
pub stream: Option<bool>,
37+
#[serde(skip_serializing_if = "Option::is_none")]
38+
pub logprobs: Option<i32>,
39+
#[serde(skip_serializing_if = "Option::is_none")]
40+
pub echo: Option<bool>,
41+
#[serde(skip_serializing_if = "Option::is_none")]
42+
pub stop: Option<Vec<String>>,
43+
#[serde(skip_serializing_if = "Option::is_none")]
44+
pub presence_penalty: Option<f32>,
45+
#[serde(skip_serializing_if = "Option::is_none")]
46+
pub frequency_penalty: Option<f32>,
47+
#[serde(skip_serializing_if = "Option::is_none")]
48+
pub best_of: Option<i32>,
49+
#[serde(skip_serializing_if = "Option::is_none")]
50+
pub logit_bias: Option<HashMap<String, i32>>,
51+
#[serde(skip_serializing_if = "Option::is_none")]
52+
pub user: Option<String>,
53+
}
54+
55+
#[derive(Debug, Deserialize)]
56+
pub struct CompletionChoice {
57+
pub text: String,
58+
pub index: i64,
59+
pub finish_reason: String,
60+
pub logprobs: Option<LogprobResult>,
61+
}
62+
63+
#[derive(Debug, Deserialize)]
64+
pub struct LogprobResult {
65+
pub tokens: Vec<String>,
66+
pub token_logprobs: Vec<f32>,
67+
pub top_logprobs: Vec<HashMap<String, f32>>,
68+
pub text_offset: Vec<i32>,
69+
}
70+
71+
#[derive(Debug, Deserialize)]
72+
pub struct CompletionResponse {
73+
pub id: String,
74+
pub object: String,
75+
pub created: i64,
76+
pub model: String,
77+
pub choices: Vec<CompletionChoice>,
78+
pub usage: common::Usage,
79+
}

0 commit comments

Comments
 (0)