From 0cbbcc0541b291afc44665466c1b01b744012127 Mon Sep 17 00:00:00 2001 From: pancake Date: Tue, 14 May 2024 23:56:17 +0200 Subject: [PATCH] Indent with 4 spaces --- Makefile | 9 +- main.py | 2 +- r2ai.sh | 1 + r2ai/anthropic.py | 269 +++++++------ r2ai/auto.py | 604 ++++++++++++++-------------- r2ai/backend/kobaldcpp.py | 46 ++- r2ai/const.py | 5 +- r2ai/index.py | 811 +++++++++++++++++++------------------- r2ai/large.py | 327 +++++++-------- r2ai/voice.py | 144 +++---- r2ai/web.py | 187 +++++---- run-venv.sh | 11 - 12 files changed, 1198 insertions(+), 1218 deletions(-) delete mode 100755 run-venv.sh diff --git a/Makefile b/Makefile index f31fdd9c..772b421e 100644 --- a/Makefile +++ b/Makefile @@ -6,18 +6,23 @@ PIP=$(PYTHON) -m pip LINTED=r2ai/code_block.py LINTED+=r2ai/bubble.py +LINTED+=r2ai/const.py +LINTED+=r2ai/voice.py LINTED+=setup.py LINTED+=main.py +LINTED+=r2ai/backend/kobaldcpp.py +# LINTED+=r2ai/index.py +# LINTED+=r2ai/anthropic.py ifeq ($(R2PM_BINDIR),) FATAL ERROR endif -.PHONY: all all.old venv deps clean deps-global pub lint cilint +.PHONY: all all.old deps clean deps-global pub lint cilint .PHONY: install uninstall user-install user-uninstall all: venv - ./r2ai.sh + @./r2ai.sh large: . venv/bin/activate ; $(PYTHON) main.py -l diff --git a/main.py b/main.py index 7d5c62df..92a9fe6f 100644 --- a/main.py +++ b/main.py @@ -14,4 +14,4 @@ runpy.run_path(os.path.join(r2aihome, 'r2ai', 'main.py')) else: ARGS = " ".join(sys.argv[1:]) - os.system(f"cd {r2aihome}; ./run-venv.sh {ARGS}") + os.system(f"cd {r2aihome}; ./r2ai.sh {ARGS}") diff --git a/r2ai.sh b/r2ai.sh index 0923106e..5dea34b9 100755 --- a/r2ai.sh +++ b/r2ai.sh @@ -14,4 +14,5 @@ if [ ! -d venv ]; then else . venv/bin/activate fi +# export PYTHONPATH=$PWD $PYTHON main.py $@ diff --git a/r2ai/anthropic.py b/r2ai/anthropic.py index 6885c0ab..caa69706 100644 --- a/r2ai/anthropic.py +++ b/r2ai/anthropic.py @@ -1,156 +1,153 @@ import re -import sys import random import string def get_random_tool_call_id(): - return "call_" + "".join( - [random.choice(string.ascii_letters + string.digits) for _ in range(24)] - ) + return "call_" + "".join( + [random.choice(string.ascii_letters + string.digits) for _ in range(24)] + ) def construct_tool_parameters_prompt(parameters): - prompt = "" - props = parameters["properties"] + prompt = "" + props = parameters["properties"] + for name in props: + parameter = props[name] + prompt += ( + "\n" + f"{name}\n" + f"{parameter['description']}\n" + f"{parameter['type']}\n" + "\n" + ) + return prompt - for name in props: - parameter = props[name] - prompt += ( - "\n" - f"{name}\n" - f"{parameter['description']}\n" - f"{parameter['type']}\n" - "\n" +def construct_tool_prompt(func): + tool = func['function'] + prompt = ( + "\n" + f"{tool['name']}\n" + "\n" + f"{tool['description']}\n" + "\n" + "\n" + f"{construct_tool_parameters_prompt(tool['parameters'])}\n" + "\n" + "" ) + return prompt - return prompt +def construct_tool_use_system_prompt(tools): + tool_use_system_prompt = ( + "In this environment you have access to a set of tools " + "you can use to answer the user's question.\n\n" + "You may call them like this:\n" + "\n" + "\n" + "$TOOL_NAME\n" + "\n" + "<$PARAMETER_NAME>$PARAMETER_VALUE\n" + "...\n" + "\n" + "\n" + "\n" + "\n" + "Here are the tools available:\n" + "\n" + + '\n'.join([construct_tool_prompt(tool) for tool in tools]) + + "\n" + ) + return tool_use_system_prompt -def construct_tool_prompt(func): - tool = func['function'] - prompt = ( - "\n" - f"{tool['name']}\n" - "\n" - f"{tool['description']}\n" - "\n" - "\n" - f"{construct_tool_parameters_prompt(tool['parameters'])}\n" - "\n" - "" - ) +TAGS = r'|||||||' - return prompt +def parse_tags(invoke_string): + tool_name = re.findall(r'.*?', invoke_string, re.DOTALL) + if not tool_name: + raise Exception("Missing tags inside of tags.") + if len(tool_name) > 1: + raise Exception("More than one tool_name specified inside single set of tags.") -def construct_tool_use_system_prompt(tools): - tool_use_system_prompt = ( - "In this environment you have access to a set of tools you can use to answer the user's question.\n" - "\n" - "You may call them like this:\n" - "\n" - "\n" - "$TOOL_NAME\n" - "\n" - "<$PARAMETER_NAME>$PARAMETER_VALUE\n" - "...\n" - "\n" - "\n" - "\n" - "\n" - "Here are the tools available:\n" - "\n" - + '\n'.join([construct_tool_prompt(tool) for tool in tools]) + - "\n" - ) - - return tool_use_system_prompt + parameters = re.findall(r'.*?', invoke_string, re.DOTALL) + if not parameters: + raise Exception("Missing tags inside of tags.") + if len(parameters) > 1: + raise Exception("More than one set of tags specified inside single set of tags.") + # Check for balanced tags inside parameters + # TODO: This will fail if the parameter value contains <> pattern + # TODO: or if there is a parameter called parameters. Fix that issue. + tags = re.findall(r'<.*?>', parameters[0].replace('', '').replace('', ''), re.DOTALL) + if len(tags) % 2 != 0: + raise Exception("Imbalanced tags inside tags.") + return tool_name, parameters, tags def _function_calls_valid_format_and_invoke_extraction(last_completion): - """Check if the function call follows a valid format and extract the attempted function calls if so. Does not check if the tools actually exist or if they are called with the requisite params.""" - - # Check if there are any of the relevant XML tags present that would indicate an attempted function call. - function_call_tags = re.findall(r'|||||||', last_completion, re.DOTALL) - if not function_call_tags: - # TODO: Should we return something in the text to claude indicating that it did not do anything to indicate an attempted function call (in case it was in fact trying to and we missed it)? - return {"status": True, "invokes": []} - - # Extract content between tags. If there are multiple we will only parse the first and ignore the rest, regardless of their correctness. - match = re.search(r'(.*)', last_completion, re.DOTALL) - if not match: - return {"status": False, "reason": "No valid tags present in your query."} - - func_calls = match.group(1) + """Check if the function call follows a valid format and extract the + attempted function calls if so. Does not check if the tools actually + exist or if they are called with the requisite params.""" + # Check if there are any of the relevant XML tags present that would + # indicate an attempted function call. + function_call_tags = re.findall(TAGS, last_completion, re.DOTALL) + if not function_call_tags: + # TODO: Should we return something in the text to claude indicating + # that it did not do anything to indicate an attempted function call + # (in case it was in fact trying to and we missed it)? + return {"status": True, "invokes": []} + # Extract content between tags. If there are multiple we + # will only parse the first and ignore the rest, regardless of their correctness. + match = re.search(r'(.*)', last_completion, re.DOTALL) + if not match: + return {"status": False, "reason": "No valid tags present in your query."} + func_calls = match.group(1) - prefix_match = re.search(r'^(.*?)', last_completion, re.DOTALL) - if prefix_match: - func_call_prefix_content = prefix_match.group(1) - - # Check for invoke tags - # TODO: Is this faster or slower than bundling with the next check? - invoke_regex = r'.*?' - if not re.search(invoke_regex, func_calls, re.DOTALL): - return {"status": False, "reason": "Missing tags inside of tags."} - - # Check each invoke contains tool name and parameters - invoke_strings = re.findall(invoke_regex, func_calls, re.DOTALL) - invokes = [] - for invoke_string in invoke_strings: - tool_name = re.findall(r'.*?', invoke_string, re.DOTALL) - if not tool_name: - return {"status": False, "reason": "Missing tags inside of tags."} - if len(tool_name) > 1: - return {"status": False, "reason": "More than one tool_name specified inside single set of tags."} + prefix_match = re.search(r'^(.*?)', last_completion, re.DOTALL) + if prefix_match: + func_call_prefix_content = prefix_match.group(1) + # Check for invoke tags + # TODO: Is this faster or slower than bundling with the next check? + invoke_regex = r'.*?' + if not re.search(invoke_regex, func_calls, re.DOTALL): + return {"status": False, "reason": "Missing tags inside of tags."} + # Check each invoke contains tool name and parameters + invoke_strings = re.findall(invoke_regex, func_calls, re.DOTALL) + invokes = [] + for invoke_string in invoke_strings: + try: + tool_name, parameters, tags = parse_tags(invoke_string) + except Exception as e: + return {"status": False, "reason": e} - parameters = re.findall(r'.*?', invoke_string, re.DOTALL) - if not parameters: - return {"status": False, "reason": "Missing tags inside of tags."} - if len(parameters) > 1: - return {"status": False, "reason": "More than one set of tags specified inside single set of tags."} - - # Check for balanced tags inside parameters - # TODO: This will fail if the parameter value contains <> pattern or if there is a parameter called parameters. Fix that issue. - tags = re.findall(r'<.*?>', parameters[0].replace('', '').replace('', ''), re.DOTALL) - if len(tags) % 2 != 0: - return {"status": False, "reason": "Imbalanced tags inside tags."} - - # Loop through the tags and check if each even-indexed tag matches the tag in the position after it (with the / of course). If valid store their content for later use. - # TODO: Add a check to make sure there aren't duplicates provided of a given parameter. - arguments = {} - for i in range(0, len(tags), 2): - opening_tag = tags[i] - closing_tag = tags[i+1] - closing_tag_without_second_char = closing_tag[:1] + closing_tag[2:] - if closing_tag[1] != '/' or opening_tag != closing_tag_without_second_char: - return {"status": False, "reason": "Non-matching opening and closing tags inside tags."} - - arguments[opening_tag[1:-1]] = re.search(rf'{opening_tag}(.*?){closing_tag}', parameters[0], re.DOTALL).group(1) - - # Parse out the full function call - invokes.append({ - "function": { - "name": tool_name[0].replace('', '').replace('', ''), + # Loop through the tags and check if each even-indexed tag matches the + # tag in the position after it (with the / of course). If valid store + # their content for later use. + # TODO: Add a check to make sure there aren't duplicates provided of a given parameter. + arguments = {} + for i in range(0, len(tags), 2): + opening_tag = tags[i] + closing_tag = tags[i+1] + closing_tag_without_second_char = closing_tag[:1] + closing_tag[2:] + if closing_tag[1] != '/' or opening_tag != closing_tag_without_second_char: + return {"status": False, "reason": "Non-matching opening and closing tags inside tags."} + arguments[opening_tag[1:-1]] = re.search(rf'{opening_tag}(.*?){closing_tag}', parameters[0], re.DOTALL).group(1) + # Parse out the full function call + invokes.append({ + "function": { + "name": tool_name[0].replace('', '').replace('', ''), "arguments": arguments, - }, - "id": get_random_tool_call_id() - }) - - return {"status": True, "invokes": invokes, "prefix_content": func_call_prefix_content} + }, + "id": get_random_tool_call_id() + }) + return {"status": True, "invokes": invokes, "prefix_content": func_call_prefix_content} def extract_claude_tool_calls(interpreter, stream): - msg = '' - res = None - for event in stream: - if event.type == "content_block_delta": - delta = event.delta - msg += delta.text - res = _function_calls_valid_format_and_invoke_extraction(msg) - if res["status"] == True and "invokes" in res and len(res["invokes"]) > 0: - interpreter.messages.append({ "role": "assistant", "content": msg}) - return res["invokes"], res["prefix_content"] - - interpreter.messages.append({ "role": "assistant", "content": msg}) - return [], re.sub(r'.*', '', msg) - - - - - - + msg = '' + res = None + for event in stream: + if event.type == "content_block_delta": + delta = event.delta + msg += delta.text + res = _function_calls_valid_format_and_invoke_extraction(msg) + if res["status"] is True and "invokes" in res and len(res["invokes"]) > 0: + interpreter.messages.append({ "role": "assistant", "content": msg}) + return res["invokes"], res["prefix_content"] + interpreter.messages.append({ "role": "assistant", "content": msg}) + return [], re.sub(r'.*', '', msg) diff --git a/r2ai/auto.py b/r2ai/auto.py index ffc07c30..518497ee 100644 --- a/r2ai/auto.py +++ b/r2ai/auto.py @@ -5,7 +5,6 @@ from llama_cpp import Llama from llama_cpp.llama_tokenizer import LlamaHFTokenizer from transformers import AutoTokenizer - from .anthropic import construct_tool_use_system_prompt, extract_claude_tool_calls import os @@ -67,332 +66,329 @@ """ FUNCTIONARY_PROMPT_AUTO = """ - Think step by step. - Break down the task into steps and execute the necessary `radare2` commands in order to complete the task. +Think step by step. +Break down the task into steps and execute the necessary `radare2` commands in order to complete the task. """ def get_system_prompt(model): - if model.startswith("meetkai/"): - return SYSTEM_PROMPT_AUTO + "\n" + FUNCTIONARY_PROMPT_AUTO - elif model.startswith("anthropic"): - return SYSTEM_PROMPT_AUTO + "\n\n" + construct_tool_use_system_prompt(tools) - else: + if model.startswith("meetkai/"): + return SYSTEM_PROMPT_AUTO + "\n" + FUNCTIONARY_PROMPT_AUTO + if model.startswith("anthropic"): + return SYSTEM_PROMPT_AUTO + "\n\n" + construct_tool_use_system_prompt(tools) return SYSTEM_PROMPT_AUTO functionary_tokenizer = None def get_functionary_tokenizer(repo_id): - global functionary_tokenizer - if functionary_tokenizer is None: - functionary_tokenizer = AutoTokenizer.from_pretrained(repo_id, legacy=True) - return functionary_tokenizer + global functionary_tokenizer + if functionary_tokenizer is None: + functionary_tokenizer = AutoTokenizer.from_pretrained(repo_id, legacy=True) + return functionary_tokenizer def r2cmd(command: str): - """runs commands in radare2. You can run it multiple times or chain commands with pipes/semicolons. You can also use r2 interpreters to run scripts using the `#`, '#!', etc. commands. The output could be long, so try to use filters if possible or limit. This is your preferred tool""" - builtins.print('\x1b[1;32mRunning \x1b[4m' + command + '\x1b[0m') - res = r2lang.cmd(command) - builtins.print(res) - return res + """runs commands in radare2. You can run it multiple times or chain commands + with pipes/semicolons. You can also use r2 interpreters to run scripts using + the `#`, '#!', etc. commands. The output could be long, so try to use filters + if possible or limit. This is your preferred tool""" + builtins.print('\x1b[1;32mRunning \x1b[4m' + command + '\x1b[0m') + res = r2lang.cmd(command) + builtins.print(res) + return res def run_python(command: str): - """runs a python script and returns the results""" - with open('r2ai_tmp.py', 'w') as f: - f.write(command) - builtins.print('\x1b[1;32mRunning \x1b[4m' + "python code" + '\x1b[0m') - builtins.print(command) - r2lang.cmd('#!python r2ai_tmp.py > $tmp') - res = r2lang.cmd('cat $tmp') - r2lang.cmd('rm r2ai_tmp.py') - builtins.print('\x1b[1;32mResult\x1b[0m\n' + res) - return res + """runs a python script and returns the results""" + with open('r2ai_tmp.py', 'w') as f: + f.write(command) + builtins.print('\x1b[1;32mRunning \x1b[4m' + "python code" + '\x1b[0m') + builtins.print(command) + r2lang.cmd('#!python r2ai_tmp.py > $tmp') + res = r2lang.cmd('cat $tmp') + r2lang.cmd('rm r2ai_tmp.py') + builtins.print('\x1b[1;32mResult\x1b[0m\n' + res) + return res def process_tool_calls(interpreter, tool_calls): - interpreter.messages.append({ "content": None, "tool_calls": tool_calls, "role": "assistant" }) - for tool_call in tool_calls: - res = '' - args = tool_call["function"]["arguments"] - if type(args) is str: - try: - args = json.loads(args) - except: - builtins.print(f"Error parsing json: {args}", file=sys.stderr) - - if tool_call["function"]["name"] == "r2cmd": - if type(args) is str: - args = { "command": args } - if "command" in args: - res = r2cmd(args["command"]) - elif tool_call["function"]["name"] == "run_python": - res = run_python(args["command"]) - if (not res or len(res) == 0) and interpreter.model.startswith('meetkai/'): - res = "OK done" - interpreter.messages.append({"role": "tool", "content": ANSI_REGEX.sub('', res), "name": tool_call["function"]["name"], "tool_call_id": tool_call["id"] if "id" in tool_call else None}) + interpreter.messages.append({ "content": None, "tool_calls": tool_calls, "role": "assistant" }) + for tool_call in tool_calls: + res = '' + args = tool_call["function"]["arguments"] + if type(args) is str: + try: + args = json.loads(args) + except: + builtins.print(f"Error parsing json: {args}", file=sys.stderr) + if tool_call["function"]["name"] == "r2cmd": + if type(args) is str: + args = { "command": args } + if "command" in args: + res = r2cmd(args["command"]) + elif tool_call["function"]["name"] == "run_python": + res = run_python(args["command"]) + if (not res or len(res) == 0) and interpreter.model.startswith('meetkai/'): + res = "OK done" + msg = { + "role": "tool", + "content": ANSI_REGEX.sub('', res), + "name": tool_call["function"]["name"], + "tool_call_id": tool_call["id"] if "id" in tool_call else None + } + interpreter.messages.append(msg) def process_hermes_response(interpreter, response): - choice = response["choices"][0] - message = choice["message"] - interpreter.messages.append(message) - r = re.search(r'([\s\S]*?)<\/tool_call>', message["content"]) - tool_call_str = None - if r: - tool_call_str = r.group(1) - tool_calls = [] - if tool_call_str: - tool_call = json.loads(tool_call_str) - tool_calls.append({"function": tool_call}) - - if len(tool_calls) > 0: - process_tool_calls(interpreter, tool_calls) - chat(interpreter) - else: - interpreter.messages.append({ "content": message["content"], "role": "assistant" }) - sys.stdout.write(message["content"]) - builtins.print() - -def process_streaming_response(interpreter, resp): - tool_calls = [] - msgs = [] - for chunk in resp: - try: - chunk = dict(chunk) - except: - pass - delta = None - choice = dict(chunk["choices"][0]) - if "delta" in choice: - delta = dict(choice["delta"]) - else: - delta = dict(choice["message"]) - if "tool_calls" in delta and delta["tool_calls"]: - delta_tool_calls = dict(delta["tool_calls"][0]) - index = 0 if "index" not in delta_tool_calls else delta_tool_calls["index"] - fn_delta = dict(delta_tool_calls["function"]) - tool_call_id = delta_tool_calls["id"] - if len(tool_calls) < index + 1: - tool_calls.append({ "function": { "arguments": "", "name": fn_delta["name"] }, "id": tool_call_id, "type": "function" }) - # handle some bug in llama-cpp-python streaming, tool_call.arguments is sometimes blank, but function_call has it. - if fn_delta["arguments"] == '': - if "function_call" in delta and delta["function_call"]: - tool_calls[index]["function"]["arguments"] += delta["function_call"]["arguments"] - else: - tool_calls[index]["function"]["arguments"] += fn_delta["arguments"] + choice = response["choices"][0] + message = choice["message"] + interpreter.messages.append(message) + r = re.search(r'([\s\S]*?)<\/tool_call>', message["content"]) + tool_call_str = None + if r: + tool_call_str = r.group(1) + tool_calls = [] + if tool_call_str: + tool_call = json.loads(tool_call_str) + tool_calls.append({"function": tool_call}) + if len(tool_calls) > 0: + process_tool_calls(interpreter, tool_calls) + chat(interpreter) else: - if "content" in delta and delta["content"] is not None: - m = delta["content"] - if m is not None: - msgs.append(m) - sys.stdout.write(m) - builtins.print() - if(len(tool_calls) > 0): - process_tool_calls(interpreter, tool_calls) - chat(interpreter) + interpreter.messages.append({ "content": message["content"], "role": "assistant" }) + sys.stdout.write(message["content"]) + builtins.print() - if len(msgs) > 0: - response_message = ''.join(msgs) - interpreter.messages.append({"role": "assistant", "content": response_message}) +def process_streaming_response(interpreter, resp): + tool_calls = [] + msgs = [] + for chunk in resp: + try: + chunk = dict(chunk) + except: + pass + delta = None + choice = dict(chunk["choices"][0]) + if "delta" in choice: + delta = dict(choice["delta"]) + else: + delta = dict(choice["message"]) + if "tool_calls" in delta and delta["tool_calls"]: + delta_tool_calls = dict(delta["tool_calls"][0]) + index = 0 if "index" not in delta_tool_calls else delta_tool_calls["index"] + fn_delta = dict(delta_tool_calls["function"]) + tool_call_id = delta_tool_calls["id"] + if len(tool_calls) < index + 1: + tool_calls.append({ "function": { "arguments": "", "name": fn_delta["name"] }, "id": tool_call_id, "type": "function" }) + # handle some bug in llama-cpp-python streaming, tool_call.arguments is sometimes blank, but function_call has it. + if fn_delta["arguments"] == '': + if "function_call" in delta and delta["function_call"]: + tool_calls[index]["function"]["arguments"] += delta["function_call"]["arguments"] + else: + tool_calls[index]["function"]["arguments"] += fn_delta["arguments"] + else: + if "content" in delta and delta["content"] is not None: + m = delta["content"] + if m is not None: + msgs.append(m) + sys.stdout.write(m) + builtins.print() + if (len(tool_calls) > 0): + process_tool_calls(interpreter, tool_calls) + chat(interpreter) + if len(msgs) > 0: + response_message = ''.join(msgs) + interpreter.messages.append({"role": "assistant", "content": response_message}) def context_from_msg(msg): - keywords = None - datadir = "doc/auto" - use_vectordb = False - matches = index.match(msg, keywords, datadir, False, False, False, False, use_vectordb) - if matches == None: - return "" - # "(analyze using 'af', decompile using 'pdc')" - return "context: " + ", ".join(matches) + keywords = None + datadir = "doc/auto" + use_vectordb = False + matches = index.match(msg, keywords, datadir, False, False, False, False, use_vectordb) + if matches == None: + return "" + # "(analyze using 'af', decompile using 'pdc')" + return "context: " + ", ".join(matches) def chat(interpreter): - if len(interpreter.messages) == 1: - interpreter.messages.insert(0,{"role": "system", "content": get_system_prompt(interpreter.model)}) + if len(interpreter.messages) == 1: + interpreter.messages.insert(0,{"role": "system", "content": get_system_prompt(interpreter.model)}) - lastmsg = interpreter.messages[-1]["content"] - chat_context = context_from_msg (lastmsg) - #print("#### CONTEXT BEGIN") - #print(chat_context) # DEBUG - #print("#### CONTEXT END") - if chat_context != "": - interpreter.messages.insert(0,{"role": "user", "content": chat_context}) + lastmsg = interpreter.messages[-1]["content"] + chat_context = context_from_msg (lastmsg) + #print("#### CONTEXT BEGIN") + #print(chat_context) # DEBUG + #print("#### CONTEXT END") + if chat_context != "": + interpreter.messages.insert(0,{"role": "user", "content": chat_context}) - response = None - if interpreter.model.startswith("openai:"): - if not interpreter.openai_client: - try: - from openai import OpenAI - except ImportError: - print("pip install -U openai", file=sys.stderr) - print("export OPENAI_API_KEY=...", file=sys.stderr) - return - interpreter.openai_client = OpenAI() - - response = interpreter.openai_client.chat.completions.create( - model=interpreter.model[7:], - max_tokens=int(interpreter.env["llm.maxtokens"]), - tools=tools, - messages=interpreter.messages, - tool_choice="auto", - stream=True, - temperature=float(interpreter.env["llm.temperature"]), - ) - process_streaming_response(interpreter, response) - elif interpreter.model.startswith('anthropic:'): - if not interpreter.anthropic_client: - try: - from anthropic import Anthropic - except ImportError: - print("pip install -U anthropic", file=sys.stderr) - return - interpreter.anthropic_client = Anthropic() - messages = [] - system_message = construct_tool_use_system_prompt(tools) - for m in interpreter.messages: - role = m["role"] - if role == "system": - continue - if m["content"] is None: - continue - if role == "tool": - messages.append({ "role": "user", "content": f"\n\n{m['name']}\n{m['content']}\n\n" }) - # TODO: handle errors - else: - messages.append({ "role": role, "content": m["content"] }) - - stream = interpreter.anthropic_client.messages.create( - model=interpreter.model[10:], - max_tokens=int(interpreter.env["llm.maxtokens"]), - messages=messages, - system=system_message, - temperature=float(interpreter.env["llm.temperature"]), - stream=True - ) - (tool_calls, msg) = extract_claude_tool_calls(interpreter, stream) - if len(tool_calls) > 0: - process_tool_calls(interpreter, tool_calls) - chat(interpreter) - else: - builtins.print(msg) - elif interpreter.model.startswith("groq:"): - if not interpreter.groq_client: - try: - from groq import Groq - except ImportError: - print("pip install -U groq", file=sys.stderr) - return - interpreter.groq_client = Groq() - - response = interpreter.groq_client.chat.completions.create( - model=interpreter.model[5:], - max_tokens=int(interpreter.env["llm.maxtokens"]), - tools=tools, - messages=interpreter.messages, - tool_choice="auto", - temperature=float(interpreter.env["llm.temperature"]), - ) - process_streaming_response(interpreter, [response]) - elif interpreter.model.startswith("google"): - if not interpreter.google_client: - try: - import google.generativeai as google - google.configure(api_key=os.environ['GOOGLE_API_KEY']) - except ImportError: - print("pip install -U google-generativeai", file=sys.stderr) - return - - interpreter.google_client = google.GenerativeModel(interpreter.model[7:]) - if not interpreter.google_chat: - interpreter.google_chat = interpreter.google_client.start_chat( - enable_automatic_function_calling=True + response = None + if interpreter.model.startswith("openai:"): + if not interpreter.openai_client: + try: + from openai import OpenAI + except ImportError: + print("pip install -U openai", file=sys.stderr) + print("export OPENAI_API_KEY=...", file=sys.stderr) + return + interpreter.openai_client = OpenAI() + response = interpreter.openai_client.chat.completions.create( + model=interpreter.model[7:], + max_tokens=int(interpreter.env["llm.maxtokens"]), + tools=tools, + messages=interpreter.messages, + tool_choice="auto", + stream=True, + temperature=float(interpreter.env["llm.temperature"]), ) - - response = interpreter.google_chat.send_message( - interpreter.messages[-1]["content"], - generation_config={ - "max_output_tokens": int(interpreter.env["llm.maxtokens"]), - "temperature": float(interpreter.env["llm.temperature"]) - }, - safety_settings=[{ - "category": "HARM_CATEGORY_DANGEROUS_CONTENT", - "threshold": "BLOCK_NONE" - }], - tools=[r2cmd, run_python] - ) - print(response.text) - else: - chat_format = interpreter.llama_instance.chat_format - is_functionary = interpreter.model.startswith("meetkai/") - if is_functionary: - try: - from .functionary import prompt_template - except ImportError: - print("pip install -U functionary", file=sys.stderr) - return - - tokenizer = get_functionary_tokenizer(interpreter.model) - prompt_templ = prompt_template.get_prompt_template_from_tokenizer(tokenizer) - #print("############# BEGIN") - #print(dir(prompt_templ)) - #print("############# MESSAGES") - #print(interpreter.messages) - #print("############# END") - prompt_str = prompt_templ.get_prompt_from_messages(interpreter.messages + [{"role": "assistant"}], tools) - token_ids = tokenizer.encode(prompt_str) - stop_token_ids = [ - tokenizer.encode(token)[-1] - for token in prompt_templ.get_stop_tokens_for_generation() - ] - gen_tokens = [] - for token_id in interpreter.llama_instance.generate(token_ids, temp=float(interpreter.env["llm.temperature"])): - sys.stdout.write(tokenizer.decode([token_id])) - if token_id in stop_token_ids: - break - gen_tokens.append(token_id) - llm_output = tokenizer.decode(gen_tokens) - response = prompt_templ.parse_assistant_response(llm_output) - process_streaming_response(interpreter, iter([ - { "choices": [{ "message": response }] } - ])) - elif interpreter.model.startswith("NousResearch/"): - interpreter.llama_instance.chat_format = "chatml" - messages = [] - for m in interpreter.messages: - role = m["role"] - if m["content"] is None: - continue - if role == "system": - if not '' in m["content"]: - messages.append({ "role": "system", "content": f"""{m['content']}\nYou are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: - {json.dumps(tools)} -For each function call return a json object with function name and arguments within XML tags as follows: - -{{"arguments": , "name": }} -"""}) - elif role == "tool": - messages.append({ "role": "tool", "content": "\n" + '{"name": ' + m['name'] + ', "content": ' + json.dumps(m['content']) + '}\n' }) + process_streaming_response(interpreter, response) + elif interpreter.model.startswith('anthropic:'): + if not interpreter.anthropic_client: + try: + from anthropic import Anthropic + except ImportError: + print("pip install -U anthropic", file=sys.stderr) + return + interpreter.anthropic_client = Anthropic() + messages = [] + system_message = construct_tool_use_system_prompt(tools) + for m in interpreter.messages: + role = m["role"] + if role == "system": + continue + if m["content"] is None: + continue + if role == "tool": + messages.append({ "role": "user", "content": f"\n\n{m['name']}\n{m['content']}\n\n" }) + # TODO: handle errors + else: + messages.append({ "role": role, "content": m["content"] }) + stream = interpreter.anthropic_client.messages.create( + model=interpreter.model[10:], + max_tokens=int(interpreter.env["llm.maxtokens"]), + messages=messages, + system=system_message, + temperature=float(interpreter.env["llm.temperature"]), + stream=True + ) + (tool_calls, msg) = extract_claude_tool_calls(interpreter, stream) + if len(tool_calls) > 0: + process_tool_calls(interpreter, tool_calls) + chat(interpreter) else: - messages.append(m) - - response = interpreter.llama_instance.create_chat_completion( - max_tokens=int(interpreter.env["llm.maxtokens"]), - messages=messages, - temperature=float(interpreter.env["llm.temperature"]), - ) - process_hermes_response(interpreter, response) - interpreter.llama_instance.chat_format = chat_format - + builtins.print(msg) + elif interpreter.model.startswith("groq:"): + if not interpreter.groq_client: + try: + from groq import Groq + except ImportError: + print("pip install -U groq", file=sys.stderr) + return + interpreter.groq_client = Groq() + response = interpreter.groq_client.chat.completions.create( + model=interpreter.model[5:], + max_tokens=int(interpreter.env["llm.maxtokens"]), + tools=tools, + messages=interpreter.messages, + tool_choice="auto", + temperature=float(interpreter.env["llm.temperature"]), + ) + process_streaming_response(interpreter, [response]) + elif interpreter.model.startswith("google"): + if not interpreter.google_client: + try: + import google.generativeai as google + google.configure(api_key=os.environ['GOOGLE_API_KEY']) + except ImportError: + print("pip install -U google-generativeai", file=sys.stderr) + return + interpreter.google_client = google.GenerativeModel(interpreter.model[7:]) + if not interpreter.google_chat: + interpreter.google_chat = interpreter.google_client.start_chat( + enable_automatic_function_calling=True + ) + response = interpreter.google_chat.send_message( + interpreter.messages[-1]["content"], + generation_config={ + "max_output_tokens": int(interpreter.env["llm.maxtokens"]), + "temperature": float(interpreter.env["llm.temperature"]) + }, + safety_settings=[{ + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold": "BLOCK_NONE" + }], + tools=[r2cmd, run_python] + ) + print(response.text) + else: + chat_format = interpreter.llama_instance.chat_format + is_functionary = interpreter.model.startswith("meetkai/") + if is_functionary: + try: + from .functionary import prompt_template + except ImportError: + print("pip install -U functionary", file=sys.stderr) + return + tokenizer = get_functionary_tokenizer(interpreter.model) + prompt_templ = prompt_template.get_prompt_template_from_tokenizer(tokenizer) + #print("############# BEGIN") + #print(dir(prompt_templ)) + #print("############# MESSAGES") + #print(interpreter.messages) + #print("############# END") + prompt_str = prompt_templ.get_prompt_from_messages(interpreter.messages + [{"role": "assistant"}], tools) + token_ids = tokenizer.encode(prompt_str) + stop_token_ids = [ + tokenizer.encode(token)[-1] + for token in prompt_templ.get_stop_tokens_for_generation() + ] + gen_tokens = [] + for token_id in interpreter.llama_instance.generate(token_ids, temp=float(interpreter.env["llm.temperature"])): + sys.stdout.write(tokenizer.decode([token_id])) + if token_id in stop_token_ids: + break + gen_tokens.append(token_id) + llm_output = tokenizer.decode(gen_tokens) + response = prompt_templ.parse_assistant_response(llm_output) + process_streaming_response(interpreter, iter([ + { "choices": [{ "message": response }] } + ])) + elif interpreter.model.startswith("NousResearch/"): + interpreter.llama_instance.chat_format = "chatml" + messages = [] + for m in interpreter.messages: + if m["content"] is None: + continue + role = m["role"] + if role == "system": + if not '' in m["content"]: + messages.append({ "role": "system", "content": f"""{m['content']}\nYou are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: + {json.dumps(tools)} + For each function call return a json object with function name and arguments within XML tags as follows: + + {{"arguments": , "name": }} + """}) + elif role == "tool": + messages.append({ "role": "tool", "content": "\n" + '{"name": ' + m['name'] + ', "content": ' + json.dumps(m['content']) + '}\n' }) + else: + messages.append(m) + response = interpreter.llama_instance.create_chat_completion( + max_tokens=int(interpreter.env["llm.maxtokens"]), + messages=messages, + temperature=float(interpreter.env["llm.temperature"]), + ) + process_hermes_response(interpreter, response) + interpreter.llama_instance.chat_format = chat_format else: - interpreter.llama_instance.chat_format = "chatml-function-calling" - response = interpreter.llama_instance.create_chat_completion( - max_tokens=int(interpreter.env["llm.maxtokens"]), - tools=tools, - messages=interpreter.messages, - tool_choice="auto", - # tool_choice={ - # "type": "function", - # "function": { - # "name": "r2cmd" - # } - # }, - # stream=is_functionary, - temperature=float(interpreter.env["llm.temperature"]), - ) - process_streaming_response(interpreter, iter([response])) - interpreter.llama_instance.chat_format = chat_format - return response + interpreter.llama_instance.chat_format = "chatml-function-calling" + response = interpreter.llama_instance.create_chat_completion( + max_tokens=int(interpreter.env["llm.maxtokens"]), + tools=tools, + messages=interpreter.messages, + tool_choice="auto", + # tool_choice={ + # "type": "function", + # "function": { + # "name": "r2cmd" + # } + # }, + # stream=is_functionary, + temperature=float(interpreter.env["llm.temperature"]), + ) + process_streaming_response(interpreter, iter([response])) + interpreter.llama_instance.chat_format = chat_format + return response diff --git a/r2ai/backend/kobaldcpp.py b/r2ai/backend/kobaldcpp.py index cbe3ac75..f0b62e46 100644 --- a/r2ai/backend/kobaldcpp.py +++ b/r2ai/backend/kobaldcpp.py @@ -1,27 +1,34 @@ - +"""Implementation for kobaldcpp http api call using openai endpoint.""" import json import requests -def chat(message, API_ENDPOINT='http://localhost:5001'): - API_ENDPOINT+='/v1/completions' +PROMPT="""Your name is r2ai, an assistant for radare2. +User will ask about actions and you must respond with the radare2 command +associated or the answer to the question. Be precise and concise when answering +""" + +def chat(message, uri='http://localhost:5001'): + """Send a message to a kobaldcpp server and return the autocompletion response + """ + url = f'{uri}/v1/completions' data = { - "max_length": 1024, - "prompt": message, - "quiet": True, - "n": 1, - "echo": False, - "stop": ["\nUser:"], - "rep_pen": 1.1, - "rep_pen_range": 256, - "rep_pen_slope": 1, - "temperature": 0.3, - "tfs": 1, - "top_a": 0, - "top_k": 100, - "top_p": 0.9, - "typical": 1 + "max_length": 1024, + "prompt": message, + "quiet": True, + "n": 1, + "echo": False, + "stop": ["\nUser:"], + "rep_pen": 1.1, + "rep_pen_range": 256, + "rep_pen_slope": 1, + "temperature": 0.3, + "tfs": 1, + "top_a": 0, + "top_k": 100, + "top_p": 0.9, + "typical": 1 } - r = requests.post(url=API_ENDPOINT, data=json.dumps(data)) + r = requests.post(url=url, data=json.dumps(data), timeout=600) j = json.loads(r.text) i = j["choices"][0]["text"] return i @@ -30,7 +37,6 @@ def chat(message, API_ENDPOINT='http://localhost:5001'): #AI="AI" #US="User" #CTX="Context" -#fullmsg = f"Context:\n```{m}```\nYour name is r2ai, an assistant for radare2. User will ask about actions and you must respond with the radare2 command associated or the answer to the question. Be precise and concise when answering" #while True: # message = input() # qmsg = f"{CTX}:\n```{fullmsg}\n```\n{US}: {message}\n" diff --git a/r2ai/const.py b/r2ai/const.py index 6efbac4f..e6746efb 100644 --- a/r2ai/const.py +++ b/r2ai/const.py @@ -1,3 +1,4 @@ +"""File containing constants.""" import os join = os.path.join @@ -6,6 +7,6 @@ R2AI_HISTFILE = "r2ai.history.txt" # windows path R2AI_RCFILE = "r2ai.txt" if "HOME" in os.environ: - R2AI_HISTFILE = join(os.environ["HOME"], ".r2ai.history") - R2AI_RCFILE = join(os.environ["HOME"], ".r2ai.rc") + R2AI_HISTFILE = join(os.environ["HOME"], ".r2ai.history") + R2AI_RCFILE = join(os.environ["HOME"], ".r2ai.rc") R2AI_USERDIR = join(os.environ["HOME"], ".r2ai.plugins") diff --git a/r2ai/index.py b/r2ai/index.py index f11f4472..811dc168 100644 --- a/r2ai/index.py +++ b/r2ai/index.py @@ -9,11 +9,11 @@ from unidecode import unidecode import sys try: - from .utils import slurp - from .const import R2AI_HISTFILE + from .utils import slurp + from .const import R2AI_HISTFILE except: - from utils import slurp - R2AI_HISTFILE = "/dev/null" + from utils import slurp + R2AI_HISTFILE = "/dev/null" have_vectordb = None vectordb_instance = None @@ -22,452 +22,439 @@ MAXMATCHES = 5 MASTODON_KEY = "" try: - if "HOME" in os.environ: - MASTODON_KEY = slurp(os.environ["HOME"] + "/.r2ai.mastodon-key").strip() + if "HOME" in os.environ: + MASTODON_KEY = slurp(os.environ["HOME"] + "/.r2ai.mastodon-key").strip() except: - pass + pass MASTODON_INSTANCE = "mastodont.cat" if "MASTODON_INSTANCE" in os.environ: - MASTODON_INSTANCE = os.environ["MASTODON_INSTANCE"] + MASTODON_INSTANCE = os.environ["MASTODON_INSTANCE"] def mastodon_search(text): -# print("mastodon", text) - global MASTODON_INSTANCE -# print(f"(mastodon) {text}") - res = [] - full_url = f"https://{MASTODON_INSTANCE}/api/v2/search?resolve=true&limit=8&type=statuses&q={text}" - try: - headers = {"Authorization": f"Bearer {MASTODON_KEY}"} - response = requests.get(full_url, headers=headers) - response.raise_for_status() # Raise an HTTPError for bad responses - for msg in response.json()["statuses"]: - content = re.sub(r'<.*?>', '', msg["content"]) - res.append(content) - except requests.exceptions.RequestException as e: - print(f"Error making request: {e}") - return res + global MASTODON_INSTANCE + res = [] + full_url = f"https://{MASTODON_INSTANCE}/api/v2/search?resolve=true&limit=8&type=statuses&q={text}" + try: + headers = {"Authorization": f"Bearer {MASTODON_KEY}"} + response = requests.get(full_url, headers=headers) + response.raise_for_status() # Raise an HTTPError for bad responses + for msg in response.json()["statuses"]: + content = re.sub(r'<.*?>', '', msg["content"]) + res.append(content) + except requests.exceptions.RequestException as e: + print(f"Error making request: {e}", file=sys.stderr) + return res def mastodon_lines(text, keywords, use_vectordb): - twords = [] - rtlines = [] - if keywords is None: - twords = filter_line(text) - for tw in twords: - if len(tw) > 2: # arbitrary - rtlines.extend(mastodon_search(tw)) - else: - twords = keywords - -# print("MASTODON_LINES...", text) -# print("MASTODON_RTLINES...", rtlines) - if use_vectordb: - return rtlines - words = {} # local rarity ratings - for line in rtlines: - fline = filter_line(line) - for a in fline: - if words.get(a): - words[a] += 1 - else: - words[a] = 1 - rtlines = sorted(set(rtlines)) - rslines = [] - swords = sorted(twords, key=lambda x: words.get(x) or 0) - nwords = " ".join(swords[:5]) - # find rarity of words in the results + text and - # print("NWORDS", nwords) - rslines.extend(mastodon_search(nwords)) - if len(rslines) < 10: - for tw in swords: - w = words.get(tw) - if len(tw) > 4 and w is not None and w > 0 and w < 40: - # print(f"RELEVANT WORD {tw} {w}") - rslines.extend(mastodon_search(tw)) - return rslines + twords = [] + rtlines = [] + if keywords is None: + twords = filter_line(text) + for tw in twords: + if len(tw) > 2: # arbitrary + rtlines.extend(mastodon_search(tw)) + else: + twords = keywords + # print("MASTODON_LINES...", text) + # print("MASTODON_RTLINES...", rtlines) + if use_vectordb: + return rtlines + words = {} # local rarity ratings + for line in rtlines: + fline = filter_line(line) + for a in fline: + if words.get(a): + words[a] += 1 + else: + words[a] = 1 + rtlines = sorted(set(rtlines)) + rslines = [] + swords = sorted(twords, key=lambda x: words.get(x) or 0) + nwords = " ".join(swords[:5]) + # find rarity of words in the results + text and + # print("NWORDS", nwords) + rslines.extend(mastodon_search(nwords)) + if len(rslines) < 10: + for tw in swords: + w = words.get(tw) + if len(tw) > 4 and w is not None and w > 0 and w < 40: + # print(f"RELEVANT WORD {tw} {w}") + rslines.extend(mastodon_search(tw)) + return rslines def hist2txt(text): - newlines = [] - lines = text.split("\n") - for line in lines: - line = line.strip().replace("\\040", " ") - if len(line) < 8: - continue - elif "?" in line: - continue - elif line.startswith("-") or line.startswith("_") or line.startswith("!"): - continue - elif line.startswith("-r"): - # newlines.append(line[2:]) - continue - else: - newlines.append(line) - newlines = sorted(set(newlines)) - return "\n".join(newlines) + newlines = [] + lines = text.split("\n") + for line in lines: + line = line.strip().replace("\\040", " ") + if len(line) < 8: + continue + elif "?" in line: + continue + elif line.startswith("-") or line.startswith("_") or line.startswith("!"): + continue + elif line.startswith("-r"): + # newlines.append(line[2:]) + continue + else: + newlines.append(line) + newlines = sorted(set(newlines)) + return "\n".join(newlines) def json2md(text): - def jsonwalk(obj): - res = "" - if isinstance(obj, list): - for a in obj: - res += jsonwalk(a) - elif isinstance(obj, dict): - if "file" in obj: - pass -# elif "impactType" in obj and obj["impactType"] == "pass": -# pass - elif "ts" in obj: - pass - else: - for k in obj.keys(): - res += "## " + k + "\n" - lst = json.dumps(obj[k]).replace("{","").replace("}", "\n").replace("\"", "").replace(",", "*").split("\n") - res += "\n".join(list(filter(lambda k: 'crc64' not in k and 'file' not in k and 'from_text' not in k and 'backtrace' not in k, lst))) - res += "\n\n" - else: - res += str(obj) # jsonwalk(obj) - return res - doc = json.loads(text) - res = jsonwalk(doc) -# print("==========") -# print(res) -# print("==========") - return res + def jsonwalk(obj): + res = "" + if isinstance(obj, list): + for a in obj: + res += jsonwalk(a) + elif isinstance(obj, dict): + if "file" in obj: + pass + # elif "impactType" in obj and obj["impactType"] == "pass": + # pass + elif "ts" in obj: + pass + else: + for k in obj.keys(): + res += "## " + k + "\n" + lst = json.dumps(obj[k]).replace("{","").replace("}", "\n").replace("\"", "").replace(",", "*").split("\n") + res += "\n".join(list(filter(lambda k: 'crc64' not in k and 'file' not in k and 'from_text' not in k and 'backtrace' not in k, lst))) + res += "\n\n" + else: + res += str(obj) # jsonwalk(obj) + return res + doc = json.loads(text) + res = jsonwalk(doc) + # print("==========") + # print(res) + # print("==========") + return res def md2txt(text): - # parser markdown and return a txt - lines = text.split("\n") - newlines = [] - data = "" - titles = ["", "", ""] - read_block = False - for line in lines: - line = line.strip() - if line == "": - continue - if read_block: - data += line + "\\n" - if line.startswith("```"): - read_block = False - continue - if line.startswith("```"): - read_block = True - elif line.startswith("* "): - if data != "": - newlines.append(":".join(titles) +":"+ data + line) - elif line.startswith("### "): - if data != "": - newlines.append(":".join(titles) +":"+ data) - data = "" - titles = [titles[0], titles[1], line[3:]] - elif line.startswith("## "): - if data != "": - newlines.append(":".join(titles) +":"+ data) - data = "" - titles = [titles[0], line[3:]] - elif line.startswith("# "): - if data != "": - newlines.append(":".join(titles)+ ":"+data) - data = "" - titles = [line[2:], "", ""] - else: - data += line + " " -# print("\n".join(newlines)) - return "\n".join(newlines) + # parser markdown and return a txt + lines = text.split("\n") + newlines = [] + data = "" + titles = ["", "", ""] + read_block = False + for line in lines: + line = line.strip() + if line == "": + continue + if read_block: + data += line + "\\n" + if line.startswith("```"): + read_block = False + continue + if line.startswith("```"): + read_block = True + elif line.startswith("* "): + if data != "": + newlines.append(":".join(titles) +":"+ data + line) + elif line.startswith("### "): + if data != "": + newlines.append(":".join(titles) +":"+ data) + data = "" + titles = [titles[0], titles[1], line[3:]] + elif line.startswith("## "): + if data != "": + newlines.append(":".join(titles) +":"+ data) + data = "" + titles = [titles[0], line[3:]] + elif line.startswith("# "): + if data != "": + newlines.append(":".join(titles)+ ":"+data) + data = "" + titles = [line[2:], "", ""] + else: + data += line + " " + # print("\n".join(newlines)) + return "\n".join(newlines) def filter_line(line): - line = unidecode(line) # remove accents - line = re.sub(r'https?://\S+', '', line) - line = re.sub(r'http?://\S+', '', line) - line = line.replace(":", " ").replace("/", " ").replace("`", " ").replace("?", " ") - line = line.replace("\"", " ").replace("'", " ") - line = line.replace("<", " ").replace(">", " ").replace("@", " ").replace("#", "") -#line = line.replace("-", " ").replace(".", " ").replace(",", " ").replace("(", " ").replace(")", " ").strip(" ") - line = line.replace(".", " ").replace(",", " ").replace("(", " ").replace(")", " ").strip(" ") - line = re.sub(r"\s+", " ", line) - if len(line) > MAXCHARS: - line = line[:MAXCHARS] - words = [] - for a in line.split(" "): - b = a.strip().lower() - try: - int(b) - continue - except: - pass - if len(b) > 0: - words.append(b) - return words + line = unidecode(line) # remove accents + line = re.sub(r'https?://\S+', '', line) + line = re.sub(r'http?://\S+', '', line) + line = line.replace(":", " ").replace("/", " ").replace("`", " ").replace("?", " ") + line = line.replace("\"", " ").replace("'", " ") + line = line.replace("<", " ").replace(">", " ").replace("@", " ").replace("#", "") + # line = line.replace("-", " ").replace(".", " ").replace(",", " ").replace("(", " ").replace(")", " ").strip(" ") + line = line.replace(".", " ").replace(",", " ").replace("(", " ").replace(")", " ").strip(" ") + line = re.sub(r"\s+", " ", line) + if len(line) > MAXCHARS: + line = line[:MAXCHARS] + words = [] + for a in line.split(" "): + b = a.strip().lower() + try: + int(b) + continue + except: + pass + if len(b) > 0: + words.append(b) + return words def smart_slurp(file): - if ignored_file(file): - return "" -# print("smart" + file) -# print(f"slurp: {file}") - text = slurp(file) - if file.endswith("r2ai.history"): - text = hist2txt(text) - elif file.endswith(".json"): - text = md2txt(json2md(text)) - elif file.endswith(".md"): - text = md2txt(text) - return text + if ignored_file(file): + return "" + # print("smart" + file) + # print(f"slurp: {file}") + text = slurp(file) + if file.endswith("r2ai.history"): + text = hist2txt(text) + elif file.endswith(".json"): + text = md2txt(json2md(text)) + elif file.endswith(".md"): + text = md2txt(text) + return text def vectordb_search2(text, keywords, use_mastodon): - global have_vectordb, vectordb_instance - vectordb_init() - result = [] - if use_mastodon: - print ("[r2ai] Searching in Mastodon", text) - lines = mastodon_lines(text, keywords, True) -# print("LINES", lines) - for line in lines: -# print("SAVE", line) - vectordb_instance.save(line, {"url":text}) - if have_vectordb == True and vectordb_instance is not None: - res = [] - try: - res = vectordb_instance.search(text, top_n=MAXMATCHES, unique=True, batch_results="diverse") - except: - try: - res = vectordb_instance.search(text, top_n=MAXMATCHES) - except: - traceback.print_exc() - pass - for r in res: - if "distance" in r: - # print("distance", r["distance"]) - if r['distance'] < 1: - result.append(r["chunk"]) - else: - # when mprt is not available we cant find the distance - result.append(r["chunk"]) - #print(result) - return sorted(set(result)) + global have_vectordb, vectordb_instance + vectordb_init() + result = [] + if use_mastodon: + print ("[r2ai] Searching in Mastodon", text) + lines = mastodon_lines(text, keywords, True) +# print("LINES", lines) + for line in lines: +# print("SAVE", line) + vectordb_instance.save(line, {"url":text}) + if have_vectordb is True and vectordb_instance is not None: + res = [] + try: + res = vectordb_instance.search(text, top_n=MAXMATCHES, unique=True, batch_results="diverse") + except: + try: + res = vectordb_instance.search(text, top_n=MAXMATCHES) + except: + traceback.print_exc() + pass + for r in res: + if "distance" in r: + # print("distance", r["distance"]) + if r['distance'] < 1: + result.append(r["chunk"]) + else: + # when mprt is not available we cant find the distance + result.append(r["chunk"]) + #print(result) + return sorted(set(result)) def vectordb_init(): - global have_vectordb, vectordb_instance - if have_vectordb == False: - print("LEAVING") - return - if vectordb_instance is not None: - return - try: - import vectordb - have_vectordb = True - except Exception as e: - os.system("python -m spacy download en_core_web_sm") - try: - import vectordb - have_vectordb = True - except: - have_vectordb = False - print("To better data index use:") - print(" pip install vectordb2") - print("On macOS you'll need to also do this:") - print(" python -m pip install spacy") - print(" python -m spacy download en_core_web_sm") - return - try: - vectordb_instance = vectordb.Memory(embeddings="best") # normal or fast - except: - vectordb_instance = vectordb.Memory() # normal or fast - if vectordb_instance is not None: - vectordb_instance.save("radare2 is a free reverse engineering tool written by pancake, aka Sergi Alvarez i Capilla. The project started in 2006 as a tool for domestic computer forensics in order to recover some deleted files and it continued the development adding new features like debugging, disassembler, decompiler, code analysis, advanced filesystem capabilities and integration with tons of tools like Frida, Radius, Ghidra, etc", {"url":"."}) # dummy entry + global have_vectordb, vectordb_instance + if have_vectordb is False: + print("LEAVING") + return + if vectordb_instance is not None: + return + try: + import vectordb + have_vectordb = True + except Exception as e: + os.system("python -m spacy download en_core_web_sm") + try: + import vectordb + have_vectordb = True + except: + have_vectordb = False + print("To better data index use:") + print(" pip install vectordb2") + print("On macOS you'll need to also do this:") + print(" python -m pip install spacy") + print(" python -m spacy download en_core_web_sm") + return + try: + vectordb_instance = vectordb.Memory(embeddings="best") # normal or fast + except: + vectordb_instance = vectordb.Memory() # normal or fast + if vectordb_instance is not None: + vectordb_instance.save("radare2 is a free reverse engineering tool written by pancake, aka Sergi Alvarez i Capilla. The project started in 2006 as a tool for domestic computer forensics in order to recover some deleted files and it continued the development adding new features like debugging, disassembler, decompiler, code analysis, advanced filesystem capabilities and integration with tons of tools like Frida, Radius, Ghidra, etc", {"url":"."}) # dummy entry def vectordb_search(text, keywords, source_files, use_mastodon, use_debug): - global have_vectordb, vectordb_instance - if have_vectordb == False: - builtins.print("no vdb found") - return [] - if have_vectordb == True and vectordb_instance is not None: - return vectordb_search2(text, keywords, use_mastodon) - vectordb_init() - if vectordb_instance is None: - builtins.print("vdb not initialized") - return - # indexing data - builtins.print("[r2ai] Indexing local data with vectordb") - saved = 0 - for file in source_files: - if ignored_file(file): - continue - lines = smart_slurp(file).splitlines() - for line in lines: -# vectordb_instance.save(line) - vectordb_instance.save(line, {"url":file}) #, "url": file}) - saved = saved + 1 - if use_mastodon: - lines = mastodon_lines(text, None, True) - for line in lines: - saved = saved + 1 - vectordb_instance.save(line, {"url":text}) - if saved == 0: - print("[r2ai] Nothing indexed") - vectordb_instance.save("", {}) - else: - print("[r2ai] VectorDB index done") - return vectordb_search2(text, keywords, use_mastodon) + global have_vectordb, vectordb_instance + if have_vectordb is False: + builtins.print("no vdb found") + return [] + if have_vectordb is True and vectordb_instance is not None: + return vectordb_search2(text, keywords, use_mastodon) + vectordb_init() + if vectordb_instance is None: + builtins.print("vdb not initialized") + return + # indexing data + builtins.print("[r2ai] Indexing local data with vectordb") + saved = 0 + for file in source_files: + if ignored_file(file): + continue + lines = smart_slurp(file).splitlines() + for line in lines: +# vectordb_instance.save(line) + vectordb_instance.save(line, {"url":file}) #, "url": file}) + saved = saved + 1 + if use_mastodon: + lines = mastodon_lines(text, None, True) + for line in lines: + saved = saved + 1 + vectordb_instance.save(line, {"url":text}) + if saved == 0: + print("[r2ai] Nothing indexed") + vectordb_instance.save("", {}) + else: + print("[r2ai] VectorDB index done") + return vectordb_search2(text, keywords, use_mastodon) class compute_rarity(): - use_mastodon = MASTODON_KEY != "" # False - use_debug = False - words = {} - lines = [] - def __init__(self, source_files, use_mastodon, use_debug): - self.use_mastodon = use_mastodon - for file in source_files: - if ignored_file(file): - continue - lines = smart_slurp(file).splitlines() - for line in lines: - self.lines.append(line) - self.compute_rarity_in_line(line) - def compute_rarity_in_line(self,line): - fline = filter_line(line) - for a in fline: - if self.words.get(a): - self.words[a] += 1 - else: - self.words[a] = 1 - def pull_realtime_lines(self, text, keywords, use_vectordb): - if self.env["debug"] == "true": - print(f"Pulling from mastodon {text}") - return mastodon_lines(text, keywords, use_vectordb) + use_mastodon = MASTODON_KEY != "" # False + use_debug = False + words = {} + lines = [] + def __init__(self, source_files, use_mastodon, use_debug): + self.use_mastodon = use_mastodon + for file in source_files: + if ignored_file(file): + continue + lines = smart_slurp(file).splitlines() + for line in lines: + self.lines.append(line) + self.compute_rarity_in_line(line) + def compute_rarity_in_line(self,line): + fline = filter_line(line) + for a in fline: + if self.words.get(a): + self.words[a] += 1 + else: + self.words[a] = 1 + def pull_realtime_lines(self, text, keywords, use_vectordb): + if self.env["debug"] == "true": + print(f"Pulling from mastodon {text}") + return mastodon_lines(text, keywords, use_vectordb) - def find_matches(self, text, keywords): - if self.use_mastodon: - # pull from mastodon - backup_lines = self.lines - backup_words = self.words - realtime_lines = self.pull_realtime_lines(text, keywords, False) - for line in realtime_lines: - self.compute_rarity_in_line(line) - self.lines.extend(realtime_lines) + def find_matches(self, text, keywords): + if self.use_mastodon: + # pull from mastodon + backup_lines = self.lines + backup_words = self.words + realtime_lines = self.pull_realtime_lines(text, keywords, False) + for line in realtime_lines: + self.compute_rarity_in_line(line) + self.lines.extend(realtime_lines) # find matches - res = [] - twords = filter_line(text) - rarity = [] - for tw in twords: - if self.words.get(tw): - rarity.append(self.words[tw]) - else: - rarity.append(0) - swords = sorted(twords, key=lambda x: self.words.get(x) or 0) - maxrate = 0 - maxline = "" - rates = {} - lines = [] - for line in self.lines: - linewords = filter_line(line) - rate = self.match_line(linewords, swords) - if rate > 0: - lines.append(line) - rates[line] = rate -# print(f"{rate} = {line}") - srates = sorted(lines, key=lambda x: rates.get(x) or 0) - srates.reverse() - if self.use_mastodon: - self.lines = backup_lines - self.words = backup_words - res = srates[0:MAXMATCHES] - res = sorted(set(res)) - return res + res = [] + twords = filter_line(text) + rarity = [] + for tw in twords: + if self.words.get(tw): + rarity.append(self.words[tw]) + else: + rarity.append(0) + swords = sorted(twords, key=lambda x: self.words.get(x) or 0) + rates = {} + lines = [] + for line in self.lines: + linewords = filter_line(line) + rate = self.match_line(linewords, swords) + if rate > 0: + lines.append(line) + rates[line] = rate +# print(f"{rate} = {line}") + srates = sorted(lines, key=lambda x: rates.get(x) or 0) + srates.reverse() + if self.use_mastodon: + self.lines = backup_lines + self.words = backup_words + res = srates[0:MAXMATCHES] + res = sorted(set(res)) + return res - def match_line(self,linewords, swords): - count = 0 - ow = "" - for w in swords: - if w == ow: - continue - if w in linewords: - rarity = 1 - if w in self.words: - rarity = self.words[w] - count += rarity - ow = w - return count + def match_line(self,linewords, swords): + count = 0 + ow = "" + for w in swords: + if w == ow: + continue + if w in linewords: + rarity = 1 + if w in self.words: + rarity = self.words[w] + count += rarity + ow = w + return count def ignored_file(fn): - if fn.endswith("package.json"): - return True - if fn.endswith("package-lock.json"): - return True - if "/." in fn: - return True - return False + if fn.endswith("package.json"): + return True + if fn.endswith("package-lock.json"): + return True + if "/." in fn: + return True + return False def find_sources(srcdir): - files = [] - try: - files = os.walk(srcdir) - except: - return [] - res = [] - for f in files: - directory = f[0] - dirfiles = f[2] - for f2 in dirfiles: - if ignored_file(f2): - continue - if f2.endswith(".txt") or f2.endswith(".md"): - res.append(f"{directory}/{f2}") - elif f2.endswith(".json"): - res.append(f"{directory}/{f2}") - return res + files = [] + try: + files = os.walk(srcdir) + except: + return [] + res = [] + for f in files: + directory = f[0] + dirfiles = f[2] + for f2 in dirfiles: + if ignored_file(f2): + continue + if f2.endswith(".txt") or f2.endswith(".md"): + res.append(f"{directory}/{f2}") + elif f2.endswith(".json"): + res.append(f"{directory}/{f2}") + return res def init(): - print("find sources and such") + print("find sources and such") def source_files(datadir, use_hist): - files = [] - if datadir is not None and datadir != "": - files.extend(find_sources(datadir)) - if use_hist: - files.append(R2AI_HISTFILE) - return files + files = [] + if datadir is not None and datadir != "": + files.extend(find_sources(datadir)) + if use_hist: + files.append(R2AI_HISTFILE) + return files def find_wikit(text, keywords): - print("wikit") - global have_vectordb, vectordb_instance - vectordb_init() - if vectordb_instance is None: - print("vdb not initialized") - return - if keywords is not None: - for kw in keywords: - print("wikit " + kw) - res = syscmdstr("wikit -a " + kw) - if len(res) > 20: - vectordb_instance.save(res, {"url":kw}) - words = filter_line(text) - for kw in words: - print("wikit " + kw) - res = syscmdstr("wikit -a " + kw) - if len(res) > 20: - vectordb_instance.save(res, {"keyword":kw}) - res = syscmdstr("wikit -a '" + " ".join(words) + "'") - if len(res) > 20: - vectordb_instance.save(res, {"keyword":kw}) + print("wikit") + global have_vectordb, vectordb_instance + vectordb_init() + if vectordb_instance is None: + print("vdb not initialized") + return + if keywords is not None: + for kw in keywords: + print("wikit " + kw) + res = syscmdstr("wikit -a " + kw) + if len(res) > 20: + vectordb_instance.save(res, {"url":kw}) + words = filter_line(text) + for kw in words: + print("wikit " + kw) + res = syscmdstr("wikit -a " + kw) + if len(res) > 20: + vectordb_instance.save(res, {"keyword":kw}) + res = syscmdstr("wikit -a '" + " ".join(words) + "'") + if len(res) > 20: + vectordb_instance.save(res, {"keyword":kw}) + def reset(): - global vectordb_instance - vectordb_instance = None + global vectordb_instance + vectordb_instance = None def match(text, keywords, datadir, use_hist, use_mastodon, use_debug, use_wikit, use_vectordb): - files = source_files(datadir, use_hist) - if use_vectordb: - if use_wikit: - find_wikit(text, keywords) - return vectordb_search(text, keywords, files, use_mastodon, use_debug) - raredb = compute_rarity(files, use_mastodon, use_debug) - if use_wikit: - print("[r2ai] Warning: data.wikit only works with vectordb") - - return raredb.find_matches(text, keywords) - -if __name__ == '__main__': - if len(sys.argv) > 1: - matches = main_indexer(sys.argv[1]) - for m in matches: - print(m) - else: - print(f"Usage: index.py [query]") + files = source_files(datadir, use_hist) + if use_vectordb: + if use_wikit: + find_wikit(text, keywords) + return vectordb_search(text, keywords, files, use_mastodon, use_debug) + raredb = compute_rarity(files, use_mastodon, use_debug) + if use_wikit: + print("[R2AI] Warning: data.wikit only works with vectordb") + return raredb.find_matches(text, keywords) diff --git a/r2ai/large.py b/r2ai/large.py index 7d784874..5d9862e5 100644 --- a/r2ai/large.py +++ b/r2ai/large.py @@ -2,167 +2,168 @@ import json class Large: - def __init__(self, ai = None): - self.mistral = None - self.window = 4096 - self.maxlen = 12000 - self.maxtokens = 5000 - # self.model = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF" - self.model = "FaradayDotDev/llama-3-8b-Instruct-GGUF" - if ai is not None: - self.env = ai.env - else: - self.env = {} - self.env["llm.gpu"] = "true" + def __init__(self, ai = None): + self.mistral = None + self.window = 4096 + self.maxlen = 12000 + self.maxtokens = 5000 + # self.model = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF" + self.model = "FaradayDotDev/llama-3-8b-Instruct-GGUF" + if ai is not None: + self.env = ai.env + else: + self.env = {} + self.env["llm.gpu"] = "true" + + def slice_text(self, amsg): + slices = [] + pos = self.maxlen + while(len(amsg) > self.maxlen): + s = amsg[:pos] + amsg = amsg[pos:] + slices.append(s) + slices.append(amsg) + return slices + + def compress_text(self, msg): + if self.mistral == None: + self.mistral = new_get_hf_llm(self, self.model, False, self.window) + # q = f"Rewrite this code into shorter pseudocode (less than 500 tokens). keep the comments and essential logic:\n```\n{msg}\n```\n" + #q = f"Rewrite this code into shorter pseudocode (less than 200 tokens). keep the relevant comments and essential logic:\n```\n{msg}\n```\n" + q = f"Resumen y responde SOLO la información relevante del siguiente texto:\n{msg}" + response = self.mistral(q, stream=False, temperature=0.001, stop="", max_tokens=self.maxtokens) + print(response["choices"]) #json.dumps(response)) + text0 = response["choices"][0]["text"] + return text0 + + def summarize_text(self, amsg): + olen = len(amsg) + while len(amsg) > self.maxlen: + slices = self.slice_text(amsg) + print(f"Re-slicing {len(slices)}") + short_slices = [] + for s in slices: + sm = self.compress_text(s) + short_slices.append(sm) + print(sm) + print(f"Went from {len(s)} to {len(sm)}") + amsg = " ".join(short_slices) + nlen = len(amsg) + print(f"total length {nlen} (original length was {olen})") + return amsg + + def keywords_ai(self, text): + # kws = self.keywords_ai("who is the author of radare?") => "author,radare2" + words = [] + ctxwindow = int(self.env["llm.window"]) + mm = new_get_hf_llm(self, self.model, False, ctxwindow) + msg = f"Considering the sentence \"{text}\" as input, Take the KEYWORDS or combination of TWO words from the given text and respond ONLY a comma separated list of the most relevant words. DO NOT introduce your response, ONLY show the words" + msg = f"Take \"{text}\" as input, and extract the keywords and combination of keywords to make a search online, the output must be a comma separated list" #Take the KEYWORDS or combination of TWO words from the given text and respond ONLY a comma separated list of the most relevant words. DO NOT introduce your response, ONLY show the words" + response = mm(msg, stream=False, temperature=0.001, stop="", max_tokens=1750) + if self.env["debug"] == "true": + print("KWSPLITRESPONSE", response) + text0 = response["choices"][0]["text"] + text0 = text0.replace('"', ",") + if text0.startswith("."): + text0 = text0[1:].strip() + try: + text0 = text0.split(":")[1].strip() + except: + pass + # print(text0) + mm = None + return [word.strip() for word in text0.split(',')] - def slice_text(self, amsg): - slices = [] - pos = self.maxlen - while(len(amsg) > self.maxlen): - s = amsg[:pos] - amsg = amsg[pos:] - slices.append(s) - slices.append(amsg) - return slices - - def compress_text(self, msg): - if self.mistral == None: - self.mistral = new_get_hf_llm(self, self.model, False, self.window) - # q = f"Rewrite this code into shorter pseudocode (less than 500 tokens). keep the comments and essential logic:\n```\n{msg}\n```\n" - #q = f"Rewrite this code into shorter pseudocode (less than 200 tokens). keep the relevant comments and essential logic:\n```\n{msg}\n```\n" - q = f"Resumen y responde SOLO la información relevante del siguiente texto:\n{msg}" - response = self.mistral(q, stream=False, temperature=0.001, stop="", max_tokens=self.maxtokens) - print(response["choices"]) #json.dumps(response)) - text0 = response["choices"][0]["text"] - return text0 - - def summarize_text(self, amsg): - olen = len(amsg) - while len(amsg) > self.maxlen: - slices = self.slice_text(amsg) - print(f"Re-slicing {len(slices)}") - short_slices = [] - for s in slices: - sm = self.compress_text(s) - short_slices.append(sm) - print(sm) - print(f"Went from {len(s)} to {len(sm)}") - amsg = " ".join(short_slices) - nlen = len(amsg) - print(f"total length {nlen} (original length was {olen})") - return amsg - - def keywords_ai(self, text): - # kws = self.keywords_ai("who is the author of radare?") => "author,radare2" - words = [] - ctxwindow = int(self.env["llm.window"]) - mm = new_get_hf_llm(self, self.model, False, ctxwindow) - msg = f"Considering the sentence \"{text}\" as input, Take the KEYWORDS or combination of TWO words from the given text and respond ONLY a comma separated list of the most relevant words. DO NOT introduce your response, ONLY show the words" - msg = f"Take \"{text}\" as input, and extract the keywords and combination of keywords to make a search online, the output must be a comma separated list" #Take the KEYWORDS or combination of TWO words from the given text and respond ONLY a comma separated list of the most relevant words. DO NOT introduce your response, ONLY show the words" - response = mm(msg, stream=False, temperature=0.001, stop="", max_tokens=1750) - if self.env["debug"] == "true": - print("KWSPLITRESPONSE", response) - text0 = response["choices"][0]["text"] - text0 = text0.replace('"', ",") - if text0.startswith("."): - text0 = text0[1:].strip() - try: - text0 = text0.split(":")[1].strip() - except: - pass - # print(text0) - mm = None - return [word.strip() for word in text0.split(',')] - def trimsource(self, msg): - msg = msg.replace("public ", "") - msg = re.sub(r'import.*\;', "", msg) - msg = msg.replace("const ", "") - msg = msg.replace("new ", "") - msg = msg.replace("undefined", "0") - msg = msg.replace("null", "0") - msg = msg.replace("false", "0") - msg = msg.replace("true", "1") - msg = msg.replace("let ", "") - msg = msg.replace("var ", "") - msg = msg.replace("class ", "") - msg = msg.replace("interface ", "") - msg = msg.replace("function ", "fn ") - msg = msg.replace("substring", "") - msg = msg.replace("this.", "") - msg = msg.replace("while (", "while(") - msg = msg.replace("if (", "if(") - msg = msg.replace("!== 0", "") - msg = msg.replace("=== true", "") - msg = msg.replace(" = ", "=") - msg = msg.replace(" === ", "==") - msg = msg.replace("\t", " ") - msg = msg.replace("\n", "") - msg = re.sub(r"/\*.*?\*/", '', msg, flags=re.DOTALL) - # msg = re.sub(r"\n+", "\n", msg) - msg = re.sub(r"\t+", ' ', msg) - msg = re.sub(r"\s+", " ", msg) - # msg = msg.replace(";", "") - return msg.strip() - - def trimsource_ai(self, msg): - words = [] - if self.mistral == None: - ctxwindow = int(self.env["llm.window"]) - self.mistral = new_get_hf_llm(self, self.model, False, ctxwindow) - # q = f"Rewrite this code into shorter pseudocode (less than 500 tokens). keep the comments and essential logic:\n```\n{msg}\n```\n" - q = f"Rewrite this code into shorter pseudocode (less than 200 tokens). keep the relevant comments and essential logic:\n```\n{msg}\n```\n" - response = self.mistral(q, stream=False, temperature=0.1, stop="", max_tokens=4096) - text0 = response["choices"][0]["text"] - if "```" in text0: - return text0.split("```")[1].strip() - return text0.strip().replace("```", "") - - def compress_code_ai(self, code): - piecesize = 1024 * 8 # mistral2 supports 32k vs 4096 - codelen = len(code) - pieces = int(codelen / piecesize) - if pieces < 1: - pieces = 1 - plen = int(codelen / pieces) - off = 0 - res = [] - for i in range(pieces): - piece = i + 1 - print(f"Processing {piece} / {pieces} ...") - if piece == pieces: - r = self.trimsource_ai(code[off:]) - else: - r = self.trimsource_ai(code[off:off+plen]) - res.append(r) - off += plen - return "\n".join(res) - - def compress_messages(self, messages): - # TODO: implement a better logic in here asking the lm to summarize the context - olen = 0 - msglen = 0 - for msg in messages: - if self.env["chat.reply"] == "false": - if msg["role"] != "user": - continue - if "content" in msg: - amsg = msg["content"] - olen += len(amsg) - if len(amsg) > int(self.env["llm.maxmsglen"]): - if "while" in amsg and "```" in amsg: - que = re.search(r"^(.*?)```", amsg, re.DOTALL).group(0).replace("```", "") - cod = re.search(r"```(.*?)$", amsg, re.DOTALL).group(0).replace("```", "") - shortcode = cod - while len(shortcode) > 4000: - olen = len(shortcode) - shortcode = self.compress_code_ai(shortcode) - nlen = len(shortcode) - print(f"Went from {olen} to {nlen}") - msg["content"] = f"{que}\n```\n{shortcode}\n```\n" - else: - print(f"total length {msglen} (original length was {olen})") - msglen += len(msg["content"]) - # print(f"total length {msglen} (original length was {olen})") - # if msglen > 4096: - # ¡print("Query is too large.. you should consider triming old messages") - return messages + def trimsource(self, msg): + msg = msg.replace("public ", "") + msg = re.sub(r'import.*\;', "", msg) + msg = msg.replace("const ", "") + msg = msg.replace("new ", "") + msg = msg.replace("undefined", "0") + msg = msg.replace("null", "0") + msg = msg.replace("false", "0") + msg = msg.replace("true", "1") + msg = msg.replace("let ", "") + msg = msg.replace("var ", "") + msg = msg.replace("class ", "") + msg = msg.replace("interface ", "") + msg = msg.replace("function ", "fn ") + msg = msg.replace("substring", "") + msg = msg.replace("this.", "") + msg = msg.replace("while (", "while(") + msg = msg.replace("if (", "if(") + msg = msg.replace("!== 0", "") + msg = msg.replace("=== true", "") + msg = msg.replace(" = ", "=") + msg = msg.replace(" === ", "==") + msg = msg.replace("\t", " ") + msg = msg.replace("\n", "") + msg = re.sub(r"/\*.*?\*/", '', msg, flags=re.DOTALL) + # msg = re.sub(r"\n+", "\n", msg) + msg = re.sub(r"\t+", ' ', msg) + msg = re.sub(r"\s+", " ", msg) + # msg = msg.replace(";", "") + return msg.strip() + + def trimsource_ai(self, msg): + words = [] + if self.mistral == None: + ctxwindow = int(self.env["llm.window"]) + self.mistral = new_get_hf_llm(self, self.model, False, ctxwindow) + # q = f"Rewrite this code into shorter pseudocode (less than 500 tokens). keep the comments and essential logic:\n```\n{msg}\n```\n" + q = f"Rewrite this code into shorter pseudocode (less than 200 tokens). keep the relevant comments and essential logic:\n```\n{msg}\n```\n" + response = self.mistral(q, stream=False, temperature=0.1, stop="", max_tokens=4096) + text0 = response["choices"][0]["text"] + if "```" in text0: + return text0.split("```")[1].strip() + return text0.strip().replace("```", "") + + def compress_code_ai(self, code): + piecesize = 1024 * 8 # mistral2 supports 32k vs 4096 + codelen = len(code) + pieces = int(codelen / piecesize) + if pieces < 1: + pieces = 1 + plen = int(codelen / pieces) + off = 0 + res = [] + for i in range(pieces): + piece = i + 1 + print(f"Processing {piece} / {pieces} ...") + if piece == pieces: + r = self.trimsource_ai(code[off:]) + else: + r = self.trimsource_ai(code[off:off+plen]) + res.append(r) + off += plen + return "\n".join(res) + + def compress_messages(self, messages): + # TODO: implement a better logic in here asking the lm to summarize the context + olen = 0 + msglen = 0 + for msg in messages: + if self.env["chat.reply"] == "false": + if msg["role"] != "user": + continue + if "content" in msg: + amsg = msg["content"] + olen += len(amsg) + if len(amsg) > int(self.env["llm.maxmsglen"]): + if "while" in amsg and "```" in amsg: + que = re.search(r"^(.*?)```", amsg, re.DOTALL).group(0).replace("```", "") + cod = re.search(r"```(.*?)$", amsg, re.DOTALL).group(0).replace("```", "") + shortcode = cod + while len(shortcode) > 4000: + olen = len(shortcode) + shortcode = self.compress_code_ai(shortcode) + nlen = len(shortcode) + print(f"Went from {olen} to {nlen}") + msg["content"] = f"{que}\n```\n{shortcode}\n```\n" + else: + print(f"total length {msglen} (original length was {olen})") + msglen += len(msg["content"]) + # print(f"total length {msglen} (original length was {olen})") + # if msglen > 4096: + # ¡print("Query is too large.. you should consider triming old messages") + return messages diff --git a/r2ai/voice.py b/r2ai/voice.py index e5b99907..b86dd511 100644 --- a/r2ai/voice.py +++ b/r2ai/voice.py @@ -1,90 +1,92 @@ +"""Helper functions to handle voice recognition and synthesis.""" + +import os +import re import subprocess from .utils import syscmdstr from subprocess import Popen, PIPE -import os -import re -have_whisper = False +HAVE_WHISPER = False model = None voice_model = "large" # base DEVICE = None try: - import whisper - have_whisper = True + import whisper + HAVE_WHISPER = True except: - pass + pass have_festival = os.path.isfile("/usr/bin/festival") def run(models): - for model in models: - cmd = f"ffmpeg -f avfoundation -list_devices true -i '' 2>&1 | grep '{model}'|cut -d '[' -f 3" - output = syscmdstr(cmd) - if output != "": - return ":" + output[0] - return None + for model in models: + cmd = f"ffmpeg -f avfoundation -list_devices true -i '' 2>&1 | grep '{model}'|cut -d '[' -f 3" + output = syscmdstr(cmd) + if output != "": + return ":" + output[0] + return None def get_microphone(lang): - global DEVICE - print (f"DE {DEVICE}") - if DEVICE is not None: - return DEVICE - tts("(r2ai)", "un moment", lang) - DEVICE = run(["AirPods", "MacBook Pro"]) - print(f"DEVICE: {DEVICE}") - return DEVICE + global DEVICE + print (f"DE {DEVICE}") + if DEVICE is not None: + return DEVICE + tts("(r2ai)", "un moment", lang) + DEVICE = run(["AirPods", "MacBook Pro"]) + print(f"DEVICE: {DEVICE}") + return DEVICE def stt(seconds, lang): - global model - global DEVICE - global voice_model - if lang == "": - lang = None - if model == None: - model = whisper.load_model(voice_model) - device = get_microphone(lang) - if device is None: - tts("(r2ai)", "cannot find a microphone", lang) - return - tts("(r2ai) listening for 5s... ", "digues?", lang) - print(f"DEVICE IS {device}") - os.system("rm -f .audiomsg.wav") - rc = os.system(f"ffmpeg -f avfoundation -t 5 -i '{device}' .audiomsg.wav > /dev/null 2>&1") - if rc != 0: - tts("(r2ai)", "cannot record from microphone. missing permissions in terminal?", lang) - return - result = None - if lang is None: - result = model.transcribe(".audiomsg.wav") - else: - result = model.transcribe(".audiomsg.wav", language=lang) - os.system("rm -f .audiomsg.wav") - tts("(r2ai)", "ok", lang) - text = result["text"].strip() - if text == "you": - return "" -# print(f"User: {text}") - return text + global model + global DEVICE + global voice_model + if lang == "": + lang = None + if model == None: + model = whisper.load_model(voice_model) + device = get_microphone(lang) + if device is None: + tts("(r2ai)", "cannot find a microphone", lang) + return + tts("(r2ai) listening for 5s... ", "digues?", lang) + print(f"DEVICE IS {device}") + os.system("rm -f .audiomsg.wav") + rc = os.system(f"ffmpeg -f avfoundation -t 5 -i '{device}' .audiomsg.wav > /dev/null 2>&1") + if rc != 0: + tts("(r2ai)", "cannot record from microphone. missing permissions in terminal?", lang) + return + result = None + if lang is None: + result = model.transcribe(".audiomsg.wav") + else: + result = model.transcribe(".audiomsg.wav", language=lang) + os.system("rm -f .audiomsg.wav") + tts("(r2ai)", "ok", lang) + text = result["text"].strip() + if text == "you": + return "" +# print(f"User: {text}") + return text def tts(author, text, lang): - clean_text = re.sub(r'https?://\S+', '', text) - clean_text = re.sub(r'http?://\S+', '', clean_text) - print(f"{author}: {text}") - if have_festival: - festlang = "english" - if lang == "ca": - festlang = "catalan" - elif lang == "es": - festlang = "spanish" - elif lang == "it": - festlang = "italian" - p = Popen(['festival', '--tts', '--language', festlang], stdin=PIPE) - p.communicate(input=text) - else: - if lang == "es": - VOICE = "Marisol" - elif lang == "ca": - VOICE = "Montse" - else: - VOICE = "Moira" - subprocess.run(["say", "-v", VOICE, clean_text]) + clean_text = re.sub(r'https?://\S+', '', text) + clean_text = re.sub(r'http?://\S+', '', clean_text) + print(f"{author}: {text}") + if have_festival: + festlang = "english" + if lang == "ca": + festlang = "catalan" + elif lang == "es": + festlang = "spanish" + elif lang == "it": + festlang = "italian" + p = Popen(['festival', '--tts', '--language', festlang], stdin=PIPE) + p.communicate(input=text) + else: + if lang == "es": + VOICE = "Marisol" + elif lang == "ca": + VOICE = "Montse" + else: + VOICE = "Moira" + subprocess.run(["say", "-v", VOICE, clean_text]) diff --git a/r2ai/web.py b/r2ai/web.py index d73222df..2a29d3e7 100644 --- a/r2ai/web.py +++ b/r2ai/web.py @@ -4,42 +4,41 @@ ores = "" def handle_tabby_query(self, ai, obj, runline2, method): - global ores - # TODO build proper health json instead of copypasting a stolen one - healthstr=''' - {"model":"TabbyML/StarCoder-1B","device":"cpu","arch":"aarch64","cpu_info":"Apple M1 Max","cpu_count":10,"cuda_devices":[],"version":{"build_date":"2024-04-22","build_timestamp":"2024-04-22T21:00:09.963266000Z","git_sha":"0b5504eccbbdde20aba26f6dbd5810f57497e6a4","git_describe":"v0.10.0"}}''' - if method == "GET" and self.path == "/v1/health": - self.send_response(200) - self.end_headers() - self.wfile.write(bytes(f'{healthstr}','utf-8')) - return True - # /v1/completions - if self.path != "/v1/completions": - print(f"UnkPath: {self.path}") - self.send_response(200) - self.end_headers() - return True - print("/v1/completions") - if obj == None: - print("ObjNone") - self.send_response(200) - self.end_headers() - return True - if "segments" not in obj: - print("Nothing") - return True - pfx = obj["segments"]["prefix"].strip() - sfx = obj["segments"]["suffix"].strip() - lng = obj["language"] - if pfx == "": - self.send_response(200) - self.end_headers() - return True - runline2(ai, "-R") - #codequery = f"What's between `{pfx}` and `{sfx}` in `{lng}`, without including the context" - codequery = f"Complete the code between `{pfx}` and `{sfx}` in `{lng}`" - response = json.loads(''' - { + global ores + # TODO build proper health json instead of copypasting a stolen one + healthstr=''' + {"model":"TabbyML/StarCoder-1B","device":"cpu","arch":"aarch64","cpu_info":"Apple M1 Max","cpu_count":10,"cuda_devices":[],"version":{"build_date":"2024-04-22","build_timestamp":"2024-04-22T21:00:09.963266000Z","git_sha":"0b5504eccbbdde20aba26f6dbd5810f57497e6a4","git_describe":"v0.10.0"}}''' + if method == "GET" and self.path == "/v1/health": + self.send_response(200) + self.end_headers() + self.wfile.write(bytes(f'{healthstr}','utf-8')) + return True + # /v1/completions + if self.path != "/v1/completions": + print(f"UnkPath: {self.path}") + self.send_response(200) + self.end_headers() + return True + print("/v1/completions") + if obj == None: + print("ObjNone") + self.send_response(200) + self.end_headers() + return True + if "segments" not in obj: + print("Nothing") + return True + pfx = obj["segments"]["prefix"].strip() + sfx = obj["segments"]["suffix"].strip() + lng = obj["language"] + if pfx == "": + self.send_response(200) + self.end_headers() + return True + runline2(ai, "-R") + #codequery = f"What's between `{pfx}` and `{sfx}` in `{lng}`, without including the context" + codequery = f"Complete the code between `{pfx}` and `{sfx}` in `{lng}`" + response = json.loads('''{ "id": "cmpl-9d8aab26-ddc1-4314-a937-6654f2c13932", "choices": [ { @@ -48,25 +47,25 @@ def handle_tabby_query(self, ai, obj, runline2, method): } ] }''') - print(f"PREFIX {pfx}") - print(f"SUFFIX {sfx}") - print(f"RES {ores}") - response["choices"][0]["text"] = ores - jresponse = json.dumps(response) - self.send_response(200) - self.end_headers() - self.wfile.write(bytes(f'{jresponse}','utf-8')) - print("compute query") - ores = runline2(ai, codequery).strip() - ores = ores.replace(pfx, "") - ores = ores.replace(sfx, "") - ores = re.sub(r'```.*$', '', ores) - ores = ores.replace("```javascript", "") - ores = ores.replace("```", "") - ores = ores.strip() - print(f"RES2 {ores}") - #ores = ores.replace("\n", "") - print("computed") + print(f"PREFIX {pfx}") + print(f"SUFFIX {sfx}") + print(f"RES {ores}") + response["choices"][0]["text"] = ores + jresponse = json.dumps(response) + self.send_response(200) + self.end_headers() + self.wfile.write(bytes(f'{jresponse}','utf-8')) + print("compute query") + ores = runline2(ai, codequery).strip() + ores = ores.replace(pfx, "") + ores = ores.replace(sfx, "") + ores = re.sub(r'```.*$', '', ores) + ores = ores.replace("```javascript", "") + ores = ores.replace("```", "") + ores = ores.strip() + print(f"RES2 {ores}") + #ores = ores.replace("\n", "") + print("computed") def handle_custom_request(self, ai, msg, runline2, method): print("CUSTOM") @@ -82,45 +81,41 @@ def handle_custom_request(self, ai, msg, runline2, method): return True def start_http_server(ai, runline2): - import http.server - import socketserver - - WANTCTX = ai.env["http.chatctx"] == "true" - PORT = int(ai.env["http.port"]) - BASEPATH = ai.env["http.path"] - - Handler = http.server.SimpleHTTPRequestHandler - - class SimpleHTTPRequestHandler(Handler): - def do_GET(self): - print("GET") - if handle_custom_request(self, ai, "", runline2, "GET"): - return - self.send_response(404) - self.end_headers() - self.wfile.write(bytes(f'Invalid request. Use POST and /{BASEPATH}', 'utf-8')) - def do_POST(self): - print("POST") - if not WANTCTX: - runline2(ai, "-R") - content_length = int(self.headers['Content-Length']) - msg = self.rfile.read(content_length).decode('utf-8') - if handle_custom_request(self, ai, msg, runline2, "POST"): - return - if self.path.startswith(BASEPATH): - self.send_response(200) - self.end_headers() - res = runline2(ai, msg) - self.wfile.write(bytes(f'{res}','utf-8')) - else: - self.send_response(404) - self.end_headers() - self.wfile.write(bytes(f'Invalid request. Use {BASEPATH}')) - - print("[r2ai] Serving at port", PORT) - Handler.protocol_version = "HTTP/1.0" - server = socketserver.TCPServer(("", PORT), SimpleHTTPRequestHandler) - server.allow_reuse_address = True - server.allow_reuse_port = True - server.serve_forever() + import http.server + import socketserver + WANTCTX = ai.env["http.chatctx"] == "true" + PORT = int(ai.env["http.port"]) + BASEPATH = ai.env["http.path"] + Handler = http.server.SimpleHTTPRequestHandler + class SimpleHTTPRequestHandler(Handler): + def do_GET(self): + print("GET") + if handle_custom_request(self, ai, "", runline2, "GET"): + return + self.send_response(404) + self.end_headers() + self.wfile.write(bytes(f'Invalid request. Use POST and /{BASEPATH}', 'utf-8')) + def do_POST(self): + print("POST") + if not WANTCTX: + runline2(ai, "-R") + content_length = int(self.headers['Content-Length']) + msg = self.rfile.read(content_length).decode('utf-8') + if handle_custom_request(self, ai, msg, runline2, "POST"): + return + if self.path.startswith(BASEPATH): + self.send_response(200) + self.end_headers() + res = runline2(ai, msg) + self.wfile.write(bytes(f'{res}','utf-8')) + else: + self.send_response(404) + self.end_headers() + self.wfile.write(bytes(f'Invalid request. Use {BASEPATH}')) + print("[R2AI] Serving at port", PORT) + Handler.protocol_version = "HTTP/1.0" + server = socketserver.TCPServer(("", PORT), SimpleHTTPRequestHandler) + server.allow_reuse_address = True + server.allow_reuse_port = True + server.serve_forever() diff --git a/run-venv.sh b/run-venv.sh deleted file mode 100755 index ebfc3110..00000000 --- a/run-venv.sh +++ /dev/null @@ -1,11 +0,0 @@ -#!/bin/sh -if [ ! -d venv ]; then - python -m venv venv - if [ ! -d venv ]; then - echo "Cannot create venv" >&2 - exit 1 - fi -fi -. venv/bin/activate -export PYTHONPATH=$PWD -python r2ai/main.py $@