Commit 350246d

Merge pull request #185 from callanwu/master
Add support for replicate
2 parents: ff2821d + b75c7e5

File tree: 5 files changed, +21 -32 lines


examples/Single_Agent/chat_bot/config.json (-12)
@@ -37,20 +37,8 @@
         ],
         "begin_role": "Yang bufan",
         "begin_query": "hello,What are you looking for me for?",
-        "LLM_type": "OpenAI",
-        "LLM": {
-            "temperature": 1.0,
-            "model": "gpt-3.5-turbo-16k-0613",
-            "log_path": "logs/Response_state"
-        },
         "agent_states": {
             "Yang bufan": {
-                "LLM_type": "OpenAI",
-                "LLM": {
-                    "temperature": 1.0,
-                    "model": "gpt-3.5-turbo-16k-0613",
-                    "log_path": "logs/Yang_bufan"
-                },
                 "style": {
                     "role": "The director of a private detective agency, a cold detective"
                 },
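These per-state and per-role LLM blocks become redundant: the Agent.py hunk below makes from_config fall back to the file-level "LLM_type"/"LLM" settings whenever a role's agent_states entry omits them; see the sketch after that diff.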

requirements.txt (+3 -3)
@@ -3,7 +3,7 @@ fastapi
 google_api_python_client
 google_auth_oauthlib
 gradio
-langchain==0.0.329
+langchain
 numpy
 openai
 litellm
@@ -21,7 +21,7 @@ text2vec
 torch
 tqdm
 uvicorn
-pydantic==1.10.9
-typing_extensions==4.5.0
+pydantic
+typing_extensions
 serpapi
 google-search-results
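The exact-version pins on langchain, pydantic, and typing_extensions are relaxed here, presumably so pip can resolve versions compatible with the litellm-based Replicate path; the commit itself does not state the reason.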

src/agents/Agent/Agent.py (+4)
@@ -87,6 +87,10 @@ def from_config(cls, config_path):
                 current_state_begin_role = current_state["begin_role"] if "begin_role" in current_state else current_state["roles"][0]
                 agent_begins[state_name]["is_begin"] = current_state_begin_role==agent_role if "begin_role" in current_state else False
                 agent_begins[state_name]["begin_query"] = current_state["begin_query"] if "begin_query" in current_state else " "
+                if "LLM_type" not in current_state["agent_states"][agent_role]:
+                    current_state["agent_states"][agent_role]["LLM_type"] = config["LLM_type"]
+                if "LLM" not in current_state["agent_states"][agent_role]:
+                    current_state["agent_states"][agent_role]["LLM"] = config["LLM"]
                 agent_LLMs[state_name] = init_LLM("logs"+os.sep+f"{agent_name}",**current_state["agent_states"][agent_role])
             agents[agent_name] = cls(
                 agent_name,
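Taken together with the config.json hunk above, the effect is that per-role LLM settings are now optional and default to the file-level values. A minimal, self-contained sketch of that fallback under an assumed config shape (the "states"/"Response_state" key names are illustrative, not taken from the diff):

    # Hypothetical trimmed-down config mirroring the example file above.
    config = {
        "LLM_type": "OpenAI",
        "LLM": {"temperature": 1.0, "model": "gpt-3.5-turbo-16k-0613", "log_path": "logs"},
        "states": {  # assumed key name, for illustration only
            "Response_state": {
                "agent_states": {
                    "Yang bufan": {"style": {"role": "a cold detective"}},  # no LLM block
                },
            },
        },
    }

    role_cfg = config["states"]["Response_state"]["agent_states"]["Yang bufan"]

    # The same fallback the four new lines in from_config perform:
    if "LLM_type" not in role_cfg:
        role_cfg["LLM_type"] = config["LLM_type"]
    if "LLM" not in role_cfg:
        role_cfg["LLM"] = config["LLM"]

    print(role_cfg["LLM_type"], role_cfg["LLM"]["model"])
    # -> OpenAI gpt-3.5-turbo-16k-0613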

src/agents/LLM/base_LLM.py (+6 -8)
@@ -19,7 +19,6 @@ def __init__(self,**kwargs) -> None:
         super().__init__()
         self.MAX_CHAT_HISTORY = eval(
             os.environ["MAX_CHAT_HISTORY"]) if "MAX_CHAT_HISTORY" in os.environ else 10
-
         self.model = kwargs["model"] if "model" in kwargs else "gpt-3.5-turbo-16k-0613"
         self.temperature = kwargs["temperature"] if "temperature" in kwargs else 0.3
         self.log_path = kwargs["log_path"].replace("/",os.sep) if "log_path" in kwargs else "logs"
@@ -205,16 +204,15 @@ def get_response(self,
                    messages=messages,
                    functions=functions,
                    function_call=function_call,
-                   temperature=temperature,
-                   custom_llm_provider = "openai"
+                   temperature=temperature
                )
            else:
                response = litellm.completion(
                    model=model,
                    messages=messages,
                    temperature=temperature,
-                   stream=stream,
-                   custom_llm_provider = "openai")
+                   stream=stream
+                   )
            break
        except Exception as e:
            print(e)
@@ -248,9 +246,9 @@ def init_LLM(default_log_path,**kwargs):
        )
    if LLM_type == "Replicate":
        LLM = (
-            OpenAILLM(**kwargs["LLM"])
+            ReplicateLLM(**kwargs["LLM"])
            if "LLM" in kwargs
-            else OpenAILLM(model = "replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf",temperature=0.3,log_path=log_path)
+            else ReplicateLLM(model = "replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf",temperature=0.3,log_path=log_path)
        )
-        return LLM
+    return LLM
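The last hunk fixes what looks like a copy-paste bug (the Replicate branch previously built an OpenAILLM), and the re-indented return LLM appears to move out of the if block so init_LLM returns for every provider. Dropping the hard-coded custom_llm_provider = "openai" in get_response is what actually enables Replicate: litellm can now route on the model string's provider prefix instead of forcing every request through the OpenAI provider. A rough usage sketch (assumes litellm's standard REPLICATE_API_TOKEN environment variable; the model string is the default hard-coded in init_LLM above):

    import os
    import litellm

    # litellm reads the Replicate credential from the environment;
    # the "replicate/" prefix on the model string selects the provider.
    os.environ["REPLICATE_API_TOKEN"] = "<your-token>"

    response = litellm.completion(
        model="replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf",
        messages=[{"role": "user", "content": "hello"}],
        temperature=0.3,
    )
    print(response.choices[0].message.content)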

src/agents/utils.py (+8 -9)
@@ -59,21 +59,20 @@ def get_embedding(sentence):

     if embed_model_name in ["text-embedding-ada-002"]:
         client = OpenAI(api_key=(
-                os.environ["EMBED_API_KEY"]
-                if "EMBED_API_KEY" in os.environ
-                else os.environ["API_KEY"]
-            ))
+            os.environ["EMBED_API_KEY"]
+            if "EMBED_API_KEY" in os.environ
+            else os.environ["API_KEY"]
+        ))
         if "PROXY" in os.environ:
             assert (
                 "http:" in os.environ["PROXY"] or "socks" in os.environ["PROXY"]
             ), "PROXY error,PROXY must be http or socks"
             client.proxies = {os.environ["PROXY"]}
         if "EMBED_API_BASE" in os.environ or "EMBED_BASE" in os.environ:
-            client.base_url = (
-                os.environ["EMBED_API_BASE"]
-                if "EMBED_API_BASE" in os.environ
-                else os.environ["API_BASE"]
-            )
+            client.base_url = (os.environ["EMBED_API_BASE"]
+                               if "EMBED_API_BASE" in os.environ
+                               else os.environ["API_BASE"]
+                               )
         sentence = sentence.replace("\n", " ")
         embed = client.embeddings.create(input = sentence, model= embed_model_name,encoding_format="float").data[0].embedding
         embed = torch.tensor(embed, dtype=torch.float32)
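This hunk is a pure re-indent; behaviour is unchanged. The logic it preserves: prefer the EMBED_-prefixed environment variables and fall back to the generic ones (note the guard tests "EMBED_BASE" while the fallback reads "API_BASE"; the diff keeps that quirk as-is). A compact sketch of the same resolution order with the openai>=1.x client, passing base_url at construction instead of assigning it afterwards (env names from the diff):

    import os
    from openai import OpenAI

    # Prefer embedding-specific credentials, fall back to the generic ones.
    api_key = os.environ.get("EMBED_API_KEY") or os.environ.get("API_KEY")
    base_url = os.environ.get("EMBED_API_BASE") or os.environ.get("API_BASE")

    client = OpenAI(api_key=api_key, base_url=base_url)  # base_url=None keeps the default

    embed = client.embeddings.create(
        input="hello world",
        model="text-embedding-ada-002",
        encoding_format="float",
    ).data[0].embedding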
