Skip to content

Commit a9380b3

Browse files
authored
Http client refactor (#22)
* Add HTTP module for client and schemas, refactor all usage of reqwest http client into new HttpClient struct * docs * fmt * clippy * unused import in training.rs * removed session_cookie from HeatClient * Moved reqwest error conversion inside of http mod
1 parent 446508b commit a9380b3

File tree

7 files changed

+397
-235
lines changed

7 files changed

+397
-235
lines changed

crates/heat-sdk/src/client.rs

+54-192
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,14 @@
1+
use std::sync::Arc;
12
use std::sync::{mpsc, Mutex};
2-
use std::{collections::HashMap, sync::Arc};
33

44
use burn::tensor::backend::Backend;
5-
use reqwest::header::{COOKIE, SET_COOKIE};
6-
use serde::{Deserialize, Serialize};
5+
use serde::Serialize;
76

87
use crate::error::HeatSdkError;
98
use crate::experiment::{Experiment, TempLogStore, WsMessage};
10-
use crate::http_schemas::{EndExperimentSchema, StartExperimentSchema, URLSchema};
9+
use crate::http::{EndExperimentStatus, HttpClient};
1110
use crate::websocket::WebSocketClient;
1211

13-
enum AccessMode {
14-
Read,
15-
Write,
16-
}
17-
1812
/// Credentials to connect to the Heat server
1913
#[derive(Serialize, Debug, Clone)]
2014
pub struct HeatCredentials {
@@ -92,8 +86,7 @@ impl HeatClientConfigBuilder {
9286
#[derive(Debug, Clone)]
9387
pub struct HeatClient {
9488
config: HeatClientConfig,
95-
http_client: reqwest::blocking::Client,
96-
session_cookie: String,
89+
http_client: HttpClient,
9790
active_experiment: Option<Arc<Mutex<Experiment>>>,
9891
}
9992

@@ -102,114 +95,21 @@ pub type HeatClientState = HeatClient;
10295

10396
impl HeatClient {
10497
fn new(config: HeatClientConfig) -> HeatClient {
105-
let http_client = reqwest::blocking::Client::builder()
106-
.timeout(std::time::Duration::from_secs(15))
107-
.build()
108-
.expect("Client should be created.");
98+
let http_client = HttpClient::new(config.endpoint.clone());
10999

110100
HeatClient {
111101
config,
112102
http_client,
113-
session_cookie: "".to_string(),
114103
active_experiment: None,
115104
}
116105
}
117106

118-
#[allow(dead_code)]
119-
fn health_check(&self) -> Result<(), reqwest::Error> {
120-
let url = format!("{}/health", self.config.endpoint.clone());
121-
self.http_client.get(url).send()?;
122-
123-
Ok(())
124-
}
125-
126-
fn create_and_start_experiment(&self, config: &impl Serialize) -> Result<String, HeatSdkError> {
127-
#[derive(Deserialize)]
128-
struct ExperimentResponse {
129-
experiment_id: String,
130-
}
131-
132-
let url = format!(
133-
"{}/projects/{}/experiments",
134-
self.config.endpoint.clone(),
135-
self.config.project_id.clone()
136-
);
137-
138-
// Create a new experiment
139-
let exp_uuid = self
140-
.http_client
141-
.post(url)
142-
.header(COOKIE, &self.session_cookie)
143-
.send()?
144-
.error_for_status()?
145-
.json::<ExperimentResponse>()?
146-
.experiment_id;
147-
148-
let json = StartExperimentSchema {
149-
config: serde_json::to_value(config).unwrap(),
150-
};
151-
152-
// Start the experiment
153-
self.http_client
154-
.put(format!(
155-
"{}/experiments/{}/start",
156-
self.config.endpoint.clone(),
157-
exp_uuid
158-
))
159-
.header(COOKIE, &self.session_cookie)
160-
.json(&json)
161-
.send()?
162-
.error_for_status()?;
163-
164-
println!("Experiment UUID: {}", exp_uuid);
165-
Ok(exp_uuid)
166-
}
167-
168107
fn connect(&mut self) -> Result<(), HeatSdkError> {
169-
let url = format!("{}/login/api-key", self.config.endpoint.clone());
170-
let res = self
171-
.http_client
172-
.post(url)
173-
.form::<HeatCredentials>(&self.config.credentials)
174-
.send()?;
175-
// store session cookie
176-
if res.status().is_success() {
177-
let cookie_header = res.headers().get(SET_COOKIE);
178-
if let Some(cookie) = cookie_header {
179-
cookie
180-
.to_str()
181-
.expect("Session cookie should be convert to str")
182-
.clone_into(&mut self.session_cookie);
183-
} else {
184-
return Err(HeatSdkError::ClientError(
185-
"Cannot connect to Heat server, bad session ID.".to_string(),
186-
));
187-
}
188-
} else {
189-
let error_message = format!("Cannot connect to Heat server({:?})", res.text()?);
190-
return Err(HeatSdkError::ClientError(error_message));
191-
}
108+
self.http_client.login(&self.config.credentials)?;
192109

193110
Ok(())
194111
}
195112

196-
fn request_ws(&self, exp_uuid: &str) -> Result<String, HeatSdkError> {
197-
let url = format!(
198-
"{}/experiments/{}/ws",
199-
self.config.endpoint.clone(),
200-
exp_uuid
201-
);
202-
let ws_endpoint = self
203-
.http_client
204-
.get(url)
205-
.header(COOKIE, &self.session_cookie)
206-
.send()?
207-
.error_for_status()?
208-
.json::<URLSchema>()?
209-
.url;
210-
Ok(ws_endpoint)
211-
}
212-
213113
/// Create a new HeatClient with the given configuration.
214114
pub fn create(config: HeatClientConfig) -> Result<HeatClientState, HeatSdkError> {
215115
let mut client = HeatClient::new(config);
@@ -232,18 +132,19 @@ impl HeatClient {
232132

233133
/// Start a new experiment. This will create a new experiment on the Heat backend and start it.
234134
pub fn start_experiment(&mut self, config: &impl Serialize) -> Result<(), HeatSdkError> {
235-
let exp_uuid = self.create_and_start_experiment(config)?;
236-
let ws_endpoint = self.request_ws(exp_uuid.as_str())?;
135+
let exp_uuid = self
136+
.http_client
137+
.create_experiment(&self.config.project_id)?;
138+
self.http_client.start_experiment(&exp_uuid, config)?;
139+
140+
println!("Experiment UUID: {}", exp_uuid);
141+
142+
let ws_endpoint = self.http_client.request_websocket_url(&exp_uuid)?;
237143

238144
let mut ws_client = WebSocketClient::new();
239-
ws_client.connect(ws_endpoint, &self.session_cookie)?;
145+
ws_client.connect(ws_endpoint, self.http_client.get_session_cookie().unwrap())?;
240146

241-
let exp_log_store = TempLogStore::new(
242-
self.http_client.clone(),
243-
self.config.endpoint.clone(),
244-
exp_uuid.clone(),
245-
self.session_cookie.clone(),
246-
);
147+
let exp_log_store = TempLogStore::new(self.http_client.clone(), exp_uuid.clone());
247148

248149
let experiment = Arc::new(Mutex::new(Experiment::new(
249150
exp_uuid,
@@ -255,70 +156,38 @@ impl HeatClient {
255156
Ok(())
256157
}
257158

159+
/// Get the sender for the active experiment's WebSocket connection.
258160
pub fn get_experiment_sender(&self) -> Result<mpsc::Sender<WsMessage>, HeatSdkError> {
259161
let experiment = self.active_experiment.as_ref().unwrap();
260162
let experiment = experiment.lock().unwrap();
261163
experiment.get_ws_sender()
262164
}
263165

264-
fn request_checkpoint_url(
265-
&self,
266-
path: &str,
267-
access: AccessMode,
268-
) -> Result<String, reqwest::Error> {
269-
let url = format!("{}/checkpoints", self.config.endpoint.clone());
270-
271-
let mut body = HashMap::new();
272-
body.insert("file_path", path.to_string());
273-
body.insert(
274-
"experiment_id",
275-
self.active_experiment
276-
.as_ref()
277-
.unwrap()
278-
.lock()
279-
.unwrap()
280-
.id()
281-
.clone(),
282-
);
283-
284-
let response = match access {
285-
AccessMode::Read => self.http_client.get(url),
286-
AccessMode::Write => self.http_client.post(url),
287-
}
288-
.header(COOKIE, &self.session_cookie)
289-
.json(&body)
290-
.send()?
291-
.error_for_status()?
292-
.json::<URLSchema>()?;
293-
294-
Ok(response.url)
295-
}
296-
297-
fn upload_checkpoint(&self, url: &str, checkpoint: Vec<u8>) -> Result<(), reqwest::Error> {
298-
self.http_client.put(url).body(checkpoint).send()?;
299-
300-
Ok(())
301-
}
302-
303-
fn download_checkpoint(&self, url: &str) -> Result<Vec<u8>, reqwest::Error> {
304-
let response = self.http_client.get(url).send()?.bytes()?;
305-
306-
Ok(response.to_vec())
307-
}
308-
309166
/// Save checkpoint data to the Heat API.
310167
pub(crate) fn save_checkpoint_data(
311168
&self,
312169
path: &str,
313170
checkpoint: Vec<u8>,
314171
) -> Result<(), HeatSdkError> {
315-
let url = self.request_checkpoint_url(path, AccessMode::Write)?;
172+
let exp_uuid = self
173+
.active_experiment
174+
.as_ref()
175+
.unwrap()
176+
.lock()
177+
.unwrap()
178+
.id()
179+
.clone();
180+
181+
let url = self
182+
.http_client
183+
.request_checkpoint_save_url(&exp_uuid, path)?;
316184

317185
let time = std::time::SystemTime::now()
318186
.duration_since(std::time::UNIX_EPOCH)
319187
.unwrap()
320188
.as_millis();
321-
self.upload_checkpoint(&url, checkpoint)?;
189+
190+
self.http_client.upload_bytes_to_url(&url, checkpoint)?;
322191

323192
let time_end = std::time::SystemTime::now()
324193
.duration_since(std::time::UNIX_EPOCH)
@@ -331,12 +200,24 @@ impl HeatClient {
331200

332201
/// Load checkpoint data from the Heat API
333202
pub(crate) fn load_checkpoint_data(&self, path: &str) -> Result<Vec<u8>, HeatSdkError> {
334-
let url = self.request_checkpoint_url(path, AccessMode::Read)?;
335-
let response = self.download_checkpoint(&url)?;
203+
let exp_uuid = self
204+
.active_experiment
205+
.as_ref()
206+
.unwrap()
207+
.lock()
208+
.unwrap()
209+
.id()
210+
.clone();
336211

337-
Ok(response.to_vec())
212+
let url = self
213+
.http_client
214+
.request_checkpoint_load_url(&exp_uuid, path)?;
215+
let response = self.http_client.download_bytes_from_url(&url)?;
216+
217+
Ok(response)
338218
}
339219

220+
/// Save the final model to the Heat backend.
340221
pub(crate) fn save_final_model(&self, data: Vec<u8>) -> Result<(), HeatSdkError> {
341222
if self.active_experiment.is_none() {
342223
return Err(HeatSdkError::ClientError(
@@ -352,21 +233,10 @@ impl HeatClient {
352233
.id()
353234
.clone();
354235

355-
let url = format!(
356-
"{}/experiments/{}/save_model",
357-
self.config.endpoint.clone(),
358-
experiment_id
359-
);
360-
361-
let response = self
236+
let url = self
362237
.http_client
363-
.post(url)
364-
.header(COOKIE, &self.session_cookie)
365-
.send()?
366-
.error_for_status()?
367-
.json::<URLSchema>()?;
368-
369-
self.http_client.put(response.url).body(data).send()?;
238+
.request_final_model_save_url(&experiment_id)?;
239+
self.http_client.upload_bytes_to_url(&url, data)?;
370240

371241
Ok(())
372242
}
@@ -387,19 +257,19 @@ impl HeatClient {
387257
return Err(HeatSdkError::ClientError(e.to_string()));
388258
}
389259

390-
self.end_experiment_internal(EndExperimentSchema::Success)
260+
self.end_experiment_internal(EndExperimentStatus::Success)
391261
}
392262

393263
/// End the active experiment with an error reason.
394264
/// This will close the WebSocket connection and upload the logs to the Heat backend.
395265
/// No model will be uploaded.
396266
pub fn end_experiment_with_error(&mut self, error_reason: String) -> Result<(), HeatSdkError> {
397-
self.end_experiment_internal(EndExperimentSchema::Fail(error_reason))
267+
self.end_experiment_internal(EndExperimentStatus::Fail(error_reason))
398268
}
399269

400270
fn end_experiment_internal(
401271
&mut self,
402-
end_status: EndExperimentSchema,
272+
end_status: EndExperimentStatus,
403273
) -> Result<(), HeatSdkError> {
404274
let experiment: Arc<Mutex<Experiment>> = self.active_experiment.take().unwrap();
405275
let mut experiment = experiment.lock()?;
@@ -409,15 +279,7 @@ impl HeatClient {
409279

410280
// End the experiment in the backend
411281
self.http_client
412-
.put(format!(
413-
"{}/experiments/{}/end",
414-
self.config.endpoint.clone(),
415-
experiment.id()
416-
))
417-
.header(COOKIE, &self.session_cookie)
418-
.json(&end_status)
419-
.send()?
420-
.error_for_status()?;
282+
.end_experiment(experiment.id(), end_status)?;
421283

422284
Ok(())
423285
}
@@ -428,7 +290,7 @@ impl Drop for HeatClient {
428290
// if the ref count is 1, then we are the last reference to the client, so we should end the experiment
429291
if let Some(exp_arc) = &self.active_experiment {
430292
if Arc::strong_count(exp_arc) == 1 {
431-
self.end_experiment_internal(EndExperimentSchema::Success)
293+
self.end_experiment_internal(EndExperimentStatus::Success)
432294
.expect("Should be able to end the experiment after dropping the last client.");
433295
}
434296
}

crates/heat-sdk/src/error.rs

-14
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,6 @@ pub enum HeatSdkError {
1616
UnknownError(String),
1717
}
1818

19-
impl From<reqwest::Error> for HeatSdkError {
20-
fn from(error: reqwest::Error) -> Self {
21-
match error.status() {
22-
Some(status) => match status {
23-
reqwest::StatusCode::REQUEST_TIMEOUT => {
24-
HeatSdkError::ServerTimeoutError(error.to_string())
25-
}
26-
_ => HeatSdkError::ServerError(status.to_string()),
27-
},
28-
None => HeatSdkError::ServerError(error.to_string()),
29-
}
30-
}
31-
}
32-
3319
impl<T> From<std::sync::PoisonError<std::sync::MutexGuard<'_, T>>> for HeatSdkError {
3420
fn from(error: std::sync::PoisonError<std::sync::MutexGuard<'_, T>>) -> Self {
3521
HeatSdkError::ClientError(error.to_string())

0 commit comments

Comments
 (0)