From f35d30c17e46bd9eebced0db3687a5eeaded02fe Mon Sep 17 00:00:00 2001 From: XiaoXinYo <1104361313@qq.com> Date: Mon, 29 May 2023 17:57:08 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=AF=B9cookie=E5=92=8Ckey?= =?UTF-8?q?=E7=9A=84=E6=A3=80=E6=B5=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 43 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 40 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 237805c..304cd5c 100644 --- a/main.py +++ b/main.py @@ -17,15 +17,45 @@ allow_methods=['*'], allow_headers=['*'], ) -APP.include_router(chatgpt.CHATGPT_APP, prefix='/chatgpt') -APP.include_router(bing.BING_APP, prefix='/bing') APP.include_router(bard.Bard_APP, prefix='/bard') +APP.include_router(bing.BING_APP, prefix='/bing') +APP.include_router(chatgpt.CHATGPT_APP, prefix='/chatgpt') APP.include_router(ernie.ERNIE_APP, prefix='/ernie') @APP.on_event('startup') async def startup() -> None: asyncio.get_event_loop().create_task(chat_bot.checkChatBot()) +@APP.on_event('shutdown') +async def shutdown() -> None: + asyncio.get_event_loop().create_task(chat_bot.checkChatBot(False)) + +@APP.middleware('http') +async def middleware(request: Request, call_next) -> None: + urls = request.url.path.split('/') + if len(urls) == 3: + model = urls[1] + mode = urls[2] + if mode == 'ask': + generate = lambda model_: core.GenerateResponse().error(110, f'{model_}未配置') + else: + generate = lambda model_: core.GenerateResponse().error(110, f'{model_}未配置', streamResponse=True) + if model == 'bard': + if not chat_bot.BARD_COOKIE: + return generate('Bard') + elif model == 'bing': + if not chat_bot.BING_COOKIE: + return generate('Bing') + elif model == 'chatgpt': + if not config.CHATGPT_KEY: + return generate('ChatGPT') + elif model == 'ernie': + if not chat_bot.ERNIE_COOKIE: + return generate('文心一言') + + response = await call_next(request) + return response + @APP.exception_handler(404) def error404(request: Request, exc: Exception) -> Response: return core.GenerateResponse().error(404, '未找到文件') @@ -35,4 +65,11 @@ def error500(request: Request, exc: Exception) -> Response: return core.GenerateResponse().error(500, '未知错误') if __name__ == '__main__': - uvicorn.run(APP, host=config.HOST, port=config.PORT) \ No newline at end of file + appConfig = { + 'host': config.HOST, + 'port': config.PORT, + } + if config.SSL['enable']: + uvicorn.run(APP, **appConfig, ssl_keyfile=config.SSL['keyPath'], ssl_certfile=config.SSL['certPath']) + else: + uvicorn.run(APP, **appConfig) \ No newline at end of file