@@ -6,6 +6,7 @@ use log::{debug, error, info, trace, warn};
 use rand_distr::Distribution;
 use rayon::iter::split;
 use rayon::prelude::*;
+use reqwest::Url;
 use reqwest_eventsource::{Error, Event, EventSource};
 use serde::{Deserialize, Serialize};
 use std::cmp::Ordering;
@@ -58,7 +59,7 @@ impl Clone for Box<dyn TextGenerationBackend + Send + Sync> {
 #[derive(Debug, Clone)]
 pub struct OpenAITextGenerationBackend {
     pub api_key: String,
-    pub base_url: String,
+    pub base_url: Url,
     pub model_name: String,
     pub client: reqwest::Client,
     pub tokenizer: Arc<Tokenizer>,
@@ -101,7 +102,7 @@ pub struct OpenAITextGenerationRequest {
 impl OpenAITextGenerationBackend {
     pub fn try_new(
         api_key: String,
-        base_url: String,
+        base_url: Url,
         model_name: String,
         tokenizer: Arc<Tokenizer>,
         timeout: time::Duration,
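
Caller-side sketch (not part of this change; the address and error handling are hypothetical): with base_url now typed as reqwest::Url, whoever constructs the backend parses the URL string once up front, so an invalid address fails at startup instead of on the first request.

    use reqwest::Url;

    // Hypothetical call site: parse the user-supplied string once.
    // Real code would propagate the error instead of using expect().
    let base_url: Url = "http://localhost:8080".parse().expect("invalid base URL");
    // base_url is then handed to OpenAITextGenerationBackend::try_new(..),
    // with the remaining arguments unchanged.
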
@@ -128,7 +129,9 @@ impl TextGenerationBackend for OpenAITextGenerationBackend {
         request: Arc<TextGenerationRequest>,
         sender: Sender<TextGenerationAggregatedResponse>,
     ) {
-        let url = format!("{base_url}/v1/chat/completions", base_url = self.base_url);
+        let mut url = self.base_url.clone();
+        url.set_path("/v1/chat/completions");
+        // let url = format!("{base_url}", base_url = self.base_url);
         let mut aggregated_response = TextGenerationAggregatedResponse::new(request.clone());
         let messages = vec![OpenAITextGenerationMessage {
             role: "user".to_string(),
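
A behavioral note worth illustrating (a minimal sketch, not part of the diff; the localhost URLs are made up): Url::set_path replaces the entire path of the base URL, so any path prefix on the configured base URL is dropped rather than joined.

    use reqwest::Url;

    fn main() {
        let base: Url = "http://localhost:8080".parse().unwrap();
        let mut url = base.clone();
        url.set_path("/v1/chat/completions");
        assert_eq!(url.as_str(), "http://localhost:8080/v1/chat/completions");

        // set_path replaces the whole path, so a prefix like /proxy is not preserved.
        let mut prefixed: Url = "http://localhost:8080/proxy".parse().unwrap();
        prefixed.set_path("/v1/chat/completions");
        assert_eq!(prefixed.as_str(), "http://localhost:8080/v1/chat/completions");
    }

For a bare host this matches the old format!-based behavior; the two only differ when the base URL already carries a path.
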
@@ -829,7 +832,7 @@ mod tests {
                 w.write_all(b"data: [DONE]\n\n")
             })
             .create_async().await;
-        let url = s.url();
+        let url = s.url().parse().unwrap();
         let tokenizer = Arc::new(Tokenizer::from_pretrained("gpt2", None).unwrap());
         let backend = OpenAITextGenerationBackend::try_new(
             "".to_string(),
@@ -890,7 +893,7 @@ mod tests {
                 w.write_all(b"data: [DONE]\n\n")
             })
             .create_async().await;
-        let url = s.url();
+        let url = s.url().parse().unwrap();
         let tokenizer = Arc::new(Tokenizer::from_pretrained("gpt2", None).unwrap());
         let backend = OpenAITextGenerationBackend::try_new(
             "".to_string(),
@@ -975,7 +978,7 @@ mod tests {
             .with_chunked_body(|w| w.write_all(b"data: {\"error\": \"Internal server error\"}\n\n"))
             .create_async()
             .await;
-        let url = s.url();
+        let url = s.url().parse().unwrap();
         let tokenizer = Arc::new(Tokenizer::from_pretrained("gpt2", None).unwrap());
         let backend = OpenAITextGenerationBackend::try_new(
             "".to_string(),
@@ -1021,7 +1024,7 @@ mod tests {
             .with_chunked_body(|w| w.write_all(b"this is wrong\n\n"))
             .create_async()
             .await;
-        let url = s.url();
+        let url = s.url().parse().unwrap();
         let tokenizer = Arc::new(Tokenizer::from_pretrained("gpt2", None).unwrap());
         let backend = OpenAITextGenerationBackend::try_new(
             "".to_string(),
@@ -1067,7 +1070,7 @@ mod tests {
             .with_chunked_body(|w| w.write_all(b"data: {\"foo\": \"bar\"}\n\n"))
             .create_async()
             .await;
-        let url = s.url();
+        let url = s.url().parse().unwrap();
         let tokenizer = Arc::new(Tokenizer::from_pretrained("gpt2", None).unwrap());
         let backend = OpenAITextGenerationBackend::try_new(
             "".to_string(),
@@ -1117,7 +1120,7 @@ mod tests {
                 w.write_all(b"data: [DONE]\n\n")
             })
             .create_async().await;
-        let url = s.url();
+        let url = s.url().parse().unwrap();
         let tokenizer = Arc::new(Tokenizer::from_pretrained("gpt2", None).unwrap());
         let backend = OpenAITextGenerationBackend::try_new(
             "".to_string(),