66import os
77import re
88
9- from comps import ChatQnAGateway , MicroService , ServiceOrchestrator , ServiceType
9+ from comps import Gateway , MegaServiceEndpoint , MicroService , ServiceOrchestrator , ServiceType
10+ from comps .cores .proto .api_protocol import (
11+ ChatCompletionRequest ,
12+ ChatCompletionResponse ,
13+ ChatCompletionResponseChoice ,
14+ ChatMessage ,
15+ UsageInfo ,
16+ )
17+ from comps .cores .proto .docarray import LLMParams , RerankerParms , RetrieverParms
18+ from fastapi import Request
19+ from fastapi .responses import StreamingResponse
1020from langchain_core .prompts import PromptTemplate
1121
1222
@@ -35,7 +45,6 @@ def generate_rag_prompt(question, documents):
3545 return template .format (context = context_str , question = question )
3646
3747
38- MEGA_SERVICE_HOST_IP = os .getenv ("MEGA_SERVICE_HOST_IP" , "0.0.0.0" )
3948MEGA_SERVICE_PORT = int (os .getenv ("MEGA_SERVICE_PORT" , 8888 ))
4049GUARDRAIL_SERVICE_HOST_IP = os .getenv ("GUARDRAIL_SERVICE_HOST_IP" , "0.0.0.0" )
4150GUARDRAIL_SERVICE_PORT = int (os .getenv ("GUARDRAIL_SERVICE_PORT" , 80 ))
@@ -178,13 +187,14 @@ def align_generator(self, gen, **kwargs):
178187 yield "data: [DONE]\n \n "
179188
180189
181- class ChatQnAService :
190+ class ChatQnAService ( Gateway ) :
182191 def __init__ (self , host = "0.0.0.0" , port = 8000 ):
183192 self .host = host
184193 self .port = port
185194 ServiceOrchestrator .align_inputs = align_inputs
186195 ServiceOrchestrator .align_outputs = align_outputs
187196 ServiceOrchestrator .align_generator = align_generator
197+
188198 self .megaservice = ServiceOrchestrator ()
189199
190200 def add_remote_service (self ):
@@ -228,7 +238,6 @@ def add_remote_service(self):
228238 self .megaservice .flow_to (embedding , retriever )
229239 self .megaservice .flow_to (retriever , rerank )
230240 self .megaservice .flow_to (rerank , llm )
231- self .gateway = ChatQnAGateway (megaservice = self .megaservice , host = "0.0.0.0" , port = self .port )
232241
233242 def add_remote_service_without_rerank (self ):
234243
@@ -261,7 +270,6 @@ def add_remote_service_without_rerank(self):
261270 self .megaservice .add (embedding ).add (retriever ).add (llm )
262271 self .megaservice .flow_to (embedding , retriever )
263272 self .megaservice .flow_to (retriever , llm )
264- self .gateway = ChatQnAGateway (megaservice = self .megaservice , host = "0.0.0.0" , port = self .port )
265273
266274 def add_remote_service_with_guardrails (self ):
267275 guardrail_in = MicroService (
@@ -319,7 +327,66 @@ def add_remote_service_with_guardrails(self):
319327 self .megaservice .flow_to (retriever , rerank )
320328 self .megaservice .flow_to (rerank , llm )
321329 # self.megaservice.flow_to(llm, guardrail_out)
322- self .gateway = ChatQnAGateway (megaservice = self .megaservice , host = "0.0.0.0" , port = self .port )
330+
331+ async def handle_request (self , request : Request ):
332+ data = await request .json ()
333+ stream_opt = data .get ("stream" , True )
334+ chat_request = ChatCompletionRequest .parse_obj (data )
335+ prompt = self ._handle_message (chat_request .messages )
336+ parameters = LLMParams (
337+ max_tokens = chat_request .max_tokens if chat_request .max_tokens else 1024 ,
338+ top_k = chat_request .top_k if chat_request .top_k else 10 ,
339+ top_p = chat_request .top_p if chat_request .top_p else 0.95 ,
340+ temperature = chat_request .temperature if chat_request .temperature else 0.01 ,
341+ frequency_penalty = chat_request .frequency_penalty if chat_request .frequency_penalty else 0.0 ,
342+ presence_penalty = chat_request .presence_penalty if chat_request .presence_penalty else 0.0 ,
343+ repetition_penalty = chat_request .repetition_penalty if chat_request .repetition_penalty else 1.03 ,
344+ streaming = stream_opt ,
345+ chat_template = chat_request .chat_template if chat_request .chat_template else None ,
346+ )
347+ retriever_parameters = RetrieverParms (
348+ search_type = chat_request .search_type if chat_request .search_type else "similarity" ,
349+ k = chat_request .k if chat_request .k else 4 ,
350+ distance_threshold = chat_request .distance_threshold if chat_request .distance_threshold else None ,
351+ fetch_k = chat_request .fetch_k if chat_request .fetch_k else 20 ,
352+ lambda_mult = chat_request .lambda_mult if chat_request .lambda_mult else 0.5 ,
353+ score_threshold = chat_request .score_threshold if chat_request .score_threshold else 0.2 ,
354+ )
355+ reranker_parameters = RerankerParms (
356+ top_n = chat_request .top_n if chat_request .top_n else 1 ,
357+ )
358+ result_dict , runtime_graph = await self .megaservice .schedule (
359+ initial_inputs = {"text" : prompt },
360+ llm_parameters = parameters ,
361+ retriever_parameters = retriever_parameters ,
362+ reranker_parameters = reranker_parameters ,
363+ )
364+ for node , response in result_dict .items ():
365+ if isinstance (response , StreamingResponse ):
366+ return response
367+ last_node = runtime_graph .all_leaves ()[- 1 ]
368+ response = result_dict [last_node ]["text" ]
369+ choices = []
370+ usage = UsageInfo ()
371+ choices .append (
372+ ChatCompletionResponseChoice (
373+ index = 0 ,
374+ message = ChatMessage (role = "assistant" , content = response ),
375+ finish_reason = "stop" ,
376+ )
377+ )
378+ return ChatCompletionResponse (model = "chatqna" , choices = choices , usage = usage )
379+
380+ def start (self ):
381+
382+ super ().__init__ (
383+ megaservice = self .megaservice ,
384+ host = self .host ,
385+ port = self .port ,
386+ endpoint = str (MegaServiceEndpoint .CHAT_QNA ),
387+ input_datatype = ChatCompletionRequest ,
388+ output_datatype = ChatCompletionResponse ,
389+ )
323390
324391
325392if __name__ == "__main__" :
@@ -329,10 +396,12 @@ def add_remote_service_with_guardrails(self):
329396
330397 args = parser .parse_args ()
331398
332- chatqna = ChatQnAService (host = MEGA_SERVICE_HOST_IP , port = MEGA_SERVICE_PORT )
399+ chatqna = ChatQnAService (port = MEGA_SERVICE_PORT )
333400 if args .without_rerank :
334401 chatqna .add_remote_service_without_rerank ()
335402 elif args .with_guardrails :
336403 chatqna .add_remote_service_with_guardrails ()
337404 else :
338405 chatqna .add_remote_service ()
406+
407+ chatqna .start ()
0 commit comments