diff --git a/.env_template b/.env_template
index e36bfa1..e7b6dcd 100644
--- a/.env_template
+++ b/.env_template
@@ -2,8 +2,11 @@
 ENVIRONMENT=development
 
 # ===== MODEL =====
+API_TYPE=openai or azure
 API_KEY={如:sk-e3fba700e89140408078debbc9fa9c0c}
 API_BASE_URL=https://api.deepseek.com/v1
+API_VERSION=your-azure-api-version
+# API_VERSION is only used when API_TYPE is azure; it is ignored for openai
 MODEL_NAME=deepseek-chat
 MAX_TOKENS=4096
 TEMPERATURE=0.0
@@ -22,32 +25,40 @@
 TAVILY_API_KEY={如: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}
 
 # Co-Sight可分层配置模型:规划,执行,工具以及多模态
 # 在对应的模型配置项下面,配置模型参数(API_KEY,API_BASE_URL,MODEL_NAME都配置方可生效)
 #
 # ===== PLAN MODEL =====
+# PLAN_API_TYPE=
 # PLAN_API_KEY=
 # PLAN_API_BASE_URL=
+# PLAN_API_VERSION=
 # PLAN_MODEL_NAME=
 # PLAN_MAX_TOKENS=
 # PLAN_TEMPERATURE=
 # PLAN_PROXY=
 #
 #
 # ===== ACT MODEL =====
+# ACT_API_TYPE=
 # ACT_API_KEY=
 # ACT_API_BASE_URL=
+# ACT_API_VERSION=
 # ACT_MODEL_NAME=
 # ACT_MAX_TOKENS=
 # ACT_TEMPERATURE=
 # ACT_PROXY=
 #
 #
 # ===== TOOL MODEL =====
+# TOOL_API_TYPE=
 # TOOL_API_KEY=
 # TOOL_API_BASE_URL=
+# TOOL_API_VERSION=
 # TOOL_MODEL_NAME=
 # TOOL_MAX_TOKENS=
 # TOOL_TEMPERATURE=
 # TOOL_PROXY=
 #
 #
 # ===== VISION MODEL =====
+# VISION_API_TYPE=
 # VISION_API_KEY=
 # VISION_API_BASE_URL=
+# VISION_API_VERSION=
 # VISION_MODEL_NAME=
 # VISION_MAX_TOKENS=
 # VISION_TEMPERATURE=
diff --git a/app/cosight/llm/chat_llm.py b/app/cosight/llm/chat_llm.py
index 36bbc1b..119f190 100644
--- a/app/cosight/llm/chat_llm.py
+++ b/app/cosight/llm/chat_llm.py
@@ -13,19 +13,20 @@
 # License for the specific language governing permissions and limitations
 # under the License.
 
-from typing import List, Dict, Any
-from openai import OpenAI
+from typing import List, Dict, Any, Optional, Union
+from openai import OpenAI, AzureOpenAI
 
 from app.cosight.task.time_record_util import time_record
 
 
 class ChatLLM:
-    def __init__(self, base_url: str, api_key: str, model: str, client: OpenAI, max_tokens: int = 4096,
+    def __init__(self, base_url: str, api_key: str, model: str, client: Union[OpenAI, AzureOpenAI], api_version: Optional[str] = None, max_tokens: int = 4096,
                  temperature: float = 0.0, stream: bool = False, tools: List[Any] = None):
         self.tools = tools or []
         self.client = client
         self.base_url = base_url
         self.api_key = api_key
+        self.api_version = api_version
         self.model = model
         self.stream = stream
         self.temperature = temperature
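For reference, a minimal sketch of how the widened ChatLLM signature could be wired up against Azure; the endpoint, key, API version, and deployment name below are placeholders, not values taken from this patch:

    # Hypothetical caller: the client is still built outside ChatLLM, so Azure support
    # only means constructing AzureOpenAI and passing api_version through.
    from openai import AzureOpenAI
    from app.cosight.llm.chat_llm import ChatLLM

    client = AzureOpenAI(
        base_url="https://my-resource.openai.azure.com/openai",  # placeholder endpoint
        api_key="<azure-api-key>",                               # placeholder key
        api_version="2024-02-01",                                # placeholder version
    )
    llm = ChatLLM(
        base_url="https://my-resource.openai.azure.com/openai",
        api_key="<azure-api-key>",
        model="my-gpt-4o-deployment",  # on Azure this is the deployment name
        client=client,
        api_version="2024-02-01",
    )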
diff --git a/config/config.py b/config/config.py
index 0636a66..32d1a2f 100644
--- a/config/config.py
+++ b/config/config.py
@@ -28,8 +28,10 @@ def get_model_config() -> dict[str, Optional[str | int | float]]:
     temperature = os.environ.get("TEMPERATURE")
 
     return {
+        "api_type": os.environ.get("API_TYPE", "openai"),
         "api_key": os.environ.get("API_KEY"),
         "base_url": os.environ.get("API_BASE_URL"),
+        "api_version": os.environ.get("API_VERSION", None),
         "model": os.environ.get("MODEL_NAME"),
         "max_tokens": int(max_tokens) if max_tokens and max_tokens.strip() else None,
         "temperature": float(temperature) if temperature and temperature.strip() else None,
@@ -40,10 +42,14 @@ def get_model_config() -> dict[str, Optional[str | int | float]]:
 # ========== 规划大模型配置 ==========
 def get_plan_model_config() -> dict[str, Optional[str | int | float]]:
     """获取Plan专用API配置,如果缺少配置则退回默认"""
+    plan_api_type = os.environ.get("PLAN_API_TYPE", "openai")
     plan_api_key = os.environ.get("PLAN_API_KEY")
     plan_base_url = os.environ.get("PLAN_API_BASE_URL")
+    plan_api_version = os.environ.get("PLAN_API_VERSION", None)
     model_name = os.environ.get("PLAN_MODEL_NAME")
 
+    if plan_api_type == "azure" and plan_api_version is None:
+        raise ValueError("Azure API requires PLAN_API_VERSION to be set.")
     # 检查三个字段是否都存在且非空
     if not (plan_api_key and plan_base_url and model_name):
         return get_model_config()
@@ -52,8 +58,10 @@ def get_plan_model_config() -> dict[str, Optional[str | int | float]]:
     temperature = os.environ.get("PLAN_TEMPERATURE")
 
     return {
+        "api_type": plan_api_type,
         "api_key": plan_api_key,
         "base_url": plan_base_url,
+        "api_version": plan_api_version,
         "model": model_name,
         "max_tokens": int(max_tokens) if max_tokens and max_tokens.strip() else None,
         "temperature": float(temperature) if temperature and temperature.strip() else None,
@@ -64,10 +72,13 @@ def get_plan_model_config() -> dict[str, Optional[str | int | float]]:
 # ========== 执行大模型配置 ==========
 def get_act_model_config() -> dict[str, Optional[str | int | float]]:
     """获取Act专用API配置,如果缺少配置则退回默认"""
+    act_api_type = os.environ.get("ACT_API_TYPE", "openai")
     act_api_key = os.environ.get("ACT_API_KEY")
     act_base_url = os.environ.get("ACT_API_BASE_URL")
+    act_api_version = os.environ.get("ACT_API_VERSION", None)
     model_name = os.environ.get("ACT_MODEL_NAME")
-
+    if act_api_type == "azure" and act_api_version is None:
+        raise ValueError("Azure API requires ACT_API_VERSION to be set.")
     # 检查三个字段是否都存在且非空
     if not (act_api_key and act_base_url and model_name):
         return get_model_config()
@@ -76,8 +87,10 @@ def get_act_model_config() -> dict[str, Optional[str | int | float]]:
     temperature = os.environ.get("ACT_TEMPERATURE")
 
     return {
+        "api_type": act_api_type,
         "api_key": act_api_key,
         "base_url": act_base_url,
+        "api_version": act_api_version,
         "model": model_name,
         "max_tokens": int(max_tokens) if max_tokens and max_tokens.strip() else None,
         "temperature": float(temperature) if temperature and temperature.strip() else None,
@@ -88,20 +101,25 @@ def get_act_model_config() -> dict[str, Optional[str | int | float]]:
 # ========== 工具大模型配置 ==========
 def get_tool_model_config() -> dict[str, Optional[str | int | float]]:
     """获取Tool专用API配置,如果缺少配置则退回默认"""
+    tool_api_type = os.environ.get("TOOL_API_TYPE", "openai")
     tool_api_key = os.environ.get("TOOL_API_KEY")
     tool_base_url = os.environ.get("TOOL_API_BASE_URL")
+    tool_api_version = os.environ.get("TOOL_API_VERSION", None)
     model_name = os.environ.get("TOOL_MODEL_NAME")
-
+    if tool_api_type == "azure" and tool_api_version is None:
+        raise ValueError("Azure API requires TOOL_API_VERSION to be set.")
     # 检查三个字段是否都存在且非空
-    if not (tool_api_key and tool_base_url and model_name):
+    if not (tool_api_type and tool_api_key and tool_base_url and model_name):
         return get_model_config()
 
     max_tokens = os.environ.get("TOOL_MAX_TOKENS")
     temperature = os.environ.get("TOOL_TEMPERATURE")
 
     return {
+        "api_type": tool_api_type,
         "api_key": tool_api_key,
         "base_url": tool_base_url,
+        "api_version": tool_api_version,
         "model": model_name,
         "max_tokens": int(max_tokens) if max_tokens and max_tokens.strip() else None,
         "temperature": float(temperature) if temperature and temperature.strip() else None,
@@ -112,20 +130,25 @@ def get_tool_model_config() -> dict[str, Optional[str | int | float]]:
 # ========== 多模态大模型配置 ==========
 def get_vision_model_config() -> dict[str, Optional[str | int | float]]:
     """获取Vision专用API配置,如果缺少配置则退回默认"""
+    vision_api_type = os.environ.get("VISION_API_TYPE", "openai")
     vision_api_key = os.environ.get("VISION_API_KEY")
     vision_base_url = os.environ.get("VISION_API_BASE_URL")
+    vision_api_version = os.environ.get("VISION_API_VERSION", None)
     model_name = os.environ.get("VISION_MODEL_NAME")
-
+    if vision_api_type == "azure" and vision_api_version is None:
+        raise ValueError("Azure API requires VISION_API_VERSION to be set.")
     # 检查三个字段是否都存在且非空
-    if not (vision_api_key and vision_base_url and model_name):
+    if not (vision_api_type and vision_api_key and vision_base_url and model_name):
         return get_model_config()
 
     max_tokens = os.environ.get("VISION_MAX_TOKENS")
     temperature = os.environ.get("VISION_TEMPERATURE")
 
     return {
+        "api_type": vision_api_type,
         "api_key": vision_api_key,
         "base_url": vision_base_url,
+        "api_version": vision_api_version,
         "model": model_name,
         "max_tokens": int(max_tokens) if max_tokens and max_tokens.strip() else None,
         "temperature": float(temperature) if temperature and temperature.strip() else None,
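Illustratively, with a plan-specific Azure configuration in the environment, get_plan_model_config() now returns the two extra fields and fails fast when the version is missing; the values below are placeholders:

    import os
    from config.config import get_plan_model_config

    os.environ.update({
        "PLAN_API_TYPE": "azure",
        "PLAN_API_KEY": "<azure-api-key>",
        "PLAN_API_BASE_URL": "https://my-resource.openai.azure.com/openai",
        "PLAN_API_VERSION": "2024-02-01",   # omitting this with PLAN_API_TYPE=azure raises ValueError
        "PLAN_MODEL_NAME": "my-gpt-4o-deployment",
    })

    cfg = get_plan_model_config()
    assert cfg["api_type"] == "azure" and cfg["api_version"] == "2024-02-01"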
to be set.") # 检查三个字段是否都存在且非空 - if not (vision_api_key and vision_base_url and model_name): + if not (vision_api_type and vision_api_key and vision_base_url and model_name): return get_model_config() max_tokens = os.environ.get("VISION_MAX_TOKENS") temperature = os.environ.get("VISION_TEMPERATURE") return { + "api_type": vision_api_type, "api_key": vision_api_key, "base_url": vision_base_url, + "api_version": vision_api_version, "model": model_name, "max_tokens": int(max_tokens) if max_tokens and max_tokens.strip() else None, "temperature": float(temperature) if temperature and temperature.strip() else None, diff --git a/llm.py b/llm.py index 98fd40c..05c3962 100644 --- a/llm.py +++ b/llm.py @@ -13,13 +13,15 @@ # License for the specific language governing permissions and limitations # under the License. import httpx -from openai import OpenAI +from openai import OpenAI, AzureOpenAI +from openai import OpenAI, AzureOpenAI from app.cosight.llm.chat_llm import ChatLLM from config.config import * def set_model(model_config: dict[str, Optional[str | int | float]]): + openai_llm = None http_client_kwargs = { "headers": { 'Content-Type': 'application/json', @@ -31,12 +33,20 @@ def set_model(model_config: dict[str, Optional[str | int | float]]): if model_config['proxy']: http_client_kwargs["proxy"] = model_config['proxy'] - - openai_llm = OpenAI( - base_url=model_config['base_url'], - api_key=model_config['api_key'], - http_client=httpx.Client(**http_client_kwargs) - ) + if model_config['api_type'] == "azure" and model_config['api_version'] is not None: + openai_llm = AzureOpenAI( + base_url=model_config['base_url'], + api_key=model_config['api_key'], + api_version=model_config['api_version'], + http_client=httpx.Client(**http_client_kwargs) + ) + if model_config['api_type'] != "azure": + openai_llm = OpenAI( + base_url=model_config['base_url'], + api_key=model_config['api_key'], + http_client=httpx.Client(**http_client_kwargs) + ) + chat_llm_kwargs = { "model": model_config['model'],