Skip to content

Commit

Permalink
Merge pull request #701 from Ikaros-521/owner
Browse files Browse the repository at this point in the history
新增 图像识别 板块,使用多模态模型(Gemini)的能力实现,暂时仅支持n秒后截图识别内容,未来会继续拓展
  • Loading branch information
Ikaros-521 authored Mar 15, 2024
2 parents 55fbd42 + a6d6e46 commit bd2a1f1
Show file tree
Hide file tree
Showing 8 changed files with 309 additions and 4 deletions.
14 changes: 14 additions & 0 deletions config.json
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,20 @@
"beam_size": 5
}
},
"image_recognition": {
"enable": true,
"screenshot_window_title": "任务管理器",
"img_save_path": "./out/图像识别",
"prompt": "请讲解一下图片里的内容",
"screenshot_delay": 3.0,
"gemini": {
"enable": true,
"model": "gemini-pro-vision",
"api_key": "",
"http_proxy": "http://127.0.0.1:10809",
"https_proxy": "http://127.0.0.1:10809"
}
},
"captions": {
"enable": true,
"file_path": "log/字幕.txt",
Expand Down
14 changes: 14 additions & 0 deletions config.json.bak
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,20 @@
"beam_size": 5
}
},
"image_recognition": {
"enable": true,
"screenshot_window_title": "任务管理器",
"img_save_path": "./out/图像识别",
"prompt": "请讲解一下图片里的内容",
"screenshot_delay": 3.0,
"gemini": {
"enable": true,
"model": "gemini-pro-vision",
"api_key": "",
"http_proxy": "http://127.0.0.1:10809",
"https_proxy": "http://127.0.0.1:10809"
}
},
"captions": {
"enable": true,
"file_path": "log/字幕.txt",
Expand Down
45 changes: 45 additions & 0 deletions tests/test_common/get_screenshot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pygetwindow as gw
import pyautogui

def capture_window_by_title(window_title):
    """Save a screenshot of the first window returned for ``window_title``.

    The window's rectangle (left, top, width, height) is captured from the
    screen with pyautogui and written as a PNG into the current working
    directory.

    Args:
        window_title (str): Title passed to ``gw.getWindowsWithTitle``.

    Returns:
        str | None: File name the screenshot was saved under, or None when
        no window was returned for the title.
    """
    import re

    matches = gw.getWindowsWithTitle(window_title)
    if not matches:
        print("未找到指定的窗口")
        return None

    # Only the first match is used, as before.
    win = matches[0]
    screenshot = pyautogui.screenshot(
        region=(win.left, win.top, win.width, win.height)
    )

    # Window titles may contain characters that are not allowed in file
    # names (path separators, colons, ...); replace them so save() does not
    # fail. Titles without such characters keep their original file name.
    safe_title = re.sub(r'[\\/:*?"<>|\r\n\t]', "_", window_title)
    file_name = f'{safe_title}.png'

    screenshot.save(file_name)
    print(f"截图已保存为 {file_name}")
    return file_name


# 获取所有有标题的窗口对象
def list_visible_windows():
    """Return the titles of all windows that have a non-empty title.

    Returns:
        list: Title strings of the windows returned by
        ``gw.getWindowsWithTitle('')``, skipping windows without a title.
    """
    # Query with an empty title string and keep only windows that actually
    # carry a title.
    return [window.title for window in gw.getWindowsWithTitle('') if window.title]

# List the titles of all windows that currently have one. The returned list
# is printed so the exact title to capture can be copied from the output
# (previously the result was discarded and nothing was shown).
for title in list_visible_windows():
    print(title)

# Capture the window with the given title; replace the string below with the
# title of the window you want to capture.
capture_window_by_title("伊卡酱 fans群等3个会话")
Binary file added tests/test_gemini/1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
41 changes: 38 additions & 3 deletions tests/test_gemini/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,39 @@ def list_models(self):
if 'generateContent' in m.supported_generation_methods:
logging.info(m.name)

def get_resp_with_img(self, prompt, img_data):
    """Send a text prompt together with an image to the Gemini vision model.

    Args:
        prompt (str): Text instruction sent alongside the image.
        img_data (str | PIL.Image.Image | object): A file path (opened with
            ``PIL.Image.open``), an already loaded PIL image, or any other
            object, which is forwarded to the API unchanged.

    Returns:
        str | None: The stripped response text, or None when ``img_data`` is
        None or any error occurs (the traceback is logged).
    """
    # None can never be a valid image; fail early with a clear message
    # instead of a traceback from deep inside the request.
    if img_data is None:
        logging.error("图片数据为空,无法进行图像识别")
        return None

    try:
        import PIL.Image

        # A string is treated as a file path; anything else is passed as-is.
        opened_here = isinstance(img_data, str)
        img = PIL.Image.open(img_data) if opened_here else img_data

        try:
            # NOTE(review): the model name is hard-coded here; a configured
            # model name, if any, is not consulted - confirm this is intended.
            model = genai.GenerativeModel('gemini-pro-vision')

            response = model.generate_content([prompt, img], stream=False)
            resp_content = response.text.strip()
        finally:
            # Close only images opened by this method so the file handle is
            # released; caller-owned images are left untouched.
            if opened_here:
                img.close()

        logging.debug(f"resp_content={resp_content}")

        return resp_content
    except Exception:
        logging.error(traceback.format_exc())
        return None

def get_resp(self, prompt):
"""请求对应接口,获取返回值
Expand Down Expand Up @@ -106,7 +139,9 @@ def get_resp(self, prompt):

gemini = Gemini(data)

logging.info(gemini.get_resp("你可以扮演猫娘吗,每句话后面加个喵"))
logging.info(gemini.get_resp("早上好"))
logging.info(gemini.get_resp("我的眼睛好酸"))
# logging.info(gemini.get_resp("你可以扮演猫娘吗,每句话后面加个喵"))
# logging.info(gemini.get_resp("早上好"))
# logging.info(gemini.get_resp("我的眼睛好酸"))

logging.info(gemini.get_resp_with_img("根据图片内容,猜猜我吃的什么", "1.png"))

81 changes: 81 additions & 0 deletions utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,12 @@ def get_live2d_model_name(self, path):
,@/[ .[\@` =@@@ @@@@ \@@@` ,` @@@^ .[ =@@@. @@@^
"""
def ensure_directory_exists(self, path):
    """Create the directory ``path`` (including parents) if nothing exists there.

    Args:
        path (str): Directory path that should exist after the call.
    """
    # Nothing to do (and nothing to log) when the path already exists.
    if os.path.exists(path):
        return

    # exist_ok=True closes the gap between the check above and the creation:
    # if the directory is created concurrently in between, makedirs no longer
    # raises FileExistsError.
    os.makedirs(path, exist_ok=True)
    logging.info(f"路径已创建:{path}")

# 写入内容到指定文件中 返回T/F
def write_content_to_file(self, file_path, content, write_log=True):
Expand Down Expand Up @@ -1061,3 +1067,78 @@ def check_useful(data_json):
return False

return check_useful(data_json)


"""
图像操作
"""
# 获取所有有标题的窗口对象
def list_visible_windows(self):
    """Return the titles of all windows that have a non-empty title.

    Returns:
        list: Title strings of the windows returned by
        ``gw.getWindowsWithTitle('')``; an empty list when the window
        query is not possible (the traceback is logged).
    """
    try:
        # Imported lazily so that the dependency is only needed when the
        # image recognition feature is actually used.
        import pygetwindow as gw

        return [win.title for win in gw.getWindowsWithTitle('') if win.title]
    except Exception:
        # pygetwindow is not usable on every platform; without this guard a
        # failing import or query would propagate to the caller, which uses
        # this list while building the UI.
        logging.error(traceback.format_exc())
        return []



def capture_window_by_title(self, img_save_path: str, window_title: str):
    """Take a screenshot of a window selected by title and save it as PNG.

    The window's screen rectangle is captured, so the window must not be
    covered by other windows and has to be in the foreground.

    Args:
        img_save_path (str): Directory the image is saved into; created if
            it does not exist.
        window_title (str): Title passed to ``gw.getWindowsWithTitle``; the
            first returned window is captured.

    Returns:
        str | None: Absolute path of the saved image including file name, or
        None when no window was found or any error occurred.
    """
    try:
        import re

        import pygetwindow as gw
        import pyautogui

        matches = gw.getWindowsWithTitle(window_title)
        if not matches:
            logging.error(f"未找到指定的窗口:{window_title}")
            return None

        # Only the first match is used.
        win = matches[0]
        screenshot = pyautogui.screenshot(
            region=(win.left, win.top, win.width, win.height)
        )

        # Make sure the target directory exists before saving.
        self.ensure_directory_exists(img_save_path)

        destination_directory = os.path.abspath(img_save_path)
        logging.debug(f"destination_directory={destination_directory}")

        # Window titles may contain characters that are not allowed in file
        # names (path separators, colons, ...); replace them so the save
        # does not fail. Titles without such characters are unchanged.
        safe_title = re.sub(r'[\\/:*?"<>|\r\n\t]', "_", window_title)
        destination_path = os.path.join(destination_directory, f"{safe_title}.png")
        logging.debug(f"destination_path={destination_path}")

        screenshot.save(destination_path)

        logging.info(f"截图已保存到:{destination_path}")

        return destination_path
    except Exception:
        logging.error(traceback.format_exc())

    return None
36 changes: 36 additions & 0 deletions utils/gpt_model/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,39 @@ def get_resp(self, prompt):
logging.error(traceback.format_exc())
return None

def get_resp_with_img(self, prompt, img_data):
    """Send a text prompt together with an image to the Gemini vision model.

    Args:
        prompt (str): Text instruction sent alongside the image.
        img_data (str | PIL.Image.Image | object): A file path (opened with
            ``PIL.Image.open``), an already loaded PIL image, or any other
            object, which is forwarded to the API unchanged.

    Returns:
        str | None: The stripped response text, or None when ``img_data`` is
        None or any error occurs (the traceback is logged).
    """
    # None can never be a valid image; fail early with a clear message
    # instead of a traceback from deep inside the request.
    if img_data is None:
        logging.error("图片数据为空,无法进行图像识别")
        return None

    try:
        import PIL.Image

        # A string is treated as a file path; anything else is passed as-is.
        opened_here = isinstance(img_data, str)
        img = PIL.Image.open(img_data) if opened_here else img_data

        try:
            # NOTE(review): the model name is hard-coded here; a configured
            # model name, if any, is not consulted - confirm this is intended.
            model = genai.GenerativeModel('gemini-pro-vision')

            response = model.generate_content([prompt, img], stream=False)
            resp_content = response.text.strip()
        finally:
            # Close only images opened by this method so the file handle is
            # released; caller-owned images are left untouched.
            if opened_here:
                img.close()

        logging.debug(f"resp_content={resp_content}")

        return resp_content
    except Exception:
        logging.error(traceback.format_exc())
        return None


if __name__ == '__main__':
# 配置日志输出格式
Expand Down Expand Up @@ -109,4 +142,7 @@ def get_resp(self, prompt):
logging.info(gemini.get_resp("你可以扮演猫娘吗,每句话后面加个喵"))
logging.info(gemini.get_resp("早上好"))
logging.info(gemini.get_resp("我的眼睛好酸"))

logging.info(gemini.get_resp_with_img("根据图片内容,猜猜我吃的什么", "1.png"))


82 changes: 81 additions & 1 deletion webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -1590,6 +1590,22 @@ def common_textarea_handle(content):
config_data["talk"]["faster_whisper"]["download_root"] = input_faster_whisper_download_root.value
config_data["talk"]["faster_whisper"]["beam_size"] = int(input_faster_whisper_beam_size.value)

"""
图像识别
"""
if True:
config_data["image_recognition"]["enable"] = button_image_recognition_enable.value
config_data["image_recognition"]["screenshot_window_title"] = select_image_recognition_screenshot_window_title.value
config_data["image_recognition"]["img_save_path"] = input_image_recognition_img_save_path.value
config_data["image_recognition"]["prompt"] = input_image_recognition_prompt.value
config_data["image_recognition"]["screenshot_delay"] = float(input_image_recognition_screenshot_delay.value)

config_data["image_recognition"]["gemini"]["enable"] = switch_image_recognition_gemini_enable.value
config_data["image_recognition"]["gemini"]["model"] = select_image_recognition_gemini_model.value
config_data["image_recognition"]["gemini"]["api_key"] = input_image_recognition_gemini_api_key.value
config_data["image_recognition"]["gemini"]["http_proxy"] = input_image_recognition_gemini_http_proxy.value
config_data["image_recognition"]["gemini"]["https_proxy"] = input_image_recognition_gemini_https_proxy.value

"""
助播
"""
Expand Down Expand Up @@ -1832,8 +1848,9 @@ def common_textarea_handle(content):
svc_page = ui.tab('变声')
visual_body_page = ui.tab('虚拟身体')
copywriting_page = ui.tab('文案')
integral_page = ui.tab('积分')
talk_page = ui.tab('聊天')
image_recognition_page = ui.tab('图像识别')
integral_page = ui.tab('积分')
assistant_anchor_page = ui.tab('助播')
translate_page = ui.tab('翻译')
data_analysis_page = ui.tab('数据分析')
Expand Down Expand Up @@ -3658,6 +3675,69 @@ def talk_chat_box_tuning():
button_talk_chat_box_tuning = ui.button('调教', on_click=lambda: talk_chat_box_tuning(), color=button_internal_color).style(button_internal_css)
button_talk_chat_box_reread_first = ui.button('直接复读-插队首', on_click=lambda: talk_chat_box_reread(0), color=button_internal_color).style(button_internal_css)

with ui.tab_panel(image_recognition_page).style(tab_panel_css):
with ui.card().style(card_css):
ui.label("通用")
with ui.row():
button_image_recognition_enable = ui.switch('启用', value=config.get("image_recognition", "enable")).style(switch_internal_css)
window_titles = common.list_visible_windows()
data_json = {}
for line in window_titles:
data_json[line] = line
select_image_recognition_screenshot_window_title = ui.select(
label='截图窗口标题',
options=data_json,
value=config.get("image_recognition", "screenshot_window_title")
).style("width:300px")
input_image_recognition_img_save_path = ui.input(label='截图保存路径', value=config.get("image_recognition", "img_save_path"), placeholder='截图保存路径,支持绝对或相对路径')
input_image_recognition_prompt = ui.input(label='携带的提示词', value=config.get("image_recognition", "prompt"), placeholder='图片识别时附带的提示词,协同图片获取回答')
input_image_recognition_screenshot_delay = ui.input(label='N秒后进行截图', value=config.get("image_recognition", "screenshot_delay"), placeholder='截图延迟,方便用户打开对应窗口').style("width:100px")

async def image_recognition_screenshot_and_send():
    """Wait the configured delay, screenshot the selected window, have Gemini
    describe it and post the answer to the local ``/send`` API as a "reread"
    message.

    The delay, save path and window title are read from the UI inputs; the
    Gemini settings and the prompt are read from the loaded config.
    Nothing is sent when the program is not running, the delay is not a
    number, the screenshot fails, or no recognition result is returned.
    """
    global running_flag

    if running_flag != 1:
        ui.notify(position="top", type="warning", message="请先点击“一键运行”,然后再进行截图识别")
        return

    # Validate the delay before announcing the countdown, so a non-numeric
    # input is reported to the user instead of raising after the notice.
    try:
        delay = float(input_image_recognition_screenshot_delay.value)
    except (TypeError, ValueError):
        logging.error(f"截图延迟不是有效的数字:{input_image_recognition_screenshot_delay.value}")
        ui.notify(position="top", type="negative", message="截图延迟必须是数字")
        return

    logging.info(f"{input_image_recognition_screenshot_delay.value}后触发截图识别")
    ui.notify(position="top", type="positive", message=f"{input_image_recognition_screenshot_delay.value}后触发截图识别")

    await asyncio.sleep(delay)

    # NOTE(review): the capture, Gemini request and HTTP post below are
    # synchronous calls inside a coroutine; presumably they block the event
    # loop while they run - consider offloading them.

    # Screenshot the window selected by title.
    screenshot_path = common.capture_window_by_title(input_image_recognition_img_save_path.value, select_image_recognition_screenshot_window_title.value)
    if screenshot_path is None:
        # capture_window_by_title returns None on failure; there is no image
        # to recognise, so stop here.
        ui.notify(position="top", type="negative", message="截图失败,请检查截图窗口标题是否正确")
        return

    from utils.gpt_model.gemini import Gemini

    gemini = Gemini(config.get("image_recognition", "gemini"))

    resp_content = gemini.get_resp_with_img(config.get("image_recognition", "prompt"), screenshot_path)
    if not resp_content:
        # get_resp_with_img returns None on error; do not post an empty
        # message to the API.
        ui.notify(position="top", type="negative", message="图像识别失败,未获取到返回内容")
        return

    data = {
        "type": "reread",
        "username": config.get("talk", "username"),
        "content": resp_content,
        "insert_index": -1
    }

    common.send_request(f'http://{config.get("api_ip")}:{config.get("api_port")}/send', "POST", data)

button_image_recognition_screenshot_and_send = ui.button('截图并发送', on_click=lambda: image_recognition_screenshot_and_send(), color=button_internal_color).style(button_internal_css)

with ui.card().style(card_css):
ui.label("Gemini")
with ui.row():
switch_image_recognition_gemini_enable = ui.switch('启用', value=config.get("image_recognition", "gemini", "enable")).style(switch_internal_css)
select_image_recognition_gemini_model = ui.select(
label='模型',
options={'gemini-pro-vision': 'gemini-pro-vision'},
value=config.get("image_recognition", "gemini", "model")
).style("width:150px")
input_image_recognition_gemini_api_key = ui.input(label='API Key', value=config.get("image_recognition", "gemini", "api_key"), placeholder='Gemini API KEY')
input_image_recognition_gemini_http_proxy = ui.input(label='HTTP代理地址', value=config.get("image_recognition", "gemini", "http_proxy"), placeholder='http代理地址,需要魔法才能使用,所以需要配置此项。').style("width:200px;")
input_image_recognition_gemini_https_proxy = ui.input(label='HTTPS代理地址', value=config.get("image_recognition", "gemini", "https_proxy"), placeholder='https代理地址,需要魔法才能使用,所以需要配置此项。').style("width:200px;")

with ui.tab_panel(assistant_anchor_page).style(tab_panel_css):
with ui.row():
switch_assistant_anchor_enable = ui.switch('启用', value=config.get("assistant_anchor", "enable")).style(switch_internal_css)
Expand Down

0 comments on commit bd2a1f1

Please sign in to comment.