diff --git a/view/bing.py b/view/bing.py index 4a35870..9077efd 100644 --- a/view/bing.py +++ b/view/bing.py @@ -1,45 +1,17 @@ # -*- coding: utf-8 -*- # Author: XiaoXinYo -from typing import AsyncGenerator +from typing import Union, AsyncGenerator from fastapi import APIRouter, Request, Response -from module import auxiliary, core -import asyncio +from fastapi.responses import StreamingResponse +from module import auxiliary, core, chat_bot +from EdgeGPT import ConversationStyle import config -import uuid -from EdgeGPT import Chatbot, ConversationStyle import re import BingImageCreator -from fastapi.responses import StreamingResponse -BingAPP = APIRouter() +BING_APP = APIRouter() STYLES = ['balanced', 'creative', 'precise'] -CHATBOT = {} - -async def checkToken(): - global CHATBOT - while True: - for token in CHATBOT.copy(): - chatBot = CHATBOT[token] - if auxiliary.getTimeStamp() - chatBot['useTimeStamp'] > 5 * 60: - await chatBot['chatBot'].close() - del chatBot - await asyncio.sleep(60) - -def getChatBot(token: str) -> tuple: - global CHATBOT - if token: - if token in CHATBOT: - chatBot = CHATBOT[token]['chatBot'] - else: - return token, None - else: - chatBot = Chatbot(proxy=config.PROXY, cookie_path='./cookie/bing.json') - token = str(uuid.uuid4()) - CHATBOT[token] = {} - CHATBOT[token]['chatBot'] = chatBot - CHATBOT[token]['useTimeStamp'] = auxiliary.getTimeStamp() - return token, chatBot def getStyleEnum(style: str) -> ConversationStyle: enum = ConversationStyle @@ -51,6 +23,11 @@ def getStyleEnum(style: str) -> ConversationStyle: enum = enum.precise return enum +def getPrompt(prompt: Union[str, None]) -> Union[str, None]: + if prompt: + return f'[system](#additional_instructions)\n{config.BING_PROMPT.get(prompt, prompt)}' + return None + def filterAnswer(answer: str) -> str: answer = re.sub(r'\[\^.*?\^]', '', answer) return answer @@ -96,22 +73,29 @@ def needReset(data: dict, answer: str) -> bool: return True return False -@BingAPP.route('/ask', methods=['GET', 'POST']) +@BING_APP.route('/ask', methods=['GET', 'POST']) async def ask(request: Request) -> Response: parameter = await core.getrequestParameter(request) - style = parameter.get('style') question = parameter.get('question') + style = parameter.get('style') or 'balanced' + prompt = parameter.get('prompt') or '' token = parameter.get('token') - if not style or not question: + if not question: return core.GenerateResponse().error(110, '参数不能为空') elif style not in STYLES: return core.GenerateResponse().error(110, 'style不存在') - - token, chatBot = getChatBot(token) - if not chatBot: - return core.GenerateResponse().error(120, 'token不存在') - data = await chatBot.ask(question, conversation_style=getStyleEnum(style)) + elif prompt and not auxiliary.isEnglish(prompt): + return core.GenerateResponse().error(110, 'prompt仅支持英文') + if token: + chatBot = chat_bot.getChatBot(token) + if not chatBot: + return core.GenerateResponse().error(120, 'token不存在') + chatBot = chatBot['chatBot'] + else: + token, chatBot = chat_bot.generateChatBot('Bing') + + data = await chatBot.ask(question, conversation_style=getStyleEnum(style), webpage_context=getPrompt(prompt)) if data['item']['result']['value'] == 'Throttled': return core.GenerateResponse().error(120, '已上限,24小时后尝试') @@ -134,20 +118,27 @@ async def ask(request: Request) -> Response: return core.GenerateResponse().success(info) -@BingAPP.route('/ask_stream', methods=['GET', 'POST']) +@BING_APP.route('/ask_stream', methods=['GET', 'POST']) async def askStream(request: Request) -> Response: parameter = await core.getrequestParameter(request) - style = parameter.get('style') question = parameter.get('question') + style = parameter.get('style') or 'balanced' + prompt = parameter.get('prompt') or '' token = parameter.get('token') - if not style or not question: + if not question: return core.GenerateResponse().error(110, '参数不能为空') elif style not in STYLES: return core.GenerateResponse().error(110, 'style不存在') - - token, chatBot = getChatBot(token) - if not chatBot: - return core.GenerateResponse().error(120, 'token不存在') + elif prompt and not auxiliary.isEnglish(prompt): + return core.GenerateResponse().error(110, 'prompt仅支持英文') + + if token: + chatBot = chat_bot.getChatBot(token) + if not chatBot: + return core.GenerateResponse().error(120, 'token不存在') + chatBot = chatBot['chatBot'] + else: + token, chatBot = chat_bot.generateChatBot('Bing') async def generator() -> AsyncGenerator: index = 0 @@ -159,7 +150,7 @@ async def generator() -> AsyncGenerator: 'reset': False, 'token': token } - async for final, data in chatBot.ask_stream(question, conversation_style=getStyleEnum(style)): + async for final, data in chatBot.ask_stream(question, conversation_style=getStyleEnum(style), webpage_context=getPrompt(prompt)): if not final: answer = data[index:] index = len(data) @@ -171,8 +162,9 @@ async def generator() -> AsyncGenerator: if data['item']['result']['value'] == 'Throttled': yield core.GenerateResponse().error(120, '已上限,24小时后尝试', True) break - - info['answer'] = getAnswer(data) + + answer = getAnswer(data) + info['answer'] = answer info['suggests'] = getSuggest(data) info['urls'] = getUrl(data) info['done'] = True @@ -185,13 +177,13 @@ async def generator() -> AsyncGenerator: return StreamingResponse(generator(), media_type='text/event-stream') -@BingAPP.route('/image', methods=['GET', 'POST']) +@BING_APP.route('/image', methods=['GET', 'POST']) async def image(request: Request) -> Response: keyword = (await core.getrequestParameter(request)).get('keyword') if not keyword: return core.GenerateResponse().error(110, '参数不能为空') elif not auxiliary.isEnglish(keyword): - return core.GenerateResponse().error(110, '仅支持英文') + return core.GenerateResponse().error(110, 'keyword仅支持英文') cookie = auxiliary.getCookie('./cookie/bing.json', ['_U'])['_U'] return core.GenerateResponse().success(BingImageCreator.ImageGen(cookie).get_images(keyword)) \ No newline at end of file