OpenAI 1.0 and vLLM support (#127)
* migrated to the new openai library and added vLLM support

* updated README to add vLLM support

---------

Co-authored-by: taisazero <[email protected]>
chenweize1998 and taisazero authored Mar 24, 2024
1 parent fa916e1 commit c7e5c1c
Showing 12 changed files with 376 additions and 146 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -173,4 +173,5 @@ results
tmp/
data/toolbench
logs/
ci_smoke_test_output/
ci_smoke_test_output/
.env
17 changes: 17 additions & 0 deletions README.md
@@ -218,6 +218,8 @@ https://github.com/OpenBMB/AgentVerse/assets/11704492/4d07da68-f942-4205-b558-f1
- [Framework Required Modules](#framework-required-modules-1)
- [CLI Example](#cli-example-1)
- [Local Model Support](#local-model-support)
- [vLLM Support](#vllm-support)
- [FSChat Support](#fschat-support)
- [1. Install the Additional Dependencies](#1-install-the-additional-dependencies)
- [2. Launch the Local Server](#2-launch-the-local-server)
- [3. Modify the Config File](#3-modify-the-config-file)
@@ -351,6 +353,21 @@ We have provided more tasks in `agentverse/tasks/tasksolving/tool_using/` that s
Also, you can take a look at `agentverse/tasks/tasksolving` for more experiments we have done in our paper.

## Local Model Support
## vLLM Support
If you want to use vLLM, follow the guide [here](https://docs.vllm.ai/en/latest/getting_started/quickstart.html) to install and set up the vLLM server, which is useful for handling larger inference workloads. Then set the following environment variables so AgentVerse can connect to the vLLM server:
```bash
export VLLM_API_KEY="your_api_key_here"
export VLLM_API_BASE="http://your_vllm_url_here"
```

Then modify the `model` field in the task config file so that it matches the name of the model served by the vLLM server. For example:
```yaml
model_type: vllm
model: llama-2-7b-chat-hf
```
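Under the hood, vLLM exposes an OpenAI-compatible endpoint, so the migrated `openai` 1.x client can talk to it directly. Below is a minimal connectivity sketch, assuming the server is running, the two variables above are exported, and the model is served under the name `llama-2-7b-chat-hf`:
```python
import os

from openai import OpenAI

# Minimal sketch: reach the vLLM server through its OpenAI-compatible API
# using the environment variables set above.
client = OpenAI(
    api_key=os.environ["VLLM_API_KEY"],
    base_url=os.environ["VLLM_API_BASE"],
)

response = client.chat.completions.create(
    model="llama-2-7b-chat-hf",  # must match the name the vLLM server serves
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(response.choices[0].message.content)
```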
## FSChat Support
This section provides a step-by-step guide to integrating FSChat into AgentVerse. FSChat is a framework for running local models such as LLaMA and Vicuna on your local machine.
### 1. Install the Additional Dependencies
If you want to use local models such as LLaMA, you need to install some additional dependencies:
```bash
# (dependency list collapsed in this diff view)
```
@@ -1,22 +1,22 @@
import json
import ast
import openai
from openai import OpenAI
from string import Template
from colorama import Fore
from aiohttp import ClientSession
from copy import deepcopy
from typing import TYPE_CHECKING, Any, List, Tuple

import httpx
from agentverse.agents import ExecutorAgent
from agentverse.message import Message, ExecutorMessage, SolverMessage
from agentverse.logging import logger

from . import BaseExecutor, executor_registry
import asyncio
from agentverse.llms.utils.jsonrepair import JsonRepair
from agentverse.llms.openai import DEFAULT_CLIENT_ASYNC as client_async

url = "http://127.0.0.1:8080"

SUMMARIZE_PROMPT = """Here is the text gathered from a webpage, and a question you need to answer from the webpage.
-- Webpage --
${webpage}
@@ -219,7 +219,7 @@ async def _summarize_webpage(webpage, question):
)
for _ in range(3):
try:
response = await openai.ChatCompletion.acreate(
response = await client_async.chat.completions.create(
messages=[{"role": "user", "content": summarize_prompt}],
model="gpt-3.5-turbo-16k",
functions=[
@@ -261,7 +261,7 @@ async def _summarize_webpage(webpage, question):
continue
arguments = ast.literal_eval(
JsonRepair(
response["choices"][0]["message"]["function_call"]["arguments"]
response.choices[0].message.function_call.arguments
).repair()
)
ret = (
@@ -300,7 +300,7 @@ async def _summarize_webpage(webpage, question):
}
for i in range(3):
try:
async with ClientSession(cookies=cookies, trust_env=True) as session:
async with httpx.AsyncClient(cookies=cookies, trust_env=True) as session:
if cookies is None:
async with session.post(
f"{url}/get_cookie", timeout=30
@@ -327,12 +327,12 @@ async def _summarize_webpage(webpage, question):
) as response:
content = await response.text()
if command == "WebEnv_browse_website":
openai.aiosession.set(session)
client_async.http_client = session
result = await _summarize_webpage(
content, arguments["goals_to_browse"]
)
elif command == "WebEnv_search_and_browse":
openai.aiosession.set(session)
client_async.http_client = session
content = json.loads(content)

# for i in range(len(content)):
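The hunks above show the core of the OpenAI 1.0 migration: the module-level `openai.ChatCompletion.acreate` call becomes a method on a shared async client, and dict-style response access becomes attribute access. As a standalone sketch of that pattern, using the stock `AsyncOpenAI` client rather than the project's `DEFAULT_CLIENT_ASYNC`:
```python
import asyncio

from openai import AsyncOpenAI

# Sketch of the openai>=1.0 async pattern; the repository builds its own
# client in agentverse.llms.openai and imports it as client_async.
client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment


async def main() -> None:
    response = await client.chat.completions.create(
        model="gpt-3.5-turbo-16k",
        messages=[{"role": "user", "content": "Summarize this page in one line."}],
    )
    # 1.x returns typed objects, so fields are attributes rather than dict keys.
    print(response.choices[0].message.content)


asyncio.run(main())
```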
31 changes: 26 additions & 5 deletions agentverse/llms/__init__.py
@@ -9,12 +9,33 @@
"vicuna-13b-v1.5",
]
LOCAL_LLMS_MAPPING = {
"llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf",
"llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf",
"llama-2-70b-chat-hf": "meta-llama/Llama-2-70b-chat-hf",
"vicuna-7b-v1.5": "lmsys/vicuna-7b-v1.5",
"vicuna-13b-v1.5": "lmsys/vicuna-13b-v1.5",
"llama-2-7b-chat-hf": {
"hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
"base_url": "http://localhost:5000/v1",
"api_key": "EMPTY",
},
"llama-2-13b-chat-hf": {
"hf_model_name": "meta-llama/Llama-2-13b-chat-hf",
"base_url": "http://localhost:5000/v1",
"api_key": "EMPTY",
},
"llama-2-70b-chat-hf": {
"hf_model_name": "meta-llama/Llama-2-70b-chat-hf",
"base_url": "http://localhost:5000/v1",
"api_key": "EMPTY",
},
"vicuna-7b-v1.5": {
"hf_model_name": "lmsys/vicuna-7b-v1.5",
"base_url": "http://localhost:5000/v1",
"api_key": "EMPTY",
},
"vicuna-13b-v1.5": {
"hf_model_name": "lmsys/vicuna-13b-v1.5",
"base_url": "http://localhost:5000/v1",
"api_key": "EMPTY",
},
}


from .base import BaseLLM, BaseChatModel, BaseCompletionModel, LLMResult
from .openai import OpenAIChat
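Each entry in the expanded `LOCAL_LLMS_MAPPING` now bundles the Hugging Face model name with a per-model `base_url` and `api_key`. A sketch of how such an entry can be consumed, assuming the local server speaks the OpenAI-compatible API (the actual wiring presumably lives in `agentverse/llms/openai.py`, not shown in this excerpt):
```python
from openai import OpenAI

from agentverse.llms import LOCAL_LLMS_MAPPING

# Hypothetical lookup: build a client for a locally served chat model.
entry = LOCAL_LLMS_MAPPING["llama-2-7b-chat-hf"]
client = OpenAI(api_key=entry["api_key"], base_url=entry["base_url"])

response = client.chat.completions.create(
    model="llama-2-7b-chat-hf",
    messages=[{"role": "user", "content": "Hello from a local model."}],
)
print(response.choices[0].message.content)
```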
5 changes: 3 additions & 2 deletions agentverse/llms/base.py
@@ -1,6 +1,5 @@
from abc import abstractmethod
from typing import Dict, Any

from typing import Any, Dict, Optional
from pydantic import BaseModel, Field


@@ -20,6 +19,8 @@ class BaseModelArgs(BaseModel):
class BaseLLM(BaseModel):
args: BaseModelArgs = Field(default_factory=BaseModelArgs)
max_retry: int = Field(default=3)
client_args: Optional[Dict] = Field(default={})
is_azure: bool = Field(default=False)

@abstractmethod
def get_spend(self) -> float:
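`BaseLLM` also gains `client_args` and `is_azure`, so each model wrapper can carry its own connection settings. A hedged sketch of how a subclass might turn those fields into a client; the concrete logic belongs to `OpenAIChat` and is not shown in this excerpt:
```python
from openai import AzureOpenAI, OpenAI


def build_client(client_args: dict, is_azure: bool):
    """Hypothetical helper: choose an Azure or standard OpenAI client
    based on the new BaseLLM fields."""
    if is_azure:
        # A real Azure deployment would also need azure_endpoint and
        # api_version inside client_args.
        return AzureOpenAI(**client_args)
    return OpenAI(**client_args)
```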
