1
+ use std:: sync:: Arc ;
1
2
use std:: sync:: { mpsc, Mutex } ;
2
- use std:: { collections:: HashMap , sync:: Arc } ;
3
3
4
4
use burn:: tensor:: backend:: Backend ;
5
- use reqwest:: header:: { COOKIE , SET_COOKIE } ;
6
- use serde:: { Deserialize , Serialize } ;
5
+ use serde:: Serialize ;
7
6
8
7
use crate :: error:: HeatSdkError ;
9
8
use crate :: experiment:: { Experiment , TempLogStore , WsMessage } ;
10
- use crate :: http_schemas :: { EndExperimentSchema , StartExperimentSchema , URLSchema } ;
9
+ use crate :: http :: { EndExperimentStatus , HttpClient } ;
11
10
use crate :: websocket:: WebSocketClient ;
12
11
13
- enum AccessMode {
14
- Read ,
15
- Write ,
16
- }
17
-
18
12
/// Credentials to connect to the Heat server
19
13
#[ derive( Serialize , Debug , Clone ) ]
20
14
pub struct HeatCredentials {
@@ -92,8 +86,7 @@ impl HeatClientConfigBuilder {
92
86
#[ derive( Debug , Clone ) ]
93
87
pub struct HeatClient {
94
88
config : HeatClientConfig ,
95
- http_client : reqwest:: blocking:: Client ,
96
- session_cookie : String ,
89
+ http_client : HttpClient ,
97
90
active_experiment : Option < Arc < Mutex < Experiment > > > ,
98
91
}
99
92
@@ -102,114 +95,21 @@ pub type HeatClientState = HeatClient;
102
95
103
96
impl HeatClient {
104
97
fn new ( config : HeatClientConfig ) -> HeatClient {
105
- let http_client = reqwest:: blocking:: Client :: builder ( )
106
- . timeout ( std:: time:: Duration :: from_secs ( 15 ) )
107
- . build ( )
108
- . expect ( "Client should be created." ) ;
98
+ let http_client = HttpClient :: new ( config. endpoint . clone ( ) ) ;
109
99
110
100
HeatClient {
111
101
config,
112
102
http_client,
113
- session_cookie : "" . to_string ( ) ,
114
103
active_experiment : None ,
115
104
}
116
105
}
117
106
118
- #[ allow( dead_code) ]
119
- fn health_check ( & self ) -> Result < ( ) , reqwest:: Error > {
120
- let url = format ! ( "{}/health" , self . config. endpoint. clone( ) ) ;
121
- self . http_client . get ( url) . send ( ) ?;
122
-
123
- Ok ( ( ) )
124
- }
125
-
126
- fn create_and_start_experiment ( & self , config : & impl Serialize ) -> Result < String , HeatSdkError > {
127
- #[ derive( Deserialize ) ]
128
- struct ExperimentResponse {
129
- experiment_id : String ,
130
- }
131
-
132
- let url = format ! (
133
- "{}/projects/{}/experiments" ,
134
- self . config. endpoint. clone( ) ,
135
- self . config. project_id. clone( )
136
- ) ;
137
-
138
- // Create a new experiment
139
- let exp_uuid = self
140
- . http_client
141
- . post ( url)
142
- . header ( COOKIE , & self . session_cookie )
143
- . send ( ) ?
144
- . error_for_status ( ) ?
145
- . json :: < ExperimentResponse > ( ) ?
146
- . experiment_id ;
147
-
148
- let json = StartExperimentSchema {
149
- config : serde_json:: to_value ( config) . unwrap ( ) ,
150
- } ;
151
-
152
- // Start the experiment
153
- self . http_client
154
- . put ( format ! (
155
- "{}/experiments/{}/start" ,
156
- self . config. endpoint. clone( ) ,
157
- exp_uuid
158
- ) )
159
- . header ( COOKIE , & self . session_cookie )
160
- . json ( & json)
161
- . send ( ) ?
162
- . error_for_status ( ) ?;
163
-
164
- println ! ( "Experiment UUID: {}" , exp_uuid) ;
165
- Ok ( exp_uuid)
166
- }
167
-
168
107
fn connect ( & mut self ) -> Result < ( ) , HeatSdkError > {
169
- let url = format ! ( "{}/login/api-key" , self . config. endpoint. clone( ) ) ;
170
- let res = self
171
- . http_client
172
- . post ( url)
173
- . form :: < HeatCredentials > ( & self . config . credentials )
174
- . send ( ) ?;
175
- // store session cookie
176
- if res. status ( ) . is_success ( ) {
177
- let cookie_header = res. headers ( ) . get ( SET_COOKIE ) ;
178
- if let Some ( cookie) = cookie_header {
179
- cookie
180
- . to_str ( )
181
- . expect ( "Session cookie should be convert to str" )
182
- . clone_into ( & mut self . session_cookie ) ;
183
- } else {
184
- return Err ( HeatSdkError :: ClientError (
185
- "Cannot connect to Heat server, bad session ID." . to_string ( ) ,
186
- ) ) ;
187
- }
188
- } else {
189
- let error_message = format ! ( "Cannot connect to Heat server({:?})" , res. text( ) ?) ;
190
- return Err ( HeatSdkError :: ClientError ( error_message) ) ;
191
- }
108
+ self . http_client . login ( & self . config . credentials ) ?;
192
109
193
110
Ok ( ( ) )
194
111
}
195
112
196
- fn request_ws ( & self , exp_uuid : & str ) -> Result < String , HeatSdkError > {
197
- let url = format ! (
198
- "{}/experiments/{}/ws" ,
199
- self . config. endpoint. clone( ) ,
200
- exp_uuid
201
- ) ;
202
- let ws_endpoint = self
203
- . http_client
204
- . get ( url)
205
- . header ( COOKIE , & self . session_cookie )
206
- . send ( ) ?
207
- . error_for_status ( ) ?
208
- . json :: < URLSchema > ( ) ?
209
- . url ;
210
- Ok ( ws_endpoint)
211
- }
212
-
213
113
/// Create a new HeatClient with the given configuration.
214
114
pub fn create ( config : HeatClientConfig ) -> Result < HeatClientState , HeatSdkError > {
215
115
let mut client = HeatClient :: new ( config) ;
@@ -232,18 +132,19 @@ impl HeatClient {
232
132
233
133
/// Start a new experiment. This will create a new experiment on the Heat backend and start it.
234
134
pub fn start_experiment ( & mut self , config : & impl Serialize ) -> Result < ( ) , HeatSdkError > {
235
- let exp_uuid = self . create_and_start_experiment ( config) ?;
236
- let ws_endpoint = self . request_ws ( exp_uuid. as_str ( ) ) ?;
135
+ let exp_uuid = self
136
+ . http_client
137
+ . create_experiment ( & self . config . project_id ) ?;
138
+ self . http_client . start_experiment ( & exp_uuid, config) ?;
139
+
140
+ println ! ( "Experiment UUID: {}" , exp_uuid) ;
141
+
142
+ let ws_endpoint = self . http_client . request_websocket_url ( & exp_uuid) ?;
237
143
238
144
let mut ws_client = WebSocketClient :: new ( ) ;
239
- ws_client. connect ( ws_endpoint, & self . session_cookie ) ?;
145
+ ws_client. connect ( ws_endpoint, self . http_client . get_session_cookie ( ) . unwrap ( ) ) ?;
240
146
241
- let exp_log_store = TempLogStore :: new (
242
- self . http_client . clone ( ) ,
243
- self . config . endpoint . clone ( ) ,
244
- exp_uuid. clone ( ) ,
245
- self . session_cookie . clone ( ) ,
246
- ) ;
147
+ let exp_log_store = TempLogStore :: new ( self . http_client . clone ( ) , exp_uuid. clone ( ) ) ;
247
148
248
149
let experiment = Arc :: new ( Mutex :: new ( Experiment :: new (
249
150
exp_uuid,
@@ -255,70 +156,38 @@ impl HeatClient {
255
156
Ok ( ( ) )
256
157
}
257
158
159
+ /// Get the sender for the active experiment's WebSocket connection.
258
160
pub fn get_experiment_sender ( & self ) -> Result < mpsc:: Sender < WsMessage > , HeatSdkError > {
259
161
let experiment = self . active_experiment . as_ref ( ) . unwrap ( ) ;
260
162
let experiment = experiment. lock ( ) . unwrap ( ) ;
261
163
experiment. get_ws_sender ( )
262
164
}
263
165
264
- fn request_checkpoint_url (
265
- & self ,
266
- path : & str ,
267
- access : AccessMode ,
268
- ) -> Result < String , reqwest:: Error > {
269
- let url = format ! ( "{}/checkpoints" , self . config. endpoint. clone( ) ) ;
270
-
271
- let mut body = HashMap :: new ( ) ;
272
- body. insert ( "file_path" , path. to_string ( ) ) ;
273
- body. insert (
274
- "experiment_id" ,
275
- self . active_experiment
276
- . as_ref ( )
277
- . unwrap ( )
278
- . lock ( )
279
- . unwrap ( )
280
- . id ( )
281
- . clone ( ) ,
282
- ) ;
283
-
284
- let response = match access {
285
- AccessMode :: Read => self . http_client . get ( url) ,
286
- AccessMode :: Write => self . http_client . post ( url) ,
287
- }
288
- . header ( COOKIE , & self . session_cookie )
289
- . json ( & body)
290
- . send ( ) ?
291
- . error_for_status ( ) ?
292
- . json :: < URLSchema > ( ) ?;
293
-
294
- Ok ( response. url )
295
- }
296
-
297
- fn upload_checkpoint ( & self , url : & str , checkpoint : Vec < u8 > ) -> Result < ( ) , reqwest:: Error > {
298
- self . http_client . put ( url) . body ( checkpoint) . send ( ) ?;
299
-
300
- Ok ( ( ) )
301
- }
302
-
303
- fn download_checkpoint ( & self , url : & str ) -> Result < Vec < u8 > , reqwest:: Error > {
304
- let response = self . http_client . get ( url) . send ( ) ?. bytes ( ) ?;
305
-
306
- Ok ( response. to_vec ( ) )
307
- }
308
-
309
166
/// Save checkpoint data to the Heat API.
310
167
pub ( crate ) fn save_checkpoint_data (
311
168
& self ,
312
169
path : & str ,
313
170
checkpoint : Vec < u8 > ,
314
171
) -> Result < ( ) , HeatSdkError > {
315
- let url = self . request_checkpoint_url ( path, AccessMode :: Write ) ?;
172
+ let exp_uuid = self
173
+ . active_experiment
174
+ . as_ref ( )
175
+ . unwrap ( )
176
+ . lock ( )
177
+ . unwrap ( )
178
+ . id ( )
179
+ . clone ( ) ;
180
+
181
+ let url = self
182
+ . http_client
183
+ . request_checkpoint_save_url ( & exp_uuid, path) ?;
316
184
317
185
let time = std:: time:: SystemTime :: now ( )
318
186
. duration_since ( std:: time:: UNIX_EPOCH )
319
187
. unwrap ( )
320
188
. as_millis ( ) ;
321
- self . upload_checkpoint ( & url, checkpoint) ?;
189
+
190
+ self . http_client . upload_bytes_to_url ( & url, checkpoint) ?;
322
191
323
192
let time_end = std:: time:: SystemTime :: now ( )
324
193
. duration_since ( std:: time:: UNIX_EPOCH )
@@ -331,12 +200,24 @@ impl HeatClient {
331
200
332
201
/// Load checkpoint data from the Heat API
333
202
pub ( crate ) fn load_checkpoint_data ( & self , path : & str ) -> Result < Vec < u8 > , HeatSdkError > {
334
- let url = self . request_checkpoint_url ( path, AccessMode :: Read ) ?;
335
- let response = self . download_checkpoint ( & url) ?;
203
+ let exp_uuid = self
204
+ . active_experiment
205
+ . as_ref ( )
206
+ . unwrap ( )
207
+ . lock ( )
208
+ . unwrap ( )
209
+ . id ( )
210
+ . clone ( ) ;
336
211
337
- Ok ( response. to_vec ( ) )
212
+ let url = self
213
+ . http_client
214
+ . request_checkpoint_load_url ( & exp_uuid, path) ?;
215
+ let response = self . http_client . download_bytes_from_url ( & url) ?;
216
+
217
+ Ok ( response)
338
218
}
339
219
220
+ /// Save the final model to the Heat backend.
340
221
pub ( crate ) fn save_final_model ( & self , data : Vec < u8 > ) -> Result < ( ) , HeatSdkError > {
341
222
if self . active_experiment . is_none ( ) {
342
223
return Err ( HeatSdkError :: ClientError (
@@ -352,21 +233,10 @@ impl HeatClient {
352
233
. id ( )
353
234
. clone ( ) ;
354
235
355
- let url = format ! (
356
- "{}/experiments/{}/save_model" ,
357
- self . config. endpoint. clone( ) ,
358
- experiment_id
359
- ) ;
360
-
361
- let response = self
236
+ let url = self
362
237
. http_client
363
- . post ( url)
364
- . header ( COOKIE , & self . session_cookie )
365
- . send ( ) ?
366
- . error_for_status ( ) ?
367
- . json :: < URLSchema > ( ) ?;
368
-
369
- self . http_client . put ( response. url ) . body ( data) . send ( ) ?;
238
+ . request_final_model_save_url ( & experiment_id) ?;
239
+ self . http_client . upload_bytes_to_url ( & url, data) ?;
370
240
371
241
Ok ( ( ) )
372
242
}
@@ -387,19 +257,19 @@ impl HeatClient {
387
257
return Err ( HeatSdkError :: ClientError ( e. to_string ( ) ) ) ;
388
258
}
389
259
390
- self . end_experiment_internal ( EndExperimentSchema :: Success )
260
+ self . end_experiment_internal ( EndExperimentStatus :: Success )
391
261
}
392
262
393
263
/// End the active experiment with an error reason.
394
264
/// This will close the WebSocket connection and upload the logs to the Heat backend.
395
265
/// No model will be uploaded.
396
266
pub fn end_experiment_with_error ( & mut self , error_reason : String ) -> Result < ( ) , HeatSdkError > {
397
- self . end_experiment_internal ( EndExperimentSchema :: Fail ( error_reason) )
267
+ self . end_experiment_internal ( EndExperimentStatus :: Fail ( error_reason) )
398
268
}
399
269
400
270
fn end_experiment_internal (
401
271
& mut self ,
402
- end_status : EndExperimentSchema ,
272
+ end_status : EndExperimentStatus ,
403
273
) -> Result < ( ) , HeatSdkError > {
404
274
let experiment: Arc < Mutex < Experiment > > = self . active_experiment . take ( ) . unwrap ( ) ;
405
275
let mut experiment = experiment. lock ( ) ?;
@@ -409,15 +279,7 @@ impl HeatClient {
409
279
410
280
// End the experiment in the backend
411
281
self . http_client
412
- . put ( format ! (
413
- "{}/experiments/{}/end" ,
414
- self . config. endpoint. clone( ) ,
415
- experiment. id( )
416
- ) )
417
- . header ( COOKIE , & self . session_cookie )
418
- . json ( & end_status)
419
- . send ( ) ?
420
- . error_for_status ( ) ?;
282
+ . end_experiment ( experiment. id ( ) , end_status) ?;
421
283
422
284
Ok ( ( ) )
423
285
}
@@ -428,7 +290,7 @@ impl Drop for HeatClient {
428
290
// if the ref count is 1, then we are the last reference to the client, so we should end the experiment
429
291
if let Some ( exp_arc) = & self . active_experiment {
430
292
if Arc :: strong_count ( exp_arc) == 1 {
431
- self . end_experiment_internal ( EndExperimentSchema :: Success )
293
+ self . end_experiment_internal ( EndExperimentStatus :: Success )
432
294
. expect ( "Should be able to end the experiment after dropping the last client." ) ;
433
295
}
434
296
}
0 commit comments