 import importlib
 
 # library import
+from fastapi import Request
+from fastapi.responses import StreamingResponse
 
 # comps import
-from comps import MicroService, ServiceOrchestrator, ServiceType
-from app_gateway import AppGateway
+from comps import MicroService, ServiceOrchestrator, ServiceRoleType, ServiceType
+from comps.cores.mega.utils import handle_message
+from comps.cores.proto.api_protocol import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionResponseChoice,
+    ChatMessage,
+    UsageInfo,
+)
+from comps.cores.proto.docarray import LLMParams, RerankerParms, RetrieverParms
 
-HOST_IP = os.getenv("HOST_IP", "0,0,0,0")
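+# workflow node categories mapped to their request-parameter schemas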
+category_params_map = {
+    'LLM': LLMParams,
+    'Reranking': RerankerParms,
+    'Retriever': RetrieverParms,
+}
+
+HOST_IP = os.getenv("HOST_IP", "0.0.0.0")
 USE_NODE_ID_AS_IP = os.getenv("USE_NODE_ID_AS_IP", "").lower() == 'true'
 
 class AppService:
     def __init__(self, host="0.0.0.0", port=8000):
         self.host = host
         self.port = port
         self.megaservice = ServiceOrchestrator()
+        self.megaservice.align_inputs = self.align_inputs
+        self.endpoint = "/v1/app-backend"
         with open('config/workflow-info.json', 'r') as f:
             self.workflow_info = json.load(f)
 
@@ -56,8 +74,7 @@ def add_remote_service(self):
                 self.megaservice.flow_to(services[prev_node], microservice)
             for next_node in node['connected_to']:
                 nodes.append(next_node)
-        self.megaservice.align_inputs = self.align_inputs
-        self.gateway = AppGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
+
     def align_inputs(self, inputs, *args, **kwargs):
         """Override this method in megaservice definition."""
         print('\n' * 2, 'align_inputs')
@@ -73,10 +90,75 @@ def align_inputs(self, inputs, *args, **kwargs):
         except Exception as e:
             print('unable to parse input', e)
         return inputs
+
+    async def handle_request(self, request: Request):
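+        """Build per-node params from the request, schedule the orchestrator graph, and return the result."""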
+        data = await request.json()
+        print('\n' * 5, '====== handle_request ======\n', data)
+        if 'chat_completion_ids' in self.workflow_info:
+            prompt = handle_message(data['messages'])
+            params = {}
+            llm_parameters = None
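+            # for each workflow node, fill its parameter dict from the request body,
+            # falling back to the node's configured inference_params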
+            for node_id, node in self.workflow_info['nodes'].items():
+                if node['category'] in category_params_map:
+                    param_class = category_params_map[node['category']]()
+                    param_keys = [key for key in dir(param_class) if not key.startswith('__') and not callable(getattr(param_class, key))]
+                    print('param_keys', param_keys)
+                    params_dict = {}
+                    for key in param_keys:
+                        if key in data:
+                            params_dict[key] = data[key]
+                            # handle special case for stream and streaming
+                            if key in ['stream', 'streaming']:
+                                params_dict[key] = data.get('stream', True) and data.get('streaming', True)
+                        elif key in node['inference_params']:
+                            params_dict[key] = node['inference_params'][key]
+                    params[node_id] = params_dict
+                if node['category'] == 'LLM':
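+                    # LLMParams expects max_new_tokens; mirror the OpenAI-style max_tokens value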
+                    params[node_id]['max_new_tokens'] = params[node_id]['max_tokens']
+                    llm_parameters = LLMParams(**params[node_id])
+            result_dict, runtime_graph = await self.megaservice.schedule(
+                initial_inputs={'query': prompt, 'text': prompt},
+                llm_parameters=llm_parameters,
+                params=params,
+            )
+            print('runtime_graph', runtime_graph.graph)
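+            # if any node produced a stream, pass it straight back to the client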
+            for node, response in result_dict.items():
+                if isinstance(response, StreamingResponse):
+                    return response
+            last_node = runtime_graph.all_leaves()[-1]
+            print('result_dict:', result_dict)
+            print('last_node:', last_node)
+            response = result_dict[last_node]['text']
+            choices = []
+            usage = UsageInfo()
+            choices.append(
+                ChatCompletionResponseChoice(
+                    index=0,
+                    message=ChatMessage(role='assistant', content=response),
+                    finish_reason='stop',
+                )
+            )
+            return ChatCompletionResponse(model='custom_app', choices=choices, usage=usage)
+
+    def start(self):
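+        """Wrap the orchestrator in a MicroService and start serving the endpoint."""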
+
+        self.service = MicroService(
+            self.__class__.__name__,
+            service_role=ServiceRoleType.MEGASERVICE,
+            host=self.host,
+            port=self.port,
+            endpoint=self.endpoint,
+            input_datatype=ChatCompletionRequest,
+            output_datatype=ChatCompletionResponse,
+        )
 
+        self.service.add_route(self.endpoint, self.handle_request, methods=["POST"])
+        self.service.start()
 
 if __name__ == "__main__":
     megaservice_host_ip = None if USE_NODE_ID_AS_IP else HOST_IP
-    chatqna = AppService(host=HOST_IP, port=8888)
+    print('pre initialize appService')
+    app = AppService(host=HOST_IP, port=8888)
     print('after initialize appService')
-    chatqna.add_remote_service()
+    app.add_remote_service()
+    app.start()