Add support for Google Gemini API
dnakov authored and trufae committed May 8, 2024
1 parent 079144b commit 260d61c
Showing 4 changed files with 88 additions and 13 deletions.
62 changes: 50 additions & 12 deletions r2ai/auto.py
@@ -86,6 +86,25 @@ def get_functionary_tokenizer(repo_id):
         functionary_tokenizer = AutoTokenizer.from_pretrained(repo_id, legacy=True)
     return functionary_tokenizer
 
+def r2cmd(command: str):
+    """runs commands in radare2. You can run it multiple times or chain commands with pipes/semicolons. You can also use r2 interpreters to run scripts using the `#`, '#!', etc. commands. The output could be long, so try to use filters if possible or limit. This is your preferred tool"""
+    builtins.print('\x1b[1;32mRunning \x1b[4m' + command + '\x1b[0m')
+    res = r2lang.cmd(command)
+    builtins.print(res)
+    return res
+
+def run_python(command: str):
+    """runs a python script and returns the results"""
+    with open('r2ai_tmp.py', 'w') as f:
+        f.write(command)
+    builtins.print('\x1b[1;32mRunning \x1b[4m' + "python code" + '\x1b[0m')
+    builtins.print(command)
+    r2lang.cmd('#!python r2ai_tmp.py > $tmp')
+    res = r2lang.cmd('cat $tmp')
+    r2lang.cmd('rm r2ai_tmp.py')
+    builtins.print('\x1b[1;32mResult\x1b[0m\n' + res)
+    return res
+
 def process_tool_calls(interpreter, tool_calls):
     interpreter.messages.append({ "content": None, "tool_calls": tool_calls, "role": "assistant" })
     for tool_call in tool_calls:
@@ -101,18 +120,9 @@ def process_tool_calls(interpreter, tool_calls):
             if type(args) is str:
                 args = { "command": args }
             if "command" in args:
-                builtins.print('\x1b[1;32mRunning \x1b[4m' + args["command"] + '\x1b[0m')
-                res = r2lang.cmd(args["command"])
-                builtins.print(res)
+                res = r2cmd(args["command"])
         elif tool_call["function"]["name"] == "run_python":
-            with open('r2ai_tmp.py', 'w') as f:
-                f.write(args["command"])
-            builtins.print('\x1b[1;32mRunning \x1b[4m' + "python code" + '\x1b[0m')
-            builtins.print(args["command"])
-            r2lang.cmd('#!python r2ai_tmp.py > $tmp')
-            res = r2lang.cmd('cat $tmp')
-            r2lang.cmd('rm r2ai_tmp.py')
-            builtins.print('\x1b[1;32mResult\x1b[0m\n' + res)
+            res = run_python(args["command"])
         if (not res or len(res) == 0) and interpreter.model.startswith('meetkai/'):
             res = "OK done"
         interpreter.messages.append({"role": "tool", "content": ANSI_REGEX.sub('', res), "name": tool_call["function"]["name"], "tool_call_id": tool_call["id"] if "id" in tool_call else None})
@@ -197,7 +207,7 @@ def chat(interpreter):
     lastmsg = interpreter.messages[-1]["content"]
     chat_context = context_from_msg (lastmsg)
     #print("#### CONTEXT BEGIN")
-    print(chat_context) # DEBUG
+    #print(chat_context) # DEBUG
     #print("#### CONTEXT END")
     if chat_context != "":
         interpreter.messages.insert(0,{"role": "user", "content": chat_context})
@@ -277,6 +287,34 @@ def chat(interpreter):
             temperature=float(interpreter.env["llm.temperature"]),
         )
         process_streaming_response(interpreter, [response])
+    elif interpreter.model.startswith("google"):
+        if not interpreter.google_client:
+            try:
+                import google.generativeai as google
+                google.configure(api_key=os.environ['GOOGLE_API_KEY'])
+            except ImportError:
+                print("pip install -U google-generativeai", file=sys.stderr)
+                return
+
+            interpreter.google_client = google.GenerativeModel(interpreter.model[7:])
+        if not interpreter.google_chat:
+            interpreter.google_chat = interpreter.google_client.start_chat(
+                enable_automatic_function_calling=True
+            )
+
+        response = interpreter.google_chat.send_message(
+            interpreter.messages[-1]["content"],
+            generation_config={
+                "max_output_tokens": int(interpreter.env["llm.maxtokens"]),
+                "temperature": float(interpreter.env["llm.temperature"])
+            },
+            safety_settings=[{
+                "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+                "threshold": "BLOCK_NONE"
+            }],
+            tools=[r2cmd, run_python]
+        )
+        print(response.text)
     else:
         chat_format = interpreter.llama_instance.chat_format
         is_functionary = interpreter.model.startswith("meetkai/")
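The new "google" branch above leans on the SDK's automatic function calling: r2cmd and run_python are passed as plain Python callables, and google-generativeai derives each tool's declaration from its signature and docstring, executes the calls the model requests, and feeds the results back before returning text. A minimal standalone sketch of that flow, assuming the google-generativeai package (0.5+) is installed and GOOGLE_API_KEY is set; add() is a hypothetical stand-in for the real tools:

import os
import google.generativeai as genai

def add(a: int, b: int) -> int:
    """Returns the sum of two integers."""  # the docstring becomes the tool description
    return a + b

genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
model = genai.GenerativeModel("gemini-1.5-pro-latest")
chat = model.start_chat(enable_automatic_function_calling=True)
# The SDK runs any add() calls the model emits and loops the results
# back, so the returned response already contains the final answer.
response = chat.send_message("What is 2 + 3?", tools=[add])
print(response.text)
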
33 changes: 33 additions & 0 deletions r2ai/interpreter.py
@@ -7,6 +7,8 @@
 from .voice import tts
 from .const import R2AI_HOMEDIR
 from . import auto
+import os
+
 try:
     from openai import OpenAI
     have_openai = True
@@ -28,6 +30,14 @@
     have_groq = False
     pass
 
+try:
+    import google.generativeai as google
+    google.configure(api_key=os.environ['GOOGLE_API_KEY'])
+    have_google = True
+except Exception as e:
+    have_google = False
+    pass
+
 import re
 import os
 import traceback
@@ -502,6 +512,8 @@ def __init__(self):
         self.openai_client = None
         self.anthropic_client = None
         self.groq_client = None
+        self.google_client = None
+        self.google_chat = None
         self.api_base = None # Will set it to whatever OpenAI wants
         self.system_message = ""
         self.env["debug"] = "false"
@@ -912,6 +924,27 @@ def respond(self):
                 temperature=float(self.env["llm.temperature"]),
                 messages=self.messages
             )
+            if self.env["chat.reply"] == "true":
+                self.messages.append({"role": "assistant", "content": completion.content})
+            print(completion.content)
+        elif self.model.startswith('google:'):
+            if have_google:
+                if not self.google_client:
+                    self.google_client = google.GenerativeModel(self.model[7:])
+                if not self.google_chat:
+                    self.google_chat = self.google_client.start_chat()
+
+                completion = self.google_chat.send_message(
+                    self.messages[-1]["content"],
+                    generation_config={
+                        "max_output_tokens": maxtokens,
+                        "temperature": float(self.env["llm.temperature"])
+                    }
+                )
+                if self.env["chat.reply"] == "true":
+                    self.messages.append({"role": "assistant", "content": completion.text})
+                print(completion.text)
+            return
         else:
             # non-openai aka local-llama model
             if self.llama_instance == None:
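In respond() only the newest message, self.messages[-1]["content"], is forwarded because the ChatSession returned by start_chat() keeps the conversation history itself; every send_message() call appends both the user turn and the model reply to chat.history. A short sketch of that behavior, under the same google-generativeai assumptions as above:

import os
import google.generativeai as genai

genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
model = genai.GenerativeModel("gemini-1.0-pro")
chat = model.start_chat()

chat.send_message("Remember the number 7.")
reply = chat.send_message("What number did I ask you to remember?")
print(reply.text)
# Both exchanges were accumulated client-side by the session:
print(len(chat.history))  # 4 entries: two user turns, two model replies
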
1 change: 1 addition & 0 deletions r2ai/main.py
@@ -51,6 +51,7 @@ def __main__():
         have_r2pipe = True
     except:
         pass
+
     if not have_rlang and not have_r2pipe and sys.argv[0] != 'main.py' and os.path.exists("venv/bin/python"):
         os.system("venv/bin/python main.py")
         sys.exit(0)
5 changes: 4 additions & 1 deletion r2ai/models.py
@@ -60,6 +60,9 @@ def models():
   -m groq:gemma-7b-it
   -m groq:llama2-70b-4096
   -m groq:mixtral-8x7b-32768
+ Google:
+  -m google:gemini-1.0-pro
+  -m google:gemini-1.5-pro-latest
  GPT4:
   -m NousResearch/Hermes-2-Pro-Mistral-7B-GGUF
   -m TheBloke/Chronos-70B-v2-GGUF
@@ -472,7 +475,7 @@ def enough_disk_space(size, path) -> bool:
     return False
 
 def new_get_hf_llm(repo_id, debug_mode, context_window):
-    if repo_id.startswith("openai:") or repo_id.startswith("anthropic:") or repo_id.startswith("groq:"):
+    if repo_id.startswith("openai:") or repo_id.startswith("anthropic:") or repo_id.startswith("groq:") or repo_id.startswith("google:"):
         return repo_id
     if not os.path.exists(repo_id):
         return get_hf_llm(repo_id, debug_mode, context_window)
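The one-line change above is what routes provider-prefixed names around the local GGUF download path: any repo_id carrying a known remote prefix is returned verbatim and dispatched to that provider's client later. A condensed sketch of the routing decision, not the full function:

def is_remote_model(repo_id: str) -> bool:
    # Provider-prefixed names go to remote APIs instead of llama.cpp.
    return repo_id.startswith(("openai:", "anthropic:", "groq:", "google:"))

assert is_remote_model("google:gemini-1.5-pro-latest")
assert not is_remote_model("TheBloke/Chronos-70B-v2-GGUF")
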
