diff --git a/config.json b/config.json index 7b680c23..bd484ad2 100644 --- a/config.json +++ b/config.json @@ -716,6 +716,20 @@ "beam_size": 5 } }, + "image_recognition": { + "enable": true, + "screenshot_window_title": "任务管理器", + "img_save_path": "./out/图像识别", + "prompt": "请讲解一下图片里的内容", + "screenshot_delay": 3.0, + "gemini": { + "enable": true, + "model": "gemini-pro-vision", + "api_key": "", + "http_proxy": "http://127.0.0.1:10809", + "https_proxy": "http://127.0.0.1:10809" + } + }, "captions": { "enable": true, "file_path": "log/字幕.txt", diff --git a/config.json.bak b/config.json.bak index 7b680c23..bd484ad2 100644 --- a/config.json.bak +++ b/config.json.bak @@ -716,6 +716,20 @@ "beam_size": 5 } }, + "image_recognition": { + "enable": true, + "screenshot_window_title": "任务管理器", + "img_save_path": "./out/图像识别", + "prompt": "请讲解一下图片里的内容", + "screenshot_delay": 3.0, + "gemini": { + "enable": true, + "model": "gemini-pro-vision", + "api_key": "", + "http_proxy": "http://127.0.0.1:10809", + "https_proxy": "http://127.0.0.1:10809" + } + }, "captions": { "enable": true, "file_path": "log/字幕.txt", diff --git a/tests/test_common/get_screenshot.py b/tests/test_common/get_screenshot.py new file mode 100644 index 00000000..08e40e53 --- /dev/null +++ b/tests/test_common/get_screenshot.py @@ -0,0 +1,45 @@ +import pygetwindow as gw +import pyautogui + +def capture_window_by_title(window_title): + try: + # 使用窗口标题查找窗口 + win = gw.getWindowsWithTitle(window_title)[0] # 获取第一个匹配的窗口 + if win: + # 获取窗口的位置和大小 + left, top = win.left, win.top + width, height = win.width, win.height + + # 使用pyautogui捕获指定区域的截图 + screenshot = pyautogui.screenshot(region=(left, top, width, height)) + screenshot.save(f'{window_title}.png') + print(f"截图已保存为 {window_title}.png") + else: + print("未找到指定的窗口") + except IndexError: + print("未找到指定的窗口") + + +# 获取所有有标题的窗口对象 +def list_visible_windows(): + """获取所有有标题的窗口对象 + + Returns: + list: 获取所有有标题的窗口名列表 + """ + windows = gw.getWindowsWithTitle('') + + window_titles = [] + + # 打印每个窗口的标题 + for win in windows: + if win.title: # 确保窗口有标题 + window_titles.append(win.title) + + return window_titles + +# 调用函数,列出所有可见窗口的标题 +list_visible_windows() + +# 调用函数,替换"Your Window Title Here"为你想要捕获的窗口的标题 +capture_window_by_title("伊卡酱 fans群等3个会话") diff --git a/tests/test_gemini/1.png b/tests/test_gemini/1.png new file mode 100644 index 00000000..b06ef07c Binary files /dev/null and b/tests/test_gemini/1.png differ diff --git a/tests/test_gemini/api.py b/tests/test_gemini/api.py index 189948e0..1a0c7cd4 100644 --- a/tests/test_gemini/api.py +++ b/tests/test_gemini/api.py @@ -23,6 +23,39 @@ def list_models(self): if 'generateContent' in m.supported_generation_methods: logging.info(m.name) + def get_resp_with_img(self, prompt, img_data): + try: + import PIL.Image + + # 检查 img_data 的类型 + if isinstance(img_data, str): # 如果是字符串,假定为文件路径 + # 使用 PIL.Image.open() 打开图片文件 + img = PIL.Image.open(img_data) + elif isinstance(img_data, PIL.Image.Image): # 如果已经是 PIL.Image.Image 对象 + # 直接返回这个图像对象 + img = img_data + else: + img = img_data + + model = genai.GenerativeModel('gemini-pro-vision') + + response = model.generate_content( + [ + prompt, + img + ], + stream=False + ) + + resp_content = response.text.strip() + + logging.debug(f"resp_content={resp_content}") + + return resp_content + except Exception as e: + logging.error(traceback.format_exc()) + return None + def get_resp(self, prompt): """请求对应接口,获取返回值 @@ -106,7 +139,9 @@ def get_resp(self, prompt): gemini = Gemini(data) - logging.info(gemini.get_resp("你可以扮演猫娘吗,每句话后面加个喵")) - logging.info(gemini.get_resp("早上好")) - logging.info(gemini.get_resp("我的眼睛好酸")) + # logging.info(gemini.get_resp("你可以扮演猫娘吗,每句话后面加个喵")) + # logging.info(gemini.get_resp("早上好")) + # logging.info(gemini.get_resp("我的眼睛好酸")) + + logging.info(gemini.get_resp_with_img("根据图片内容,猜猜我吃的什么", "1.png")) \ No newline at end of file diff --git a/utils/common.py b/utils/common.py index 62be6f2e..3876bd77 100644 --- a/utils/common.py +++ b/utils/common.py @@ -727,6 +727,12 @@ def get_live2d_model_name(self, path): ,@/[ .[\@` =@@@ @@@@ \@@@` ,` @@@^ .[ =@@@. @@@^ """ + def ensure_directory_exists(self, path): + # 检查路径是否存在 + if not os.path.exists(path): + # 如果路径不存在,创建它 + os.makedirs(path) + logging.info(f"路径已创建:{path}") # 写入内容到指定文件中 返回T/F def write_content_to_file(self, file_path, content, write_log=True): @@ -1061,3 +1067,78 @@ def check_useful(data_json): return False return check_useful(data_json) + + + """ + 图像操作 + """ + # 获取所有有标题的窗口对象 + def list_visible_windows(self): + """获取所有有标题的窗口对象 + + Returns: + list: 获取所有有标题的窗口名列表 + """ + import pygetwindow as gw + + windows = gw.getWindowsWithTitle('') + + window_titles = [] + + # 打印每个窗口的标题 + for win in windows: + if win.title: # 确保窗口有标题 + window_titles.append(win.title) + + return window_titles + + + + def capture_window_by_title(self, img_save_path: str, window_title: str): + """根据窗口名截图(截图窗口不能被遮挡,必须前置窗口) + + Args: + img_save_path (str): 图片保存路径 + window_title (str): 窗口标题 + + Returns: + str: 图片保存路径含文件名 + """ + try: + import pygetwindow as gw + import pyautogui + + # 使用窗口标题查找窗口 + win = gw.getWindowsWithTitle(window_title)[0] # 获取第一个匹配的窗口 + if win: + # 获取窗口的位置和大小 + left, top = win.left, win.top + width, height = win.width, win.height + + # 使用pyautogui捕获指定区域的截图 + screenshot = pyautogui.screenshot(region=(left, top, width, height)) + + # 判断路径存在,不存在就创建 + self.ensure_directory_exists(img_save_path) + + # logging.debug(f"img_save_path={img_save_path}") + destination_directory = os.path.abspath(img_save_path) + logging.debug(f"destination_directory={destination_directory}") + + # 获取图片路径含文件名 + destination_path = os.path.join(destination_directory, f"{window_title}.png") + logging.debug(f"destination_path={destination_path}") + + screenshot.save(destination_path) + + logging.info(f"截图已保存到:{destination_path}") + + return destination_path + else: + logging.error(f"未找到指定的窗口:{window_title}") + except IndexError: + logging.error(f"未找到指定的窗口:{window_title}") + except Exception as e: + logging.error(traceback.format_exc()) + + return None \ No newline at end of file diff --git a/utils/gpt_model/gemini.py b/utils/gpt_model/gemini.py index 189948e0..35ab8218 100644 --- a/utils/gpt_model/gemini.py +++ b/utils/gpt_model/gemini.py @@ -82,6 +82,39 @@ def get_resp(self, prompt): logging.error(traceback.format_exc()) return None + def get_resp_with_img(self, prompt, img_data): + try: + import PIL.Image + + # 检查 img_data 的类型 + if isinstance(img_data, str): # 如果是字符串,假定为文件路径 + # 使用 PIL.Image.open() 打开图片文件 + img = PIL.Image.open(img_data) + elif isinstance(img_data, PIL.Image.Image): # 如果已经是 PIL.Image.Image 对象 + # 直接返回这个图像对象 + img = img_data + else: + img = img_data + + model = genai.GenerativeModel('gemini-pro-vision') + + response = model.generate_content( + [ + prompt, + img + ], + stream=False + ) + + resp_content = response.text.strip() + + logging.debug(f"resp_content={resp_content}") + + return resp_content + except Exception as e: + logging.error(traceback.format_exc()) + return None + if __name__ == '__main__': # 配置日志输出格式 @@ -109,4 +142,7 @@ def get_resp(self, prompt): logging.info(gemini.get_resp("你可以扮演猫娘吗,每句话后面加个喵")) logging.info(gemini.get_resp("早上好")) logging.info(gemini.get_resp("我的眼睛好酸")) + + logging.info(gemini.get_resp_with_img("根据图片内容,猜猜我吃的什么", "1.png")) + \ No newline at end of file diff --git a/webui.py b/webui.py index 970941c3..6516baa1 100644 --- a/webui.py +++ b/webui.py @@ -1590,6 +1590,22 @@ def common_textarea_handle(content): config_data["talk"]["faster_whisper"]["download_root"] = input_faster_whisper_download_root.value config_data["talk"]["faster_whisper"]["beam_size"] = int(input_faster_whisper_beam_size.value) + """ + 图像识别 + """ + if True: + config_data["image_recognition"]["enable"] = button_image_recognition_enable.value + config_data["image_recognition"]["screenshot_window_title"] = select_image_recognition_screenshot_window_title.value + config_data["image_recognition"]["img_save_path"] = input_image_recognition_img_save_path.value + config_data["image_recognition"]["prompt"] = input_image_recognition_prompt.value + config_data["image_recognition"]["screenshot_delay"] = float(input_image_recognition_screenshot_delay.value) + + config_data["image_recognition"]["gemini"]["enable"] = switch_image_recognition_gemini_enable.value + config_data["image_recognition"]["gemini"]["model"] = select_image_recognition_gemini_model.value + config_data["image_recognition"]["gemini"]["api_key"] = input_image_recognition_gemini_api_key.value + config_data["image_recognition"]["gemini"]["http_proxy"] = input_image_recognition_gemini_http_proxy.value + config_data["image_recognition"]["gemini"]["https_proxy"] = input_image_recognition_gemini_https_proxy.value + """ 助播 """ @@ -1832,8 +1848,9 @@ def common_textarea_handle(content): svc_page = ui.tab('变声') visual_body_page = ui.tab('虚拟身体') copywriting_page = ui.tab('文案') - integral_page = ui.tab('积分') talk_page = ui.tab('聊天') + image_recognition_page = ui.tab('图像识别') + integral_page = ui.tab('积分') assistant_anchor_page = ui.tab('助播') translate_page = ui.tab('翻译') data_analysis_page = ui.tab('数据分析') @@ -3658,6 +3675,69 @@ def talk_chat_box_tuning(): button_talk_chat_box_tuning = ui.button('调教', on_click=lambda: talk_chat_box_tuning(), color=button_internal_color).style(button_internal_css) button_talk_chat_box_reread_first = ui.button('直接复读-插队首', on_click=lambda: talk_chat_box_reread(0), color=button_internal_color).style(button_internal_css) + with ui.tab_panel(image_recognition_page).style(tab_panel_css): + with ui.card().style(card_css): + ui.label("通用") + with ui.row(): + button_image_recognition_enable = ui.switch('启用', value=config.get("image_recognition", "enable")).style(switch_internal_css) + window_titles = common.list_visible_windows() + data_json = {} + for line in window_titles: + data_json[line] = line + select_image_recognition_screenshot_window_title = ui.select( + label='截图窗口标题', + options=data_json, + value=config.get("image_recognition", "screenshot_window_title") + ).style("width:300px") + input_image_recognition_img_save_path = ui.input(label='截图保存路径', value=config.get("image_recognition", "img_save_path"), placeholder='截图保存路径,支持绝对或相对路径') + input_image_recognition_prompt = ui.input(label='携带的提示词', value=config.get("image_recognition", "prompt"), placeholder='图片识别时附带的提示词,协同图片获取回答') + input_image_recognition_screenshot_delay = ui.input(label='N秒后进行截图', value=config.get("image_recognition", "screenshot_delay"), placeholder='截图延迟,方便用户打开对应窗口').style("width:100px") + + async def image_recognition_screenshot_and_send(): + global running_flag + + if running_flag != 1: + ui.notify(position="top", type="warning", message="请先点击“一键运行”,然后再进行截图识别") + return + + logging.info(f"{input_image_recognition_screenshot_delay.value}后触发截图识别") + ui.notify(position="top", type="positive", message=f"{input_image_recognition_screenshot_delay.value}后触发截图识别") + + await asyncio.sleep(float(input_image_recognition_screenshot_delay.value)) + + # 根据窗口名截图 + screenshot_path = common.capture_window_by_title(input_image_recognition_img_save_path.value, select_image_recognition_screenshot_window_title.value) + + from utils.gpt_model.gemini import Gemini + + gemini = Gemini(config.get("image_recognition", "gemini")) + + resp_content = gemini.get_resp_with_img(config.get("image_recognition", "prompt"), screenshot_path) + + data = { + "type": "reread", + "username": config.get("talk", "username"), + "content": resp_content, + "insert_index": -1 + } + + common.send_request(f'http://{config.get("api_ip")}:{config.get("api_port")}/send', "POST", data) + + button_image_recognition_screenshot_and_send = ui.button('截图并发送', on_click=lambda: image_recognition_screenshot_and_send(), color=button_internal_color).style(button_internal_css) + + with ui.card().style(card_css): + ui.label("Gemini") + with ui.row(): + switch_image_recognition_gemini_enable = ui.switch('启用', value=config.get("image_recognition", "gemini", "enable")).style(switch_internal_css) + select_image_recognition_gemini_model = ui.select( + label='模型', + options={'gemini-pro-vision': 'gemini-pro-vision'}, + value=config.get("image_recognition", "gemini", "model") + ).style("width:150px") + input_image_recognition_gemini_api_key = ui.input(label='API Key', value=config.get("image_recognition", "gemini", "api_key"), placeholder='Gemini API KEY') + input_image_recognition_gemini_http_proxy = ui.input(label='HTTP代理地址', value=config.get("image_recognition", "gemini", "http_proxy"), placeholder='http代理地址,需要魔法才能使用,所以需要配置此项。').style("width:200px;") + input_image_recognition_gemini_https_proxy = ui.input(label='HTTPS代理地址', value=config.get("image_recognition", "gemini", "https_proxy"), placeholder='https代理地址,需要魔法才能使用,所以需要配置此项。').style("width:200px;") + with ui.tab_panel(assistant_anchor_page).style(tab_panel_css): with ui.row(): switch_assistant_anchor_enable = ui.switch('启用', value=config.get("assistant_anchor", "enable")).style(switch_internal_css)