diff --git a/PROJECT_STRUCTURE.md b/PROJECT_STRUCTURE.md index 8456dc3..9fb5b7f 100644 --- a/PROJECT_STRUCTURE.md +++ b/PROJECT_STRUCTURE.md @@ -42,11 +42,13 @@ ultralytics-main/ ## 🚀 使用方法 ### 1. 运行GUI应用(推荐) + ```bash python scripts/run_gui.py ``` ### 2. 测试模型 + ```bash # 测试图片 python scripts/test_model.py --source image.jpg @@ -59,11 +61,13 @@ python scripts/test_model.py --source 0 ``` ### 3. 训练模型 + ```bash python scripts/train_model.py ``` ### 4. 验证模型 + ```bash python scripts/validate_model.py -``` \ No newline at end of file +``` diff --git a/UI/README.md b/UI/README.md index bbb8b01..075b2e6 100644 --- a/UI/README.md +++ b/UI/README.md @@ -23,6 +23,7 @@ ### 方式一:使用运行脚本(推荐) 在项目根目录运行: + ```bash python run_gui.py ``` @@ -36,52 +37,64 @@ python -m hys.main ### 方式三:在代码中导入 ```python -from hys import YOLODetectionGUI, main import tkinter as tk +from hys import YOLODetectionGUI + root = tk.Tk() app = YOLODetectionGUI(root) root.mainloop() ``` - ## 模块说明 ### config.py + 包含所有配置常量: + - `MODEL_PATH` - YOLO模型路径 - `WINDOW_TITLE` - 窗口标题 - `WINDOW_SIZE` - 窗口大小 - `DND_AVAILABLE` - 拖拽功能是否可用 ### gui_main.py + 主窗口类 `YOLODetectionGUI`,负责: + - 模式选择界面 - 检测界面管理 - 文件选择和处理 - 结果保存 ### detection_processor.py + 检测处理器 `DetectionProcessor`,负责: + - 屏幕检测 - 摄像头检测 - 文件检测(图片/视频) ### file_handler.py + 文件处理类 `FileHandler`,提供: + - 文件选择 - 文件类型判断 - 图片预览加载 - 检测结果保存 ### gui_utils.py + 线程安全的GUI更新器 `ThreadSafeGUIUpdater`,确保: + - 所有GUI更新在主线程执行 - 避免线程安全问题 - 安全的窗口关闭处理 ### detection_ui.py + 检测界面UI组件创建类 `DetectionUI`,提供: + - 统一的检测界面创建 - 文件检测特殊UI组件 @@ -100,4 +113,3 @@ root.mainloop() ```bash pip install tkinterdnd2 ``` - diff --git a/UI/__init__.py b/UI/__init__.py index a29346b..ef05f0d 100644 --- a/UI/__init__.py +++ b/UI/__init__.py @@ -1,8 +1,6 @@ -""" -YOLO目标检测GUI应用程序包 -""" +"""YOLO目标检测GUI应用程序包.""" + from .gui_main import YOLODetectionGUI from .main import main -__all__ = ['YOLODetectionGUI', 'main'] - +__all__ = ["YOLODetectionGUI", "main"] diff --git a/UI/config.py b/UI/config.py index fa13f93..d4ddffc 100644 --- a/UI/config.py +++ b/UI/config.py @@ -1,18 +1,20 @@ -""" -配置和常量定义 -""" +"""配置和常量定义.""" + import tkinter as tk # 尝试导入拖拽支持库(可选) try: from tkinterdnd2 import DND_FILES, TkinterDnD + DND_AVAILABLE = True except ImportError: DND_AVAILABLE = False + # 创建一个兼容类 class TkinterDnD: class Tk(tk.Tk): pass + # 定义 DND_FILES 占位符(即使不使用) DND_FILES = None @@ -35,8 +37,7 @@ class Tk(tk.Tk): # 默认检测参数 DEFAULT_CONF = 0.25 # 默认置信度阈值 (0-1) -DEFAULT_IOU = 0.45 # 默认IoU阈值 (0-1) +DEFAULT_IOU = 0.45 # 默认IoU阈值 (0-1) # 默认保存文件夹 DEFAULT_SAVE_DIR = "detection_saves" # 默认检测结果保存文件夹 - diff --git a/UI/detection_processor.py b/UI/detection_processor.py index 414dad8..5657fe3 100644 --- a/UI/detection_processor.py +++ b/UI/detection_processor.py @@ -1,11 +1,9 @@ -""" -检测处理逻辑 - 屏幕、摄像头、文件检测 -""" +"""检测处理逻辑 - 屏幕、摄像头、文件检测.""" + +import os +import sys import threading import tkinter as tk -import sys -import os -from ultralytics import YOLO # 处理相对导入和绝对导入 try: @@ -16,8 +14,8 @@ class DetectionProcessor: - """检测处理器""" - + """检测处理器.""" + def __init__(self, yolo_model, gui_updater, buttons, status_label, info_text, video_label): self.yolo = yolo_model self.gui_updater = gui_updater @@ -28,35 +26,37 @@ def __init__(self, yolo_model, gui_updater, buttons, status_label, info_text, vi self.is_running = False self.frame_count = 0 self.conf = 0.25 # 默认置信度阈值 - self.iou = 0.45 # 默认IoU阈值 + self.iou = 0.45 # 默认IoU阈值 self.save_dir = None # 保存文件夹 self.save_frame_count = 0 # 保存的帧计数 - + def set_params(self, conf, iou): - """设置检测参数""" + """设置检测参数.""" self.conf = conf self.iou = iou - + def set_save_dir(self, save_dir): - """设置保存文件夹""" + """设置保存文件夹.""" self.save_dir = save_dir # 确保文件夹存在 import os + os.makedirs(save_dir, exist_ok=True) - + def save_detected_frame(self, result): - """保存检测到目标的帧""" + """保存检测到目标的帧.""" if not self.save_dir or len(result.boxes) == 0: return - + try: - import cv2 import os - import numpy as np from datetime import datetime - + + import cv2 + import numpy as np + # 获取原始图像 - if hasattr(result, 'orig_img'): + if hasattr(result, "orig_img"): orig_img = result.orig_img # YOLO返回的orig_img通常是RGB格式,需要转为BGR if isinstance(orig_img, np.ndarray) and len(orig_img.shape) == 3: @@ -65,113 +65,115 @@ def save_detected_frame(self, result): orig_img_bgr = cv2.cvtColor(orig_img, cv2.COLOR_RGB2BGR) else: orig_img_bgr = orig_img - + # 生成文件名(时间戳 + 帧编号) timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") filename = f"detected_{timestamp}_{self.save_frame_count:06d}.jpg" filepath = os.path.join(self.save_dir, filename) - + # 保存帧 cv2.imwrite(filepath, orig_img_bgr) self.save_frame_count += 1 - + # 每保存10帧提示一次(避免信息过多) if self.save_frame_count % 10 == 0: - self.gui_updater.add_info(self.info_text, - f"已保存 {self.save_frame_count} 帧到: {os.path.basename(self.save_dir)}") - except Exception as e: + self.gui_updater.add_info( + self.info_text, f"已保存 {self.save_frame_count} 帧到: {os.path.basename(self.save_dir)}" + ) + except Exception: # 保存失败时不中断检测,只记录错误 pass - + def start_screen_detection(self): - """启动屏幕检测""" + """启动屏幕检测.""" if self.is_running: return - + try: self.is_running = True - self.buttons['start'].config(state=tk.DISABLED) - self.buttons['stop'].config(state=tk.NORMAL) + self.buttons["start"].config(state=tk.DISABLED) + self.buttons["stop"].config(state=tk.NORMAL) self.gui_updater.update_status(self.status_label, "屏幕检测运行中...") self.gui_updater.add_info(self.info_text, "开始屏幕检测...") - + # 在新线程中运行屏幕检测 detection_thread = threading.Thread(target=self.process_screen, daemon=True) detection_thread.start() - + except Exception as e: - self.gui_updater.update_status(self.status_label, f"错误: {str(e)}") - self.gui_updater.add_info(self.info_text, f"启动屏幕检测时发生错误: {str(e)}") + self.gui_updater.update_status(self.status_label, f"错误: {e!s}") + self.gui_updater.add_info(self.info_text, f"启动屏幕检测时发生错误: {e!s}") self.is_running = False - + def start_camera_detection(self): - """启动摄像头检测""" + """启动摄像头检测.""" if self.is_running: return - + try: self.is_running = True - self.buttons['start'].config(state=tk.DISABLED) - self.buttons['stop'].config(state=tk.NORMAL) + self.buttons["start"].config(state=tk.DISABLED) + self.buttons["stop"].config(state=tk.NORMAL) self.gui_updater.update_status(self.status_label, "摄像头检测运行中...") self.gui_updater.add_info(self.info_text, "摄像头已启动,开始实时检测...") - + # 在新线程中运行视频处理(YOLO会自动处理摄像头) detection_thread = threading.Thread(target=self.process_camera, daemon=True) detection_thread.start() - + except Exception as e: - self.gui_updater.update_status(self.status_label, f"错误: {str(e)}") - self.gui_updater.add_info(self.info_text, f"启动摄像头时发生错误: {str(e)}") + self.gui_updater.update_status(self.status_label, f"错误: {e!s}") + self.gui_updater.add_info(self.info_text, f"启动摄像头时发生错误: {e!s}") self.is_running = False - - def start_file_detection(self, file_path, detection_file_type_callback, - detection_results_callback, detection_has_results_callback): - """开始文件检测""" + + def start_file_detection( + self, file_path, detection_file_type_callback, detection_results_callback, detection_has_results_callback + ): + """开始文件检测.""" if self.is_running: return - + try: self.gui_updater.update_status(self.status_label, "处理中...") self.gui_updater.add_info(self.info_text, f"正在处理文件: {file_path}") - + # 检查文件类型 is_video = FileHandler.is_video_file(file_path) - + if is_video: # 视频文件:使用stream模式处理 - detection_file_type_callback('video') + detection_file_type_callback("video") self.is_running = True - self.buttons['start'].config(state=tk.DISABLED) - self.buttons['stop'].config(state=tk.NORMAL) - if 'save' in self.buttons: - self.buttons['save'].config(state=tk.DISABLED) - + self.buttons["start"].config(state=tk.DISABLED) + self.buttons["stop"].config(state=tk.NORMAL) + if "save" in self.buttons: + self.buttons["save"].config(state=tk.DISABLED) + # 在新线程中处理视频 detection_thread = threading.Thread( - target=self.process_video_file, - args=(file_path, detection_results_callback, detection_has_results_callback), - daemon=True + target=self.process_video_file, + args=(file_path, detection_results_callback, detection_has_results_callback), + daemon=True, ) detection_thread.start() else: # 图片文件:直接处理 - detection_file_type_callback('image') + detection_file_type_callback("image") results = self.yolo(file_path, verbose=False, conf=self.conf, iou=self.iou) - + # 绘制检测结果 annotated_frame = results[0].plot() - + # 保存检测结果 detection_results_callback(annotated_frame, results) - + # 更新显示 self.gui_updater.update_frame(self.video_label, annotated_frame) - + # 检查是否有检测结果 has_detections = len(results[0].boxes) > 0 detection_has_results_callback(has_detections) - + # 显示检测信息 if has_detections: self.update_detection_info(results, show_all=True) @@ -179,147 +181,145 @@ def start_file_detection(self, file_path, detection_file_type_callback, else: self.gui_updater.add_info(self.info_text, "未检测到目标") self.gui_updater.add_info(self.info_text, "检测完成,可以点击'保存检测结果'保存") - + # 启用保存按钮 - if 'save' in self.buttons: - self.gui_updater.update_button_state(self.buttons, 'save', tk.NORMAL) + if "save" in self.buttons: + self.gui_updater.update_button_state(self.buttons, "save", tk.NORMAL) self.gui_updater.update_status(self.status_label, "检测完成") self.is_running = False - self.gui_updater.update_button_state(self.buttons, 'start', tk.NORMAL) - self.gui_updater.update_button_state(self.buttons, 'stop', tk.DISABLED) - + self.gui_updater.update_button_state(self.buttons, "start", tk.NORMAL) + self.gui_updater.update_button_state(self.buttons, "stop", tk.DISABLED) + except Exception as e: - self.gui_updater.update_status(self.status_label, f"错误: {str(e)}") - self.gui_updater.add_info(self.info_text, f"处理文件时发生错误: {str(e)}") + self.gui_updater.update_status(self.status_label, f"错误: {e!s}") + self.gui_updater.add_info(self.info_text, f"处理文件时发生错误: {e!s}") from tkinter import messagebox - messagebox.showerror("错误", f"处理文件时发生错误: {str(e)}") + + messagebox.showerror("错误", f"处理文件时发生错误: {e!s}") self.is_running = False - self.gui_updater.update_button_state(self.buttons, 'start', tk.NORMAL) - self.gui_updater.update_button_state(self.buttons, 'stop', tk.DISABLED) - if 'save' in self.buttons: - self.gui_updater.update_button_state(self.buttons, 'save', tk.DISABLED) - + self.gui_updater.update_button_state(self.buttons, "start", tk.NORMAL) + self.gui_updater.update_button_state(self.buttons, "stop", tk.DISABLED) + if "save" in self.buttons: + self.gui_updater.update_button_state(self.buttons, "save", tk.DISABLED) + def process_screen(self): - """处理屏幕检测""" + """处理屏幕检测.""" try: # source="screen" 会自动处理屏幕捕获 - for result in self.yolo(source="screen", stream=True, verbose=False, - conf=self.conf, iou=self.iou): + for result in self.yolo(source="screen", stream=True, verbose=False, conf=self.conf, iou=self.iou): if not self.is_running: break - + # 获取带标注的帧 annotated_frame = result.plot() - + # 如果检测到目标,保存原始帧(不带标注) if len(result.boxes) > 0: self.save_detected_frame(result) - + # 更新显示 self.gui_updater.update_frame(self.video_label, annotated_frame) - + # 显示检测信息 self.update_detection_info([result]) - + except Exception as e: - self.gui_updater.add_info(self.info_text, f"屏幕检测时发生错误: {str(e)}") + self.gui_updater.add_info(self.info_text, f"屏幕检测时发生错误: {e!s}") self.is_running = False - self.gui_updater.update_button_state(self.buttons, 'start', tk.NORMAL) - self.gui_updater.update_button_state(self.buttons, 'stop', tk.DISABLED) - + self.gui_updater.update_button_state(self.buttons, "start", tk.NORMAL) + self.gui_updater.update_button_state(self.buttons, "stop", tk.DISABLED) + def process_camera(self): - """处理摄像头检测""" + """处理摄像头检测.""" try: # source=0 表示摄像头 - for result in self.yolo(source=0, stream=True, verbose=False, - conf=self.conf, iou=self.iou): + for result in self.yolo(source=0, stream=True, verbose=False, conf=self.conf, iou=self.iou): if not self.is_running: break - + # 获取带标注的帧 annotated_frame = result.plot() - + # 如果检测到目标,保存原始帧(不带标注) if len(result.boxes) > 0: self.save_detected_frame(result) - + # 更新显示 self.gui_updater.update_frame(self.video_label, annotated_frame) - + # 显示检测信息 self.update_detection_info([result]) except Exception as e: - self.gui_updater.add_info(self.info_text, f"摄像头检测时发生错误: {str(e)}") + self.gui_updater.add_info(self.info_text, f"摄像头检测时发生错误: {e!s}") finally: # 清理资源:关闭 dataset 以释放摄像头 - if hasattr(self.yolo, 'predictor') and self.yolo.predictor is not None: + if hasattr(self.yolo, "predictor") and self.yolo.predictor is not None: predictor = self.yolo.predictor - if hasattr(predictor, 'dataset') and predictor.dataset is not None: + if hasattr(predictor, "dataset") and predictor.dataset is not None: dataset = predictor.dataset - if hasattr(dataset, 'close'): + if hasattr(dataset, "close"): dataset.close() self.is_running = False - self.gui_updater.update_button_state(self.buttons, 'start', tk.NORMAL) - self.gui_updater.update_button_state(self.buttons, 'stop', tk.DISABLED) - + self.gui_updater.update_button_state(self.buttons, "start", tk.NORMAL) + self.gui_updater.update_button_state(self.buttons, "stop", tk.DISABLED) + def process_video_file(self, file_path, detection_results_callback, detection_has_results_callback): - """处理视频文件""" + """处理视频文件.""" try: # 使用YOLO的stream模式处理视频 has_detections = False video_frames = [] # 保存所有处理后的帧 - - for result in self.yolo(source=file_path, stream=True, verbose=False, - conf=self.conf, iou=self.iou): + + for result in self.yolo(source=file_path, stream=True, verbose=False, conf=self.conf, iou=self.iou): if not self.is_running: break - + # 检查是否有检测结果 if len(result.boxes) > 0: has_detections = True - + # 获取带标注的帧 annotated_frame = result.plot() - + # 保存帧到列表 video_frames.append(annotated_frame.copy()) - + # 更新显示 self.gui_updater.update_frame(self.video_label, annotated_frame) - + # 显示检测信息 self.update_detection_info([result]) - + if self.is_running: # 保存检测结果 detection_results_callback(video_frames, None) detection_has_results_callback(has_detections) - + if has_detections: self.gui_updater.add_info(self.info_text, "检测完成,可以点击'保存检测结果'保存") else: self.gui_updater.add_info(self.info_text, "未检测到目标") self.gui_updater.add_info(self.info_text, "检测完成,可以点击'保存检测结果'保存") - + # 启用保存按钮 - if 'save' in self.buttons: - self.gui_updater.update_button_state(self.buttons, 'save', tk.NORMAL) + if "save" in self.buttons: + self.gui_updater.update_button_state(self.buttons, "save", tk.NORMAL) self.gui_updater.add_info(self.info_text, "视频处理完成") self.gui_updater.update_status(self.status_label, "检测完成") self.is_running = False - self.gui_updater.update_button_state(self.buttons, 'start', tk.NORMAL) - self.gui_updater.update_button_state(self.buttons, 'stop', tk.DISABLED) - + self.gui_updater.update_button_state(self.buttons, "start", tk.NORMAL) + self.gui_updater.update_button_state(self.buttons, "stop", tk.DISABLED) + except Exception as e: - self.gui_updater.add_info(self.info_text, f"处理视频时发生错误: {str(e)}") + self.gui_updater.add_info(self.info_text, f"处理视频时发生错误: {e!s}") self.is_running = False - self.gui_updater.update_button_state(self.buttons, 'start', tk.NORMAL) - self.gui_updater.update_button_state(self.buttons, 'stop', tk.DISABLED) - if 'save' in self.buttons: - self.gui_updater.update_button_state(self.buttons, 'save', tk.DISABLED) - + self.gui_updater.update_button_state(self.buttons, "start", tk.NORMAL) + self.gui_updater.update_button_state(self.buttons, "stop", tk.DISABLED) + if "save" in self.buttons: + self.gui_updater.update_button_state(self.buttons, "save", tk.DISABLED) + def update_detection_info(self, results, show_all=False): - """更新检测信息""" + """更新检测信息.""" if len(results[0].boxes) > 0: detections = [] for box in results[0].boxes: @@ -327,7 +327,7 @@ def update_detection_info(self, results, show_all=False): conf = float(box.conf[0]) cls_name = self.yolo.names[cls_id] detections.append(f"{cls_name}: {conf:.2f}") - + # 对于图片检测,显示所有信息;对于实时检测,每10帧更新一次 if show_all or self.frame_count % 10 == 0: info = f"检测到 {len(results[0].boxes)} 个目标: {', '.join(detections[:10])}" @@ -337,10 +337,9 @@ def update_detection_info(self, results, show_all=False): elif show_all: # 文件检测模式下,如果没有检测到目标,显示提示 self.gui_updater.add_info(self.info_text, "未检测到目标") - + self.frame_count += 1 - + def stop(self): - """停止检测""" + """停止检测.""" self.is_running = False - diff --git a/UI/detection_ui.py b/UI/detection_ui.py index 1f7601a..6ee06ab 100644 --- a/UI/detection_ui.py +++ b/UI/detection_ui.py @@ -1,21 +1,20 @@ -""" -检测界面UI组件创建 -""" +"""检测界面UI组件创建.""" + +import os +import sys import tkinter as tk from tkinter import ttk -import sys -import os # 处理相对导入和绝对导入 try: # 尝试相对导入(当作为包的一部分运行时) - from .config import DND_AVAILABLE, DND_FILES, DEFAULT_CONF, DEFAULT_IOU + from .config import DEFAULT_CONF, DEFAULT_IOU, DND_AVAILABLE, DND_FILES except ImportError: # 如果相对导入失败,使用绝对导入(直接运行时) # 添加父目录到路径 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) try: - from hys.config import DND_AVAILABLE, DND_FILES, DEFAULT_CONF, DEFAULT_IOU + from hys.config import DEFAULT_CONF, DEFAULT_IOU, DND_AVAILABLE, DND_FILES except ImportError: # 如果都失败,使用默认值 DND_AVAILABLE = False @@ -25,60 +24,58 @@ class DetectionUI: - """检测界面UI组件""" - + """检测界面UI组件.""" + @staticmethod def create_detection_ui(root, mode_name, control_frame_callback=None): - """创建检测界面""" + """创建检测界面.""" # 清除现有组件 for widget in root.winfo_children(): widget.destroy() - + # 主框架 main_frame = ttk.Frame(root, padding="10") main_frame.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S)) - + # 配置网格权重 root.columnconfigure(0, weight=1) root.rowconfigure(0, weight=1) main_frame.columnconfigure(0, weight=1) main_frame.rowconfigure(1, weight=1) - + # 标题栏 header_frame = ttk.Frame(main_frame) header_frame.grid(row=0, column=0, sticky=(tk.W, tk.E), pady=10) - - title_label = ttk.Label(header_frame, text=f"YOLO {mode_name}", - font=("Arial", 16, "bold")) + + title_label = ttk.Label(header_frame, text=f"YOLO {mode_name}", font=("Arial", 16, "bold")) title_label.pack(side=tk.LEFT) - + # 返回按钮 - back_button = ttk.Button(header_frame, text="← 返回模式选择", - command=control_frame_callback) + back_button = ttk.Button(header_frame, text="← 返回模式选择", command=control_frame_callback) back_button.pack(side=tk.RIGHT) - + # 视频显示区域 video_frame = ttk.Frame(main_frame) video_frame.grid(row=1, column=0, sticky=(tk.W, tk.E, tk.N, tk.S), pady=10) video_frame.columnconfigure(0, weight=1) video_frame.rowconfigure(0, weight=1) - + # 视频标签(居中显示,使用grid布局防止窗口收缩) - video_label = ttk.Label(video_frame, text="准备中...", - background="black", foreground="white", - font=("Arial", 12), anchor="center") + video_label = ttk.Label( + video_frame, text="准备中...", background="black", foreground="white", font=("Arial", 12), anchor="center" + ) video_label.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S)) - + # 控制按钮框架 control_frame = ttk.Frame(main_frame) control_frame.grid(row=2, column=0, pady=10) - + # 参数设置框架 params_frame = ttk.LabelFrame(control_frame, text="检测参数", padding="5") params_frame.pack(side=tk.LEFT, padx=10) - + # 使用已导入的默认值(在文件顶部已导入) - + # Conf参数输入 conf_frame = ttk.Frame(params_frame) conf_frame.pack(side=tk.LEFT, padx=5) @@ -87,7 +84,7 @@ def create_detection_ui(root, mode_name, control_frame_callback=None): conf_entry = ttk.Entry(conf_frame, textvariable=conf_var, width=6, font=("Arial", 9)) conf_entry.pack(side=tk.LEFT, padx=2) ttk.Label(conf_frame, text="(0-1)", font=("Arial", 8), foreground="gray").pack(side=tk.LEFT) - + # IOU参数输入 iou_frame = ttk.Frame(params_frame) iou_frame.pack(side=tk.LEFT, padx=5) @@ -96,85 +93,75 @@ def create_detection_ui(root, mode_name, control_frame_callback=None): iou_entry = ttk.Entry(iou_frame, textvariable=iou_var, width=6, font=("Arial", 9)) iou_entry.pack(side=tk.LEFT, padx=2) ttk.Label(iou_frame, text="(0-1)", font=("Arial", 8), foreground="gray").pack(side=tk.LEFT) - + # 参数说明提示(换行显示更清晰) help_text = "Conf:置信度阈值(0-1,越高越严格) | IOU:重叠度阈值(0-1,越高越宽松)" - help_label = ttk.Label(params_frame, - text=help_text, - font=("Arial", 8), foreground="gray") + help_label = ttk.Label(params_frame, text=help_text, font=("Arial", 8), foreground="gray") help_label.pack(side=tk.LEFT, padx=10) - + # 启动/停止按钮 button_frame = ttk.Frame(control_frame) button_frame.pack(side=tk.LEFT, padx=10) - - start_button = ttk.Button(button_frame, text="开始检测", - command=lambda: None) # 将在外部设置 + + start_button = ttk.Button(button_frame, text="开始检测", command=lambda: None) # 将在外部设置 start_button.pack(side=tk.LEFT, padx=5) - - stop_button = ttk.Button(button_frame, text="停止检测", - state=tk.DISABLED) + + stop_button = ttk.Button(button_frame, text="停止检测", state=tk.DISABLED) stop_button.pack(side=tk.LEFT, padx=5) - + # 状态标签 - status_label = ttk.Label(control_frame, text="状态: 未启动", - font=("Arial", 10)) + status_label = ttk.Label(control_frame, text="状态: 未启动", font=("Arial", 10)) status_label.pack(side=tk.LEFT, padx=20) - + # 信息显示区域 info_frame = ttk.LabelFrame(main_frame, text="检测信息", padding="10") info_frame.grid(row=3, column=0, sticky=(tk.W, tk.E), pady=10) info_frame.columnconfigure(0, weight=1) - + info_text = tk.Text(info_frame, height=5, width=80, wrap=tk.WORD) info_text.grid(row=0, column=0, sticky=(tk.W, tk.E)) - + scrollbar = ttk.Scrollbar(info_frame, orient=tk.VERTICAL, command=info_text.yview) scrollbar.grid(row=0, column=1, sticky=(tk.N, tk.S)) info_text.configure(yscrollcommand=scrollbar.set) - + return { - 'main_frame': main_frame, - 'video_label': video_label, - 'control_frame': control_frame, - 'start_button': start_button, - 'stop_button': stop_button, - 'status_label': status_label, - 'info_text': info_text, - 'conf_var': conf_var, - 'iou_var': iou_var, - 'conf_entry': conf_entry, - 'iou_entry': iou_entry + "main_frame": main_frame, + "video_label": video_label, + "control_frame": control_frame, + "start_button": start_button, + "stop_button": stop_button, + "status_label": status_label, + "info_text": info_text, + "conf_var": conf_var, + "iou_var": iou_var, + "conf_entry": conf_entry, + "iou_entry": iou_entry, } - + @staticmethod - def setup_file_detection_ui(control_frame, video_label, select_file_callback, - on_file_drop_callback, save_results_callback): - """设置文件检测界面的额外组件""" + def setup_file_detection_ui( + control_frame, video_label, select_file_callback, on_file_drop_callback, save_results_callback + ): + """设置文件检测界面的额外组件.""" # 添加文件选择区域 file_select_frame = ttk.Frame(control_frame) file_select_frame.pack(pady=10) - - select_button = ttk.Button(file_select_frame, text="选择文件", - command=select_file_callback) + + select_button = ttk.Button(file_select_frame, text="选择文件", command=select_file_callback) select_button.pack(side=tk.LEFT, padx=5) - + # 启用拖拽功能(如果可用) if DND_AVAILABLE: video_label.drop_target_register(DND_FILES) - video_label.dnd_bind('<>', on_file_drop_callback) - + video_label.dnd_bind("<>", on_file_drop_callback) + # 提示文字 - hint_label = ttk.Label(file_select_frame, - text="或拖拽文件到显示区域", - font=("Arial", 9), foreground="gray") + hint_label = ttk.Label(file_select_frame, text="或拖拽文件到显示区域", font=("Arial", 9), foreground="gray") hint_label.pack(side=tk.LEFT, padx=10) - + # 保存结果按钮 - save_button = ttk.Button(control_frame, text="保存检测结果", - command=save_results_callback, - state=tk.DISABLED) + save_button = ttk.Button(control_frame, text="保存检测结果", command=save_results_callback, state=tk.DISABLED) save_button.pack(side=tk.LEFT, padx=5) - - return save_button + return save_button diff --git a/UI/file_handler.py b/UI/file_handler.py index af46fef..971c73b 100644 --- a/UI/file_handler.py +++ b/UI/file_handler.py @@ -1,102 +1,98 @@ -""" -文件处理和保存功能 -""" -import cv2 +"""文件处理和保存功能.""" + import os -import tkinter as tk from tkinter import filedialog, messagebox +import cv2 + class FileHandler: - """文件处理类""" - + """文件处理类.""" + @staticmethod def select_file(): - """选择文件""" + """选择文件.""" file_path = filedialog.askopenfilename( title="选择图片或视频文件", filetypes=[ ("图片文件", "*.jpg *.jpeg *.png *.bmp *.gif *.tiff"), ("视频文件", "*.mp4 *.avi *.mov *.mkv *.flv *.wmv"), - ("所有文件", "*.*") - ] + ("所有文件", "*.*"), + ], ) return file_path - + @staticmethod def is_video_file(file_path): - """判断是否为视频文件""" + """判断是否为视频文件.""" ext = os.path.splitext(file_path)[1].lower() - return ext in ['.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv'] - + return ext in [".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv"] + @staticmethod def load_image_preview(file_path): - """加载图片预览""" + """加载图片预览.""" try: img = cv2.imread(file_path) return img except: return None - + @staticmethod - def save_detection_results(detection_results, detection_file_type, selected_file_path, - add_info_callback, show_message_callback): - """保存检测结果""" + def save_detection_results( + detection_results, detection_file_type, selected_file_path, add_info_callback, show_message_callback + ): + """保存检测结果.""" # 检查是否有检测结果 if detection_results is None: messagebox.showinfo("提示", "没有检测结果可保存") return False - + # 对于视频,检查列表是否为空 - if detection_file_type == 'video' and len(detection_results) == 0: + if detection_file_type == "video" and len(detection_results) == 0: messagebox.showinfo("提示", "没有检测结果可保存") return False - + try: - if detection_file_type == 'image': + if detection_file_type == "image": # 图片文件:让用户选择保存位置 default_filename = f"detected_{os.path.basename(selected_file_path)}" - + save_path = filedialog.asksaveasfilename( title="保存检测结果", defaultextension=".jpg", - filetypes=[ - ("JPEG文件", "*.jpg"), - ("PNG文件", "*.png"), - ("所有文件", "*.*") - ], - initialfile=default_filename + filetypes=[("JPEG文件", "*.jpg"), ("PNG文件", "*.png"), ("所有文件", "*.*")], + initialfile=default_filename, ) - + if save_path: cv2.imwrite(save_path, detection_results) add_info_callback(f"已将检测结果保存至: {save_path}") show_message_callback("成功", f"检测结果已保存至:\n{save_path}") return True - - elif detection_file_type == 'video': + + elif detection_file_type == "video": # 视频文件:让用户选择保存文件夹 save_dir = filedialog.askdirectory(title="选择保存文件夹") - + if save_dir: # 获取原始文件名 original_filename = os.path.basename(selected_file_path) name_without_ext = os.path.splitext(original_filename)[0] output_filename = f"detected_{name_without_ext}.mp4" output_path = os.path.join(save_dir, output_filename) - + # 获取视频帧的尺寸 if len(detection_results) > 0: height, width = detection_results[0].shape[:2] - + # 创建视频写入器 - fourcc = cv2.VideoWriter_fourcc(*'mp4v') + fourcc = cv2.VideoWriter_fourcc(*"mp4v") out = cv2.VideoWriter(output_path, fourcc, 30.0, (width, height)) - + # 写入所有帧 for frame in detection_results: out.write(frame) - + out.release() add_info_callback(f"已将检测结果保存至: {output_path}") show_message_callback("成功", f"检测结果已保存至:\n{output_path}") @@ -104,11 +100,10 @@ def save_detection_results(detection_results, detection_file_type, selected_file else: messagebox.showerror("错误", "没有视频帧可保存") return False - + except Exception as e: - add_info_callback(f"保存检测结果时发生错误: {str(e)}") - messagebox.showerror("错误", f"保存检测结果时发生错误: {str(e)}") + add_info_callback(f"保存检测结果时发生错误: {e!s}") + messagebox.showerror("错误", f"保存检测结果时发生错误: {e!s}") return False - - return False + return False diff --git a/UI/gui_main.py b/UI/gui_main.py index 46736e4..c9fd9e2 100644 --- a/UI/gui_main.py +++ b/UI/gui_main.py @@ -1,65 +1,63 @@ -""" -主窗口和模式选择界面 -""" -import tkinter as tk -from tkinter import ttk, messagebox -from queue import Queue +"""主窗口和模式选择界面.""" + import os -import cv2 import sys +import tkinter as tk +from queue import Queue +from tkinter import messagebox, ttk # 处理相对导入和绝对导入 try: # 尝试相对导入(当作为包的一部分运行时) - from .config import MODEL_PATH, WINDOW_TITLE, WINDOW_SIZE, DND_AVAILABLE, DND_FILES - from .detection_ui import DetectionUI - from .gui_utils import ThreadSafeGUIUpdater + from .config import DND_AVAILABLE, DND_FILES, MODEL_PATH, WINDOW_SIZE, WINDOW_TITLE from .detection_processor import DetectionProcessor + from .detection_ui import DetectionUI from .file_handler import FileHandler + from .gui_utils import ThreadSafeGUIUpdater except ImportError: # 如果相对导入失败,使用绝对导入(直接运行时) # 添加父目录到路径 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - from hys.config import MODEL_PATH, WINDOW_TITLE, WINDOW_SIZE, DND_AVAILABLE, DND_FILES - from hys.detection_ui import DetectionUI - from hys.gui_utils import ThreadSafeGUIUpdater + from hys.config import MODEL_PATH, WINDOW_SIZE, WINDOW_TITLE from hys.detection_processor import DetectionProcessor + from hys.detection_ui import DetectionUI from hys.file_handler import FileHandler + from hys.gui_utils import ThreadSafeGUIUpdater from ultralytics import YOLO class YOLODetectionGUI: - """YOLO目标检测GUI主类""" - + """YOLO目标检测GUI主类.""" + def __init__(self, root): self.root = root self.root.title(WINDOW_TITLE) self.root.geometry(WINDOW_SIZE) - + # 模型路径(初始使用默认路径) self.model_path = MODEL_PATH self.yolo = None self.model_loaded = False - + # 加载YOLO模型 self.load_model(self.model_path) - + # 检测模式:'screen', 'camera', 'file', None self.detection_mode = None self.is_running = False - + # 文件检测设置 self.selected_file_path = None self.detection_results = None self.detection_results_info = None self.detection_file_type = None self.detection_has_results = False - + # 参数输入控件引用(初始化为None) self.conf_var = None self.iou_var = None - + # 保存文件夹设置 try: from .config import DEFAULT_SAVE_DIR @@ -68,28 +66,28 @@ def __init__(self, root): self.save_dir = DEFAULT_SAVE_DIR # 确保默认文件夹存在 os.makedirs(self.save_dir, exist_ok=True) - + # 当前显示的图像 self.current_frame = None - + # 线程安全的GUI更新队列 self.gui_queue = Queue() self.gui_updater = ThreadSafeGUIUpdater(self.root, self.gui_queue) self.root.after(100, self.gui_updater.process_gui_queue) - + # UI组件引用 self.ui_components = {} self.buttons = {} self.detection_processor = None - + # 显示模式选择界面 self.show_mode_selection() - + # 绑定窗口关闭事件 self.root.protocol("WM_DELETE_WINDOW", self.on_closing) - + def load_model(self, model_path): - """加载YOLO模型""" + """加载YOLO模型.""" try: self.yolo = YOLO(model_path, task="detect") self.model_path = model_path @@ -97,28 +95,24 @@ def load_model(self, model_path): return True except Exception as e: self.model_loaded = False - messagebox.showerror("错误", f"无法加载YOLO模型: {str(e)}") + messagebox.showerror("错误", f"无法加载YOLO模型: {e!s}") return False - + def select_model(self): - """选择模型文件""" + """选择模型文件.""" # 如果正在检测,先停止 if self.is_running: result = messagebox.askyesno("提示", "当前正在检测中,更换模型将停止检测。\n\n是否继续?") if not result: return self.stop_detection() - + from tkinter import filedialog + model_path = filedialog.askopenfilename( - title="选择YOLO模型文件", - filetypes=[ - ("PyTorch模型", "*.pt"), - ("所有文件", "*.*") - ], - initialdir="." + title="选择YOLO模型文件", filetypes=[("PyTorch模型", "*.pt"), ("所有文件", "*.*")], initialdir="." ) - + if model_path: # 尝试加载模型 if self.load_model(model_path): @@ -130,38 +124,36 @@ def select_model(self): if os.path.exists(MODEL_PATH): self.load_model(MODEL_PATH) self.show_mode_selection() - + def show_mode_selection(self): - """显示模式选择界面""" + """显示模式选择界面.""" # 保存当前窗口大小 # 清除现有组件 for widget in self.root.winfo_children(): widget.destroy() - + # 主框架 main_frame = ttk.Frame(self.root, padding="20") main_frame.pack(expand=True, fill=tk.BOTH) - + # 标题 - title_label = ttk.Label(main_frame, text="YOLO 目标检测系统", - font=("Arial", 20, "bold")) + title_label = ttk.Label(main_frame, text="YOLO 目标检测系统", font=("Arial", 20, "bold")) title_label.pack(pady=20) - + # 副标题 - subtitle_label = ttk.Label(main_frame, text="请选择检测模式", - font=("Arial", 12)) + subtitle_label = ttk.Label(main_frame, text="请选择检测模式", font=("Arial", 12)) subtitle_label.pack(pady=10) - + # 模型选择区域 model_frame = ttk.LabelFrame(main_frame, text="模型设置", padding="10") model_frame.pack(pady=15, padx=20, fill=tk.X) - + # 当前模型显示 model_info_frame = ttk.Frame(model_frame) model_info_frame.pack(fill=tk.X, pady=5) - + ttk.Label(model_info_frame, text="当前模型:", font=("Arial", 10)).pack(side=tk.LEFT, padx=5) - + # 显示模型文件名(如果路径太长,显示文件名;否则显示完整路径) if self.model_path: model_name = os.path.basename(self.model_path) @@ -170,24 +162,25 @@ def show_mode_selection(self): model_name = model_name[:27] + "..." else: model_name = "未加载" - - model_status_label = ttk.Label(model_info_frame, text=model_name, - font=("Arial", 10, "bold"), - foreground="blue" if self.model_loaded else "red") + + model_status_label = ttk.Label( + model_info_frame, + text=model_name, + font=("Arial", 10, "bold"), + foreground="blue" if self.model_loaded else "red", + ) model_status_label.pack(side=tk.LEFT, padx=5) - + # 添加完整路径提示(鼠标悬停时显示) if self.model_path and os.path.exists(self.model_path): full_path = os.path.abspath(self.model_path) model_status_label.bind("", lambda e: self._show_tooltip(e, full_path)) model_status_label.bind("", lambda e: self._hide_tooltip()) - + # 选择模型按钮 - select_model_button = ttk.Button(model_frame, text="📦 选择模型", - command=self.select_model, - width=20) + select_model_button = ttk.Button(model_frame, text="📦 选择模型", command=self.select_model, width=20) select_model_button.pack(pady=5) - + # 模型状态提示 if not self.model_loaded: status_text = "⚠️ 模型未加载,请先选择模型" @@ -195,172 +188,161 @@ def show_mode_selection(self): else: status_text = "✓ 模型已加载" status_color = "green" - - status_hint = ttk.Label(model_frame, text=status_text, - font=("Arial", 9), - foreground=status_color) + + status_hint = ttk.Label(model_frame, text=status_text, font=("Arial", 9), foreground=status_color) status_hint.pack(pady=2) - + # 保存文件夹设置区域 save_frame = ttk.LabelFrame(main_frame, text="保存设置", padding="10") save_frame.pack(pady=15, padx=20, fill=tk.X) - + # 当前保存文件夹显示 save_info_frame = ttk.Frame(save_frame) save_info_frame.pack(fill=tk.X, pady=5) - + ttk.Label(save_info_frame, text="保存文件夹:", font=("Arial", 10)).pack(side=tk.LEFT, padx=5) - + # 显示保存文件夹路径(如果路径太长,截断) save_dir_display = self.save_dir if len(save_dir_display) > 50: save_dir_display = "..." + save_dir_display[-47:] - - save_dir_label = ttk.Label(save_info_frame, text=save_dir_display, - font=("Arial", 9), - foreground="blue") + + save_dir_label = ttk.Label(save_info_frame, text=save_dir_display, font=("Arial", 9), foreground="blue") save_dir_label.pack(side=tk.LEFT, padx=5) - + # 添加完整路径提示 full_save_path = os.path.abspath(self.save_dir) save_dir_label.bind("", lambda e: self._show_tooltip(e, full_save_path)) save_dir_label.bind("", lambda e: self._hide_tooltip()) - + # 选择保存文件夹按钮 - select_save_button = ttk.Button(save_frame, text="📁 选择保存文件夹", - command=self.select_save_folder, - width=20) + select_save_button = ttk.Button(save_frame, text="📁 选择保存文件夹", command=self.select_save_folder, width=20) select_save_button.pack(pady=5) - + # 保存文件夹说明 - save_hint = ttk.Label(save_frame, - text="检测到目标时,会自动保存帧到此文件夹", - font=("Arial", 8), - foreground="gray") + save_hint = ttk.Label( + save_frame, text="检测到目标时,会自动保存帧到此文件夹", font=("Arial", 8), foreground="gray" + ) save_hint.pack(pady=2) - + # 按钮框架 button_frame = ttk.Frame(main_frame) button_frame.pack(pady=30) - + # 屏幕检测按钮 - screen_button = ttk.Button(button_frame, text="🖥️ 屏幕检测", - command=lambda: self.select_mode('screen'), - width=25) + screen_button = ttk.Button( + button_frame, text="🖥️ 屏幕检测", command=lambda: self.select_mode("screen"), width=25 + ) screen_button.pack(pady=15, padx=10) - + # 摄像头检测按钮 - camera_button = ttk.Button(button_frame, text="📷 摄像头检测", - command=lambda: self.select_mode('camera'), - width=25) + camera_button = ttk.Button( + button_frame, text="📷 摄像头检测", command=lambda: self.select_mode("camera"), width=25 + ) camera_button.pack(pady=15, padx=10) - + # 文件检测按钮 - file_button = ttk.Button(button_frame, text="📁 文件检测(图片/视频)", - command=lambda: self.select_mode('file'), - width=25) + file_button = ttk.Button( + button_frame, text="📁 文件检测(图片/视频)", command=lambda: self.select_mode("file"), width=25 + ) file_button.pack(pady=15, padx=10) - + # 说明文字 - info_label = ttk.Label(main_frame, - text="提示:文件检测模式下,您可以拖拽文件到窗口或点击按钮选择文件", - font=("Arial", 9), foreground="gray") + info_label = ttk.Label( + main_frame, + text="提示:文件检测模式下,您可以拖拽文件到窗口或点击按钮选择文件", + font=("Arial", 9), + foreground="gray", + ) info_label.pack(pady=20) - + def select_mode(self, mode): - """选择检测模式""" + """选择检测模式.""" if not self.model_loaded: messagebox.showerror("错误", "YOLO模型未加载,无法进行检测\n\n请先点击'选择模型'按钮加载模型") return - + self.detection_mode = mode - - if mode == 'screen': + + if mode == "screen": self.setup_screen_detection() - elif mode == 'camera': + elif mode == "camera": self.setup_camera_detection() - elif mode == 'file': + elif mode == "file": self.setup_file_detection() - + def setup_screen_detection(self): - """设置屏幕检测界面""" + """设置屏幕检测界面.""" self.create_detection_ui("屏幕检测") - + def setup_camera_detection(self): - """设置摄像头检测界面""" + """设置摄像头检测界面.""" self.create_detection_ui("摄像头检测") - + def setup_file_detection(self): - """设置文件检测界面""" + """设置文件检测界面.""" self.create_detection_ui("文件检测") - + # 清空之前选择的文件和结果 self.selected_file_path = None self.detection_results = None self.detection_results_info = None - + # 设置文件检测的额外UI save_button = DetectionUI.setup_file_detection_ui( - self.ui_components['control_frame'], - self.ui_components['video_label'], + self.ui_components["control_frame"], + self.ui_components["video_label"], self.select_file, self.on_file_drop, - self.save_detection_results + self.save_detection_results, ) - self.buttons['save'] = save_button - + self.buttons["save"] = save_button + # 更新状态提示 - self.gui_updater.update_status(self.ui_components['status_label'], "请选择文件") - + self.gui_updater.update_status(self.ui_components["status_label"], "请选择文件") + def create_detection_ui(self, mode_name): - """创建检测界面""" - ui_dict = DetectionUI.create_detection_ui( - self.root, - mode_name, - self.show_mode_selection - ) - + """创建检测界面.""" + ui_dict = DetectionUI.create_detection_ui(self.root, mode_name, self.show_mode_selection) + self.ui_components = ui_dict - + # 设置按钮命令 - ui_dict['start_button'].config(command=self.toggle_detection) - ui_dict['stop_button'].config(command=self.stop_detection) - + ui_dict["start_button"].config(command=self.toggle_detection) + ui_dict["stop_button"].config(command=self.stop_detection) + # 保存按钮引用 - self.buttons = { - 'start': ui_dict['start_button'], - 'stop': ui_dict['stop_button'] - } - + self.buttons = {"start": ui_dict["start_button"], "stop": ui_dict["stop_button"]} + # 创建检测处理器 self.detection_processor = DetectionProcessor( self.yolo, self.gui_updater, self.buttons, - ui_dict['status_label'], - ui_dict['info_text'], - ui_dict['video_label'] + ui_dict["status_label"], + ui_dict["info_text"], + ui_dict["video_label"], ) - + # 设置保存文件夹(仅在屏幕和摄像头检测模式下) - if self.detection_mode in ['screen', 'camera']: + if self.detection_mode in ["screen", "camera"]: self.detection_processor.set_save_dir(self.save_dir) # 重置保存帧计数 self.detection_processor.save_frame_count = 0 - + # 保存参数输入控件的引用 - if 'conf_var' in ui_dict and 'iou_var' in ui_dict: - self.conf_var = ui_dict['conf_var'] - self.iou_var = ui_dict['iou_var'] + if "conf_var" in ui_dict and "iou_var" in ui_dict: + self.conf_var = ui_dict["conf_var"] + self.iou_var = ui_dict["iou_var"] else: # 如果没有参数控件,创建默认值 self.conf_var = None self.iou_var = None - + def _validate_and_get_params(self): - """验证并获取检测参数""" + """验证并获取检测参数.""" from .config import DEFAULT_CONF, DEFAULT_IOU - + try: # 获取conf参数 if self.conf_var: @@ -368,211 +350,208 @@ def _validate_and_get_params(self): conf = float(conf_str) if conf_str else DEFAULT_CONF else: conf = DEFAULT_CONF - + # 验证conf范围 if conf < 0 or conf > 1: messagebox.showerror("参数错误", f"Conf参数必须在0-1之间,当前值: {conf}") return None, None conf = max(0.0, min(1.0, conf)) # 确保在范围内 - + # 获取iou参数 if self.iou_var: iou_str = self.iou_var.get().strip() iou = float(iou_str) if iou_str else DEFAULT_IOU else: iou = DEFAULT_IOU - + # 验证iou范围 if iou < 0 or iou > 1: messagebox.showerror("参数错误", f"IOU参数必须在0-1之间,当前值: {iou}") return None, None iou = max(0.0, min(1.0, iou)) # 确保在范围内 - + return conf, iou - + except ValueError: messagebox.showerror("参数错误", "请输入有效的数字(0-1之间)") return None, None except Exception as e: - messagebox.showerror("参数错误", f"读取参数时发生错误: {str(e)}") + messagebox.showerror("参数错误", f"读取参数时发生错误: {e!s}") return None, None - + def toggle_detection(self): - """切换检测状态""" - if self.detection_mode == 'screen': + """切换检测状态.""" + if self.detection_mode == "screen": if not self.is_running: # 获取并验证参数 conf, iou = self._validate_and_get_params() if conf is None or iou is None: return - + # 设置参数 self.detection_processor.set_params(conf, iou) self.detection_processor.start_screen_detection() self.is_running = True else: self.stop_detection() - elif self.detection_mode == 'camera': + elif self.detection_mode == "camera": if not self.is_running: # 获取并验证参数 conf, iou = self._validate_and_get_params() if conf is None or iou is None: return - + # 设置参数 self.detection_processor.set_params(conf, iou) self.detection_processor.start_camera_detection() self.is_running = True else: self.stop_detection() - elif self.detection_mode == 'file': + elif self.detection_mode == "file": if not self.is_running: if self.selected_file_path: # 获取并验证参数 conf, iou = self._validate_and_get_params() if conf is None or iou is None: return - + # 设置参数 self.detection_processor.set_params(conf, iou) self.detection_processor.start_file_detection( self.selected_file_path, self._set_detection_file_type, self._set_detection_results, - self._set_detection_has_results + self._set_detection_has_results, ) self.is_running = True else: messagebox.showinfo("提示", "请先选择要检测的文件") else: self.stop_detection() - + def _set_detection_file_type(self, file_type): - """设置检测文件类型""" + """设置检测文件类型.""" self.detection_file_type = file_type - + def _set_detection_results(self, results, info=None): - """设置检测结果""" + """设置检测结果.""" self.detection_results = results if info is not None: self.detection_results_info = info - + def _set_detection_has_results(self, has_results): - """设置是否有检测结果""" + """设置是否有检测结果.""" self.detection_has_results = has_results - + def stop_detection(self): - """停止检测""" + """停止检测.""" self.is_running = False - + if self.detection_processor: self.detection_processor.stop() - + # 更新控件状态 try: - if 'start' in self.buttons and self.buttons['start'].winfo_exists(): - self.buttons['start'].config(state=tk.NORMAL) - if 'stop' in self.buttons and self.buttons['stop'].winfo_exists(): - self.buttons['stop'].config(state=tk.DISABLED) - - status_label = self.ui_components.get('status_label') + if "start" in self.buttons and self.buttons["start"].winfo_exists(): + self.buttons["start"].config(state=tk.NORMAL) + if "stop" in self.buttons and self.buttons["stop"].winfo_exists(): + self.buttons["stop"].config(state=tk.DISABLED) + + status_label = self.ui_components.get("status_label") if status_label and status_label.winfo_exists(): - if self.detection_mode == 'file' and self.selected_file_path: + if self.detection_mode == "file" and self.selected_file_path: self.gui_updater.update_status(status_label, "已停止,可重新开始检测") else: self.gui_updater.update_status(status_label, "已停止") - - info_text = self.ui_components.get('info_text') + + info_text = self.ui_components.get("info_text") if info_text and info_text.winfo_exists(): self.gui_updater.add_info(info_text, "检测已停止。") - + # 清空显示(文件检测模式不清空,保持文件预览) - video_label = self.ui_components.get('video_label') - if video_label and video_label.winfo_exists() and self.detection_mode != 'file': - video_label.config(image='', text="检测已停止") - + video_label = self.ui_components.get("video_label") + if video_label and video_label.winfo_exists() and self.detection_mode != "file": + video_label.config(image="", text="检测已停止") + # 禁用保存按钮(如果存在) - if 'save' in self.buttons and self.buttons['save'].winfo_exists(): - self.buttons['save'].config(state=tk.DISABLED) + if "save" in self.buttons and self.buttons["save"].winfo_exists(): + self.buttons["save"].config(state=tk.DISABLED) except: pass # 窗口可能已关闭,忽略错误 - + def select_file(self): - """选择文件""" + """选择文件.""" file_path = FileHandler.select_file() - + if file_path: self.selected_file_path = file_path - info_text = self.ui_components.get('info_text') + info_text = self.ui_components.get("info_text") if info_text: self.gui_updater.add_info(info_text, f"已选择文件: {os.path.basename(file_path)}") - - status_label = self.ui_components.get('status_label') + + status_label = self.ui_components.get("status_label") if status_label: self.gui_updater.update_status(status_label, "已选择文件,点击'开始检测'开始检测") - + # 显示文件预览(如果是图片) if not FileHandler.is_video_file(file_path): img = FileHandler.load_image_preview(file_path) if img is not None: - video_label = self.ui_components.get('video_label') + video_label = self.ui_components.get("video_label") if video_label: self.gui_updater.update_frame(video_label, img) - + def on_file_drop(self, event): - """处理文件拖拽""" + """处理文件拖拽.""" file_path = event.data.strip() # 移除可能的花括号 - if file_path.startswith('{') and file_path.endswith('}'): + if file_path.startswith("{") and file_path.endswith("}"): file_path = file_path[1:-1] - + # 检查文件是否存在 if os.path.exists(file_path): self.selected_file_path = file_path - info_text = self.ui_components.get('info_text') + info_text = self.ui_components.get("info_text") if info_text: self.gui_updater.add_info(info_text, f"已选择文件: {os.path.basename(file_path)}") - - status_label = self.ui_components.get('status_label') + + status_label = self.ui_components.get("status_label") if status_label: self.gui_updater.update_status(status_label, "已选择文件,点击'开始检测'开始检测") - + # 显示文件预览(如果是图片) if not FileHandler.is_video_file(file_path): img = FileHandler.load_image_preview(file_path) if img is not None: - video_label = self.ui_components.get('video_label') + video_label = self.ui_components.get("video_label") if video_label: self.gui_updater.update_frame(video_label, img) else: messagebox.showerror("错误", "文件不存在") - + def save_detection_results(self): - """保存检测结果""" + """保存检测结果.""" + def add_info(msg): - info_text = self.ui_components.get('info_text') + info_text = self.ui_components.get("info_text") if info_text: self.gui_updater.add_info(info_text, msg) - + def show_message(title, msg): messagebox.showinfo(title, msg) - + FileHandler.save_detection_results( - self.detection_results, - self.detection_file_type, - self.selected_file_path, - add_info, - show_message + self.detection_results, self.detection_file_type, self.selected_file_path, add_info, show_message ) - + def select_save_folder(self): - """选择保存文件夹""" + """选择保存文件夹.""" from tkinter import filedialog + folder = filedialog.askdirectory( - title="选择保存文件夹", - initialdir=self.save_dir if os.path.exists(self.save_dir) else "." + title="选择保存文件夹", initialdir=self.save_dir if os.path.exists(self.save_dir) else "." ) - + if folder: self.save_dir = folder # 确保文件夹存在 @@ -580,31 +559,29 @@ def select_save_folder(self): # 刷新界面显示 self.show_mode_selection() messagebox.showinfo("成功", f"保存文件夹已设置为:\n{os.path.abspath(folder)}") - + def _show_tooltip(self, event, text): - """显示工具提示""" + """显示工具提示.""" tooltip = tk.Toplevel() tooltip.wm_overrideredirect(True) - tooltip.wm_geometry(f"+{event.x_root+10}+{event.y_root+10}") - label = tk.Label(tooltip, text=text, background="yellow", - relief="solid", borderwidth=1, font=("Arial", 9)) + tooltip.wm_geometry(f"+{event.x_root + 10}+{event.y_root + 10}") + label = tk.Label(tooltip, text=text, background="yellow", relief="solid", borderwidth=1, font=("Arial", 9)) label.pack() self._tooltip_window = tooltip - + def _hide_tooltip(self): - """隐藏工具提示""" - if hasattr(self, '_tooltip_window'): + """隐藏工具提示.""" + if hasattr(self, "_tooltip_window"): try: self._tooltip_window.destroy() except: pass - if hasattr(self, '_tooltip_window'): - delattr(self, '_tooltip_window') - + if hasattr(self, "_tooltip_window"): + delattr(self, "_tooltip_window") + def on_closing(self): - """窗口关闭时的处理""" + """窗口关闭时的处理.""" # 只有在检测界面时才需要停止检测 if self.is_running: self.stop_detection() self.root.destroy() - diff --git a/UI/gui_utils.py b/UI/gui_utils.py index bf41df6..32a0bce 100644 --- a/UI/gui_utils.py +++ b/UI/gui_utils.py @@ -1,34 +1,37 @@ -""" -GUI工具函数 - 线程安全的GUI更新方法 -""" -import cv2 -import tkinter as tk -import sys +"""GUI工具函数 - 线程安全的GUI更新方法.""" + import os +import sys +import tkinter as tk + +import cv2 from PIL import Image, ImageTk + # 处理相对导入和绝对导入 def _import_config(): - """导入配置模块""" + """导入配置模块.""" try: - from .config import GUI_QUEUE_INTERVAL, MAX_INFO_LINES, MAX_DISPLAY_WIDTH, MAX_DISPLAY_HEIGHT + from .config import GUI_QUEUE_INTERVAL, MAX_DISPLAY_HEIGHT, MAX_DISPLAY_WIDTH, MAX_INFO_LINES + return GUI_QUEUE_INTERVAL, MAX_INFO_LINES, MAX_DISPLAY_WIDTH, MAX_DISPLAY_HEIGHT except ImportError: sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - from hys.config import GUI_QUEUE_INTERVAL, MAX_INFO_LINES, MAX_DISPLAY_WIDTH, MAX_DISPLAY_HEIGHT + from hys.config import GUI_QUEUE_INTERVAL, MAX_DISPLAY_HEIGHT, MAX_DISPLAY_WIDTH, MAX_INFO_LINES + return GUI_QUEUE_INTERVAL, MAX_INFO_LINES, MAX_DISPLAY_WIDTH, MAX_DISPLAY_HEIGHT class ThreadSafeGUIUpdater: - """线程安全的GUI更新器""" - + """线程安全的GUI更新器.""" + def __init__(self, root, gui_queue): self.root = root self.gui_queue = gui_queue self.frame_count = 0 - + def process_gui_queue(self): - """处理GUI更新队列(在主线程中调用)""" + """处理GUI更新队列(在主线程中调用).""" try: while True: try: @@ -36,89 +39,89 @@ def process_gui_queue(self): task = self.gui_queue.get_nowait() if task is None: # 停止信号 break - + task_type, args = task - - if task_type == 'update_status': + + if task_type == "update_status": self._update_status_direct(args) - elif task_type == 'add_info': + elif task_type == "add_info": self._add_info_direct(args) - elif task_type == 'update_frame': + elif task_type == "update_frame": self._update_frame_direct(args) - elif task_type == 'update_button_state': + elif task_type == "update_button_state": self._update_button_state_direct(args) - + except: break # 队列为空,退出循环 except: pass # 忽略错误,继续运行 - + # 继续处理队列 try: GUI_QUEUE_INTERVAL, _, _, _ = _import_config() self.root.after(GUI_QUEUE_INTERVAL, self.process_gui_queue) except: pass # 窗口可能已关闭 - + def _update_status_direct(self, args): - """直接更新状态标签(仅在主线程调用)""" + """直接更新状态标签(仅在主线程调用).""" status_label, status = args try: if status_label and status_label.winfo_exists(): status_label.config(text=f"状态: {status}") except: pass - + def _add_info_direct(self, args): - """直接添加信息(仅在主线程调用)""" + """直接添加信息(仅在主线程调用).""" info_text, message = args try: if info_text and info_text.winfo_exists(): info_text.insert(tk.END, f"{message}\n") info_text.see(tk.END) - + # 限制信息条数 _, MAX_INFO_LINES, _, _ = _import_config() - lines = info_text.get("1.0", tk.END).split('\n') + lines = info_text.get("1.0", tk.END).split("\n") if len(lines) > MAX_INFO_LINES: info_text.delete("1.0", f"{len(lines) - MAX_INFO_LINES}.0") except: pass - + def _update_frame_direct(self, args): - """直接更新帧(仅在主线程调用)""" + """直接更新帧(仅在主线程调用).""" video_label, frame, max_width, max_height = args try: if not video_label or not video_label.winfo_exists(): return - + # 转换颜色空间 BGR -> RGB frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - + # 调整大小以适应显示区域 height, width = frame_rgb.shape[:2] - + if width > max_width or height > max_height: scale = min(max_width / width, max_height / height) new_width = int(width * scale) new_height = int(height * scale) frame_rgb = cv2.resize(frame_rgb, (new_width, new_height)) - + # 转换为PIL Image image = Image.fromarray(frame_rgb) photo = ImageTk.PhotoImage(image=image) - + # 更新标签 video_label.config(image=photo, text="") video_label.image = photo # 保持引用 - + # 更新帧计数 self.frame_count += 1 except: pass - + def _update_button_state_direct(self, args): - """直接更新按钮状态(仅在主线程调用)""" + """直接更新按钮状态(仅在主线程调用).""" buttons, button_name, state = args try: button = buttons.get(button_name) @@ -126,33 +129,32 @@ def _update_button_state_direct(self, args): button.config(state=state) except: pass - + def update_status(self, status_label, status): - """更新状态标签(线程安全)""" + """更新状态标签(线程安全).""" try: - self.gui_queue.put(('update_status', [status_label, status])) + self.gui_queue.put(("update_status", [status_label, status])) except: pass - + def add_info(self, info_text, message): - """添加信息到信息显示区域(线程安全)""" + """添加信息到信息显示区域(线程安全).""" try: - self.gui_queue.put(('add_info', [info_text, message])) + self.gui_queue.put(("add_info", [info_text, message])) except: pass - + def update_frame(self, video_label, frame): - """更新显示的帧(线程安全)""" + """更新显示的帧(线程安全).""" try: _, _, MAX_DISPLAY_WIDTH, MAX_DISPLAY_HEIGHT = _import_config() - self.gui_queue.put(('update_frame', [video_label, frame, MAX_DISPLAY_WIDTH, MAX_DISPLAY_HEIGHT])) + self.gui_queue.put(("update_frame", [video_label, frame, MAX_DISPLAY_WIDTH, MAX_DISPLAY_HEIGHT])) except: pass - + def update_button_state(self, buttons, button_name, state): - """更新按钮状态(线程安全)""" + """更新按钮状态(线程安全).""" try: - self.gui_queue.put(('update_button_state', [buttons, button_name, state])) + self.gui_queue.put(("update_button_state", [buttons, button_name, state])) except: pass - diff --git a/UI/main.py b/UI/main.py index 39078cb..73b36df 100644 --- a/UI/main.py +++ b/UI/main.py @@ -1,9 +1,8 @@ -""" -程序入口 -""" -import tkinter as tk -import sys +"""程序入口.""" + import os +import sys +import tkinter as tk # 添加父目录到路径,以便导入hys包 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -19,18 +18,17 @@ def main(): - """主函数""" + """主函数.""" # 使用TkinterDnD来支持拖拽功能(如果可用) if DND_AVAILABLE: root = TkinterDnD.Tk() else: root = tk.Tk() print("提示: 安装 tkinterdnd2 库可启用文件拖拽功能: pip install tkinterdnd2") - - app = YOLODetectionGUI(root) + + YOLODetectionGUI(root) root.mainloop() if __name__ == "__main__": main() - diff --git a/configs/yolo_fire.yaml b/configs/yolo_fire.yaml index c454327..4c87f64 100644 --- a/configs/yolo_fire.yaml +++ b/configs/yolo_fire.yaml @@ -3,4 +3,4 @@ train: images/train val: images/val names: 0: smoke - 1: fire \ No newline at end of file + 1: fire diff --git a/pyproject.toml b/pyproject.toml index eaf5956..b5c6248 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,7 +88,7 @@ dev = [ "minijinja>=2.0.0", # render docs macros without mkdocs-macros-plugin ] export = [ - "numpy<2.0.0", # TF 2.20 compatibility + "numpy<3.0.0", # TF 2.20 compatibility "onnx>=1.12.0; platform_system != 'Darwin'", # ONNX export "onnx>=1.12.0,<1.18.0; platform_system == 'Darwin'", # TF inference hanging on MacOS "coremltools>=8.0; platform_system != 'Windows' and python_version <= '3.13'", # CoreML supported on macOS and Linux diff --git a/scripts/README.md b/scripts/README.md index 15da9ec..f5da989 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -3,6 +3,7 @@ ## 可用脚本 ### 1. run_gui.py - 运行GUI界面 + 启动图形化界面进行火灾检测 ```bash @@ -10,12 +11,14 @@ python scripts/run_gui.py ``` **功能:** + - 图片/视频检测 - 实时摄像头检测 - 可视化结果展示 - 参数调节 ### 2. train_model.py - 训练模型 + 训练YOLO火灾检测模型 ```bash @@ -23,12 +26,14 @@ python scripts/train_model.py ``` **配置参数:** + - 模型: `models/yolov8n.pt` - 数据配置: `configs/yolo_fire.yaml` - Epochs: 50 - Batch size: 48 ### 3. validate_model.py - 验证模型 + 在验证集上评估模型性能 ```bash @@ -36,12 +41,14 @@ python scripts/validate_model.py ``` **输出指标:** + - mAP50 - mAP50-95 - Precision - Recall ### 4. test_model.py - 测试模型 + 测试模型在不同输入源上的表现 ```bash @@ -62,6 +69,7 @@ python scripts/test_model.py --source 0 --model runs/detect/train2/weights/best. ``` **参数说明:** + - `--source`: 输入源(图片/视频路径、摄像头编号、文件夹) - `--model`: 模型路径(默认: runs/detect/train2/weights/best.pt) - `--conf`: 置信度阈值(默认: 0.25) @@ -70,19 +78,25 @@ python scripts/test_model.py --source 0 --model runs/detect/train2/weights/best. ## 快速开始 ### 新手推荐 + 使用GUI界面,最简单直观: + ```bash python scripts/run_gui.py ``` ### 命令行用户 + 快速测试摄像头: + ```bash python scripts/test_model.py --source 0 ``` ### 开发者 + 训练自己的模型: + ```bash python scripts/train_model.py python scripts/validate_model.py diff --git a/scripts/run_gui.py b/scripts/run_gui.py index 9f4fbc3..b214f90 100644 --- a/scripts/run_gui.py +++ b/scripts/run_gui.py @@ -1,18 +1,18 @@ -""" -运行YOLO火灾检测GUI应用程序 -""" +"""运行YOLO火灾检测GUI应用程序.""" + if __name__ == "__main__": - import sys import os - + import sys + # 设置使用项目本地的Ultralytics配置 project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - os.environ['ULTRALYTICS_CONFIG_DIR'] = os.path.join(project_root, 'configs') - + os.environ["ULTRALYTICS_CONFIG_DIR"] = os.path.join(project_root, "configs") + # 确保项目根目录在Python路径中 if project_root not in sys.path: sys.path.insert(0, project_root) - + # 导入并运行 from UI.main import main + main() diff --git a/scripts/test_model.py b/scripts/test_model.py index acea87c..4f2b041 100644 --- a/scripts/test_model.py +++ b/scripts/test_model.py @@ -1,20 +1,21 @@ """ 测试YOLO火灾检测模型 -支持图片、视频和摄像头 +支持图片、视频和摄像头. """ + import os -import sys # 设置使用项目本地的Ultralytics配置 -os.environ['ULTRALYTICS_CONFIG_DIR'] = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs') +os.environ["ULTRALYTICS_CONFIG_DIR"] = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs") -from ultralytics import YOLO import argparse +from ultralytics import YOLO + + def test_model(source, model_path="models/yolov8n.pt", save=True, conf=0.25): - """ - 测试模型 - + """测试模型. + Args: source: 输入源 - 图片路径: 'image.jpg' @@ -29,36 +30,29 @@ def test_model(source, model_path="models/yolov8n.pt", save=True, conf=0.25): results = model(source=source, save=save, conf=conf) return results + def main(): - parser = argparse.ArgumentParser(description='测试YOLO火灾检测模型') - parser.add_argument('--source', type=str, default=0, - help='输入源: 图片/视频路径, 摄像头编号(0), 或文件夹路径') - parser.add_argument('--model', type=str, default='runs/detect/train2/weights/best.pt', - help='模型路径') - parser.add_argument('--conf', type=float, default=0.25, - help='置信度阈值 (0-1)') - parser.add_argument('--no-save', action='store_true', - help='不保存检测结果') - + parser = argparse.ArgumentParser(description="测试YOLO火灾检测模型") + parser.add_argument("--source", type=str, default=0, help="输入源: 图片/视频路径, 摄像头编号(0), 或文件夹路径") + parser.add_argument("--model", type=str, default="runs/detect/train2/weights/best.pt", help="模型路径") + parser.add_argument("--conf", type=float, default=0.25, help="置信度阈值 (0-1)") + parser.add_argument("--no-save", action="store_true", help="不保存检测结果") + args = parser.parse_args() - + # 如果source是数字字符串,转换为整数(摄像头编号) source = args.source try: source = int(source) except ValueError: pass - + print(f"使用模型: {args.model}") print(f"输入源: {source}") print(f"置信度阈值: {args.conf}") - - test_model( - source=source, - model_path=args.model, - save=not args.no_save, - conf=args.conf - ) + + test_model(source=source, model_path=args.model, save=not args.no_save, conf=args.conf) + if __name__ == "__main__": main() diff --git a/scripts/train_model.py b/scripts/train_model.py index a347bbe..f8edb87 100644 --- a/scripts/train_model.py +++ b/scripts/train_model.py @@ -1,22 +1,18 @@ -""" -训练YOLO火灾检测模型 -""" +"""训练YOLO火灾检测模型.""" + import os # 设置使用项目本地的Ultralytics配置 -os.environ['ULTRALYTICS_CONFIG_DIR'] = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs') +os.environ["ULTRALYTICS_CONFIG_DIR"] = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs") from ultralytics import YOLO + def train(): - """训练模型""" + """训练模型.""" model = YOLO("models/yolov8n.pt") - model.train( - data="configs/yolo_fire.yaml", - workers=0, - epochs=50, - batch=48 - ) + model.train(data="configs/yolo_fire.yaml", workers=0, epochs=50, batch=48) + if __name__ == "__main__": train() diff --git a/scripts/validate_model.py b/scripts/validate_model.py index 72e275e..d031a20 100644 --- a/scripts/validate_model.py +++ b/scripts/validate_model.py @@ -1,27 +1,28 @@ -""" -验证YOLO火灾检测模型性能 -""" +"""验证YOLO火灾检测模型性能.""" + import os # 设置使用项目本地的Ultralytics配置 -os.environ['ULTRALYTICS_CONFIG_DIR'] = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs') +os.environ["ULTRALYTICS_CONFIG_DIR"] = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs") from ultralytics import YOLO + def validate(): - """在验证集上验证模型""" + """在验证集上验证模型.""" # 加载训练好的模型 model = YOLO("runs/detect/train2/weights/best.pt") - + # 在验证集上进行验证 results = model.val(data="configs/yolo_fire.yaml") - + # 打印验证结果 - print(f"\n验证结果:") + print("\n验证结果:") print(f"mAP50: {results.box.map50:.4f}") print(f"mAP50-95: {results.box.map:.4f}") print(f"Precision: {results.box.mp:.4f}") print(f"Recall: {results.box.mr:.4f}") + if __name__ == "__main__": validate() diff --git a/ultralytics/trackers/bot_sort.py b/ultralytics/trackers/bot_sort.py index 9a75c08..9a06a12 100644 --- a/ultralytics/trackers/bot_sort.py +++ b/ultralytics/trackers/bot_sort.py @@ -71,7 +71,7 @@ def __init__( self.curr_feat = None if feat is not None: self.update_features(feat) - self.features = deque([], maxlen=feat_history) + self.features = deque(maxlen=feat_history) self.alpha = 0.9 def update_features(self, feat: np.ndarray) -> None: