Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions .env_template
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
ENVIRONMENT=development

# ===== MODEL =====
API_TYPE=openai or azure
API_KEY={如:sk-e3fba700e89140408078debbc9fa9c0c}
API_BASE_URL=https://api.deepseek.com/v1
API_VERSION=your-azure-api-version
# it won't work if API_TYPE is openai
MODEL_NAME=deepseek-chat
MAX_TOKENS=4096
TEMPERATURE=0.0
Expand All @@ -22,32 +25,40 @@ TAVILY_API_KEY={如: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}
# Co-Sight可分层配置模型:规划,执行,工具以及多模态
# 在对应的模型配置项下面,配置模型参数(API_KEY,API_BASE_URL,MODEL_NAME都配置方可生效)
# # ===== PLAN MODEL =====
# PLAN_API_TYPE=
# PLAN_API_KEY=
# PLAN_API_BASE_URL=
# PLAN_API_VERSION=
# PLAN_MODEL_NAME=
# PLAN_MAX_TOKENS=
# PLAN_TEMPERATURE=
# PLAN_PROXY=
#
# # ===== ACT MODEL =====
# ACT_API_TYPE=
# ACT_API_KEY=
# ACT_API_BASE_URL=
# ACT_API_VERSION=
# ACT_MODEL_NAME=
# ACT_MAX_TOKENS=
# ACT_TEMPERATURE=
# ACT_PROXY=
#
# # ===== TOOL MODEL =====
# TOOL_API_TYPE=
# TOOL_API_KEY=
# TOOL_API_BASE_URL=
# TOOL_API_VERSION=
# TOOL_MODEL_NAME=
# TOOL_MAX_TOKENS=
# TOOL_TEMPERATURE=
# TOOL_PROXY=
#
# # ===== VISION MODEL =====
# VISION_API_TYPE=
# VISION_API_KEY=
# VISION_API_BASE_URL=
# VISION_API_VERSION=
# VISION_MODEL_NAME=
# VISION_MAX_TOKENS=
# VISION_TEMPERATURE=
Expand Down
7 changes: 4 additions & 3 deletions app/cosight/llm/chat_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,20 @@
# License for the specific language governing permissions and limitations
# under the License.

from typing import List, Dict, Any
from openai import OpenAI
from typing import List, Dict, Any, Optional, Union
from openai import OpenAI,AzureOpenAI

from app.cosight.task.time_record_util import time_record


