diff --git a/src/config.py b/src/config.py index 57dc7ab..0f79e10 100644 --- a/src/config.py +++ b/src/config.py @@ -19,6 +19,7 @@ api_key_list = [value for key, value in _config.items('openai') if key.startswith('api') and value] temperature = _config.get('openai', 'temperature') +base_url = _config.get('openai', 'base_url') room_id = _config.getint('room', 'id') mysql = dict(_config.items('mysql')) sqlite = dict(_config.items('sqlite')) diff --git a/src/core/vup.py b/src/core/vup.py index 11a3cf0..be14f58 100644 --- a/src/core/vup.py +++ b/src/core/vup.py @@ -50,6 +50,7 @@ async def generate_chat(self, embedding): messages = self.event.get_prompt_messages(**extra_kwargs) logger.info(f"prompt:{messages[1]} 开始请求gpt") chat = ChatOpenAI(temperature=config.temperature, max_retries=2, max_tokens=150, + openai_api_base=config.base_url, openai_api_key=get_openai_key()) llm_res = chat.generate([messages]) assistant_content = llm_res.generations[0][0].text diff --git a/src/manager.py b/src/manager.py index 94cd57d..262d4b6 100644 --- a/src/manager.py +++ b/src/manager.py @@ -45,13 +45,16 @@ def test_net(self): from langchain import OpenAI import requests # 测试外网环境(可能异常) - r = requests.get(url='https://www.youtube.com/', verify=False, proxies={ - 'http': f'http://{config.proxy}/', - 'https': f'http://{config.proxy}/' - }) + proxies = None + if config.proxy: + proxies = { + 'http': f'http://{config.proxy}/', + 'https': f'http://{config.proxy}/' + } + r = requests.get(url='https://www.youtube.com/', verify=False, proxies=proxies) assert r.status_code == 200 # 测试openai库 - llm = OpenAI(temperature=config.temperature, openai_api_key=get_openai_key(), verbose=config.debug) + llm = OpenAI(temperature=config.temperature, openai_api_base=config.base_url, openai_api_key=get_openai_key(), verbose=config.debug) text = "python是世界上最好的语言 " print(llm(text)) print('测试成功!') diff --git a/src/utils/utils.py b/src/utils/utils.py index 3c8a29b..9fdd701 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -154,6 +154,7 @@ def top_n_indices_from_embeddings( def sync_get_embedding(texts: List[str], model="text-embedding-ada-002"): + openai.api_base=config.base_url res = openai.Embedding.create(input=texts, model=model, api_key=get_openai_key()) if isinstance(texts, list) and len(texts) == 1: return res['data'][0]['embedding']