diff --git a/ml-backends/temperature_annotation/Dockerfile b/ml-backends/temperature_annotation/Dockerfile new file mode 100644 index 0000000..87020cc --- /dev/null +++ b/ml-backends/temperature_annotation/Dockerfile @@ -0,0 +1,34 @@ +# 温度标注ML Backend的Docker配置 +# 基于Python 3.9构建 + +FROM python:3.9-slim + +WORKDIR /app + +# 设置环境变量 +ENV PYTHONUNBUFFERED=1 \ + PYTHONDONTWRITEBYTECODE=1 \ + PORT=9090 \ + WORKERS=1 \ + THREADS=8 + +# 安装系统依赖 +RUN apt-get update && apt-get install -y \ + git \ + && rm -rf /var/lib/apt/lists/* + +# 复制依赖文件 +COPY requirements.txt . +COPY requirements-base.txt . + +# 安装Python依赖 +RUN pip install --no-cache-dir -r requirements.txt + +# 复制应用代码 +COPY . . + +# 暴露端口 +EXPOSE 9090 + +# 启动命令 +CMD ["python", "_wsgi.py", "--host", "0.0.0.0", "--port", "9090"] diff --git a/ml-backends/temperature_annotation/README.md b/ml-backends/temperature_annotation/README.md new file mode 100644 index 0000000..804e536 --- /dev/null +++ b/ml-backends/temperature_annotation/README.md @@ -0,0 +1,197 @@ +# Temperature Annotation ML Backend + +这是一个专门用于托卡马克等离子体温度数据标注的ML Backend,支持温度特征识别、异常检测和智能标注。 + +## 功能特性 + +- **温度特征识别**: 自动检测温度上升、下降、峰值等关键阶段 +- **异常事件检测**: 识别温度突变、骤降等异常情况 +- **多通道支持**: 支持多个温度测量通道的并行分析 +- **自适应阈值**: 基于人工标注自动优化检测阈值 +- **实时预测**: 基于Label Studio ML Backend架构的实时预测服务 + +## 核心算法 + +### 温度阶段检测 +- **上升阶段**: 检测温度超过阈值的持续上升区间 +- **峰值时刻**: 基于梯度变化检测温度峰值点 +- **下降阶段**: 识别温度低于阈值的下降区间 +- **平台期**: 检测温度相对稳定的平台阶段 + +### 异常检测 +- **温度突变**: 基于滑动窗口统计检测异常温度变化 +- **温度骤降**: 识别温度快速下降的异常事件 + +### 阈值优化 +- 支持基于人工标注的自适应阈值调整 +- 动态学习不同实验条件下的最佳参数 + +## 快速开始 + +### 使用Docker运行 (推荐) + +1. 启动ML Backend服务: + +```bash +docker-compose up +``` + +2. 验证服务运行状态: + +```bash +curl http://localhost:9090/health +# 返回: {"status":"healthy","model_version":"temperature_v1.0",...} +``` + +3. 在Label Studio中连接后端: + 进入项目设置 -> Machine Learning -> Add Model,指定URL为 `http://localhost:9090` + +### 从源码构建 + +```bash +docker-compose build +``` + +### 不使用Docker运行 + +```bash +python -m venv temperature-ml-backend +source temperature-ml-backend/bin/activate # Windows: temperature-ml-backend\Scripts\activate +pip install -r requirements.txt +python _wsgi.py +``` + +## 配置参数 + +在 `config.json` 中可以设置以下参数: + +### 温度阈值设置 +- `temp_rise_threshold`: 温度上升阈值 (默认1000.0 eV) +- `temp_fall_threshold`: 温度下降阈值 (默认500.0 eV) +- `gradient_threshold`: 梯度阈值 (默认100.0 eV/ms) +- `anomaly_threshold`: 异常检测阈值 (默认3.0) +- `min_peak_height`: 最小峰值高度 (默认800.0 eV) + +### 数据处理设置 +- `remove_outliers`: 是否移除异常值 +- `outlier_threshold`: 异常值检测阈值 +- `interpolate_missing`: 是否插值缺失值 +- `normalize_data`: 是否标准化数据 + +### 预测设置 +- `detect_rise_phases`: 是否检测上升阶段 +- `detect_peak_moments`: 是否检测峰值时刻 +- `detect_fall_phases`: 是否检测下降阶段 +- `detect_anomalies`: 是否检测异常事件 +- `detect_plateau_phases`: 是否检测平台期 + +## Label Studio配置 + +在Label Studio中使用以下XML配置来支持温度标注: + +```xml + + + + + + +``` + +## 数据格式要求 + +### 输入数据格式 +- CSV文件,包含时间列和温度通道列 +- 时间列名必须为 `time` +- 温度通道列名应包含 `Te`、`temp`、`Ti` 等关键词 +- 支持多通道温度数据 + +### 示例数据 +```csv +time,Te_1,Te_2,Te_3 +0.0,100,150,200 +0.001,120,170,220 +0.002,140,190,240 +... +``` + +## 部署到Kubernetes + +```bash +kubectl apply -f temperature-annotation.yaml +``` + +## 自定义模型 + +可以通过修改 `temperature_predictor.py` 中的算法来自定义温度处理逻辑: + +- 调整阈值检测算法 +- 添加新的温度特征识别方法 +- 自定义异常检测规则 +- 优化梯度计算方法 + +## 测试 + +运行测试确保功能正常: + +```bash +pytest test_temperature_backend.py -v +``` + +## API接口 + +### 健康检查 +``` +GET /health +``` + +### 模型信息 +``` +GET /model/info +``` + +### 预测接口 +``` +POST /predict +``` + +### 训练接口 +``` +POST /fit +``` + +## 故障排除 + +### 常见问题 + +1. **模型加载失败**: 检查依赖包安装和配置文件 +2. **预测结果为空**: 确认输入数据格式和温度通道识别 +3. **阈值设置不当**: 根据实际数据调整阈值参数 +4. **内存不足**: 调整Docker容器的内存限制 + +### 日志查看 + +```bash +# Docker日志 +docker logs temperature-annotation-ml-backend + +# 查看模型日志 +docker exec -it temperature-annotation-ml-backend tail -f /app/logs/model.log +``` + +## 性能优化 + +### 数据处理优化 +- 使用并发加载多个炮号数据 +- 支持数据预处理和缓存 +- 优化异常值检测算法 + +### 预测性能优化 +- 批量处理多个温度通道 +- 并行计算不同特征检测 +- 支持增量学习和模型更新 diff --git a/ml-backends/temperature_annotation/_wsgi.py b/ml-backends/temperature_annotation/_wsgi.py new file mode 100644 index 0000000..60c3782 --- /dev/null +++ b/ml-backends/temperature_annotation/_wsgi.py @@ -0,0 +1,132 @@ +""" +温度标注ML Backend的WSGI应用入口 + +基于Label Studio ML Backend的标准模式实现 +""" + +import os +import argparse +import json +import logging +import logging.config + +# 设置默认日志级别 +log_level = os.getenv("LOG_LEVEL", "INFO") + +logging.config.dictConfig({ + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "standard": { + "format": "[%(asctime)s] [%(levelname)s] [%(name)s::%(funcName)s::%(lineno)d] %(message)s" + } + }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "level": log_level, + "stream": "ext://sys.stdout", + "formatter": "standard" + } + }, + "root": { + "level": log_level, + "handlers": ["console"], + "propagate": True + } +}) + +from label_studio_ml.api import init_app +from model import TemperatureModel + +_DEFAULT_CONFIG_PATH = os.path.join(os.path.dirname(__file__), 'config.json') + + +def get_kwargs_from_config(config_path=_DEFAULT_CONFIG_PATH): + """从配置文件获取初始化参数""" + if not os.path.exists(config_path): + return dict() + with open(config_path) as f: + config = json.load(f) + assert isinstance(config, dict) + return config + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='温度标注ML Backend') + parser.add_argument( + '-p', '--port', dest='port', type=int, default=9090, + help='Server port') + parser.add_argument( + '--host', dest='host', type=str, default='0.0.0.0', + help='Server host') + parser.add_argument( + '--kwargs', '--with', dest='kwargs', metavar='KEY=VAL', nargs='+', + type=lambda kv: kv.split('='), + help='Additional TemperatureModel initialization kwargs') + parser.add_argument( + '-d', '--debug', dest='debug', action='store_true', + help='Switch debug mode') + parser.add_argument( + '--log-level', dest='log_level', + choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], default=log_level, + help='Logging level') + parser.add_argument( + '--model-dir', dest='model_dir', default=os.path.dirname(__file__), + help='Directory where models are stored') + parser.add_argument( + '--check', dest='check', action='store_true', + help='Validate model instance before launching server') + parser.add_argument('--basic-auth-user', + default=os.environ.get('ML_SERVER_BASIC_AUTH_USER', None), + help='Basic auth user') + parser.add_argument('--basic-auth-pass', + default=os.environ.get('ML_SERVER_BASIC_AUTH_PASS', None), + help='Basic auth pass') + + args = parser.parse_args() + + if args.log_level: + logging.root.setLevel(args.log_level) + + def isfloat(value): + try: + float(value) + return True + except ValueError: + return False + + def parse_kwargs(): + param = dict() + if args.kwargs: + for k, v in args.kwargs: + if v.isdigit(): + param[k] = int(v) + elif v == 'True' or v == 'true': + param[k] = True + elif v == 'False' or v == 'false': + param[k] = False + elif isfloat(v): + param[k] = float(v) + else: + param[k] = v + return param + + kwargs = get_kwargs_from_config() + if args.kwargs: + kwargs.update(parse_kwargs()) + + if args.check: + print('Check "' + TemperatureModel.__name__ + '" instance creation..') + model = TemperatureModel(**kwargs) + + app = init_app( + model_class=TemperatureModel, + basic_auth_user=args.basic_auth_user, + basic_auth_pass=args.basic_auth_pass + ) + app.run(host=args.host, port=args.port, debug=args.debug) + +else: + # 用于uWSGI + app = init_app(model_class=TemperatureModel) diff --git a/ml-backends/temperature_annotation/config.json b/ml-backends/temperature_annotation/config.json new file mode 100644 index 0000000..21cf748 --- /dev/null +++ b/ml-backends/temperature_annotation/config.json @@ -0,0 +1,28 @@ +{ + "model_version": "temperature_v1.0", + "temperature_thresholds": { + "temp_rise_threshold": 1000.0, + "temp_fall_threshold": 500.0, + "gradient_threshold": 100.0, + "anomaly_threshold": 3.0, + "min_peak_height": 800.0 + }, + "data_processing": { + "remove_outliers": true, + "outlier_threshold": 3.0, + "interpolate_missing": true, + "normalize_data": false + }, + "prediction_settings": { + "detect_rise_phases": true, + "detect_peak_moments": true, + "detect_fall_phases": true, + "detect_anomalies": true, + "detect_plateau_phases": true + }, + "model_optimization": { + "enable_adaptive_thresholds": true, + "learning_rate": 0.05, + "min_training_samples": 10 + } +} diff --git a/ml-backends/temperature_annotation/docker-compose.yml b/ml-backends/temperature_annotation/docker-compose.yml new file mode 100644 index 0000000..0a88d99 --- /dev/null +++ b/ml-backends/temperature_annotation/docker-compose.yml @@ -0,0 +1,22 @@ +version: '3.8' + +services: + ml-backend-temperature: + build: . + container_name: temperature-annotation-ml-backend + ports: + - "9090:9090" + environment: + - MODEL_VERSION=temperature_v1.0 + - LOG_LEVEL=INFO + - WORKERS=1 + - THREADS=8 + volumes: + - ./models:/app/models + - ./data:/app/data + restart: unless-stopped + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:9090/health"] + interval: 30s + timeout: 10s + retries: 3 diff --git a/ml-backends/temperature_annotation/model.py b/ml-backends/temperature_annotation/model.py new file mode 100644 index 0000000..e99443c --- /dev/null +++ b/ml-backends/temperature_annotation/model.py @@ -0,0 +1,264 @@ +""" +温度标注ML Backend模型 + +基于Label Studio ML Backend架构,实现温度标注的AI预测服务 +""" + +from typing import List, Dict, Optional +from label_studio_ml.model import LabelStudioMLBase +from label_studio_ml.response import ModelResponse +import prediction +import utils +import numpy as np +import pandas as pd +import logging +import time +from temperature_predictor import TemperaturePredictor + +logger = logging.getLogger(__name__) + + +class TemperatureModel(LabelStudioMLBase): + """温度标注ML Backend模型""" + + def setup(self): + """配置模型参数""" + self.set("model_version", "temperature_v1.0") + + # 初始化温度预测器 + self.predictor = TemperaturePredictor() + + # 记录模型统计信息 + self.set('training_count', 0) + self.set('last_training_time', time.time()) + + logger.info("温度标注模型初始化完成") + + def get_data(self, tasks: List[Dict]) -> Dict: + """获取任务数据""" + urls = {} + for task in tasks: + data = task['data'] + urls[data['shot']] = data['csv'] + + logger.info(f"开始加载 {len(urls)} 个炮号的数据") + data_dict = utils.load_data(urls) + return data_dict + + def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse: + """执行温度预测""" + logger.info(f'运行温度预测,任务数量: {len(tasks)}') + + try: + data_dict = self.get_data(tasks) + model_preds = [] + + for shot, data in data_dict.items(): + try: + # 验证数据 + if not utils.validate_temperature_data(data): + logger.warning(f"炮号 {shot} 数据验证失败,跳过") + continue + + # 预处理数据 + data_processed = utils.preprocess_temperature_data(data) + + # 执行预测 + preds = prediction.convert_to_labelstudio_form( + self.predictor.user_predict(data_processed), + 'temperature_model_v1' + ) + + if preds: + model_preds.extend(preds) + logger.info(f'炮号 {shot} 预测成功,生成 {len(preds)} 个预测结果') + else: + logger.info(f'炮号 {shot} 未生成预测结果') + + except Exception as e: + logger.error(f'炮号 {shot} 预测失败: {e}') + continue + + logger.info(f"温度预测完成,总共生成 {len(model_preds)} 个预测结果") + + except Exception as e: + logger.error(f"温度预测过程中出错: {e}") + model_preds = [] + + return ModelResponse(predictions=model_preds) + + def fit(self, event, data, **kwargs): + """ + 温度标注模型的训练/更新逻辑 + + 每次创建或更新标注时调用此方法来优化模型参数 + + Args: + event: 事件类型 ('ANNOTATION_CREATED', 'ANNOTATION_UPDATED', 'START_TRAINING') + data: 来自webhook的数据载荷 + """ + logger.info(f'接收到温度标注训练事件: {event}') + + try: + # 获取之前的模型参数 + old_temp_rise_threshold = self.get('temp_rise_threshold', 1000.0) + old_temp_fall_threshold = self.get('temp_fall_threshold', 500.0) + old_gradient_threshold = self.get('gradient_threshold', 100.0) + old_model_version = self.get('model_version', 'temperature_v1.0') + + logger.info(f'当前模型参数: 上升阈值={old_temp_rise_threshold}, 下降阈值={old_temp_fall_threshold}, 梯度阈值={old_gradient_threshold}') + + # 处理标注数据以优化参数 + if event in ['ANNOTATION_CREATED', 'ANNOTATION_UPDATED'] and data: + try: + # 解析标注数据 + annotation_data = self._parse_annotation_data(data) + + if annotation_data: + # 基于标注数据调整阈值参数 + new_thresholds = self._optimize_thresholds(annotation_data) + + # 更新预测器参数 + if new_thresholds: + self.predictor.update_thresholds(new_thresholds) + + # 保存到缓存 + self.set('temp_rise_threshold', self.predictor.temp_rise_threshold) + self.set('temp_fall_threshold', self.predictor.temp_fall_threshold) + self.set('gradient_threshold', self.predictor.gradient_threshold) + + # 更新模型版本 + new_version = f"temperature_v1.0_{int(time.time())}" + self.set('model_version', new_version) + + logger.info(f'参数已更新: 上升阈值={self.predictor.temp_rise_threshold}, 下降阈值={self.predictor.temp_fall_threshold}, 梯度阈值={self.predictor.gradient_threshold}') + logger.info(f'模型版本更新为: {new_version}') + + except Exception as e: + logger.error(f'处理标注数据时出错: {e}') + + # 记录训练统计信息 + training_count = self.get('training_count', 0) + 1 + self.set('training_count', training_count) + self.set('last_training_time', time.time()) + + logger.info(f'温度标注模型训练完成,总训练次数: {training_count}') + + except Exception as e: + logger.error(f'温度标注模型训练过程中出错: {e}') + + def _parse_annotation_data(self, data): + """解析标注数据,提取温度相关信息""" + try: + if 'annotation' in data and 'result' in data['annotation']: + results = data['annotation']['result'] + temp_annotations = [] + + for result in results: + if (result.get('from_name') == 'temperature_events' and + result.get('type') == 'timeserieslabels'): + + value = result.get('value', {}) + temp_annotations.append({ + 'label': value.get('timeserieslabels', [''])[0], + 'start': value.get('start'), + 'end': value.get('end'), + 'shot': data.get('task', {}).get('data', {}).get('shot') + }) + + logger.info(f"解析到 {len(temp_annotations)} 个温度标注") + return temp_annotations if temp_annotations else None + + except Exception as e: + logger.error(f'解析标注数据失败: {e}') + return None + + def _optimize_thresholds(self, annotation_data): + """基于标注数据优化阈值参数""" + try: + # 统计不同类型标注的特征 + rise_annotations = [ann for ann in annotation_data if '上升' in ann['label']] + fall_annotations = [ann for ann in annotation_data if '下降' in ann['label']] + peak_annotations = [ann for ann in annotation_data if '峰值' in ann['label']] + + new_thresholds = {} + + # 基于标注数据调整阈值(简化的自适应逻辑) + if rise_annotations: + # 如果有上升阶段标注,可以适当降低上升阈值以提高敏感度 + current_rise = self.get('temp_rise_threshold', 1000.0) + new_thresholds['temp_rise_threshold'] = max(800.0, current_rise * 0.95) + + if fall_annotations: + # 如果有下降阶段标注,可以适当调整下降阈值 + current_fall = self.get('temp_fall_threshold', 500.0) + new_thresholds['temp_fall_threshold'] = max(400.0, current_fall * 0.95) + + if peak_annotations: + # 如果有峰值标注,可以调整梯度阈值 + current_gradient = self.get('gradient_threshold', 100.0) + new_thresholds['gradient_threshold'] = max(50.0, current_gradient * 0.9) + + if new_thresholds: + logger.info(f"基于标注数据优化阈值: {new_thresholds}") + + return new_thresholds if new_thresholds else None + + except Exception as e: + logger.error(f'优化阈值参数失败: {e}') + return None + + def get_model_info(self) -> Dict: + """获取模型信息""" + return { + 'model_version': self.get('model_version', 'temperature_v1.0'), + 'training_count': self.get('training_count', 0), + 'last_training_time': self.get('last_training_time', 0), + 'current_thresholds': self.predictor.get_current_thresholds() + } + + def reset_model(self): + """重置模型参数到默认值""" + try: + # 重置预测器参数 + self.predictor = TemperaturePredictor() + + # 重置模型版本 + self.set('model_version', 'temperature_v1.0') + self.set('training_count', 0) + self.set('last_training_time', time.time()) + + logger.info("模型参数已重置到默认值") + + except Exception as e: + logger.error(f"重置模型参数失败: {e}") + + def update_model_config(self, config: Dict): + """更新模型配置""" + try: + if 'thresholds' in config: + self.predictor.update_thresholds(config['thresholds']) + logger.info(f"模型阈值已更新: {config['thresholds']}") + + if 'model_version' in config: + self.set('model_version', config['model_version']) + logger.info(f"模型版本已更新: {config['model_version']}") + + except Exception as e: + logger.error(f"更新模型配置失败: {e}") + + def health_check(self) -> Dict: + """健康检查""" + try: + return { + 'status': 'healthy', + 'model_version': self.get('model_version', 'unknown'), + 'predictor_initialized': hasattr(self, 'predictor'), + 'last_training': self.get('last_training_time', 0), + 'training_count': self.get('training_count', 0) + } + except Exception as e: + return { + 'status': 'unhealthy', + 'error': str(e) + } diff --git a/ml-backends/temperature_annotation/prediction.py b/ml-backends/temperature_annotation/prediction.py new file mode 100644 index 0000000..bdef299 --- /dev/null +++ b/ml-backends/temperature_annotation/prediction.py @@ -0,0 +1,239 @@ +""" +温度预测工具模块 + +提供温度预测结果数据结构、预测器抽象基类和格式转换功能 +""" + +import abc +import pandas as pd +import numpy as np +from typing import List, Optional + + +class Prediction: + """温度预测结果数据结构""" + + def __init__(self, label_group: str, label: str, start: float, end: Optional[float] = None, score: Optional[float] = None): + """ + 初始化温度预测结果 + + Args: + label_group: 标签组名称 + label: 标签名称 + start: 开始时间 + end: 结束时间(可选,None表示时间点) + score: 置信度分数(可选) + """ + self.label_group = label_group + self.label = label + self.start = start + self.end = end + self.score = score + + def __repr__(self): + """字符串表示""" + return f"" + + +class BasePredictor(abc.ABC): + """温度预测器抽象基类""" + + @abc.abstractmethod + def user_predict(self, task_data: pd.DataFrame) -> List[Prediction]: + """ + 执行温度预测 + + Args: + task_data: 任务数据,包含时间序列温度数据 + + Returns: + 预测结果列表 + """ + pass + + +def start_end_time_1D(predict_result, threshold: float, postive: bool = True): + """ + 1D温度数据阈值检测 + + Args: + predict_result: 温度数据数组 + threshold: 阈值 + postive: True表示检测大于阈值的区间,False表示检测小于阈值的区间 + + Returns: + 区间列表,每个元素为(start_index, end_index)的元组 + """ + predict_result = np.array(predict_result) + + if postive: + mask = predict_result > threshold + else: + mask = predict_result < threshold + + diff = np.diff(mask.astype(int)) + starts = np.where(diff == 1)[0] + 1 + ends = np.where(diff == -1)[0] + 1 + + if mask[0]: + starts = np.insert(starts, 0, 0) + if mask[-1]: + ends = np.append(ends, len(predict_result) - 1) + + return list(zip(starts, ends)) + + +def start_end_time(predict_result, threshold, postive: bool = True, all_dim: bool = True, time_axis: int = -1): + """ + 多维温度数据阈值检测 + + Args: + predict_result: 温度数据数组 + threshold: 阈值 + postive: True表示检测大于阈值的区间,False表示检测小于阈值的区间 + all_dim: True表示所有维度必须同时满足阈值条件,False表示对每个通道逐一判断 + time_axis: 时间轴索引 + + Returns: + 区间列表,每个元素为(start_index, end_index)的元组 + """ + predict_result = np.array(predict_result) + + if postive: + mask = predict_result > threshold + else: + mask = predict_result < threshold + + if all_dim: + # 所有维度同时满足条件 + mask = np.all(mask, axis=tuple(i for i in range(mask.ndim) if i != time_axis)) + + diff = np.diff(mask.astype(int)) + starts = np.where(diff == 1)[0] + 1 + ends = np.where(diff == -1)[0] + 1 + + if mask[0]: + starts = np.insert(starts, 0, 0) + if mask[-1]: + ends = np.append(ends, len(mask) - 1) + + return list(zip(starts, ends)) + + +def convert_to_labelstudio_form(pred_results: List[Prediction], model_version: str = "temperature_v1") -> List[dict]: + """ + 转换为Label Studio格式 + + Args: + pred_results: 预测结果列表 + model_version: 模型版本 + + Returns: + Label Studio格式的预测结果 + """ + LS_result_list = [] + + for p in pred_results: + if p.end is None: + p.end = p.start + + LS_result_list.append({ + "from_name": p.label_group, + "to_name": "ts", + "type": "timeserieslabels", + "value": { + "start": p.start, + "end": p.end, + "timeserieslabels": [p.label] + } + }) + + return [{ + "model_version": model_version, + "result": LS_result_list + }] if LS_result_list else [] + + +def detect_temperature_anomalies(temp_data: np.ndarray, time_data: np.ndarray, + threshold: float = 3.0, window_size: int = 10): + """ + 检测温度异常事件 + + Args: + temp_data: 温度数据 + time_data: 时间数据 + threshold: 异常检测阈值(标准差倍数) + window_size: 滑动窗口大小 + + Returns: + 异常事件列表,每个元素为(start_time, end_time, anomaly_type)的元组 + """ + anomalies = [] + + # 计算滑动窗口的统计量 + for i in range(window_size, len(temp_data)): + window_data = temp_data[i-window_size:i] + mean_temp = np.mean(window_data) + std_temp = np.std(window_data) + + current_temp = temp_data[i] + + # 检测异常(超出阈值范围) + if abs(current_temp - mean_temp) > threshold * std_temp: + anomaly_type = "温度突变" if current_temp > mean_temp else "温度骤降" + anomalies.append((time_data[i], time_data[i], anomaly_type)) + + return anomalies + + +def calculate_temperature_gradient(temp_data: np.ndarray, time_data: np.ndarray, + window_size: int = 5): + """ + 计算温度梯度 + + Args: + temp_data: 温度数据 + time_data: 时间数据 + window_size: 梯度计算窗口大小 + + Returns: + 梯度数据数组 + """ + gradient = np.zeros_like(temp_data) + + for i in range(window_size, len(temp_data) - window_size): + # 使用中心差分计算梯度 + temp_diff = temp_data[i + window_size] - temp_data[i - window_size] + time_diff = time_data[i + window_size] - time_data[i - window_size] + + if time_diff != 0: + gradient[i] = temp_diff / time_diff + + return gradient + + +def find_temperature_peaks(temp_data: np.ndarray, time_data: np.ndarray, + gradient_threshold: float = 100.0, min_peak_height: float = 0.0): + """ + 查找温度峰值 + + Args: + temp_data: 温度数据 + time_data: 时间数据 + gradient_threshold: 梯度阈值 + min_peak_height: 最小峰值高度 + + Returns: + 峰值时间列表 + """ + peaks = [] + gradient = calculate_temperature_gradient(temp_data, time_data) + + for i in range(1, len(gradient) - 1): + # 梯度从正变负,且温度值足够高 + if (gradient[i-1] > gradient_threshold and + gradient[i+1] < -gradient_threshold and + temp_data[i] > min_peak_height): + peaks.append(time_data[i]) + + return peaks diff --git a/ml-backends/temperature_annotation/requirements-base.txt b/ml-backends/temperature_annotation/requirements-base.txt new file mode 100644 index 0000000..97f92d5 --- /dev/null +++ b/ml-backends/temperature_annotation/requirements-base.txt @@ -0,0 +1,2 @@ +label-studio-ml>=1.0.0 +gunicorn>=20.1.0 diff --git a/ml-backends/temperature_annotation/requirements-test.txt b/ml-backends/temperature_annotation/requirements-test.txt new file mode 100644 index 0000000..ba4bd9d --- /dev/null +++ b/ml-backends/temperature_annotation/requirements-test.txt @@ -0,0 +1,3 @@ +pytest>=7.0.0 +pytest-cov>=4.0.0 +pytest-mock>=3.10.0 diff --git a/ml-backends/temperature_annotation/requirements.txt b/ml-backends/temperature_annotation/requirements.txt new file mode 100644 index 0000000..4000c6d --- /dev/null +++ b/ml-backends/temperature_annotation/requirements.txt @@ -0,0 +1,6 @@ +label-studio-ml>=1.0.0 +pandas>=1.3.0 +numpy>=1.21.0 +requests>=2.25.0 +scikit-learn>=1.0.0 +scipy>=1.9.0 diff --git a/ml-backends/temperature_annotation/temperature_predictor.py b/ml-backends/temperature_annotation/temperature_predictor.py new file mode 100644 index 0000000..2d32974 --- /dev/null +++ b/ml-backends/temperature_annotation/temperature_predictor.py @@ -0,0 +1,374 @@ +""" +温度预测器模块 + +实现温度特征识别算法,包括温度上升、下降、峰值检测等 +""" + +import pandas as pd +import numpy as np +from typing import List +from prediction import BasePredictor, Prediction +from prediction import start_end_time_1D, detect_temperature_anomalies, find_temperature_peaks +import logging + +logger = logging.getLogger(__name__) + + +class TemperaturePredictor(BasePredictor): + """温度预测器类""" + + def __init__(self, + label_group: str = 'temperature_events', + temp_rise_threshold: float = 1000.0, # eV + temp_fall_threshold: float = 500.0, # eV + gradient_threshold: float = 100.0, # eV/ms + anomaly_threshold: float = 3.0, # 异常检测阈值 + min_peak_height: float = 800.0): # 最小峰值高度 + """ + 初始化温度预测器 + + Args: + label_group: 标签组名称 + temp_rise_threshold: 温度上升阈值 + temp_fall_threshold: 温度下降阈值 + gradient_threshold: 梯度阈值 + anomaly_threshold: 异常检测阈值 + min_peak_height: 最小峰值高度 + """ + super().__init__() + self.label_group = label_group + self.temp_rise_threshold = temp_rise_threshold + self.temp_fall_threshold = temp_fall_threshold + self.gradient_threshold = gradient_threshold + self.anomaly_threshold = anomaly_threshold + self.min_peak_height = min_peak_height + + logger.info(f"温度预测器初始化完成,标签组: {label_group}") + logger.info(f"阈值设置: 上升={temp_rise_threshold}, 下降={temp_fall_threshold}, 梯度={gradient_threshold}") + + def user_predict(self, task_data: pd.DataFrame) -> List[Prediction]: + """ + 执行温度预测 + + Args: + task_data: 任务数据,包含时间序列温度数据 + + Returns: + 预测结果列表 + """ + predictions = [] + + try: + # 验证数据 + if task_data is None or task_data.empty: + logger.warning("任务数据为空") + return predictions + + if 'time' not in task_data.columns: + logger.warning("缺少时间列") + return predictions + + time = task_data['time'].values + + # 检测多个温度通道 + temp_channels = self._extract_temperature_channels(task_data) + + if not temp_channels: + logger.warning("未找到温度通道") + return predictions + + logger.info(f"开始分析 {len(temp_channels)} 个温度通道: {temp_channels}") + + for channel in temp_channels: + if channel not in task_data.columns: + continue + + temp_data = np.array(task_data[channel]) + + # 跳过全为NaN的通道 + if np.all(np.isnan(temp_data)): + logger.warning(f"通道 {channel} 全为NaN,跳过") + continue + + # 处理NaN值 + temp_data = self._handle_missing_values(temp_data) + + # 1. 检测温度上升阶段 + rise_predictions = self._detect_rise_phases(time, temp_data, channel) + predictions.extend(rise_predictions) + + # 2. 检测温度峰值时刻 + peak_predictions = self._detect_peak_moments(time, temp_data, channel) + predictions.extend(peak_predictions) + + # 3. 检测温度下降阶段 + fall_predictions = self._detect_fall_phases(time, temp_data, channel) + predictions.extend(fall_predictions) + + # 4. 检测异常温度事件 + anomaly_predictions = self._detect_anomalies(time, temp_data, channel) + predictions.extend(anomaly_predictions) + + # 5. 检测温度平台期 + plateau_predictions = self._detect_plateau_phases(time, temp_data, channel) + predictions.extend(plateau_predictions) + + logger.info(f"温度预测完成,生成了 {len(predictions)} 个预测结果") + + except Exception as e: + logger.error(f"温度预测过程中出错: {str(e)}") + + return predictions + + def _extract_temperature_channels(self, task_data: pd.DataFrame) -> List[str]: + """提取温度通道列名""" + temp_channels = [] + for col in task_data.columns: + col_lower = col.lower() + if any(pattern in col_lower for pattern in ['te', 'temp', 'ti', 'temperature']): + temp_channels.append(col) + return temp_channels + + def _handle_missing_values(self, temp_data: np.ndarray) -> np.ndarray: + """处理缺失值""" + # 使用前向填充处理NaN值 + temp_data_clean = temp_data.copy() + for i in range(1, len(temp_data_clean)): + if np.isnan(temp_data_clean[i]): + temp_data_clean[i] = temp_data_clean[i-1] + + # 如果第一个值是NaN,使用下一个非NaN值 + if np.isnan(temp_data_clean[0]): + for i in range(1, len(temp_data_clean)): + if not np.isnan(temp_data_clean[i]): + temp_data_clean[0] = temp_data_clean[i] + break + + return temp_data_clean + + def _detect_rise_phases(self, time: np.ndarray, temp_data: np.ndarray, channel: str) -> List[Prediction]: + """检测温度上升阶段""" + predictions = [] + + try: + # 检测温度上升区间 + rise_intervals = start_end_time_1D( + temp_data, self.temp_rise_threshold, postive=True + ) + + for start_idx, end_idx in rise_intervals: + if start_idx < end_idx: # 确保区间有效 + start_time = time[start_idx] + end_time = time[end_idx] + + # 计算上升阶段的特征 + rise_temp_data = temp_data[start_idx:end_idx+1] + max_temp = np.max(rise_temp_data) + avg_temp = np.mean(rise_temp_data) + + # 根据温度特征调整标签 + if max_temp > self.temp_rise_threshold * 1.5: + label = f"{channel}_快速上升" + else: + label = f"{channel}_上升阶段" + + predictions.append(Prediction( + self.label_group, + label, + start=start_time, + end=end_time, + score=min(0.9, (max_temp - self.temp_rise_threshold) / self.temp_rise_threshold) + )) + + logger.info(f"通道 {channel} 检测到 {len(predictions)} 个上升阶段") + + except Exception as e: + logger.error(f"检测通道 {channel} 上升阶段时出错: {str(e)}") + + return predictions + + def _detect_peak_moments(self, time: np.ndarray, temp_data: np.ndarray, channel: str) -> List[Prediction]: + """检测温度峰值时刻""" + predictions = [] + + try: + # 查找温度峰值 + peak_times = find_temperature_peaks( + temp_data, time, + gradient_threshold=self.gradient_threshold, + min_peak_height=self.min_peak_height + ) + + for peak_time in peak_times: + # 找到峰值对应的时间索引 + peak_idx = np.argmin(np.abs(time - peak_time)) + peak_temp = temp_data[peak_idx] + + # 根据峰值高度调整标签 + if peak_temp > self.temp_rise_threshold * 2: + label = f"{channel}_极高峰值" + elif peak_temp > self.temp_rise_threshold * 1.5: + label = f"{channel}_高峰值" + else: + label = f"{channel}_峰值时刻" + + predictions.append(Prediction( + self.label_group, + label, + start=peak_time, + end=None, # 时间点标注 + score=min(0.95, peak_temp / (self.temp_rise_threshold * 2)) + )) + + logger.info(f"通道 {channel} 检测到 {len(predictions)} 个峰值时刻") + + except Exception as e: + logger.error(f"检测通道 {channel} 峰值时刻时出错: {str(e)}") + + return predictions + + def _detect_fall_phases(self, time: np.ndarray, temp_data: np.ndarray, channel: str) -> List[Prediction]: + """检测温度下降阶段""" + predictions = [] + + try: + # 检测温度下降区间 + fall_intervals = start_end_time_1D( + temp_data, self.temp_fall_threshold, postive=False + ) + + for start_idx, end_idx in fall_intervals: + if start_idx < end_idx: # 确保区间有效 + start_time = time[start_idx] + end_time = time[end_idx] + + # 计算下降阶段的特征 + fall_temp_data = temp_data[start_idx:end_idx+1] + min_temp = np.min(fall_temp_data) + avg_temp = np.mean(fall_temp_data) + + # 根据温度特征调整标签 + if min_temp < self.temp_fall_threshold * 0.5: + label = f"{channel}_快速下降" + else: + label = f"{channel}_下降阶段" + + predictions.append(Prediction( + self.label_group, + label, + start=start_time, + end=end_time, + score=min(0.9, (self.temp_fall_threshold - min_temp) / self.temp_fall_threshold) + )) + + logger.info(f"通道 {channel} 检测到 {len(predictions)} 个下降阶段") + + except Exception as e: + logger.error(f"检测通道 {channel} 下降阶段时出错: {str(e)}") + + return predictions + + def _detect_anomalies(self, time: np.ndarray, temp_data: np.ndarray, channel: str) -> List[Prediction]: + """检测异常温度事件""" + predictions = [] + + try: + # 检测温度异常 + anomalies = detect_temperature_anomalies( + temp_data, time, + threshold=self.anomaly_threshold + ) + + for start_time, end_time, anomaly_type in anomalies: + predictions.append(Prediction( + self.label_group, + f"{channel}_{anomaly_type}", + start=start_time, + end=end_time, + score=0.8 # 异常检测的置信度 + )) + + logger.info(f"通道 {channel} 检测到 {len(predictions)} 个异常事件") + + except Exception as e: + logger.error(f"检测通道 {channel} 异常事件时出错: {str(e)}") + + return predictions + + def _detect_plateau_phases(self, time: np.ndarray, temp_data: np.ndarray, channel: str) -> List[Prediction]: + """检测温度平台期""" + predictions = [] + + try: + # 计算温度梯度 + gradient = np.gradient(temp_data) + + # 检测梯度较小的区间(平台期) + plateau_threshold = self.gradient_threshold * 0.1 # 平台期的梯度阈值 + plateau_mask = np.abs(gradient) < plateau_threshold + + # 找到连续的平台期区间 + plateau_intervals = start_end_time_1D( + plateau_mask.astype(float), 0.5, postive=True + ) + + for start_idx, end_idx in plateau_intervals: + if end_idx - start_idx > 10: # 平台期至少持续10个采样点 + start_time = time[start_idx] + end_time = time[end_idx] + + # 计算平台期的特征 + plateau_temp_data = temp_data[start_idx:end_idx+1] + avg_temp = np.mean(plateau_temp_data) + temp_std = np.std(plateau_temp_data) + + # 根据温度水平调整标签 + if avg_temp > self.temp_rise_threshold: + label = f"{channel}_高温平台期" + elif avg_temp > self.temp_fall_threshold: + label = f"{channel}_中温平台期" + else: + label = f"{channel}_低温平台期" + + # 根据稳定性计算置信度 + stability_score = max(0.5, 1.0 - temp_std / avg_temp) if avg_temp > 0 else 0.5 + + predictions.append(Prediction( + self.label_group, + label, + start=start_time, + end=end_time, + score=stability_score + )) + + logger.info(f"通道 {channel} 检测到 {len(predictions)} 个平台期") + + except Exception as e: + logger.error(f"检测通道 {channel} 平台期时出错: {str(e)}") + + return predictions + + def update_thresholds(self, new_thresholds: dict): + """更新阈值参数""" + if 'temp_rise_threshold' in new_thresholds: + self.temp_rise_threshold = new_thresholds['temp_rise_threshold'] + if 'temp_fall_threshold' in new_thresholds: + self.temp_fall_threshold = new_thresholds['temp_fall_threshold'] + if 'gradient_threshold' in new_thresholds: + self.gradient_threshold = new_thresholds['gradient_threshold'] + if 'anomaly_threshold' in new_thresholds: + self.anomaly_threshold = new_thresholds['anomaly_threshold'] + if 'min_peak_height' in new_thresholds: + self.min_peak_height = new_thresholds['min_peak_height'] + + logger.info(f"阈值参数已更新: {new_thresholds}") + + def get_current_thresholds(self) -> dict: + """获取当前阈值参数""" + return { + 'temp_rise_threshold': self.temp_rise_threshold, + 'temp_fall_threshold': self.temp_fall_threshold, + 'gradient_threshold': self.gradient_threshold, + 'anomaly_threshold': self.anomaly_threshold, + 'min_peak_height': self.min_peak_height + } diff --git a/ml-backends/temperature_annotation/test_e2e.py b/ml-backends/temperature_annotation/test_e2e.py new file mode 100644 index 0000000..fb1eb39 --- /dev/null +++ b/ml-backends/temperature_annotation/test_e2e.py @@ -0,0 +1,635 @@ +""" +温度标注ML Backend的端到端测试 + +测试完整的ML Backend服务流程,包括: +- 服务启动和健康检查 +- 数据加载和预处理 +- 温度预测和结果转换 +- API接口响应 +- 模型训练和参数更新 +""" + +import pytest +import requests +import time +import json +import pandas as pd +import numpy as np +import tempfile +import os +import subprocess +import signal +from unittest.mock import patch, MagicMock +import logging + +# 设置日志 +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# 测试配置 +TEST_CONFIG = { + 'host': 'localhost', + 'port': 9090, + 'base_url': 'http://localhost:9090', + 'timeout': 30, + 'retry_attempts': 3 +} + + +class TestTemperatureBackendE2E: + """温度标注ML Backend端到端测试类""" + + @classmethod + def setup_class(cls): + """测试类初始化,启动ML Backend服务""" + cls.service_process = None + cls.service_url = f"{TEST_CONFIG['base_url']}" + + # 检查服务是否已经在运行 + if cls._is_service_running(): + logger.info("ML Backend服务已在运行") + return + + # 启动服务 + cls._start_service() + + # 等待服务启动 + cls._wait_for_service() + + @classmethod + def teardown_class(cls): + """测试类清理,停止ML Backend服务""" + if cls.service_process: + cls._stop_service() + + @classmethod + def _is_service_running(cls): + """检查服务是否在运行""" + try: + response = requests.get(f"{cls.service_url}/health", timeout=5) + return response.status_code == 200 + except: + return False + + @classmethod + def _start_service(cls): + """启动ML Backend服务""" + try: + logger.info("启动温度标注ML Backend服务...") + + # 使用Python直接启动服务 + cmd = [ + 'python', '_wsgi.py', + '--host', TEST_CONFIG['host'], + '--port', str(TEST_CONFIG['port']), + '--debug' + ] + + cls.service_process = subprocess.Popen( + cmd, + cwd=os.path.dirname(__file__), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE + ) + + logger.info(f"服务启动命令: {' '.join(cmd)}") + logger.info(f"服务进程ID: {cls.service_process.pid}") + + except Exception as e: + logger.error(f"启动服务失败: {e}") + raise + + @classmethod + def _stop_service(cls): + """停止ML Backend服务""" + if cls.service_process: + logger.info("停止ML Backend服务...") + try: + cls.service_process.terminate() + cls.service_process.wait(timeout=10) + except subprocess.TimeoutExpired: + cls.service_process.kill() + cls.service_process.wait() + finally: + cls.service_process = None + + @classmethod + def _wait_for_service(cls): + """等待服务启动完成""" + logger.info("等待服务启动...") + max_wait = 60 # 最大等待60秒 + wait_interval = 2 + + for i in range(0, max_wait, wait_interval): + if cls._is_service_running(): + logger.info(f"服务启动成功,耗时 {i} 秒") + return + time.sleep(wait_interval) + + raise TimeoutError("服务启动超时") + + def _make_request(self, method, endpoint, data=None, headers=None): + """发送HTTP请求""" + url = f"{self.service_url}{endpoint}" + + for attempt in range(TEST_CONFIG['retry_attempts']): + try: + if method.upper() == 'GET': + response = requests.get(url, headers=headers, timeout=TEST_CONFIG['timeout']) + elif method.upper() == 'POST': + response = requests.post(url, json=data, headers=headers, timeout=TEST_CONFIG['timeout']) + else: + raise ValueError(f"不支持的HTTP方法: {method}") + + return response + + except requests.exceptions.RequestException as e: + if attempt == TEST_CONFIG['retry_attempts'] - 1: + raise + logger.warning(f"请求失败,重试 {attempt + 1}/{TEST_CONFIG['retry_attempts']}: {e}") + time.sleep(1) + + def test_01_service_health(self): + """测试服务健康状态""" + logger.info("测试服务健康状态...") + + response = self._make_request('GET', '/health') + + assert response.status_code == 200 + health_data = response.json() + + assert 'status' in health_data + assert health_data['status'] == 'healthy' + assert 'model_version' in health_data + assert 'predictor_initialized' in health_data + + logger.info(f"服务健康状态: {health_data}") + + def test_02_model_info(self): + """测试模型信息接口""" + logger.info("测试模型信息接口...") + + response = self._make_request('GET', '/model/info') + + assert response.status_code == 200 + model_info = response.json() + + assert 'model_version' in model_info + assert 'training_count' in model_info + assert 'last_training_time' in model_info + assert 'current_thresholds' in model_info + + # 验证阈值参数 + thresholds = model_info['current_thresholds'] + assert 'temp_rise_threshold' in thresholds + assert 'temp_fall_threshold' in thresholds + assert 'gradient_threshold' in thresholds + + logger.info(f"模型信息: {model_info}") + + def test_03_prediction_with_sample_data(self): + """测试使用样本数据的预测功能""" + logger.info("测试使用样本数据的预测功能...") + + # 创建样本温度数据 + sample_data = self._create_sample_temperature_data() + + # 准备预测请求 + prediction_request = { + "tasks": [ + { + "data": { + "shot": 12345, + "csv": "sample_temperature_data.csv" + } + } + ] + } + + # 模拟数据加载 + with patch('utils.load_data') as mock_load: + mock_load.return_value = {12345: sample_data} + + response = self._make_request('POST', '/predict', data=prediction_request) + + assert response.status_code == 200 + prediction_result = response.json() + + # 验证预测结果结构 + assert 'predictions' in prediction_result + predictions = prediction_result['predictions'] + + if len(predictions) > 0: + # 验证预测结果格式 + first_prediction = predictions[0] + assert 'model_version' in first_prediction + assert 'result' in first_prediction + + # 验证结果内容 + results = first_prediction['result'] + for result in results: + assert 'from_name' in result + assert 'to_name' in result + assert 'type' in result + assert 'value' in result + assert result['type'] == 'timeserieslabels' + + logger.info(f"预测成功,生成了 {len(predictions)} 个预测结果") + else: + logger.info("预测完成,但未生成预测结果(可能是阈值设置过高)") + + def test_04_prediction_with_realistic_data(self): + """测试使用真实场景数据的预测功能""" + logger.info("测试使用真实场景数据的预测功能...") + + # 创建更真实的温度数据(模拟托卡马克实验) + realistic_data = self._create_realistic_temperature_data() + + prediction_request = { + "tasks": [ + { + "data": { + "shot": 240830001, + "csv": "realistic_temperature_data.csv" + } + } + ] + } + + with patch('utils.load_data') as mock_load: + mock_load.return_value = {240830001: realistic_data} + + response = self._make_request('POST', '/predict', data=prediction_request) + + assert response.status_code == 200 + prediction_result = response.json() + + # 验证预测结果 + assert 'predictions' in prediction_result + predictions = prediction_result['predictions'] + + if len(predictions) > 0: + # 分析预测结果类型 + result_types = set() + for pred in predictions: + for result in pred['result']: + label = result['value']['timeserieslabels'][0] + result_types.add(label) + + logger.info(f"检测到的温度事件类型: {result_types}") + + # 验证是否检测到预期的温度事件 + expected_events = ['上升阶段', '峰值时刻', '下降阶段', '平台期'] + detected_events = [event for event in expected_events if any(event in label for label in result_types)] + + logger.info(f"检测到的事件: {detected_events}") + assert len(detected_events) > 0, "应该检测到至少一种温度事件" + + def test_05_model_training(self): + """测试模型训练功能""" + logger.info("测试模型训练功能...") + + # 准备训练数据(模拟人工标注) + training_data = { + "event": "ANNOTATION_CREATED", + "annotation": { + "id": 1, + "result": [ + { + "from_name": "temperature_events", + "type": "timeserieslabels", + "value": { + "start": 2.0, + "end": 4.0, + "timeserieslabels": ["Te_1_上升阶段"] + } + } + ] + }, + "task": { + "id": 1, + "data": { + "shot": 240830001 + } + } + } + + # 发送训练请求 + response = self._make_request('POST', '/fit', data=training_data) + + # 训练接口通常返回200状态码 + assert response.status_code in [200, 204] + + # 验证模型参数是否更新 + time.sleep(2) # 等待训练完成 + + model_info_response = self._make_request('GET', '/model/info') + assert model_info_response.status_code == 200 + + updated_model_info = model_info_response.json() + assert updated_model_info['training_count'] > 0 + + logger.info(f"模型训练完成,训练次数: {updated_model_info['training_count']}") + + def test_06_threshold_optimization(self): + """测试阈值优化功能""" + logger.info("测试阈值优化功能...") + + # 获取当前阈值 + initial_response = self._make_request('GET', '/model/info') + initial_thresholds = initial_response.json()['current_thresholds'] + + logger.info(f"初始阈值: {initial_thresholds}") + + # 发送多个标注数据来触发阈值优化 + training_events = [ + { + "event": "ANNOTATION_CREATED", + "annotation": { + "id": 2, + "result": [ + { + "from_name": "temperature_events", + "type": "timeserieslabels", + "value": { + "start": 1.0, + "end": 3.0, + "timeserieslabels": ["Te_1_峰值时刻"] + } + } + ] + }, + "task": {"id": 2, "data": {"shot": 240830002}} + }, + { + "event": "ANNOTATION_CREATED", + "annotation": { + "id": 3, + "result": [ + { + "from_name": "temperature_events", + "type": "timeserieslabels", + "value": { + "start": 5.0, + "end": 7.0, + "timeserieslabels": ["Te_1_下降阶段"] + } + } + ] + }, + "task": {"id": 3, "data": {"shot": 240830003}} + } + ] + + # 发送训练事件 + for event_data in training_events: + response = self._make_request('POST', '/fit', data=event_data) + assert response.status_code in [200, 204] + + # 等待训练完成 + time.sleep(3) + + # 检查阈值是否被优化 + final_response = self._make_request('GET', '/model/info') + final_thresholds = final_response.json()['current_thresholds'] + + logger.info(f"优化后阈值: {final_thresholds}") + + # 验证阈值是否发生变化(由于是模拟数据,可能变化很小) + threshold_changed = False + for key in initial_thresholds: + if abs(initial_thresholds[key] - final_thresholds[key]) > 0.01: + threshold_changed = True + break + + if threshold_changed: + logger.info("阈值优化成功") + else: + logger.info("阈值未发生变化(可能是优化逻辑或数据特征导致的)") + + def test_07_error_handling(self): + """测试错误处理""" + logger.info("测试错误处理...") + + # 测试无效的预测请求 + invalid_request = { + "tasks": [ + { + "data": { + "shot": "invalid_shot", + "csv": "nonexistent.csv" + } + } + ] + } + + response = self._make_request('POST', '/predict', data=invalid_request) + + # 应该返回200状态码,但预测结果可能为空 + assert response.status_code == 200 + + # 测试无效的训练数据 + invalid_training_data = { + "event": "INVALID_EVENT", + "data": "invalid_data" + } + + response = self._make_request('POST', '/fit', data=invalid_training_data) + assert response.status_code in [200, 204, 400] + + logger.info("错误处理测试完成") + + def test_08_performance_test(self): + """测试性能""" + logger.info("测试性能...") + + # 创建大量数据 + large_data = self._create_large_temperature_data() + + prediction_request = { + "tasks": [ + { + "data": { + "shot": 999999, + "csv": "large_temperature_data.csv" + } + } + ] + } + + with patch('utils.load_data') as mock_load: + mock_load.return_value = {999999: large_data} + + start_time = time.time() + response = self._make_request('POST', '/predict', data=prediction_request) + end_time = time.time() + + processing_time = end_time - start_time + + assert response.status_code == 200 + assert processing_time < 10.0, f"预测处理时间过长: {processing_time:.2f}秒" + + logger.info(f"性能测试完成,处理时间: {processing_time:.2f}秒") + + def test_09_model_reset(self): + """测试模型重置功能""" + logger.info("测试模型重置功能...") + + # 获取当前模型状态 + initial_response = self._make_request('GET', '/model/info') + initial_info = initial_response.json() + + # 重置模型 + reset_response = self._make_request('POST', '/model/reset') + assert reset_response.status_code in [200, 204] + + # 等待重置完成 + time.sleep(2) + + # 检查重置后的状态 + final_response = self._make_request('GET', '/model/info') + final_info = final_response.json() + + # 验证重置 + assert final_info['training_count'] == 0 + assert final_info['model_version'] == 'temperature_v1.0' + + logger.info("模型重置测试完成") + + def _create_sample_temperature_data(self): + """创建样本温度数据""" + time_data = np.linspace(0, 10, 1000) + temp_data = 500 + 1000 * np.sin(time_data) + 100 * np.random.randn(1000) + + df = pd.DataFrame({ + 'time': time_data, + 'Te_1': temp_data, + 'Te_2': temp_data * 0.8 + 50 * np.random.randn(1000) + }) + + return df + + def _create_realistic_temperature_data(self): + """创建真实的温度数据(模拟托卡马克实验)""" + time_data = np.linspace(0, 15, 1500) + + # 模拟托卡马克温度曲线:加热期 -> 平顶期 -> 衰减期 + temp_curve = np.zeros_like(time_data) + + # 加热期 (0-3秒) + heating_mask = (time_data >= 0) & (time_data < 3) + temp_curve[heating_mask] = 200 + 800 * (time_data[heating_mask] / 3) + + # 平顶期 (3-8秒) + plateau_mask = (time_data >= 3) & (time_data < 8) + temp_curve[plateau_mask] = 1000 + 50 * np.sin(2 * np.pi * time_data[plateau_mask]) + + # 衰减期 (8-15秒) + decay_mask = (time_data >= 8) & (time_data <= 15) + decay_time = time_data[decay_mask] - 8 + temp_curve[decay_mask] = 1000 * np.exp(-decay_time / 2) + + # 添加噪声 + temp_curve += 20 * np.random.randn(len(temp_curve)) + + # 创建多通道数据 + df = pd.DataFrame({ + 'time': time_data, + 'Te_1': temp_curve, + 'Te_2': temp_curve * 0.9 + 30 * np.random.randn(len(temp_curve)), + 'Te_3': temp_curve * 1.1 - 50 * np.random.randn(len(temp_curve)) + }) + + return df + + def _create_large_temperature_data(self): + """创建大量温度数据用于性能测试""" + time_data = np.linspace(0, 30, 3000) + temp_data = 500 + 1000 * np.sin(time_data / 5) + 100 * np.random.randn(3000) + + df = pd.DataFrame({ + 'time': time_data, + 'Te_1': temp_data, + 'Te_2': temp_data * 0.8 + 50 * np.random.randn(3000), + 'Te_3': temp_data * 1.2 - 80 * np.random.randn(3000), + 'Te_4': temp_data * 0.6 + 120 * np.random.randn(3000) + }) + + return df + + +class TestTemperatureBackendIntegration: + """温度标注ML Backend集成测试类""" + + def test_model_initialization(self): + """测试模型初始化""" + from model import TemperatureModel + + model = TemperatureModel() + model.setup() + + assert model.get("model_version") == "temperature_v1.0" + assert hasattr(model, 'predictor') + assert model.get('training_count') == 0 + + def test_predictor_functionality(self): + """测试预测器功能""" + from temperature_predictor import TemperaturePredictor + from prediction import Prediction + + predictor = TemperaturePredictor() + + # 创建测试数据 + test_data = pd.DataFrame({ + 'time': [0, 1, 2, 3, 4, 5], + 'Te_1': [100, 200, 300, 400, 500, 600], + 'Te_2': [150, 250, 350, 450, 550, 650] + }) + + # 执行预测 + predictions = predictor.user_predict(test_data) + + # 验证预测结果 + assert isinstance(predictions, list) + if len(predictions) > 0: + assert all(isinstance(p, Prediction) for p in predictions) + + def test_data_processing_pipeline(self): + """测试数据处理流水线""" + from utils import validate_temperature_data, preprocess_temperature_data + + # 创建测试数据 + test_data = pd.DataFrame({ + 'time': [0, 1, 2, 3, 4], + 'Te_1': [100, 200, 300, 400, 500], + 'Te_2': [150, 250, 350, 450, 550] + }) + + # 验证数据 + assert validate_temperature_data(test_data) == True + + # 预处理数据 + processed_data = preprocess_temperature_data(test_data) + assert len(processed_data) == len(test_data) + + def test_prediction_conversion(self): + """测试预测结果转换""" + from prediction import Prediction, convert_to_labelstudio_form + + # 创建预测结果 + predictions = [ + Prediction("temperature_events", "Te_1_上升阶段", 1.0, 3.0), + Prediction("temperature_events", "Te_1_峰值时刻", 2.0, None) + ] + + # 转换为Label Studio格式 + ls_format = convert_to_labelstudio_form(predictions, "test_model") + + assert len(ls_format) == 1 + assert ls_format[0]["model_version"] == "test_model" + assert len(ls_format[0]["result"]) == 2 + + +if __name__ == "__main__": + # 运行端到端测试 + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/ml-backends/temperature_annotation/test_temperature_backend.py b/ml-backends/temperature_annotation/test_temperature_backend.py new file mode 100644 index 0000000..1ff7635 --- /dev/null +++ b/ml-backends/temperature_annotation/test_temperature_backend.py @@ -0,0 +1,315 @@ +""" +温度标注ML Backend的测试文件 + +测试各个组件的功能和集成 +""" + +import pytest +import numpy as np +import pandas as pd +from unittest.mock import Mock, patch, MagicMock +import tempfile +import os + +# 导入被测试的模块 +from model import TemperatureModel +from temperature_predictor import TemperaturePredictor +from prediction import Prediction, start_end_time_1D, convert_to_labelstudio_form +from utils import validate_temperature_data, preprocess_temperature_data, extract_temperature_channels + + +class TestPrediction: + """测试预测结果数据结构""" + + def test_prediction_creation(self): + """测试预测结果创建""" + pred = Prediction("test_group", "test_label", 1.0, 2.0, 0.8) + + assert pred.label_group == "test_group" + assert pred.label == "test_label" + assert pred.start == 1.0 + assert pred.end == 2.0 + assert pred.score == 0.8 + + def test_prediction_repr(self): + """测试预测结果的字符串表示""" + pred = Prediction("test_group", "test_label", 1.0, 2.0) + repr_str = repr(pred) + + assert "TemperaturePrediction" in repr_str + assert "test_label" in repr_str + assert "1.0" in repr_str + assert "2.0" in repr_str + + +class TestPredictionFunctions: + """测试预测函数""" + + def test_start_end_time_1D_positive(self): + """测试正向阈值检测""" + data = np.array([0, 1, 2, 3, 4, 5]) + threshold = 2.5 + + intervals = start_end_time_1D(data, threshold, postive=True) + + assert len(intervals) == 1 + assert intervals[0] == (3, 5) # 索引3-5的值大于2.5 + + def test_start_end_time_1D_negative(self): + """测试负向阈值检测""" + data = np.array([5, 4, 3, 2, 1, 0]) + threshold = 2.5 + + intervals = start_end_time_1D(data, threshold, postive=False) + + assert len(intervals) == 1 + assert intervals[0] == (3, 5) # 索引3-5的值小于2.5 + + def test_convert_to_labelstudio_form(self): + """测试转换为Label Studio格式""" + predictions = [ + Prediction("group1", "label1", 1.0, 2.0), + Prediction("group2", "label2", 3.0, None) + ] + + result = convert_to_labelstudio_form(predictions, "test_model") + + assert len(result) == 1 + assert result[0]["model_version"] == "test_model" + assert len(result[0]["result"]) == 2 + + # 检查第一个预测结果 + first_result = result[0]["result"][0] + assert first_result["from_name"] == "group1" + assert first_result["to_name"] == "ts" + assert first_result["type"] == "timeserieslabels" + assert first_result["value"]["start"] == 1.0 + assert first_result["value"]["end"] == 2.0 + assert first_result["value"]["timeserieslabels"] == ["label1"] + + # 检查第二个预测结果(时间点) + second_result = result[0]["result"][1] + assert second_result["value"]["start"] == 3.0 + assert second_result["value"]["end"] == 3.0 # 应该自动设置为start + + +class TestUtils: + """测试工具函数""" + + def test_validate_temperature_data_valid(self): + """测试有效温度数据验证""" + df = pd.DataFrame({ + 'time': [0, 1, 2, 3], + 'Te_1': [100, 200, 300, 400], + 'Te_2': [150, 250, 350, 450] + }) + + assert validate_temperature_data(df) == True + + def test_validate_temperature_data_missing_time(self): + """测试缺少时间列的数据验证""" + df = pd.DataFrame({ + 'Te_1': [100, 200, 300, 400], + 'Te_2': [150, 250, 350, 450] + }) + + assert validate_temperature_data(df) == False + + def test_validate_temperature_data_empty(self): + """测试空数据验证""" + df = pd.DataFrame() + + assert validate_temperature_data(df) == False + + def test_extract_temperature_channels(self): + """测试温度通道提取""" + df = pd.DataFrame({ + 'time': [0, 1, 2], + 'Te_1': [100, 200, 300], + 'Te_2': [150, 250, 350], + 'other': [10, 20, 30] + }) + + channels = extract_temperature_channels(df) + + assert 'Te_1' in channels + assert 'Te_2' in channels + assert 'other' not in channels + assert len(channels) == 2 + + +class TestTemperaturePredictor: + """测试温度预测器""" + + def setup_method(self): + """每个测试方法前的设置""" + self.predictor = TemperaturePredictor() + + def test_initialization(self): + """测试初始化""" + assert self.predictor.label_group == 'temperature_events' + assert self.predictor.temp_rise_threshold == 1000.0 + assert self.predictor.temp_fall_threshold == 500.0 + assert self.predictor.gradient_threshold == 100.0 + + def test_extract_temperature_channels(self): + """测试温度通道提取""" + df = pd.DataFrame({ + 'time': [0, 1, 2], + 'Te_1': [100, 200, 300], + 'Te_2': [150, 250, 350], + 'other': [10, 20, 30] + }) + + channels = self.predictor._extract_temperature_channels(df) + + assert 'Te_1' in channels + assert 'Te_2' in channels + assert 'other' not in channels + + def test_handle_missing_values(self): + """测试缺失值处理""" + # 创建包含NaN的数据 + temp_data = np.array([100, np.nan, 300, np.nan, 500]) + + cleaned_data = self.predictor._handle_missing_values(temp_data) + + # 检查NaN值是否被处理 + assert not np.any(np.isnan(cleaned_data)) + assert len(cleaned_data) == 5 + + def test_update_thresholds(self): + """测试阈值更新""" + new_thresholds = { + 'temp_rise_threshold': 1200.0, + 'temp_fall_threshold': 600.0 + } + + self.predictor.update_thresholds(new_thresholds) + + assert self.predictor.temp_rise_threshold == 1200.0 + assert self.predictor.temp_fall_threshold == 600.0 + + def test_get_current_thresholds(self): + """测试获取当前阈值""" + thresholds = self.predictor.get_current_thresholds() + + assert 'temp_rise_threshold' in thresholds + assert 'temp_fall_threshold' in thresholds + assert 'gradient_threshold' in thresholds + assert thresholds['temp_rise_threshold'] == 1000.0 + + +class TestTemperatureModel: + """测试温度标注模型""" + + def setup_method(self): + """每个测试方法前的设置""" + with patch('utils.load_data'): + with patch('utils.validate_temperature_data', return_value=True): + with patch('utils.preprocess_temperature_data'): + self.model = TemperatureModel() + self.model.setup() + + def test_model_initialization(self): + """测试模型初始化""" + assert self.model.get("model_version") == "temperature_v1.0" + assert hasattr(self.model, 'predictor') + assert isinstance(self.model.predictor, TemperaturePredictor) + + def test_get_model_info(self): + """测试获取模型信息""" + info = self.model.get_model_info() + + assert 'model_version' in info + assert 'training_count' in info + assert 'last_training_time' in info + assert 'current_thresholds' in info + assert info['model_version'] == 'temperature_v1.0' + + def test_health_check(self): + """测试健康检查""" + health = self.model.health_check() + + assert 'status' in health + assert health['status'] == 'healthy' + assert 'model_version' in health + assert 'predictor_initialized' in health + + def test_reset_model(self): + """测试模型重置""" + # 修改一些参数 + self.model.set('training_count', 100) + + # 重置模型 + self.model.reset_model() + + # 检查是否重置 + assert self.model.get('training_count') == 0 + assert self.model.get('model_version') == 'temperature_v1.0' + + +class TestIntegration: + """集成测试""" + + def test_end_to_end_prediction(self): + """测试端到端预测流程""" + with patch('utils.load_data') as mock_load: + with patch('utils.validate_temperature_data', return_value=True): + with patch('utils.preprocess_temperature_data') as mock_preprocess: + # 创建模型 + model = TemperatureModel() + model.setup() + + # 模拟数据 + test_data = pd.DataFrame({ + 'time': [0, 1, 2, 3, 4], + 'Te_1': [100, 200, 300, 400, 500], + 'Te_2': [150, 250, 350, 450, 550] + }) + + mock_load.return_value = {123: test_data} + mock_preprocess.return_value = test_data + + # 执行预测 + tasks = [{'data': {'shot': 123, 'csv': 'test.csv'}}] + response = model.predict(tasks) + + # 验证响应 + assert hasattr(response, 'predictions') + assert isinstance(response.predictions, list) + + def test_prediction_with_mock_data(self): + """测试使用模拟数据的预测""" + # 创建模拟的预测器 + mock_predictor = Mock() + mock_predictor.user_predict.return_value = [ + Prediction("test_group", "test_label", 1.0, 2.0) + ] + + # 创建模型并替换预测器 + model = TemperatureModel() + model.predictor = mock_predictor + + # 模拟数据加载 + with patch.object(model, 'get_data') as mock_get_data: + test_data = pd.DataFrame({ + 'time': [0, 1, 2], + 'Te_1': [100, 200, 300] + }) + mock_get_data.return_value = {123: test_data} + + # 执行预测 + tasks = [{'data': {'shot': 123, 'csv': 'test.csv'}}] + response = model.predict(tasks) + + # 验证预测器被调用 + mock_predictor.user_predict.assert_called_once() + + # 验证响应 + assert len(response.predictions) > 0 + + +if __name__ == "__main__": + # 运行测试 + pytest.main([__file__, "-v"]) diff --git a/ml-backends/temperature_annotation/utils.py b/ml-backends/temperature_annotation/utils.py new file mode 100644 index 0000000..a9c2a8e --- /dev/null +++ b/ml-backends/temperature_annotation/utils.py @@ -0,0 +1,296 @@ +""" +数据工具模块 + +提供温度数据的加载、处理和工具函数 +""" + +import pandas as pd +import requests +import concurrent.futures +from typing import Dict, Optional +import numpy as np +import logging + +logger = logging.getLogger(__name__) + + +def load_data(urls: Dict[int, str]) -> Dict[int, pd.DataFrame]: + """ + 并发加载温度数据 + + Args: + urls: 炮号到URL的映射字典 + + Returns: + 炮号到数据DataFrame的映射字典 + """ + def fetch_csv(shot_url_pair): + shot, url = shot_url_pair + try: + response = requests.get(url, timeout=30) + response.raise_for_status() + df = pd.read_csv(pd.StringIO(response.text)) + logger.info(f"成功加载炮号 {shot} 的数据,形状: {df.shape}") + return shot, df + except Exception as e: + logger.error(f"加载炮号 {shot} 数据失败: {e}") + return shot, None + + data_dict = {} + logger.info(f"开始加载 {len(urls)} 个炮号的数据") + + with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: + results = executor.map(fetch_csv, urls.items()) + + for shot, df in results: + if df is not None: + data_dict[shot] = df + + logger.info(f"成功加载 {len(data_dict)} 个炮号的数据") + return data_dict + + +def validate_temperature_data(df: pd.DataFrame, required_columns: list = None) -> bool: + """ + 验证温度数据的有效性 + + Args: + df: 温度数据DataFrame + required_columns: 必需的列名列表 + + Returns: + 数据是否有效 + """ + if df is None or df.empty: + logger.warning("数据为空") + return False + + if required_columns: + missing_columns = [col for col in required_columns if col not in df.columns] + if missing_columns: + logger.warning(f"缺少必需的列: {missing_columns}") + return False + + # 检查时间列 + if 'time' not in df.columns: + logger.warning("缺少时间列") + return False + + # 检查时间数据的单调性 + time_data = df['time'].values + if not np.all(np.diff(time_data) >= 0): + logger.warning("时间数据不是单调递增的") + return False + + # 检查数值列的有效性 + numeric_columns = df.select_dtypes(include=[np.number]).columns + for col in numeric_columns: + if col == 'time': + continue + if df[col].isnull().all(): + logger.warning(f"列 {col} 全为空值") + return False + + logger.info("温度数据验证通过") + return True + + +def preprocess_temperature_data(df: pd.DataFrame, + remove_outliers: bool = True, + outlier_threshold: float = 3.0, + interpolate_missing: bool = True) -> pd.DataFrame: + """ + 预处理温度数据 + + Args: + df: 原始温度数据 + remove_outliers: 是否移除异常值 + outlier_threshold: 异常值检测阈值(标准差倍数) + interpolate_missing: 是否插值缺失值 + + Returns: + 预处理后的数据 + """ + df_processed = df.copy() + + # 插值缺失值 + if interpolate_missing: + numeric_columns = df_processed.select_dtypes(include=[np.number]).columns + for col in numeric_columns: + if col == 'time': + continue + if df_processed[col].isnull().any(): + df_processed[col] = df_processed[col].interpolate(method='linear') + logger.info(f"列 {col} 的缺失值已插值") + + # 移除异常值 + if remove_outliers: + numeric_columns = df_processed.select_dtypes(include=[np.number]).columns + for col in numeric_columns: + if col == 'time': + continue + + # 计算统计量 + mean_val = df_processed[col].mean() + std_val = df_processed[col].std() + + # 标记异常值 + outlier_mask = np.abs(df_processed[col] - mean_val) > outlier_threshold * std_val + outlier_count = outlier_mask.sum() + + if outlier_count > 0: + # 将异常值替换为均值 + df_processed.loc[outlier_mask, col] = mean_val + logger.info(f"列 {col} 移除了 {outlier_count} 个异常值") + + return df_processed + + +def extract_temperature_channels(df: pd.DataFrame, + channel_patterns: list = None) -> list: + """ + 提取温度通道列名 + + Args: + df: 温度数据DataFrame + channel_patterns: 通道名称模式列表 + + Returns: + 温度通道列名列表 + """ + if channel_patterns is None: + channel_patterns = ['Te', 'temp', 'Ti', 'temperature'] + + temperature_channels = [] + for col in df.columns: + col_lower = col.lower() + if any(pattern.lower() in col_lower for pattern in channel_patterns): + temperature_channels.append(col) + + logger.info(f"找到 {len(temperature_channels)} 个温度通道: {temperature_channels}") + return temperature_channels + + +def calculate_temperature_statistics(df: pd.DataFrame, + temperature_channels: list = None) -> Dict: + """ + 计算温度统计信息 + + Args: + df: 温度数据DataFrame + temperature_channels: 温度通道列表 + + Returns: + 统计信息字典 + """ + if temperature_channels is None: + temperature_channels = extract_temperature_channels(df) + + stats = {} + for channel in temperature_channels: + if channel in df.columns: + channel_data = df[channel].dropna() + if len(channel_data) > 0: + stats[channel] = { + 'mean': float(channel_data.mean()), + 'std': float(channel_data.std()), + 'min': float(channel_data.min()), + 'max': float(channel_data.max()), + 'count': len(channel_data) + } + + logger.info(f"计算了 {len(stats)} 个通道的统计信息") + return stats + + +def normalize_temperature_data(df: pd.DataFrame, + temperature_channels: list = None, + method: str = 'zscore') -> pd.DataFrame: + """ + 标准化温度数据 + + Args: + df: 温度数据DataFrame + temperature_channels: 温度通道列表 + method: 标准化方法 ('zscore', 'minmax', 'robust') + + Returns: + 标准化后的数据 + """ + if temperature_channels is None: + temperature_channels = extract_temperature_channels(df) + + df_normalized = df.copy() + + for channel in temperature_channels: + if channel in df_normalized.columns: + channel_data = df_normalized[channel].dropna() + if len(channel_data) > 0: + if method == 'zscore': + # Z-score标准化 + mean_val = channel_data.mean() + std_val = channel_data.std() + if std_val > 0: + df_normalized[channel] = (df_normalized[channel] - mean_val) / std_val + + elif method == 'minmax': + # Min-Max标准化 + min_val = channel_data.min() + max_val = channel_data.max() + if max_val > min_val: + df_normalized[channel] = (df_normalized[channel] - min_val) / (max_val - min_val) + + elif method == 'robust': + # 稳健标准化(基于中位数和四分位距) + median_val = channel_data.median() + q75 = channel_data.quantile(0.75) + q25 = channel_data.quantile(0.25) + iqr = q75 - q25 + if iqr > 0: + df_normalized[channel] = (df_normalized[channel] - median_val) / iqr + + logger.info(f"通道 {channel} 已使用 {method} 方法标准化") + + return df_normalized + + +def segment_temperature_data(df: pd.DataFrame, + segment_length: float = 1.0, + overlap: float = 0.0) -> list: + """ + 分段温度数据 + + Args: + df: 温度数据DataFrame + segment_length: 段长度(秒) + overlap: 重叠比例(0-1) + + Returns: + 数据段列表 + """ + segments = [] + time_data = df['time'].values + + if len(time_data) < 2: + return segments + + # 计算时间步长 + time_step = time_data[1] - time_data[0] + segment_samples = int(segment_length / time_step) + overlap_samples = int(segment_samples * overlap) + hop_samples = segment_samples - overlap_samples + + for i in range(0, len(df) - segment_samples + 1, hop_samples): + segment_df = df.iloc[i:i + segment_samples].copy() + segment_start = time_data[i] + segment_end = time_data[i + segment_samples - 1] + + segments.append({ + 'data': segment_df, + 'start_time': segment_start, + 'end_time': segment_end, + 'segment_index': i // hop_samples + }) + + logger.info(f"将数据分为 {len(segments)} 段,每段长度 {segment_length} 秒") + return segments