|
6 | 6 | import os |
7 | 7 | import re |
8 | 8 |
|
9 | | -from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType |
| 9 | +from comps import MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceRoleType, ServiceType |
| 10 | +from comps.cores.mega.utils import handle_message |
10 | 11 | from comps.cores.proto.api_protocol import ( |
11 | 12 | ChatCompletionRequest, |
12 | 13 | ChatCompletionResponse, |
@@ -127,14 +128,15 @@ def align_generator(self, gen, **kwargs): |
127 | 128 | yield "data: [DONE]\n\n" |
128 | 129 |
|
129 | 130 |
|
130 | | -class GraphRAGService(Gateway): |
| 131 | +class GraphRAGService: |
131 | 132 | def __init__(self, host="0.0.0.0", port=8000): |
132 | 133 | self.host = host |
133 | 134 | self.port = port |
134 | 135 | ServiceOrchestrator.align_inputs = align_inputs |
135 | 136 | ServiceOrchestrator.align_outputs = align_outputs |
136 | 137 | ServiceOrchestrator.align_generator = align_generator |
137 | 138 | self.megaservice = ServiceOrchestrator() |
| 139 | + self.endpoint = str(MegaServiceEndpoint.GRAPH_RAG) |
138 | 140 |
|
139 | 141 | def add_remote_service(self): |
140 | 142 | retriever = MicroService( |
@@ -180,7 +182,7 @@ def parser_input(data, TypeClass, key): |
180 | 182 | raise ValueError(f"Unknown request type: {data}") |
181 | 183 | if chat_request is None: |
182 | 184 | raise ValueError(f"Unknown request type: {data}") |
183 | | - prompt = self._handle_message(chat_request.messages) |
| 185 | + prompt = handle_message(chat_request.messages) |
184 | 186 | parameters = LLMParams( |
185 | 187 | max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024, |
186 | 188 | top_k=chat_request.top_k if chat_request.top_k else 10, |
@@ -223,14 +225,17 @@ def parser_input(data, TypeClass, key): |
223 | 225 | return ChatCompletionResponse(model="chatqna", choices=choices, usage=usage) |
224 | 226 |
|
225 | 227 | def start(self): |
226 | | - super().__init__( |
227 | | - megaservice=self.megaservice, |
| 228 | + self.service = MicroService( |
| 229 | + self.__class__.__name__, |
| 230 | + service_role=ServiceRoleType.MEGASERVICE, |
228 | 231 | host=self.host, |
229 | 232 | port=self.port, |
230 | | - endpoint=str(MegaServiceEndpoint.GRAPH_RAG), |
| 233 | + endpoint=self.endpoint, |
231 | 234 | input_datatype=ChatCompletionRequest, |
232 | 235 | output_datatype=ChatCompletionResponse, |
233 | 236 | ) |
| 237 | + self.service.add_route(self.endpoint, self.handle_request, methods=["POST"]) |
| 238 | + self.service.start() |
234 | 239 |
|
235 | 240 |
|
236 | 241 | if __name__ == "__main__": |
|
0 commit comments