class ChatLLM:
def __init__(self, base_url: str, api_key: str, model: str, client: OpenAI, max_tokens: int = 4096,
def __init__(self, base_url: str, api_key: str, model: str, client: Union[OpenAI,AzureOpenAI],api_version:str = None, max_tokens: int = 4096,
temperature: float = 0.0, stream: bool = False, tools: List[Any] = None):
self.tools = tools or []
self.client = client
self.base_url = base_url
self.api_key = api_key
self.api_version = api_version
self.model = model
self.stream = stream
self.temperature = temperature
Expand Down
33 changes: 28 additions & 5 deletions config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ def get_model_config() -> dict[str, Optional[str | int | float]]:
temperature = os.environ.get("TEMPERATURE")

return {
"api_type": os.environ.get("API_TYPE","openai"),
"api_key": os.environ.get("API_KEY"),
"base_url": os.environ.get("API_BASE_URL"),
"api_version": os.environ.get("API_VERSION",None),
"model": os.environ.get("MODEL_NAME"),
"max_tokens": int(max_tokens) if max_tokens and max_tokens.strip() else None,
"temperature": float(temperature) if temperature and temperature.strip() else None,
Expand All @@ -40,10 +42,14 @@ def get_model_config() -> dict[str, Optional[str | int | float]]:
# ========== 规划大模型配置 ==========
def get_plan_model_config() -> dict[str, Optional[str | int | float]]:
"""获取Plan专用API配置,如果缺少配置则退回默认"""
plan_api_type = os.environ.get("PLAN_API_TYPE","openai")
plan_api_key = os.environ.get("PLAN_API_KEY")
plan_base_url = os.environ.get("PLAN_API_BASE_URL")
plan_api_version = os.environ.get("PLAN_API_VERSION",None)
model_name = os.environ.get("PLAN_MODEL_NAME")

if plan_api_type == "azure" and plan_api_version is None:
raise ValueError("Azure API requires API_VERSION to be set.")
# 检查三个字段是否都存在且非空
if not (plan_api_key and plan_base_url and model_name):
return get_model_config()
Expand All @@ -52,8 +58,10 @@ def get_plan_model_config() -> dict[str, Optional[str | int | float]]:
temperature = os.environ.get("PLAN_TEMPERATURE")

return {
"api_type": plan_api_type,
"api_key": plan_api_key,
"base_url": plan_base_url,
"api_version": plan_api_version,
"model": model_name,
"max_tokens": int(max_tokens) if max_tokens and max_tokens.strip() else None,
"temperature": float(temperature) if temperature and temperature.strip() else None,
Expand All @@ -64,10 +72,13 @@ def get_plan_model_config() -> dict[str, Optional[str | int | float]]:
# ========== 执行大模型配置 ==========
def get_act_model_config() -> dict[str, Optional[str | int | float]]:
"""获取Act专用API配置,如果缺少配置则退回默认"""
act_api_type = os.environ.get("ACT_API_TYPE","openai")
act_api_key = os.environ.get("ACT_API_KEY")
act_base_url = os.environ.get("ACT_API_BASE_URL")
act_api_version = os.environ.get("ACT_API_VERSION",None)
model_name = os.environ.get("ACT_MODEL_NAME")

if act_api_type == "azure" and act_api_version is None:
raise ValueError("Azure API requires API_VERSION to be set.")
# 检查三个字段是否都存在且非空
if not (act_api_key and act_base_url and model_name):
return get_model_config()
Expand All @@ -76,8 +87,10 @@ def get_act_model_config() -> dict[str, Optional[str | int | float]]:
temperature = os.environ.get("ACT_TEMPERATURE")

return {
"api_type": act_api_type,
"api_key": act_api_key,
"base_url": act_base_url,
"api_version": act_api_version,
"model": model_name,
"max_tokens": int(max_tokens) if max_tokens and max_tokens.strip() else None,
"temperature": float(temperature) if temperature and temperature.strip() else None,
Expand All @@ -88,20 +101,25 @@ def get_act_model_config() -> dict[str, Optional[str | int | float]]:
# ========== 工具大模型配置 ==========
def get_tool_model_config() -> dict[str, Optional[str | int | float]]:
"""获取Tool专用API配置,如果缺少配置则退回默认"""
tool_api_type = os.environ.get("TOOL_API_TYPE","openai")
tool_api_key = os.environ.get("TOOL_API_KEY")
tool_base_url = os.environ.get("TOOL_API_BASE_URL")
tool_api_version = os.environ.get("TOOL_API_VERSION",None)
model_name = os.environ.get("TOOL_MODEL_NAME")

if tool_api_type == "azure" and tool_api_version is None:
raise ValueError("Azure API requires API_VERSION to be set.")
# 检查三个字段是否都存在且非空
if not (tool_api_key and tool_base_url and model_name):
if not (tool_api_type and tool_api_key and tool_base_url and model_name):
return get_model_config()

max_tokens = os.environ.get("TOOL_MAX_TOKENS")
temperature = os.environ.get("TOOL_TEMPERATURE")

return {
"api_type": tool_api_type,
"api_key": tool_api_key,
"base_url": tool_base_url,
"api_version": tool_api_version,
"model": model_name,
"max_tokens": int(max_tokens) if max_tokens and max_tokens.strip() else None,
"temperature": float(temperature) if temperature and temperature.strip() else None,
Expand All @@ -112,20 +130,25 @@ def get_tool_model_config() -> dict[str, Optional[str | int | float]]:
# ========== 多模态大模型配置 ==========
def get_vision_model_config() -> dict[str, Optional[str | int | float]]:
"""获取Vision专用API配置,如果缺少配置则退回默认"""
vision_api_type = os.environ.get("VISION_API_TYPE","openai")
vision_api_key = os.environ.get("VISION_API_KEY")
vision_base_url = os.environ.get("VISION_API_BASE_URL")
vision_api_version = os.environ.get("VISION_API_VERSION",None)
model_name = os.environ.get("VISION_MODEL_NAME")

if vision_api_type == "azure" and vision_api_version is None:
raise ValueError("Azure API requires API_VERSION to be set.")
# 检查三个字段是否都存在且非空
if not (vision_api_key and vision_base_url and model_name):
if not (vision_api_type and vision_api_key and vision_base_url and model_name):
return get_model_config()

max_tokens = os.environ.get("VISION_MAX_TOKENS")
temperature = os.environ.get("VISION_TEMPERATURE")

return {
"api_type": vision_api_type,
"api_key": vision_api_key,
"base_url": vision_base_url,
"api_version": vision_api_version,
"model": model_name,
"max_tokens": int(max_tokens) if max_tokens and max_tokens.strip() else None,
"temperature": float(temperature) if temperature and temperature.strip() else None,
Expand Down
24 changes: 17 additions & 7 deletions llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
# License for the specific language governing permissions and limitations
# under the License.
import httpx
from openai import OpenAI
from openai import OpenAI, AzureOpenAI
from openai import OpenAI, AzureOpenAI

from app.cosight.llm.chat_llm import ChatLLM
from config.config import *


def set_model(model_config: dict[str, Optional[str | int | float]]):
openai_llm = None
http_client_kwargs = {
"headers": {
'Content-Type': 'application/json',
Expand All @@ -31,12 +33,20 @@ def set_model(model_config: dict[str, Optional[str | int | float]]):

if model_config['proxy']:
http_client_kwargs["proxy"] = model_config['proxy']

openai_llm = OpenAI(
base_url=model_config['base_url'],
api_key=model_config['api_key'],
http_client=httpx.Client(**http_client_kwargs)
)
if model_config['api_type'] == "azure" and model_config['api_version'] is not None:
openai_llm = AzureOpenAI(
base_url=model_config['base_url'],
api_key=model_config['api_key'],
api_version=model_config['api_version'],
http_client=httpx.Client(**http_client_kwargs)
)
if model_config['api_type'] != "azure":
openai_llm = OpenAI(
base_url=model_config['base_url'],
api_key=model_config['api_key'],
http_client=httpx.Client(**http_client_kwargs)
)


chat_llm_kwargs = {
"model": model_config['model'],
Expand Down