From b4177175e77ab9ed482a4eb71c8c57dc8c7b2403 Mon Sep 17 00:00:00 2001
From: pancake
Date: Thu, 23 May 2024 12:32:52 +0200
Subject: [PATCH] Implement the v1/chat/completions endpoint

---
 r2ai/web.py | 89 ++++++++++++++++++++++++++++++++++++++++-------------
 1 file changed, 67 insertions(+), 22 deletions(-)

diff --git a/r2ai/web.py b/r2ai/web.py
index efbba2c8..6570b9b4 100644
--- a/r2ai/web.py
+++ b/r2ai/web.py
@@ -21,6 +21,48 @@ def handle_v1_completions_default(self, ai, obj, runline2, method):
     self.end_headers()
     return True
 
+# {"messages": [ {"role": "user", "content": "Explain the following code:\n'''\ndata = {\n "messages": [\n {\n "role": "user",\n "content": "Explain code:\n" + c_code,\n "stream": "true",\n "max_tokens": 7100,\n "temperature": 0.2\n }\n ],\n}\n'''"}],\x0d\x0a"model":"", "stream": "true", "max_tokens": 7100, "temperature": 0.2# }
+def handle_v1_chat_completions(self, ai, obj, runline2, method):
+    print("/v1/chat/completions")
+    if obj == None:
+        print("ObjNone")
+        self.send_response(200)
+        self.end_headers()
+        return True
+    if "messages" not in obj:
+        return handle_v1_completions_default(self, ai, obj, runline2, method)
+    codequery = obj["messages"][0]["content"]
+    runline2(ai, "-R")
+    response = json.loads('''{
+      "id": "r2ai",
+      "object": "chat.completion.chunk",
+      "created": 1,
+      "choices": [
+        {
+          "finish_reason": "null",
+          "delta": {
+            "role": "assistant",
+            "content": ""
+          }
+        }
+      ]
+    }''')
+    response["choices"][0]["delta"]["content"] = ""
+    jresponse = json.dumps(response)
+    #self.wfile.write(bytes(f'data: {jresponse}','utf-8'))
+    ores = runline2(ai, codequery).strip()
+    print("============")
+    print(ores)
+    print("============")
+    response["choices"][0]["delta"]["content"] = ores
+    response["choices"][0]["finish_reason"] = "length"
+    self.send_response(200)
+    self.end_headers()
+    jresponse = json.dumps(response)
+    # print(jresponse)
+    self.wfile.write(bytes(f'data: {jresponse}','utf-8'))
+    print("computed")
+
 def handle_v1_completions(self, ai, obj, runline2, method):
     global ores
     print("/v1/completions")
@@ -72,27 +114,29 @@ def handle_v1_completions(self, ai, obj, runline2, method):
 
 def handle_tabby_query(self, ai, obj, runline2, method):
     global ores
-    # TODO build proper health json instead of copypasting a stolen one
-    model = ai.env["llm.model"]
-    healthobj = {
-        "model":ai.env["llm.model"],
-        "device":"gpu" if ai.env["llm.gpu"] else "cpu",
-        "arch": platform.machine(),
-        "cpu_info": "",
-        "cpu_count": 1,
-        "cuda_devices": [],
-        "version": {
-            "build_date": "2024-05-22",
-            "build_timestamp": "2024-05-22",
-            "git_sha": "",
-            "git_describe": "",
-        },
-    }
-    healthstr=json.dumps(healthobj)
-    print(healthstr)
-    if method == "GET" and self.path == "/v1/health":
+    print(self.path)
+    if self.path == "/v1/chat/completions":
+        return handle_v1_chat_completions(self, ai, obj, runline2, method)
+    if self.path == "/v1/health": ## GET only
         self.send_response(200)
         self.end_headers()
+        # TODO build proper health json instead of copypasting a stolen one
+        model = ai.env["llm.model"]
+        healthobj = {
+            "model":ai.env["llm.model"],
+            "device":"gpu" if ai.env["llm.gpu"] else "cpu",
+            "arch": platform.machine(),
+            "cpu_info": "",
+            "cpu_count": 1,
+            "cuda_devices": [],
+            "version": {
+                "build_date": "2024-05-22",
+                "build_timestamp": "2024-05-22",
+                "git_sha": "",
+                "git_describe": "",
+            },
+        }
+        healthstr=json.dumps(healthobj)
         self.wfile.write(bytes(f'{healthstr}','utf-8'))
         return True
     # /v1/completions
@@ -102,6 +146,7 @@ def handle_tabby_query(self, ai, obj, runline2, method):
     print(f"UnkPath: {self.path}")
     self.send_response(200)
     self.end_headers()
+    self.wfile.write(bytes('{}\n','utf-8'))
     return True
 
 def handle_custom_request(self, ai, msg, runline2, method):
@@ -112,9 +157,9 @@ def handle_custom_request(self, ai, msg, runline2, method):
         return False
     if msg.startswith("{"):
         obj = json.loads(msg)
-        if "language" in obj:
-            handle_tabby_query(self, ai, obj, runline2, method)
-            return True
+        #if "language" in obj:
+        handle_tabby_query(self, ai, obj, runline2, method)
+        return True
     return True
 
 def start_http_server_now(ai, runline2):