diff --git a/ml-backends/temperature_annotation/Dockerfile b/ml-backends/temperature_annotation/Dockerfile
new file mode 100644
index 0000000..87020cc
--- /dev/null
+++ b/ml-backends/temperature_annotation/Dockerfile
@@ -0,0 +1,34 @@
+# Docker image for the temperature-annotation ML backend.
+# Built on the official slim Python 3.9 image.
+
+FROM python:3.9-slim
+
+WORKDIR /app
+
+# Runtime environment (unbuffered logs, no .pyc files, service defaults)
+ENV PYTHONUNBUFFERED=1 \
+    PYTHONDONTWRITEBYTECODE=1 \
+    PORT=9090 \
+    WORKERS=1 \
+    THREADS=8
+
+# System packages: curl is invoked by the docker-compose healthcheck (the slim base image does not ship it); git was already installed
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    curl \
+    git \
+    && rm -rf /var/lib/apt/lists/*
+
+# Copy only the dependency manifests first so the pip layer stays cached
+COPY requirements.txt requirements-base.txt ./
+
+# Install Python dependencies from both manifests (requirements-base.txt used to be copied but never installed)
+RUN pip install --no-cache-dir -r requirements-base.txt -r requirements.txt
+
+# Copy the application code
+COPY . .
+
+# Port the backend listens on
+EXPOSE 9090
+
+# Start the backend
+CMD ["python", "_wsgi.py", "--host", "0.0.0.0", "--port", "9090"]
diff --git a/ml-backends/temperature_annotation/README.md b/ml-backends/temperature_annotation/README.md
new file mode 100644
index 0000000..804e536
--- /dev/null
+++ b/ml-backends/temperature_annotation/README.md
@@ -0,0 +1,197 @@
+# Temperature Annotation ML Backend
+
+这是一个专门用于托卡马克等离子体温度数据标注的ML Backend,支持温度特征识别、异常检测和智能标注。
+
+## 功能特性
+
+- **温度特征识别**: 自动检测温度上升、下降、峰值等关键阶段
+- **异常事件检测**: 识别温度突变、骤降等异常情况
+- **多通道支持**: 支持多个温度测量通道的并行分析
+- **自适应阈值**: 基于人工标注自动优化检测阈值
+- **实时预测**: 基于Label Studio ML Backend架构的实时预测服务
+
+## 核心算法
+
+### 温度阶段检测
+- **上升阶段**: 检测温度超过阈值的持续上升区间
+- **峰值时刻**: 基于梯度变化检测温度峰值点
+- **下降阶段**: 识别温度低于阈值的下降区间
+- **平台期**: 检测温度相对稳定的平台阶段
+
+### 异常检测
+- **温度突变**: 基于滑动窗口统计检测异常温度变化
+- **温度骤降**: 识别温度快速下降的异常事件
+
+### 阈值优化
+- 支持基于人工标注的自适应阈值调整
+- 动态学习不同实验条件下的最佳参数
+
+## 快速开始
+
+### 使用Docker运行 (推荐)
+
+1. 启动ML Backend服务:
+
+```bash
+docker-compose up
+```
+
+2. 验证服务运行状态:
+
+```bash
+curl http://localhost:9090/health
+# 返回: {"status":"healthy","model_version":"temperature_v1.0",...}
+```
+
+3. 在Label Studio中连接后端:
+ 进入项目设置 -> Machine Learning -> Add Model,指定URL为 `http://localhost:9090`
+
+### 从源码构建
+
+```bash
+docker-compose build
+```
+
+### 不使用Docker运行
+
+```bash
+python -m venv temperature-ml-backend
+source temperature-ml-backend/bin/activate # Windows: temperature-ml-backend\Scripts\activate
+pip install -r requirements.txt
+python _wsgi.py
+```
+
+## 配置参数
+
+在 `config.json` 中可以设置以下参数:
+
+### 温度阈值设置
+- `temp_rise_threshold`: 温度上升阈值 (默认1000.0 eV)
+- `temp_fall_threshold`: 温度下降阈值 (默认500.0 eV)
+- `gradient_threshold`: 梯度阈值 (默认100.0 eV/ms)
+- `anomaly_threshold`: 异常检测阈值 (默认3.0)
+- `min_peak_height`: 最小峰值高度 (默认800.0 eV)
+
+### 数据处理设置
+- `remove_outliers`: 是否移除异常值
+- `outlier_threshold`: 异常值检测阈值
+- `interpolate_missing`: 是否插值缺失值
+- `normalize_data`: 是否标准化数据
+
+### 预测设置
+- `detect_rise_phases`: 是否检测上升阶段
+- `detect_peak_moments`: 是否检测峰值时刻
+- `detect_fall_phases`: 是否检测下降阶段
+- `detect_anomalies`: 是否检测异常事件
+- `detect_plateau_phases`: 是否检测平台期
+
+## Label Studio配置
+
+在Label Studio中使用以下XML配置来支持温度标注:
+
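+```xml
+<View>
+  <TimeSeriesLabels name="temperature_events" toName="ts">
+    <Label value="上升阶段"/>
+    <Label value="峰值时刻"/>
+    <Label value="下降阶段"/>
+    <Label value="异常事件"/>
+  </TimeSeriesLabels>
+  <TimeSeries name="ts" valueType="url" value="$csv" sep="," timeColumn="time">
+    <Channel column="Te_1"/>
+  </TimeSeries>
+</View>
+```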
+
+## 数据格式要求
+
+### 输入数据格式
+- CSV文件,包含时间列和温度通道列
+- 时间列名必须为 `time`
+- 温度通道列名应包含 `Te`、`temp`、`Ti` 等关键词
+- 支持多通道温度数据
+
+### 示例数据
+```csv
+time,Te_1,Te_2,Te_3
+0.0,100,150,200
+0.001,120,170,220
+0.002,140,190,240
+...
+```
+
+## 部署到Kubernetes
+
+```bash
+kubectl apply -f temperature-annotation.yaml
+```
+
+## 自定义模型
+
+可以通过修改 `temperature_predictor.py` 中的算法来自定义温度处理逻辑:
+
+- 调整阈值检测算法
+- 添加新的温度特征识别方法
+- 自定义异常检测规则
+- 优化梯度计算方法
+
+## 测试
+
+运行测试确保功能正常:
+
+```bash
+pytest test_temperature_backend.py -v
+```
+
+## API接口
+
+### 健康检查
+```
+GET /health
+```
+
+### 模型信息
+```
+GET /model/info
+```
+
+### 预测接口
+```
+POST /predict
+```
+
+### 训练接口
+```
+POST /fit
+```
+
+## 故障排除
+
+### 常见问题
+
+1. **模型加载失败**: 检查依赖包安装和配置文件
+2. **预测结果为空**: 确认输入数据格式和温度通道识别
+3. **阈值设置不当**: 根据实际数据调整阈值参数
+4. **内存不足**: 调整Docker容器的内存限制
+
+### 日志查看
+
+```bash
+# Docker日志
+docker logs temperature-annotation-ml-backend
+
+# 持续跟踪模型日志(服务日志只输出到容器的标准输出,镜像内没有 /app/logs/model.log 文件)
+docker logs -f temperature-annotation-ml-backend
+```
+
+## 性能优化
+
+### 数据处理优化
+- 使用并发加载多个炮号数据
+- 支持数据预处理和缓存
+- 优化异常值检测算法
+
+### 预测性能优化
+- 批量处理多个温度通道
+- 并行计算不同特征检测
+- 支持增量学习和模型更新
diff --git a/ml-backends/temperature_annotation/_wsgi.py b/ml-backends/temperature_annotation/_wsgi.py
new file mode 100644
index 0000000..60c3782
--- /dev/null
+++ b/ml-backends/temperature_annotation/_wsgi.py
@@ -0,0 +1,132 @@
+"""
+温度标注ML Backend的WSGI应用入口
+
+基于Label Studio ML Backend的标准模式实现
+"""
+
+import os
+import argparse
+import json
+import logging
+import logging.config
+
+# Default log level, taken from the LOG_LEVEL environment variable ("INFO" when unset).
+log_level = os.getenv("LOG_LEVEL", "INFO")
+
+# Configure logging before the model modules are imported below, so their
+# loggers inherit this setup. All records go to stdout through one console
+# handler; note that the handler itself also filters at `log_level`.
+logging.config.dictConfig({
+    "version": 1,
+    "disable_existing_loggers": False,
+    "formatters": {
+        "standard": {
+            "format": "[%(asctime)s] [%(levelname)s] [%(name)s::%(funcName)s::%(lineno)d] %(message)s"
+        }
+    },
+    "handlers": {
+        "console": {
+            "class": "logging.StreamHandler",
+            "level": log_level,
+            "stream": "ext://sys.stdout",
+            "formatter": "standard"
+        }
+    },
+    "root": {
+        "level": log_level,
+        "handlers": ["console"],
+        "propagate": True
+    }
+})
+
+from label_studio_ml.api import init_app
+from model import TemperatureModel
+
+_DEFAULT_CONFIG_PATH = os.path.join(os.path.dirname(__file__), 'config.json')
+
+
+def get_kwargs_from_config(config_path=_DEFAULT_CONFIG_PATH):
+ """从配置文件获取初始化参数"""
+ if not os.path.exists(config_path):
+ return dict()
+ with open(config_path) as f:
+ config = json.load(f)
+ assert isinstance(config, dict)
+ return config
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description='温度标注ML Backend')
+ parser.add_argument(
+ '-p', '--port', dest='port', type=int, default=9090,
+ help='Server port')
+ parser.add_argument(
+ '--host', dest='host', type=str, default='0.0.0.0',
+ help='Server host')
+ parser.add_argument(
+ '--kwargs', '--with', dest='kwargs', metavar='KEY=VAL', nargs='+',
+ type=lambda kv: kv.split('='),
+ help='Additional TemperatureModel initialization kwargs')
+ parser.add_argument(
+ '-d', '--debug', dest='debug', action='store_true',
+ help='Switch debug mode')
+ parser.add_argument(
+ '--log-level', dest='log_level',
+ choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], default=log_level,
+ help='Logging level')
+ parser.add_argument(
+ '--model-dir', dest='model_dir', default=os.path.dirname(__file__),
+ help='Directory where models are stored')
+ parser.add_argument(
+ '--check', dest='check', action='store_true',
+ help='Validate model instance before launching server')
+ parser.add_argument('--basic-auth-user',
+ default=os.environ.get('ML_SERVER_BASIC_AUTH_USER', None),
+ help='Basic auth user')
+ parser.add_argument('--basic-auth-pass',
+ default=os.environ.get('ML_SERVER_BASIC_AUTH_PASS', None),
+ help='Basic auth pass')
+
+ args = parser.parse_args()
+
+ if args.log_level:
+ logging.root.setLevel(args.log_level)
+
+ def isfloat(value):
+ try:
+ float(value)
+ return True
+ except ValueError:
+ return False
+
+ def parse_kwargs():
+ param = dict()
+ if args.kwargs:
+ for k, v in args.kwargs:
+ if v.isdigit():
+ param[k] = int(v)
+ elif v == 'True' or v == 'true':
+ param[k] = True
+ elif v == 'False' or v == 'false':
+ param[k] = False
+ elif isfloat(v):
+ param[k] = float(v)
+ else:
+ param[k] = v
+ return param
+
+ kwargs = get_kwargs_from_config()
+ if args.kwargs:
+ kwargs.update(parse_kwargs())
+
+ if args.check:
+ print('Check "' + TemperatureModel.__name__ + '" instance creation..')
+ model = TemperatureModel(**kwargs)
+
+ app = init_app(
+ model_class=TemperatureModel,
+ basic_auth_user=args.basic_auth_user,
+ basic_auth_pass=args.basic_auth_pass
+ )
+ app.run(host=args.host, port=args.port, debug=args.debug)
+
+else:
+ # 用于uWSGI
+ app = init_app(model_class=TemperatureModel)
diff --git a/ml-backends/temperature_annotation/config.json b/ml-backends/temperature_annotation/config.json
new file mode 100644
index 0000000..21cf748
--- /dev/null
+++ b/ml-backends/temperature_annotation/config.json
@@ -0,0 +1,28 @@
+{
+ "model_version": "temperature_v1.0",
+ "temperature_thresholds": {
+ "temp_rise_threshold": 1000.0,
+ "temp_fall_threshold": 500.0,
+ "gradient_threshold": 100.0,
+ "anomaly_threshold": 3.0,
+ "min_peak_height": 800.0
+ },
+ "data_processing": {
+ "remove_outliers": true,
+ "outlier_threshold": 3.0,
+ "interpolate_missing": true,
+ "normalize_data": false
+ },
+ "prediction_settings": {
+ "detect_rise_phases": true,
+ "detect_peak_moments": true,
+ "detect_fall_phases": true,
+ "detect_anomalies": true,
+ "detect_plateau_phases": true
+ },
+ "model_optimization": {
+ "enable_adaptive_thresholds": true,
+ "learning_rate": 0.05,
+ "min_training_samples": 10
+ }
+}
diff --git a/ml-backends/temperature_annotation/docker-compose.yml b/ml-backends/temperature_annotation/docker-compose.yml
new file mode 100644
index 0000000..0a88d99
--- /dev/null
+++ b/ml-backends/temperature_annotation/docker-compose.yml
@@ -0,0 +1,22 @@
+# Compose definition for the temperature-annotation ML backend.
+version: '3.8'
+
+services:
+  ml-backend-temperature:
+    # Image is built from the Dockerfile in this directory.
+    build: .
+    container_name: temperature-annotation-ml-backend
+    ports:
+      - "9090:9090"
+    environment:
+      # LOG_LEVEL is read by _wsgi.py at import time.
+      # NOTE(review): MODEL_VERSION, WORKERS and THREADS are not read by the
+      # `python _wsgi.py` start command in the code shown - confirm they are used.
+      - MODEL_VERSION=temperature_v1.0
+      - LOG_LEVEL=INFO
+      - WORKERS=1
+      - THREADS=8
+    volumes:
+      - ./models:/app/models
+      - ./data:/app/data
+    restart: unless-stopped
+    healthcheck:
+      # Runs inside the container, so the image must provide `curl`.
+      # NOTE(review): python:3.9-slim does not ship curl - confirm the Dockerfile installs it.
+      test: ["CMD", "curl", "-f", "http://localhost:9090/health"]
+      interval: 30s
+      timeout: 10s
+      retries: 3
diff --git a/ml-backends/temperature_annotation/model.py b/ml-backends/temperature_annotation/model.py
new file mode 100644
index 0000000..e99443c
--- /dev/null
+++ b/ml-backends/temperature_annotation/model.py
@@ -0,0 +1,264 @@
+"""
+温度标注ML Backend模型
+
+基于Label Studio ML Backend架构,实现温度标注的AI预测服务
+"""
+
+from typing import List, Dict, Optional
+from label_studio_ml.model import LabelStudioMLBase
+from label_studio_ml.response import ModelResponse
+import prediction
+import utils
+import numpy as np
+import pandas as pd
+import logging
+import time
+from temperature_predictor import TemperaturePredictor
+
+logger = logging.getLogger(__name__)
+
+
+class TemperatureModel(LabelStudioMLBase):
+ """温度标注ML Backend模型"""
+
+ def setup(self):
+ """配置模型参数"""
+ self.set("model_version", "temperature_v1.0")
+
+ # 初始化温度预测器
+ self.predictor = TemperaturePredictor()
+
+ # 记录模型统计信息
+ self.set('training_count', 0)
+ self.set('last_training_time', time.time())
+
+ logger.info("温度标注模型初始化完成")
+
+ def get_data(self, tasks: List[Dict]) -> Dict:
+ """获取任务数据"""
+ urls = {}
+ for task in tasks:
+ data = task['data']
+ urls[data['shot']] = data['csv']
+
+ logger.info(f"开始加载 {len(urls)} 个炮号的数据")
+ data_dict = utils.load_data(urls)
+ return data_dict
+
+ def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse:
+ """执行温度预测"""
+ logger.info(f'运行温度预测,任务数量: {len(tasks)}')
+
+ try:
+ data_dict = self.get_data(tasks)
+ model_preds = []
+
+ for shot, data in data_dict.items():
+ try:
+ # 验证数据
+ if not utils.validate_temperature_data(data):
+ logger.warning(f"炮号 {shot} 数据验证失败,跳过")
+ continue
+
+ # 预处理数据
+ data_processed = utils.preprocess_temperature_data(data)
+
+ # 执行预测
+ preds = prediction.convert_to_labelstudio_form(
+ self.predictor.user_predict(data_processed),
+ 'temperature_model_v1'
+ )
+
+ if preds:
+ model_preds.extend(preds)
+ logger.info(f'炮号 {shot} 预测成功,生成 {len(preds)} 个预测结果')
+ else:
+ logger.info(f'炮号 {shot} 未生成预测结果')
+
+ except Exception as e:
+ logger.error(f'炮号 {shot} 预测失败: {e}')
+ continue
+
+ logger.info(f"温度预测完成,总共生成 {len(model_preds)} 个预测结果")
+
+ except Exception as e:
+ logger.error(f"温度预测过程中出错: {e}")
+ model_preds = []
+
+ return ModelResponse(predictions=model_preds)
+
+ def fit(self, event, data, **kwargs):
+ """
+ 温度标注模型的训练/更新逻辑
+
+ 每次创建或更新标注时调用此方法来优化模型参数
+
+ Args:
+ event: 事件类型 ('ANNOTATION_CREATED', 'ANNOTATION_UPDATED', 'START_TRAINING')
+ data: 来自webhook的数据载荷
+ """
+ logger.info(f'接收到温度标注训练事件: {event}')
+
+ try:
+ # 获取之前的模型参数
+ old_temp_rise_threshold = self.get('temp_rise_threshold', 1000.0)
+ old_temp_fall_threshold = self.get('temp_fall_threshold', 500.0)
+ old_gradient_threshold = self.get('gradient_threshold', 100.0)
+ old_model_version = self.get('model_version', 'temperature_v1.0')
+
+ logger.info(f'当前模型参数: 上升阈值={old_temp_rise_threshold}, 下降阈值={old_temp_fall_threshold}, 梯度阈值={old_gradient_threshold}')
+
+ # 处理标注数据以优化参数
+ if event in ['ANNOTATION_CREATED', 'ANNOTATION_UPDATED'] and data:
+ try:
+ # 解析标注数据
+ annotation_data = self._parse_annotation_data(data)
+
+ if annotation_data:
+ # 基于标注数据调整阈值参数
+ new_thresholds = self._optimize_thresholds(annotation_data)
+
+ # 更新预测器参数
+ if new_thresholds:
+ self.predictor.update_thresholds(new_thresholds)
+
+ # 保存到缓存
+ self.set('temp_rise_threshold', self.predictor.temp_rise_threshold)
+ self.set('temp_fall_threshold', self.predictor.temp_fall_threshold)
+ self.set('gradient_threshold', self.predictor.gradient_threshold)
+
+ # 更新模型版本
+ new_version = f"temperature_v1.0_{int(time.time())}"
+ self.set('model_version', new_version)
+
+ logger.info(f'参数已更新: 上升阈值={self.predictor.temp_rise_threshold}, 下降阈值={self.predictor.temp_fall_threshold}, 梯度阈值={self.predictor.gradient_threshold}')
+ logger.info(f'模型版本更新为: {new_version}')
+
+ except Exception as e:
+ logger.error(f'处理标注数据时出错: {e}')
+
+ # 记录训练统计信息
+ training_count = self.get('training_count', 0) + 1
+ self.set('training_count', training_count)
+ self.set('last_training_time', time.time())
+
+ logger.info(f'温度标注模型训练完成,总训练次数: {training_count}')
+
+ except Exception as e:
+ logger.error(f'温度标注模型训练过程中出错: {e}')
+
+ def _parse_annotation_data(self, data):
+ """解析标注数据,提取温度相关信息"""
+ try:
+ if 'annotation' in data and 'result' in data['annotation']:
+ results = data['annotation']['result']
+ temp_annotations = []
+
+ for result in results:
+ if (result.get('from_name') == 'temperature_events' and
+ result.get('type') == 'timeserieslabels'):
+
+ value = result.get('value', {})
+ temp_annotations.append({
+ 'label': value.get('timeserieslabels', [''])[0],
+ 'start': value.get('start'),
+ 'end': value.get('end'),
+ 'shot': data.get('task', {}).get('data', {}).get('shot')
+ })
+
+ logger.info(f"解析到 {len(temp_annotations)} 个温度标注")
+ return temp_annotations if temp_annotations else None
+
+ except Exception as e:
+ logger.error(f'解析标注数据失败: {e}')
+ return None
+
+ def _optimize_thresholds(self, annotation_data):
+ """基于标注数据优化阈值参数"""
+ try:
+ # 统计不同类型标注的特征
+ rise_annotations = [ann for ann in annotation_data if '上升' in ann['label']]
+ fall_annotations = [ann for ann in annotation_data if '下降' in ann['label']]
+ peak_annotations = [ann for ann in annotation_data if '峰值' in ann['label']]
+
+ new_thresholds = {}
+
+ # 基于标注数据调整阈值(简化的自适应逻辑)
+ if rise_annotations:
+ # 如果有上升阶段标注,可以适当降低上升阈值以提高敏感度
+ current_rise = self.get('temp_rise_threshold', 1000.0)
+ new_thresholds['temp_rise_threshold'] = max(800.0, current_rise * 0.95)
+
+ if fall_annotations:
+ # 如果有下降阶段标注,可以适当调整下降阈值
+ current_fall = self.get('temp_fall_threshold', 500.0)
+ new_thresholds['temp_fall_threshold'] = max(400.0, current_fall * 0.95)
+
+ if peak_annotations:
+ # 如果有峰值标注,可以调整梯度阈值
+ current_gradient = self.get('gradient_threshold', 100.0)
+ new_thresholds['gradient_threshold'] = max(50.0, current_gradient * 0.9)
+
+ if new_thresholds:
+ logger.info(f"基于标注数据优化阈值: {new_thresholds}")
+
+ return new_thresholds if new_thresholds else None
+
+ except Exception as e:
+ logger.error(f'优化阈值参数失败: {e}')
+ return None
+
+ def get_model_info(self) -> Dict:
+ """获取模型信息"""
+ return {
+ 'model_version': self.get('model_version', 'temperature_v1.0'),
+ 'training_count': self.get('training_count', 0),
+ 'last_training_time': self.get('last_training_time', 0),
+ 'current_thresholds': self.predictor.get_current_thresholds()
+ }
+
+ def reset_model(self):
+ """重置模型参数到默认值"""
+ try:
+ # 重置预测器参数
+ self.predictor = TemperaturePredictor()
+
+ # 重置模型版本
+ self.set('model_version', 'temperature_v1.0')
+ self.set('training_count', 0)
+ self.set('last_training_time', time.time())
+
+ logger.info("模型参数已重置到默认值")
+
+ except Exception as e:
+ logger.error(f"重置模型参数失败: {e}")
+
+ def update_model_config(self, config: Dict):
+ """更新模型配置"""
+ try:
+ if 'thresholds' in config:
+ self.predictor.update_thresholds(config['thresholds'])
+ logger.info(f"模型阈值已更新: {config['thresholds']}")
+
+ if 'model_version' in config:
+ self.set('model_version', config['model_version'])
+ logger.info(f"模型版本已更新: {config['model_version']}")
+
+ except Exception as e:
+ logger.error(f"更新模型配置失败: {e}")
+
+ def health_check(self) -> Dict:
+ """健康检查"""
+ try:
+ return {
+ 'status': 'healthy',
+ 'model_version': self.get('model_version', 'unknown'),
+ 'predictor_initialized': hasattr(self, 'predictor'),
+ 'last_training': self.get('last_training_time', 0),
+ 'training_count': self.get('training_count', 0)
+ }
+ except Exception as e:
+ return {
+ 'status': 'unhealthy',
+ 'error': str(e)
+ }
diff --git a/ml-backends/temperature_annotation/prediction.py b/ml-backends/temperature_annotation/prediction.py
new file mode 100644
index 0000000..bdef299
--- /dev/null
+++ b/ml-backends/temperature_annotation/prediction.py
@@ -0,0 +1,239 @@
+"""
+温度预测工具模块
+
+提供温度预测结果数据结构、预测器抽象基类和格式转换功能
+"""
+
+import abc
+import pandas as pd
+import numpy as np
+from typing import List, Optional
+
+
+class Prediction:
+ """温度预测结果数据结构"""
+
+ def __init__(self, label_group: str, label: str, start: float, end: Optional[float] = None, score: Optional[float] = None):
+ """
+ 初始化温度预测结果
+
+ Args:
+ label_group: 标签组名称
+ label: 标签名称
+ start: 开始时间
+ end: 结束时间(可选,None表示时间点)
+ score: 置信度分数(可选)
+ """
+ self.label_group = label_group
+ self.label = label
+ self.start = start
+ self.end = end
+ self.score = score
+
+ def __repr__(self):
+ """字符串表示"""
+ return f""
+
+
+class BasePredictor(abc.ABC):
+ """温度预测器抽象基类"""
+
+ @abc.abstractmethod
+ def user_predict(self, task_data: pd.DataFrame) -> List[Prediction]:
+ """
+ 执行温度预测
+
+ Args:
+ task_data: 任务数据,包含时间序列温度数据
+
+ Returns:
+ 预测结果列表
+ """
+ pass
+
+
+def start_end_time_1D(predict_result, threshold: float, postive: bool = True):
+ """
+ 1D温度数据阈值检测
+
+ Args:
+ predict_result: 温度数据数组
+ threshold: 阈值
+ postive: True表示检测大于阈值的区间,False表示检测小于阈值的区间
+
+ Returns:
+ 区间列表,每个元素为(start_index, end_index)的元组
+ """
+ predict_result = np.array(predict_result)
+
+ if postive:
+ mask = predict_result > threshold
+ else:
+ mask = predict_result < threshold
+
+ diff = np.diff(mask.astype(int))
+ starts = np.where(diff == 1)[0] + 1
+ ends = np.where(diff == -1)[0] + 1
+
+ if mask[0]:
+ starts = np.insert(starts, 0, 0)
+ if mask[-1]:
+ ends = np.append(ends, len(predict_result) - 1)
+
+ return list(zip(starts, ends))
+
+
+def start_end_time(predict_result, threshold, postive: bool = True, all_dim: bool = True, time_axis: int = -1):
+ """
+ 多维温度数据阈值检测
+
+ Args:
+ predict_result: 温度数据数组
+ threshold: 阈值
+ postive: True表示检测大于阈值的区间,False表示检测小于阈值的区间
+ all_dim: True表示所有维度必须同时满足阈值条件,False表示对每个通道逐一判断
+ time_axis: 时间轴索引
+
+ Returns:
+ 区间列表,每个元素为(start_index, end_index)的元组
+ """
+ predict_result = np.array(predict_result)
+
+ if postive:
+ mask = predict_result > threshold
+ else:
+ mask = predict_result < threshold
+
+ if all_dim:
+ # 所有维度同时满足条件
+ mask = np.all(mask, axis=tuple(i for i in range(mask.ndim) if i != time_axis))
+
+ diff = np.diff(mask.astype(int))
+ starts = np.where(diff == 1)[0] + 1
+ ends = np.where(diff == -1)[0] + 1
+
+ if mask[0]:
+ starts = np.insert(starts, 0, 0)
+ if mask[-1]:
+ ends = np.append(ends, len(mask) - 1)
+
+ return list(zip(starts, ends))
+
+
+def convert_to_labelstudio_form(pred_results: List[Prediction], model_version: str = "temperature_v1") -> List[dict]:
+ """
+ 转换为Label Studio格式
+
+ Args:
+ pred_results: 预测结果列表
+ model_version: 模型版本
+
+ Returns:
+ Label Studio格式的预测结果
+ """
+ LS_result_list = []
+
+ for p in pred_results:
+ if p.end is None:
+ p.end = p.start
+
+ LS_result_list.append({
+ "from_name": p.label_group,
+ "to_name": "ts",
+ "type": "timeserieslabels",
+ "value": {
+ "start": p.start,
+ "end": p.end,
+ "timeserieslabels": [p.label]
+ }
+ })
+
+ return [{
+ "model_version": model_version,
+ "result": LS_result_list
+ }] if LS_result_list else []
+
+
+def detect_temperature_anomalies(temp_data: np.ndarray, time_data: np.ndarray,
+ threshold: float = 3.0, window_size: int = 10):
+ """
+ 检测温度异常事件
+
+ Args:
+ temp_data: 温度数据
+ time_data: 时间数据
+ threshold: 异常检测阈值(标准差倍数)
+ window_size: 滑动窗口大小
+
+ Returns:
+ 异常事件列表,每个元素为(start_time, end_time, anomaly_type)的元组
+ """
+ anomalies = []
+
+ # 计算滑动窗口的统计量
+ for i in range(window_size, len(temp_data)):
+ window_data = temp_data[i-window_size:i]
+ mean_temp = np.mean(window_data)
+ std_temp = np.std(window_data)
+
+ current_temp = temp_data[i]
+
+ # 检测异常(超出阈值范围)
+ if abs(current_temp - mean_temp) > threshold * std_temp:
+ anomaly_type = "温度突变" if current_temp > mean_temp else "温度骤降"
+ anomalies.append((time_data[i], time_data[i], anomaly_type))
+
+ return anomalies
+
+
+def calculate_temperature_gradient(temp_data: np.ndarray, time_data: np.ndarray,
+ window_size: int = 5):
+ """
+ 计算温度梯度
+
+ Args:
+ temp_data: 温度数据
+ time_data: 时间数据
+ window_size: 梯度计算窗口大小
+
+ Returns:
+ 梯度数据数组
+ """
+ gradient = np.zeros_like(temp_data)
+
+ for i in range(window_size, len(temp_data) - window_size):
+ # 使用中心差分计算梯度
+ temp_diff = temp_data[i + window_size] - temp_data[i - window_size]
+ time_diff = time_data[i + window_size] - time_data[i - window_size]
+
+ if time_diff != 0:
+ gradient[i] = temp_diff / time_diff
+
+ return gradient
+
+
+def find_temperature_peaks(temp_data: np.ndarray, time_data: np.ndarray,
+ gradient_threshold: float = 100.0, min_peak_height: float = 0.0):
+ """
+ 查找温度峰值
+
+ Args:
+ temp_data: 温度数据
+ time_data: 时间数据
+ gradient_threshold: 梯度阈值
+ min_peak_height: 最小峰值高度
+
+ Returns:
+ 峰值时间列表
+ """
+ peaks = []
+ gradient = calculate_temperature_gradient(temp_data, time_data)
+
+ for i in range(1, len(gradient) - 1):
+ # 梯度从正变负,且温度值足够高
+ if (gradient[i-1] > gradient_threshold and
+ gradient[i+1] < -gradient_threshold and
+ temp_data[i] > min_peak_height):
+ peaks.append(time_data[i])
+
+ return peaks
diff --git a/ml-backends/temperature_annotation/requirements-base.txt b/ml-backends/temperature_annotation/requirements-base.txt
new file mode 100644
index 0000000..97f92d5
--- /dev/null
+++ b/ml-backends/temperature_annotation/requirements-base.txt
@@ -0,0 +1,2 @@
+label-studio-ml>=1.0.0
+gunicorn>=20.1.0
diff --git a/ml-backends/temperature_annotation/requirements-test.txt b/ml-backends/temperature_annotation/requirements-test.txt
new file mode 100644
index 0000000..ba4bd9d
--- /dev/null
+++ b/ml-backends/temperature_annotation/requirements-test.txt
@@ -0,0 +1,3 @@
+pytest>=7.0.0
+pytest-cov>=4.0.0
+pytest-mock>=3.10.0
diff --git a/ml-backends/temperature_annotation/requirements.txt b/ml-backends/temperature_annotation/requirements.txt
new file mode 100644
index 0000000..4000c6d
--- /dev/null
+++ b/ml-backends/temperature_annotation/requirements.txt
@@ -0,0 +1,6 @@
+label-studio-ml>=1.0.0
+pandas>=1.3.0
+numpy>=1.21.0
+requests>=2.25.0
+scikit-learn>=1.0.0
+scipy>=1.9.0
diff --git a/ml-backends/temperature_annotation/temperature_predictor.py b/ml-backends/temperature_annotation/temperature_predictor.py
new file mode 100644
index 0000000..2d32974
--- /dev/null
+++ b/ml-backends/temperature_annotation/temperature_predictor.py
@@ -0,0 +1,374 @@
+"""
+温度预测器模块
+
+实现温度特征识别算法,包括温度上升、下降、峰值检测等
+"""
+
+import pandas as pd
+import numpy as np
+from typing import List
+from prediction import BasePredictor, Prediction
+from prediction import start_end_time_1D, detect_temperature_anomalies, find_temperature_peaks
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class TemperaturePredictor(BasePredictor):
+ """温度预测器类"""
+
+ def __init__(self,
+ label_group: str = 'temperature_events',
+ temp_rise_threshold: float = 1000.0, # eV
+ temp_fall_threshold: float = 500.0, # eV
+ gradient_threshold: float = 100.0, # eV/ms
+ anomaly_threshold: float = 3.0, # 异常检测阈值
+ min_peak_height: float = 800.0): # 最小峰值高度
+ """
+ 初始化温度预测器
+
+ Args:
+ label_group: 标签组名称
+ temp_rise_threshold: 温度上升阈值
+ temp_fall_threshold: 温度下降阈值
+ gradient_threshold: 梯度阈值
+ anomaly_threshold: 异常检测阈值
+ min_peak_height: 最小峰值高度
+ """
+ super().__init__()
+ self.label_group = label_group
+ self.temp_rise_threshold = temp_rise_threshold
+ self.temp_fall_threshold = temp_fall_threshold
+ self.gradient_threshold = gradient_threshold
+ self.anomaly_threshold = anomaly_threshold
+ self.min_peak_height = min_peak_height
+
+ logger.info(f"温度预测器初始化完成,标签组: {label_group}")
+ logger.info(f"阈值设置: 上升={temp_rise_threshold}, 下降={temp_fall_threshold}, 梯度={gradient_threshold}")
+
+ def user_predict(self, task_data: pd.DataFrame) -> List[Prediction]:
+ """
+ 执行温度预测
+
+ Args:
+ task_data: 任务数据,包含时间序列温度数据
+
+ Returns:
+ 预测结果列表
+ """
+ predictions = []
+
+ try:
+ # 验证数据
+ if task_data is None or task_data.empty:
+ logger.warning("任务数据为空")
+ return predictions
+
+ if 'time' not in task_data.columns:
+ logger.warning("缺少时间列")
+ return predictions
+
+ time = task_data['time'].values
+
+ # 检测多个温度通道
+ temp_channels = self._extract_temperature_channels(task_data)
+
+ if not temp_channels:
+ logger.warning("未找到温度通道")
+ return predictions
+
+ logger.info(f"开始分析 {len(temp_channels)} 个温度通道: {temp_channels}")
+
+ for channel in temp_channels:
+ if channel not in task_data.columns:
+ continue
+
+ temp_data = np.array(task_data[channel])
+
+ # 跳过全为NaN的通道
+ if np.all(np.isnan(temp_data)):
+ logger.warning(f"通道 {channel} 全为NaN,跳过")
+ continue
+
+ # 处理NaN值
+ temp_data = self._handle_missing_values(temp_data)
+
+ # 1. 检测温度上升阶段
+ rise_predictions = self._detect_rise_phases(time, temp_data, channel)
+ predictions.extend(rise_predictions)
+
+ # 2. 检测温度峰值时刻
+ peak_predictions = self._detect_peak_moments(time, temp_data, channel)
+ predictions.extend(peak_predictions)
+
+ # 3. 检测温度下降阶段
+ fall_predictions = self._detect_fall_phases(time, temp_data, channel)
+ predictions.extend(fall_predictions)
+
+ # 4. 检测异常温度事件
+ anomaly_predictions = self._detect_anomalies(time, temp_data, channel)
+ predictions.extend(anomaly_predictions)
+
+ # 5. 检测温度平台期
+ plateau_predictions = self._detect_plateau_phases(time, temp_data, channel)
+ predictions.extend(plateau_predictions)
+
+ logger.info(f"温度预测完成,生成了 {len(predictions)} 个预测结果")
+
+ except Exception as e:
+ logger.error(f"温度预测过程中出错: {str(e)}")
+
+ return predictions
+
+ def _extract_temperature_channels(self, task_data: pd.DataFrame) -> List[str]:
+ """提取温度通道列名"""
+ temp_channels = []
+ for col in task_data.columns:
+ col_lower = col.lower()
+ if any(pattern in col_lower for pattern in ['te', 'temp', 'ti', 'temperature']):
+ temp_channels.append(col)
+ return temp_channels
+
+ def _handle_missing_values(self, temp_data: np.ndarray) -> np.ndarray:
+ """处理缺失值"""
+ # 使用前向填充处理NaN值
+ temp_data_clean = temp_data.copy()
+ for i in range(1, len(temp_data_clean)):
+ if np.isnan(temp_data_clean[i]):
+ temp_data_clean[i] = temp_data_clean[i-1]
+
+ # 如果第一个值是NaN,使用下一个非NaN值
+ if np.isnan(temp_data_clean[0]):
+ for i in range(1, len(temp_data_clean)):
+ if not np.isnan(temp_data_clean[i]):
+ temp_data_clean[0] = temp_data_clean[i]
+ break
+
+ return temp_data_clean
+
+ def _detect_rise_phases(self, time: np.ndarray, temp_data: np.ndarray, channel: str) -> List[Prediction]:
+ """检测温度上升阶段"""
+ predictions = []
+
+ try:
+ # 检测温度上升区间
+ rise_intervals = start_end_time_1D(
+ temp_data, self.temp_rise_threshold, postive=True
+ )
+
+ for start_idx, end_idx in rise_intervals:
+ if start_idx < end_idx: # 确保区间有效
+ start_time = time[start_idx]
+ end_time = time[end_idx]
+
+ # 计算上升阶段的特征
+ rise_temp_data = temp_data[start_idx:end_idx+1]
+ max_temp = np.max(rise_temp_data)
+ avg_temp = np.mean(rise_temp_data)
+
+ # 根据温度特征调整标签
+ if max_temp > self.temp_rise_threshold * 1.5:
+ label = f"{channel}_快速上升"
+ else:
+ label = f"{channel}_上升阶段"
+
+ predictions.append(Prediction(
+ self.label_group,
+ label,
+ start=start_time,
+ end=end_time,
+ score=min(0.9, (max_temp - self.temp_rise_threshold) / self.temp_rise_threshold)
+ ))
+
+ logger.info(f"通道 {channel} 检测到 {len(predictions)} 个上升阶段")
+
+ except Exception as e:
+ logger.error(f"检测通道 {channel} 上升阶段时出错: {str(e)}")
+
+ return predictions
+
+ def _detect_peak_moments(self, time: np.ndarray, temp_data: np.ndarray, channel: str) -> List[Prediction]:
+ """检测温度峰值时刻"""
+ predictions = []
+
+ try:
+ # 查找温度峰值
+ peak_times = find_temperature_peaks(
+ temp_data, time,
+ gradient_threshold=self.gradient_threshold,
+ min_peak_height=self.min_peak_height
+ )
+
+ for peak_time in peak_times:
+ # 找到峰值对应的时间索引
+ peak_idx = np.argmin(np.abs(time - peak_time))
+ peak_temp = temp_data[peak_idx]
+
+ # 根据峰值高度调整标签
+ if peak_temp > self.temp_rise_threshold * 2:
+ label = f"{channel}_极高峰值"
+ elif peak_temp > self.temp_rise_threshold * 1.5:
+ label = f"{channel}_高峰值"
+ else:
+ label = f"{channel}_峰值时刻"
+
+ predictions.append(Prediction(
+ self.label_group,
+ label,
+ start=peak_time,
+ end=None, # 时间点标注
+ score=min(0.95, peak_temp / (self.temp_rise_threshold * 2))
+ ))
+
+ logger.info(f"通道 {channel} 检测到 {len(predictions)} 个峰值时刻")
+
+ except Exception as e:
+ logger.error(f"检测通道 {channel} 峰值时刻时出错: {str(e)}")
+
+ return predictions
+
+ def _detect_fall_phases(self, time: np.ndarray, temp_data: np.ndarray, channel: str) -> List[Prediction]:
+ """检测温度下降阶段"""
+ predictions = []
+
+ try:
+ # 检测温度下降区间
+ fall_intervals = start_end_time_1D(
+ temp_data, self.temp_fall_threshold, postive=False
+ )
+
+ for start_idx, end_idx in fall_intervals:
+ if start_idx < end_idx: # 确保区间有效
+ start_time = time[start_idx]
+ end_time = time[end_idx]
+
+ # 计算下降阶段的特征
+ fall_temp_data = temp_data[start_idx:end_idx+1]
+ min_temp = np.min(fall_temp_data)
+ avg_temp = np.mean(fall_temp_data)
+
+ # 根据温度特征调整标签
+ if min_temp < self.temp_fall_threshold * 0.5:
+ label = f"{channel}_快速下降"
+ else:
+ label = f"{channel}_下降阶段"
+
+ predictions.append(Prediction(
+ self.label_group,
+ label,
+ start=start_time,
+ end=end_time,
+ score=min(0.9, (self.temp_fall_threshold - min_temp) / self.temp_fall_threshold)
+ ))
+
+ logger.info(f"通道 {channel} 检测到 {len(predictions)} 个下降阶段")
+
+ except Exception as e:
+ logger.error(f"检测通道 {channel} 下降阶段时出错: {str(e)}")
+
+ return predictions
+
+ def _detect_anomalies(self, time: np.ndarray, temp_data: np.ndarray, channel: str) -> List[Prediction]:
+ """检测异常温度事件"""
+ predictions = []
+
+ try:
+ # 检测温度异常
+ anomalies = detect_temperature_anomalies(
+ temp_data, time,
+ threshold=self.anomaly_threshold
+ )
+
+ for start_time, end_time, anomaly_type in anomalies:
+ predictions.append(Prediction(
+ self.label_group,
+ f"{channel}_{anomaly_type}",
+ start=start_time,
+ end=end_time,
+ score=0.8 # 异常检测的置信度
+ ))
+
+ logger.info(f"通道 {channel} 检测到 {len(predictions)} 个异常事件")
+
+ except Exception as e:
+ logger.error(f"检测通道 {channel} 异常事件时出错: {str(e)}")
+
+ return predictions
+
+ def _detect_plateau_phases(self, time: np.ndarray, temp_data: np.ndarray, channel: str) -> List[Prediction]:
+ """检测温度平台期"""
+ predictions = []
+
+ try:
+ # 计算温度梯度
+ gradient = np.gradient(temp_data)
+
+ # 检测梯度较小的区间(平台期)
+ plateau_threshold = self.gradient_threshold * 0.1 # 平台期的梯度阈值
+ plateau_mask = np.abs(gradient) < plateau_threshold
+
+ # 找到连续的平台期区间
+ plateau_intervals = start_end_time_1D(
+ plateau_mask.astype(float), 0.5, postive=True
+ )
+
+ for start_idx, end_idx in plateau_intervals:
+ if end_idx - start_idx > 10: # 平台期至少持续10个采样点
+ start_time = time[start_idx]
+ end_time = time[end_idx]
+
+ # 计算平台期的特征
+ plateau_temp_data = temp_data[start_idx:end_idx+1]
+ avg_temp = np.mean(plateau_temp_data)
+ temp_std = np.std(plateau_temp_data)
+
+ # 根据温度水平调整标签
+ if avg_temp > self.temp_rise_threshold:
+ label = f"{channel}_高温平台期"
+ elif avg_temp > self.temp_fall_threshold:
+ label = f"{channel}_中温平台期"
+ else:
+ label = f"{channel}_低温平台期"
+
+ # 根据稳定性计算置信度
+ stability_score = max(0.5, 1.0 - temp_std / avg_temp) if avg_temp > 0 else 0.5
+
+ predictions.append(Prediction(
+ self.label_group,
+ label,
+ start=start_time,
+ end=end_time,
+ score=stability_score
+ ))
+
+ logger.info(f"通道 {channel} 检测到 {len(predictions)} 个平台期")
+
+ except Exception as e:
+ logger.error(f"检测通道 {channel} 平台期时出错: {str(e)}")
+
+ return predictions
+
+ def update_thresholds(self, new_thresholds: dict):
+ """更新阈值参数"""
+ if 'temp_rise_threshold' in new_thresholds:
+ self.temp_rise_threshold = new_thresholds['temp_rise_threshold']
+ if 'temp_fall_threshold' in new_thresholds:
+ self.temp_fall_threshold = new_thresholds['temp_fall_threshold']
+ if 'gradient_threshold' in new_thresholds:
+ self.gradient_threshold = new_thresholds['gradient_threshold']
+ if 'anomaly_threshold' in new_thresholds:
+ self.anomaly_threshold = new_thresholds['anomaly_threshold']
+ if 'min_peak_height' in new_thresholds:
+ self.min_peak_height = new_thresholds['min_peak_height']
+
+ logger.info(f"阈值参数已更新: {new_thresholds}")
+
+ def get_current_thresholds(self) -> dict:
+ """获取当前阈值参数"""
+ return {
+ 'temp_rise_threshold': self.temp_rise_threshold,
+ 'temp_fall_threshold': self.temp_fall_threshold,
+ 'gradient_threshold': self.gradient_threshold,
+ 'anomaly_threshold': self.anomaly_threshold,
+ 'min_peak_height': self.min_peak_height
+ }
diff --git a/ml-backends/temperature_annotation/test_e2e.py b/ml-backends/temperature_annotation/test_e2e.py
new file mode 100644
index 0000000..fb1eb39
--- /dev/null
+++ b/ml-backends/temperature_annotation/test_e2e.py
@@ -0,0 +1,635 @@
+"""
+温度标注ML Backend的端到端测试
+
+测试完整的ML Backend服务流程,包括:
+- 服务启动和健康检查
+- 数据加载和预处理
+- 温度预测和结果转换
+- API接口响应
+- 模型训练和参数更新
+"""
+
+import pytest
+import requests
+import time
+import json
+import pandas as pd
+import numpy as np
+import tempfile
+import os
+import subprocess
+import signal
+from unittest.mock import patch, MagicMock
+import logging
+
+# Configure logging for the test run
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Connection settings shared by every end-to-end test
+TEST_CONFIG = dict(
+    host='localhost',
+    port=9090,
+    base_url='http://localhost:9090',
+    timeout=30,
+    retry_attempts=3,
+)
+
+
+class TestTemperatureBackendE2E:
+ """温度标注ML Backend端到端测试类"""
+
+    @classmethod
+    def setup_class(cls):
+        """Make sure an ML backend is reachable before any test in the class runs.
+
+        An already-running service is reused. Otherwise a service process is
+        started and polled until it answers the health check.
+
+        Raises:
+            Whatever ``_start_service`` / ``_wait_for_service`` raise. If waiting
+            fails, the spawned process is stopped first, because pytest does not
+            call ``teardown_class`` when ``setup_class`` itself fails.
+        """
+        cls.service_process = None
+        cls.service_url = TEST_CONFIG['base_url']
+
+        # Reuse a service that is already listening
+        if cls._is_service_running():
+            logger.info("ML Backend服务已在运行")
+            return
+
+        cls._start_service()
+
+        try:
+            cls._wait_for_service()
+        except BaseException:
+            # Do not leak the child process when start-up fails or is interrupted
+            cls._stop_service()
+            raise
+
+    @classmethod
+    def teardown_class(cls):
+        """Stop the backend process, if this test class started one."""
+        if not cls.service_process:
+            return
+        cls._stop_service()
+
+    @classmethod
+    def _is_service_running(cls):
+        """Return True when ``GET {service_url}/health`` answers with HTTP 200.
+
+        Any ``requests`` failure (connection refused, timeout, ...) is reported
+        as "not running". Other exceptions, including ``KeyboardInterrupt``,
+        are deliberately allowed to propagate.
+        """
+        try:
+            response = requests.get(f"{cls.service_url}/health", timeout=5)
+        except requests.exceptions.RequestException:
+            return False
+        return response.status_code == 200
+
+    @classmethod
+    def _start_service(cls):
+        """Spawn ``_wsgi.py`` as a child process and store it on the class.
+
+        The child is launched with the interpreter that runs the tests, from the
+        directory containing this file. Its stdout/stderr are inherited rather
+        than piped: nothing here ever reads a pipe, and an unread pipe blocks
+        the child once the OS pipe buffer is full.
+
+        Raises:
+            Exception: whatever ``subprocess.Popen`` raises (logged, then re-raised).
+        """
+        import sys  # local import: only needed to locate the running interpreter
+
+        try:
+            logger.info("启动温度标注ML Backend服务...")
+
+            cmd = [
+                sys.executable, '_wsgi.py',
+                '--host', TEST_CONFIG['host'],
+                '--port', str(TEST_CONFIG['port']),
+                '--debug'
+            ]
+
+            # abspath guards against dirname(__file__) == '' for relative invocations
+            service_dir = os.path.dirname(os.path.abspath(__file__))
+
+            cls.service_process = subprocess.Popen(cmd, cwd=service_dir)
+
+            logger.info(f"服务启动命令: {' '.join(cmd)}")
+            logger.info(f"服务进程ID: {cls.service_process.pid}")
+
+        except Exception as e:
+            logger.error(f"启动服务失败: {e}")
+            raise
+
+    @classmethod
+    def _stop_service(cls):
+        """Terminate the spawned backend process and clear the class handle.
+
+        Sends ``terminate()`` and waits up to 10 seconds; if the process has not
+        exited by then it is killed and waited for without a timeout.
+        """
+        proc = cls.service_process
+        if not proc:
+            return
+
+        logger.info("停止ML Backend服务...")
+        try:
+            proc.terminate()
+            proc.wait(timeout=10)
+        except subprocess.TimeoutExpired:
+            proc.kill()
+            proc.wait()
+        finally:
+            cls.service_process = None
+
+    @classmethod
+    def _wait_for_service(cls):
+        """Poll the health endpoint until the service answers.
+
+        Polls every 2 seconds for at most 60 seconds.
+
+        Raises:
+            RuntimeError: the spawned service process exited before becoming
+                healthy (so waiting any longer is pointless).
+            TimeoutError: the service did not become healthy within the limit.
+        """
+        logger.info("等待服务启动...")
+        max_wait = 60  # seconds
+        wait_interval = 2
+
+        for i in range(0, max_wait, wait_interval):
+            if cls._is_service_running():
+                logger.info(f"服务启动成功,耗时 {i} 秒")
+                return
+
+            # Fail fast when the child has already died (bad import, port in use, ...)
+            proc = getattr(cls, 'service_process', None)
+            if proc is not None and proc.poll() is not None:
+                raise RuntimeError(f"服务进程提前退出,退出码: {proc.returncode}")
+
+            time.sleep(wait_interval)
+
+        raise TimeoutError("服务启动超时")
+
+    def _make_request(self, method, endpoint, data=None, headers=None):
+        """Send a GET or POST to the service, retrying on ``requests`` errors.
+
+        Args:
+            method: HTTP verb, case-insensitive; only GET and POST are supported.
+            endpoint: Path appended to ``self.service_url``.
+            data: JSON body, used for POST only.
+            headers: Optional request headers.
+
+        Returns:
+            The ``requests`` response of the first attempt that does not raise.
+
+        Raises:
+            ValueError: for any other HTTP verb (not retried).
+            requests.exceptions.RequestException: when the last attempt fails.
+        """
+        url = f"{self.service_url}{endpoint}"
+        total_attempts = TEST_CONFIG['retry_attempts']
+
+        for attempt_no in range(1, total_attempts + 1):
+            try:
+                verb = method.upper()
+                if verb == 'GET':
+                    return requests.get(url, headers=headers, timeout=TEST_CONFIG['timeout'])
+                if verb == 'POST':
+                    return requests.post(url, json=data, headers=headers, timeout=TEST_CONFIG['timeout'])
+                raise ValueError(f"不支持的HTTP方法: {method}")
+
+            except requests.exceptions.RequestException as e:
+                if attempt_no == total_attempts:
+                    raise
+                logger.warning(f"请求失败,重试 {attempt_no}/{total_attempts}: {e}")
+                time.sleep(1)
+
+    def test_01_service_health(self):
+        """``GET /health`` returns 200 with a healthy status and the expected keys."""
+        logger.info("测试服务健康状态...")
+
+        resp = self._make_request('GET', '/health')
+        assert resp.status_code == 200
+
+        payload = resp.json()
+        for field in ('status', 'model_version', 'predictor_initialized'):
+            assert field in payload
+        assert payload['status'] == 'healthy'
+
+        logger.info(f"服务健康状态: {payload}")
+
+    def test_02_model_info(self):
+        """``GET /model/info`` returns 200 with model metadata and threshold keys."""
+        logger.info("测试模型信息接口...")
+
+        resp = self._make_request('GET', '/model/info')
+        assert resp.status_code == 200
+
+        info = resp.json()
+        for field in ('model_version', 'training_count', 'last_training_time', 'current_thresholds'):
+            assert field in info
+
+        # The threshold section must expose at least these three parameters
+        threshold_section = info['current_thresholds']
+        for name in ('temp_rise_threshold', 'temp_fall_threshold', 'gradient_threshold'):
+            assert name in threshold_section
+
+        logger.info(f"模型信息: {info}")
+
+    def test_03_prediction_with_sample_data(self):
+        """Post one prediction task and check the layout of the response.
+
+        NOTE(review): ``patch('utils.load_data')`` replaces the attribute in the
+        pytest process only. When the backend runs as the separate process
+        started by ``_start_service``, the mock presumably never reaches it —
+        confirm how the service is expected to obtain this sample data.
+        """
+        logger.info("测试使用样本数据的预测功能...")
+
+        shot_no = 12345
+        frame = self._create_sample_temperature_data()
+        payload = {
+            "tasks": [
+                {"data": {"shot": shot_no, "csv": "sample_temperature_data.csv"}}
+            ]
+        }
+
+        with patch('utils.load_data') as load_stub:
+            load_stub.return_value = {shot_no: frame}
+            response = self._make_request('POST', '/predict', data=payload)
+
+        assert response.status_code == 200
+        body = response.json()
+
+        assert 'predictions' in body
+        predictions = body['predictions']
+
+        if len(predictions) == 0:
+            logger.info("预测完成,但未生成预测结果(可能是阈值设置过高)")
+            return
+
+        # Only the first prediction is inspected in detail
+        head = predictions[0]
+        assert 'model_version' in head
+        assert 'result' in head
+
+        for item in head['result']:
+            for field in ('from_name', 'to_name', 'type', 'value'):
+                assert field in item
+            assert item['type'] == 'timeserieslabels'
+
+        logger.info(f"预测成功,生成了 {len(predictions)} 个预测结果")
+
+    def test_04_prediction_with_realistic_data(self):
+        """Post a tokamak-like temperature curve and expect a known event type.
+
+        NOTE(review): as in the other prediction tests, ``patch('utils.load_data')``
+        only affects the pytest process; it presumably has no effect on a backend
+        running in a separate process — confirm.
+        """
+        logger.info("测试使用真实场景数据的预测功能...")
+
+        shot_no = 240830001
+        frame = self._create_realistic_temperature_data()
+        payload = {
+            "tasks": [
+                {"data": {"shot": shot_no, "csv": "realistic_temperature_data.csv"}}
+            ]
+        }
+
+        with patch('utils.load_data') as load_stub:
+            load_stub.return_value = {shot_no: frame}
+            response = self._make_request('POST', '/predict', data=payload)
+
+        assert response.status_code == 200
+        body = response.json()
+
+        assert 'predictions' in body
+        predictions = body['predictions']
+
+        if len(predictions) > 0:
+            # Collect the first label of every returned region
+            result_types = {
+                region['value']['timeserieslabels'][0]
+                for pred in predictions
+                for region in pred['result']
+            }
+            logger.info(f"检测到的温度事件类型: {result_types}")
+
+            # At least one expected event name must occur inside some label
+            expected_events = ['上升阶段', '峰值时刻', '下降阶段', '平台期']
+            detected_events = [
+                event for event in expected_events
+                if any(event in label for label in result_types)
+            ]
+
+            logger.info(f"检测到的事件: {detected_events}")
+            assert len(detected_events) > 0, "应该检测到至少一种温度事件"
+
+    def test_05_model_training(self):
+        """Post one annotation to ``/fit`` and expect ``training_count`` above zero."""
+        logger.info("测试模型训练功能...")
+
+        # A single hand-made annotation standing in for a human label
+        labelled_region = {
+            "from_name": "temperature_events",
+            "type": "timeserieslabels",
+            "value": {
+                "start": 2.0,
+                "end": 4.0,
+                "timeserieslabels": ["Te_1_上升阶段"]
+            }
+        }
+        fit_payload = {
+            "event": "ANNOTATION_CREATED",
+            "annotation": {"id": 1, "result": [labelled_region]},
+            "task": {"id": 1, "data": {"shot": 240830001}}
+        }
+
+        fit_response = self._make_request('POST', '/fit', data=fit_payload)
+        assert fit_response.status_code in [200, 204]
+
+        # Give the service a moment before reading the model state back
+        time.sleep(2)
+
+        info_response = self._make_request('GET', '/model/info')
+        assert info_response.status_code == 200
+
+        info = info_response.json()
+        assert info['training_count'] > 0
+
+        logger.info(f"模型训练完成,训练次数: {info['training_count']}")
+
+    def test_06_threshold_optimization(self):
+        """Send two annotations to ``/fit`` and log whether any threshold moved.
+
+        The test only asserts the ``/fit`` status codes; a threshold change is
+        reported in the log but not required.
+        """
+        logger.info("测试阈值优化功能...")
+
+        def annotation_event(ident, start, end, label, shot):
+            # Build one ANNOTATION_CREATED payload with a single labelled region
+            return {
+                "event": "ANNOTATION_CREATED",
+                "annotation": {
+                    "id": ident,
+                    "result": [
+                        {
+                            "from_name": "temperature_events",
+                            "type": "timeserieslabels",
+                            "value": {
+                                "start": start,
+                                "end": end,
+                                "timeserieslabels": [label]
+                            }
+                        }
+                    ]
+                },
+                "task": {"id": ident, "data": {"shot": shot}}
+            }
+
+        before = self._make_request('GET', '/model/info').json()['current_thresholds']
+        logger.info(f"初始阈值: {before}")
+
+        training_events = [
+            annotation_event(2, 1.0, 3.0, "Te_1_峰值时刻", 240830002),
+            annotation_event(3, 5.0, 7.0, "Te_1_下降阶段", 240830003),
+        ]
+        for event_payload in training_events:
+            fit_response = self._make_request('POST', '/fit', data=event_payload)
+            assert fit_response.status_code in [200, 204]
+
+        # Give the service a moment before reading the thresholds back
+        time.sleep(3)
+
+        after = self._make_request('GET', '/model/info').json()['current_thresholds']
+        logger.info(f"优化后阈值: {after}")
+
+        # A change counts only when it exceeds 0.01 for some threshold
+        threshold_changed = any(abs(before[key] - after[key]) > 0.01 for key in before)
+
+        if threshold_changed:
+            logger.info("阈值优化成功")
+        else:
+            logger.info("阈值未发生变化(可能是优化逻辑或数据特征导致的)")
+
+    def test_07_error_handling(self):
+        """Malformed predict and fit requests must still get a tolerated status code."""
+        logger.info("测试错误处理...")
+
+        # Prediction request pointing at a non-existent shot / file
+        bad_predict_payload = {
+            "tasks": [
+                {"data": {"shot": "invalid_shot", "csv": "nonexistent.csv"}}
+            ]
+        }
+        predict_response = self._make_request('POST', '/predict', data=bad_predict_payload)
+
+        # The service is expected to answer 200, possibly with no predictions
+        assert predict_response.status_code == 200
+
+        # Training request with an unknown event and no annotation
+        bad_fit_payload = {"event": "INVALID_EVENT", "data": "invalid_data"}
+        fit_response = self._make_request('POST', '/fit', data=bad_fit_payload)
+        assert fit_response.status_code in [200, 204, 400]
+
+        logger.info("错误处理测试完成")
+
+    def test_08_performance_test(self):
+        """A prediction request must return 200 in under 10 seconds.
+
+        NOTE(review): ``patch('utils.load_data')`` only affects the pytest
+        process, so the large frame built here presumably never reaches a backend
+        running in a separate process — confirm what is actually being timed.
+        """
+        logger.info("测试性能...")
+
+        shot_no = 999999
+        frame = self._create_large_temperature_data()
+        payload = {
+            "tasks": [
+                {"data": {"shot": shot_no, "csv": "large_temperature_data.csv"}}
+            ]
+        }
+
+        with patch('utils.load_data') as load_stub:
+            load_stub.return_value = {shot_no: frame}
+
+            started = time.time()
+            response = self._make_request('POST', '/predict', data=payload)
+            finished = time.time()
+
+        processing_time = finished - started
+
+        assert response.status_code == 200
+        assert processing_time < 10.0, f"预测处理时间过长: {processing_time:.2f}秒"
+
+        logger.info(f"性能测试完成,处理时间: {processing_time:.2f}秒")
+
+    def test_09_model_reset(self):
+        """``POST /model/reset`` brings ``training_count`` back to 0."""
+        logger.info("测试模型重置功能...")
+
+        # Read (and parse) the state before the reset; the value itself is not compared
+        self._make_request('GET', '/model/info').json()
+
+        reset_response = self._make_request('POST', '/model/reset')
+        assert reset_response.status_code in [200, 204]
+
+        # Give the service a moment before reading the state back
+        time.sleep(2)
+
+        info_after = self._make_request('GET', '/model/info').json()
+
+        assert info_after['training_count'] == 0
+        assert info_after['model_version'] == 'temperature_v1.0'
+
+        logger.info("模型重置测试完成")
+
+    def _create_sample_temperature_data(self, seed: int = 0):
+        """Build a 1000-row noisy sine-wave fixture with two channels.
+
+        Args:
+            seed: Seed for the private random generator. A fixed default keeps
+                the fixture identical between test runs.
+
+        Returns:
+            DataFrame with columns ``time`` (0..10), ``Te_1`` and ``Te_2``.
+        """
+        rng = np.random.default_rng(seed)
+        n_samples = 1000
+
+        time_data = np.linspace(0, 10, n_samples)
+        temp_data = 500 + 1000 * np.sin(time_data) + 100 * rng.standard_normal(n_samples)
+
+        df = pd.DataFrame({
+            'time': time_data,
+            'Te_1': temp_data,
+            'Te_2': temp_data * 0.8 + 50 * rng.standard_normal(n_samples)
+        })
+
+        return df
+
+    def _create_realistic_temperature_data(self, seed: int = 0):
+        """Build a 1500-row fixture shaped like a heating / flat-top / decay curve.
+
+        Args:
+            seed: Seed for the private random generator. A fixed default keeps
+                the fixture identical between test runs.
+
+        Returns:
+            DataFrame with columns ``time`` (0..15), ``Te_1``, ``Te_2``, ``Te_3``.
+        """
+        rng = np.random.default_rng(seed)
+        time_data = np.linspace(0, 15, 1500)
+        n_samples = len(time_data)
+
+        temp_curve = np.zeros_like(time_data)
+
+        # Heating phase (0-3 s): linear ramp from 200 to 1000
+        heating_mask = (time_data >= 0) & (time_data < 3)
+        temp_curve[heating_mask] = 200 + 800 * (time_data[heating_mask] / 3)
+
+        # Flat-top phase (3-8 s): 1000 with a small sinusoidal ripple
+        plateau_mask = (time_data >= 3) & (time_data < 8)
+        temp_curve[plateau_mask] = 1000 + 50 * np.sin(2 * np.pi * time_data[plateau_mask])
+
+        # Decay phase (8-15 s): exponential decay with a 2 s time constant
+        decay_mask = (time_data >= 8) & (time_data <= 15)
+        decay_time = time_data[decay_mask] - 8
+        temp_curve[decay_mask] = 1000 * np.exp(-decay_time / 2)
+
+        # Measurement noise on the base curve
+        temp_curve += 20 * rng.standard_normal(n_samples)
+
+        # Derived channels: scaled copies of the base curve with their own noise
+        df = pd.DataFrame({
+            'time': time_data,
+            'Te_1': temp_curve,
+            'Te_2': temp_curve * 0.9 + 30 * rng.standard_normal(n_samples),
+            'Te_3': temp_curve * 1.1 - 50 * rng.standard_normal(n_samples)
+        })
+
+        return df
+
+    def _create_large_temperature_data(self, seed: int = 0):
+        """Build a 3000-row, four-channel fixture for the timing test.
+
+        Args:
+            seed: Seed for the private random generator. A fixed default keeps
+                the fixture identical between test runs.
+
+        Returns:
+            DataFrame with columns ``time`` (0..30) and ``Te_1`` .. ``Te_4``.
+        """
+        rng = np.random.default_rng(seed)
+        n_samples = 3000
+
+        time_data = np.linspace(0, 30, n_samples)
+        temp_data = 500 + 1000 * np.sin(time_data / 5) + 100 * rng.standard_normal(n_samples)
+
+        df = pd.DataFrame({
+            'time': time_data,
+            'Te_1': temp_data,
+            'Te_2': temp_data * 0.8 + 50 * rng.standard_normal(n_samples),
+            'Te_3': temp_data * 1.2 - 80 * rng.standard_normal(n_samples),
+            'Te_4': temp_data * 0.6 + 120 * rng.standard_normal(n_samples)
+        })
+
+        return df
+
+
+class TestTemperatureBackendIntegration:
+    """In-process integration tests that import the backend modules directly."""
+
+    def test_model_initialization(self):
+        """A freshly set-up model reports version, predictor and zero trainings."""
+        from model import TemperatureModel
+
+        fresh_model = TemperatureModel()
+        fresh_model.setup()
+
+        assert fresh_model.get("model_version") == "temperature_v1.0"
+        assert hasattr(fresh_model, 'predictor')
+        assert fresh_model.get('training_count') == 0
+
+    def test_predictor_functionality(self):
+        """``user_predict`` returns a list whose items are ``Prediction`` objects."""
+        from temperature_predictor import TemperaturePredictor
+        from prediction import Prediction
+
+        predictor = TemperaturePredictor()
+
+        # Small monotonically rising two-channel frame
+        frame = pd.DataFrame({
+            'time': [0, 1, 2, 3, 4, 5],
+            'Te_1': [100, 200, 300, 400, 500, 600],
+            'Te_2': [150, 250, 350, 450, 550, 650]
+        })
+
+        produced = predictor.user_predict(frame)
+
+        assert isinstance(produced, list)
+        if len(produced) > 0:
+            assert all(isinstance(item, Prediction) for item in produced)
+
+    def test_data_processing_pipeline(self):
+        """Validation accepts a clean frame and preprocessing keeps its length."""
+        from utils import validate_temperature_data, preprocess_temperature_data
+
+        frame = pd.DataFrame({
+            'time': [0, 1, 2, 3, 4],
+            'Te_1': [100, 200, 300, 400, 500],
+            'Te_2': [150, 250, 350, 450, 550]
+        })
+
+        assert validate_temperature_data(frame) == True
+
+        cleaned = preprocess_temperature_data(frame)
+        assert len(cleaned) == len(frame)
+
+    def test_prediction_conversion(self):
+        """Two predictions convert into one Label Studio entry with two results."""
+        from prediction import Prediction, convert_to_labelstudio_form
+
+        # One interval and one prediction whose end is None
+        interval = Prediction("temperature_events", "Te_1_上升阶段", 1.0, 3.0)
+        instant = Prediction("temperature_events", "Te_1_峰值时刻", 2.0, None)
+
+        converted = convert_to_labelstudio_form([interval, instant], "test_model")
+
+        assert len(converted) == 1
+        entry = converted[0]
+        assert entry["model_version"] == "test_model"
+        assert len(entry["result"]) == 2
+
+
+if __name__ == "__main__":
+    # Run the end-to-end tests and propagate pytest's exit code, so that running
+    # this file as a script fails when a test fails
+    raise SystemExit(pytest.main([__file__, "-v", "--tb=short"]))
diff --git a/ml-backends/temperature_annotation/test_temperature_backend.py b/ml-backends/temperature_annotation/test_temperature_backend.py
new file mode 100644
index 0000000..1ff7635
--- /dev/null
+++ b/ml-backends/temperature_annotation/test_temperature_backend.py
@@ -0,0 +1,315 @@
+"""
+温度标注ML Backend的测试文件
+
+测试各个组件的功能和集成
+"""
+
+import pytest
+import numpy as np
+import pandas as pd
+from unittest.mock import Mock, patch, MagicMock
+import tempfile
+import os
+
+# 导入被测试的模块
+from model import TemperatureModel
+from temperature_predictor import TemperaturePredictor
+from prediction import Prediction, start_end_time_1D, convert_to_labelstudio_form
+from utils import validate_temperature_data, preprocess_temperature_data, extract_temperature_channels
+
+
+class TestPrediction:
+    """Checks for the ``Prediction`` result container."""
+
+    def test_prediction_creation(self):
+        """Positional constructor arguments can be read back as attributes."""
+        pred = Prediction("test_group", "test_label", 1.0, 2.0, 0.8)
+
+        observed = (pred.label_group, pred.label, pred.start, pred.end, pred.score)
+        assert observed == ("test_group", "test_label", 1.0, 2.0, 0.8)
+
+    def test_prediction_repr(self):
+        """``repr`` mentions the class tag, the label and both time bounds."""
+        text = repr(Prediction("test_group", "test_label", 1.0, 2.0))
+
+        for fragment in ("TemperaturePrediction", "test_label", "1.0", "2.0"):
+            assert fragment in text
+
+
+class TestPredictionFunctions:
+    """Checks for the helper functions in the ``prediction`` module."""
+
+    def test_start_end_time_1D_positive(self):
+        """Rising data crosses the threshold once: indices 3..5 are above 2.5."""
+        rising = np.array([0, 1, 2, 3, 4, 5])
+
+        found = start_end_time_1D(rising, 2.5, postive=True)
+
+        assert len(found) == 1
+        assert found[0] == (3, 5)
+
+    def test_start_end_time_1D_negative(self):
+        """Falling data drops below the threshold once: indices 3..5 are below 2.5."""
+        falling = np.array([5, 4, 3, 2, 1, 0])
+
+        found = start_end_time_1D(falling, 2.5, postive=False)
+
+        assert len(found) == 1
+        assert found[0] == (3, 5)
+
+    def test_convert_to_labelstudio_form(self):
+        """Predictions are wrapped into a single entry with one result per prediction."""
+        converted = convert_to_labelstudio_form(
+            [
+                Prediction("group1", "label1", 1.0, 2.0),
+                Prediction("group2", "label2", 3.0, None),
+            ],
+            "test_model",
+        )
+
+        assert len(converted) == 1
+        entry = converted[0]
+        assert entry["model_version"] == "test_model"
+        assert len(entry["result"]) == 2
+
+        # First result: a regular interval
+        interval, instant = entry["result"]
+        assert interval["from_name"] == "group1"
+        assert interval["to_name"] == "ts"
+        assert interval["type"] == "timeserieslabels"
+        assert interval["value"]["start"] == 1.0
+        assert interval["value"]["end"] == 2.0
+        assert interval["value"]["timeserieslabels"] == ["label1"]
+
+        # Second result: end was None, so it is expected to equal start
+        assert instant["value"]["start"] == 3.0
+        assert instant["value"]["end"] == 3.0
+
+
+class TestUtils:
+    """Checks for the helper functions in the ``utils`` module."""
+
+    def test_validate_temperature_data_valid(self):
+        """A frame with a time column and two channels validates."""
+        frame = pd.DataFrame({
+            'time': [0, 1, 2, 3],
+            'Te_1': [100, 200, 300, 400],
+            'Te_2': [150, 250, 350, 450]
+        })
+
+        assert validate_temperature_data(frame) == True
+
+    def test_validate_temperature_data_missing_time(self):
+        """A frame without a time column is rejected."""
+        frame = pd.DataFrame({
+            'Te_1': [100, 200, 300, 400],
+            'Te_2': [150, 250, 350, 450]
+        })
+
+        assert validate_temperature_data(frame) == False
+
+    def test_validate_temperature_data_empty(self):
+        """An empty frame is rejected."""
+        assert validate_temperature_data(pd.DataFrame()) == False
+
+    def test_extract_temperature_channels(self):
+        """Only the two ``Te_*`` columns are reported as temperature channels."""
+        frame = pd.DataFrame({
+            'time': [0, 1, 2],
+            'Te_1': [100, 200, 300],
+            'Te_2': [150, 250, 350],
+            'other': [10, 20, 30]
+        })
+
+        found = extract_temperature_channels(frame)
+
+        for expected in ('Te_1', 'Te_2'):
+            assert expected in found
+        assert 'other' not in found
+        assert len(found) == 2
+
+
+class TestTemperaturePredictor:
+    """Checks for ``TemperaturePredictor``."""
+
+    def setup_method(self):
+        """Create a fresh predictor before every test."""
+        self.predictor = TemperaturePredictor()
+
+    def test_initialization(self):
+        """Default label group and thresholds after construction."""
+        predictor = self.predictor
+        assert predictor.label_group == 'temperature_events'
+        assert predictor.temp_rise_threshold == 1000.0
+        assert predictor.temp_fall_threshold == 500.0
+        assert predictor.gradient_threshold == 100.0
+
+    def test_extract_temperature_channels(self):
+        """The predictor's own channel extraction keeps ``Te_*`` and drops ``other``."""
+        frame = pd.DataFrame({
+            'time': [0, 1, 2],
+            'Te_1': [100, 200, 300],
+            'Te_2': [150, 250, 350],
+            'other': [10, 20, 30]
+        })
+
+        found = self.predictor._extract_temperature_channels(frame)
+
+        for expected in ('Te_1', 'Te_2'):
+            assert expected in found
+        assert 'other' not in found
+
+    def test_handle_missing_values(self):
+        """NaN entries are filled and the array length is preserved."""
+        with_gaps = np.array([100, np.nan, 300, np.nan, 500])
+
+        filled = self.predictor._handle_missing_values(with_gaps)
+
+        assert not np.any(np.isnan(filled))
+        assert len(filled) == 5
+
+    def test_update_thresholds(self):
+        """Thresholds passed to ``update_thresholds`` are stored on the predictor."""
+        self.predictor.update_thresholds({
+            'temp_rise_threshold': 1200.0,
+            'temp_fall_threshold': 600.0
+        })
+
+        assert self.predictor.temp_rise_threshold == 1200.0
+        assert self.predictor.temp_fall_threshold == 600.0
+
+    def test_get_current_thresholds(self):
+        """The threshold snapshot has the expected keys and default rise value."""
+        snapshot = self.predictor.get_current_thresholds()
+
+        for name in ('temp_rise_threshold', 'temp_fall_threshold', 'gradient_threshold'):
+            assert name in snapshot
+        assert snapshot['temp_rise_threshold'] == 1000.0
+
+
+class TestTemperatureModel:
+    """Checks for ``TemperatureModel``."""
+
+    def setup_method(self):
+        """Build and set up a model while the three ``utils`` helpers are patched."""
+        with patch('utils.load_data'), \
+                patch('utils.validate_temperature_data', return_value=True), \
+                patch('utils.preprocess_temperature_data'):
+            self.model = TemperatureModel()
+            self.model.setup()
+
+    def test_model_initialization(self):
+        """After setup the model has a version and a ``TemperaturePredictor``."""
+        model = self.model
+        assert model.get("model_version") == "temperature_v1.0"
+        assert hasattr(model, 'predictor')
+        assert isinstance(model.predictor, TemperaturePredictor)
+
+    def test_get_model_info(self):
+        """``get_model_info`` exposes the four expected keys and the version."""
+        info = self.model.get_model_info()
+
+        for field in ('model_version', 'training_count', 'last_training_time', 'current_thresholds'):
+            assert field in info
+        assert info['model_version'] == 'temperature_v1.0'
+
+    def test_health_check(self):
+        """``health_check`` reports a healthy status plus version and predictor flag."""
+        health = self.model.health_check()
+
+        for field in ('status', 'model_version', 'predictor_initialized'):
+            assert field in health
+        assert health['status'] == 'healthy'
+
+    def test_reset_model(self):
+        """``reset_model`` restores ``training_count`` to 0 and keeps the version."""
+        # Dirty the state first so the reset has something to undo
+        self.model.set('training_count', 100)
+
+        self.model.reset_model()
+
+        assert self.model.get('training_count') == 0
+        assert self.model.get('model_version') == 'temperature_v1.0'
+
+
+class TestIntegration:
+    """Integration tests that drive ``TemperatureModel.predict`` in-process."""
+
+    def test_end_to_end_prediction(self):
+        """``predict`` returns a response whose ``predictions`` attribute is a list."""
+        frame = pd.DataFrame({
+            'time': [0, 1, 2, 3, 4],
+            'Te_1': [100, 200, 300, 400, 500],
+            'Te_2': [150, 250, 350, 450, 550]
+        })
+
+        with patch('utils.load_data') as load_stub, \
+                patch('utils.validate_temperature_data', return_value=True), \
+                patch('utils.preprocess_temperature_data') as preprocess_stub:
+            model = TemperatureModel()
+            model.setup()
+
+            # Both patched helpers hand back the same small frame
+            load_stub.return_value = {123: frame}
+            preprocess_stub.return_value = frame
+
+            response = model.predict([{'data': {'shot': 123, 'csv': 'test.csv'}}])
+
+            assert hasattr(response, 'predictions')
+            assert isinstance(response.predictions, list)
+
+    def test_prediction_with_mock_data(self):
+        """With a stub predictor, ``predict`` calls it once and returns predictions."""
+        stub_predictor = Mock()
+        stub_predictor.user_predict.return_value = [
+            Prediction("test_group", "test_label", 1.0, 2.0)
+        ]
+
+        # Swap the stub in for the model's predictor
+        model = TemperatureModel()
+        model.predictor = stub_predictor
+
+        with patch.object(model, 'get_data') as get_data_stub:
+            get_data_stub.return_value = {
+                123: pd.DataFrame({
+                    'time': [0, 1, 2],
+                    'Te_1': [100, 200, 300]
+                })
+            }
+
+            response = model.predict([{'data': {'shot': 123, 'csv': 'test.csv'}}])
+
+            stub_predictor.user_predict.assert_called_once()
+            assert len(response.predictions) > 0
+
+
+if __name__ == "__main__":
+    # Run the tests and propagate pytest's exit code, so that running this file
+    # as a script fails when a test fails
+    raise SystemExit(pytest.main([__file__, "-v"]))
diff --git a/ml-backends/temperature_annotation/utils.py b/ml-backends/temperature_annotation/utils.py
new file mode 100644
index 0000000..a9c2a8e
--- /dev/null
+++ b/ml-backends/temperature_annotation/utils.py
@@ -0,0 +1,296 @@
+"""
+数据工具模块
+
+提供温度数据的加载、处理和工具函数
+"""
+
+import pandas as pd
+import requests
+import concurrent.futures
+from typing import Dict, Optional
+import numpy as np
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def load_data(urls: Dict[int, str]) -> Dict[int, pd.DataFrame]:
+    """
+    Download and parse the CSV of several shots concurrently.
+
+    Each URL is fetched with a 30 second timeout on a pool of up to five
+    threads. A shot whose download or parsing fails is logged and left out of
+    the result instead of aborting the whole batch.
+
+    Args:
+        urls: Mapping of shot number to the URL of its CSV file.
+
+    Returns:
+        Mapping of shot number to the parsed DataFrame, containing only the
+        shots that loaded successfully.
+    """
+    import io  # local import: pandas does not provide a StringIO of its own
+
+    def fetch_csv(shot_url_pair):
+        # Returns (shot, DataFrame) on success and (shot, None) on any failure
+        shot, url = shot_url_pair
+        try:
+            response = requests.get(url, timeout=30)
+            response.raise_for_status()
+            df = pd.read_csv(io.StringIO(response.text))
+            logger.info(f"成功加载炮号 {shot} 的数据,形状: {df.shape}")
+            return shot, df
+        except Exception as e:
+            logger.error(f"加载炮号 {shot} 数据失败: {e}")
+            return shot, None
+
+    data_dict = {}
+    logger.info(f"开始加载 {len(urls)} 个炮号的数据")
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
+        results = executor.map(fetch_csv, urls.items())
+
+        for shot, df in results:
+            if df is not None:
+                data_dict[shot] = df
+
+    logger.info(f"成功加载 {len(data_dict)} 个炮号的数据")
+    return data_dict
+
+
+def validate_temperature_data(df: pd.DataFrame, required_columns: list = None) -> bool:
+    """
+    Decide whether a temperature frame is usable.
+
+    The checks run in this order and the first failure is logged as a warning:
+    non-empty frame, all ``required_columns`` present, a ``time`` column
+    present, ``time`` non-decreasing, and no numeric column (other than
+    ``time``) that is entirely null.
+
+    Args:
+        df: Temperature data.
+        required_columns: Column names that must be present (optional).
+
+    Returns:
+        True when every check passes, otherwise False.
+    """
+    if df is None or df.empty:
+        logger.warning("数据为空")
+        return False
+
+    if required_columns:
+        missing_columns = [name for name in required_columns if name not in df.columns]
+        if missing_columns:
+            logger.warning(f"缺少必需的列: {missing_columns}")
+            return False
+
+    if 'time' not in df.columns:
+        logger.warning("缺少时间列")
+        return False
+
+    # Equal consecutive timestamps are accepted; only a decrease is rejected
+    time_steps = np.diff(df['time'].values)
+    if not np.all(time_steps >= 0):
+        logger.warning("时间数据不是单调递增的")
+        return False
+
+    for col in df.select_dtypes(include=[np.number]).columns:
+        if col != 'time' and df[col].isnull().all():
+            logger.warning(f"列 {col} 全为空值")
+            return False
+
+    logger.info("温度数据验证通过")
+    return True
+
+
+def preprocess_temperature_data(df: pd.DataFrame,
+                              remove_outliers: bool = True,
+                              outlier_threshold: float = 3.0,
+                              interpolate_missing: bool = True) -> pd.DataFrame:
+    """
+    Return a cleaned copy of the temperature data; the input is not modified.
+
+    Every numeric column except ``time`` is processed:
+
+    1. Missing values are filled by linear interpolation (optional).
+    2. Values further than ``outlier_threshold`` standard deviations from the
+       column mean are replaced by that mean (optional). Mean and standard
+       deviation are computed on the column including the outliers.
+
+    Args:
+        df: Raw temperature data.
+        remove_outliers: Whether to replace outliers.
+        outlier_threshold: Outlier limit in multiples of the standard deviation.
+        interpolate_missing: Whether to interpolate missing values.
+
+    Returns:
+        The processed DataFrame. A column in which outliers were replaced is
+        returned with a float dtype.
+    """
+    df_processed = df.copy()
+
+    # Interpolate missing values
+    if interpolate_missing:
+        numeric_columns = df_processed.select_dtypes(include=[np.number]).columns
+        for col in numeric_columns:
+            if col == 'time':
+                continue
+            if df_processed[col].isnull().any():
+                df_processed[col] = df_processed[col].interpolate(method='linear')
+                logger.info(f"列 {col} 的缺失值已插值")
+
+    # Replace outliers
+    if remove_outliers:
+        numeric_columns = df_processed.select_dtypes(include=[np.number]).columns
+        for col in numeric_columns:
+            if col == 'time':
+                continue
+
+            mean_val = df_processed[col].mean()
+            std_val = df_processed[col].std()
+
+            outlier_mask = np.abs(df_processed[col] - mean_val) > outlier_threshold * std_val
+            outlier_count = outlier_mask.sum()
+
+            if outlier_count > 0:
+                # The mean is a float: cast integer columns explicitly instead of
+                # relying on silent upcasting during the .loc assignment
+                if not pd.api.types.is_float_dtype(df_processed[col]):
+                    df_processed[col] = df_processed[col].astype(float)
+                df_processed.loc[outlier_mask, col] = mean_val
+                logger.info(f"列 {col} 移除了 {outlier_count} 个异常值")
+
+    return df_processed
+
+
+def extract_temperature_channels(df: pd.DataFrame,
+                               channel_patterns: list = None,
+                               exclude_columns: tuple = ('time',)) -> list:
+    """
+    Pick the columns whose name looks like a temperature channel.
+
+    A column matches when its lower-cased name contains any lower-cased
+    pattern as a substring. Columns listed in ``exclude_columns`` never match;
+    this keeps the ``time`` axis out of the result, whose name would otherwise
+    match the default ``'Ti'`` pattern.
+
+    Args:
+        df: Temperature data.
+        channel_patterns: Substrings identifying a temperature channel.
+            Defaults to ``['Te', 'temp', 'Ti', 'temperature']``.
+        exclude_columns: Column names (compared case-insensitively) that are
+            never reported as channels.
+
+    Returns:
+        Matching column labels, in the frame's column order.
+    """
+    if channel_patterns is None:
+        channel_patterns = ['Te', 'temp', 'Ti', 'temperature']
+
+    lowered_patterns = [str(pattern).lower() for pattern in channel_patterns]
+    excluded = {str(name).lower() for name in (exclude_columns or ())}
+
+    temperature_channels = []
+    for col in df.columns:
+        # str() keeps non-string column labels (e.g. integers) from raising
+        col_lower = str(col).lower()
+        if col_lower in excluded:
+            continue
+        if any(pattern in col_lower for pattern in lowered_patterns):
+            temperature_channels.append(col)
+
+    logger.info(f"找到 {len(temperature_channels)} 个温度通道: {temperature_channels}")
+    return temperature_channels
+
+
+def calculate_temperature_statistics(df: pd.DataFrame,
+                                   temperature_channels: list = None) -> Dict:
+    """
+    Summarise each temperature channel with basic statistics.
+
+    Null values are dropped before the statistics are computed. Channels that
+    are not columns of ``df``, or that have no non-null value, are left out.
+
+    Args:
+        df: Temperature data.
+        temperature_channels: Channels to summarise. When None they are
+            detected with ``extract_temperature_channels``.
+
+    Returns:
+        Mapping of channel name to a dict with ``mean``, ``std``, ``min``,
+        ``max`` (floats) and ``count`` (number of non-null values).
+    """
+    if temperature_channels is None:
+        temperature_channels = extract_temperature_channels(df)
+
+    stats = {}
+    for channel in temperature_channels:
+        if channel not in df.columns:
+            continue
+
+        valid_values = df[channel].dropna()
+        if len(valid_values) == 0:
+            continue
+
+        summary = {
+            name: float(getattr(valid_values, name)())
+            for name in ('mean', 'std', 'min', 'max')
+        }
+        summary['count'] = len(valid_values)
+        stats[channel] = summary
+
+    logger.info(f"计算了 {len(stats)} 个通道的统计信息")
+    return stats
+
+
+def normalize_temperature_data(df: pd.DataFrame,
+                             temperature_channels: list = None,
+                             method: str = 'zscore') -> pd.DataFrame:
+    """
+    Return a copy of the data with the temperature channels rescaled.
+
+    Per channel, an offset and a scale are derived from the non-null values:
+
+    * ``'zscore'``: mean and standard deviation
+    * ``'minmax'``: minimum and (maximum - minimum)
+    * ``'robust'``: median and interquartile range (q75 - q25)
+
+    The channel becomes ``(value - offset) / scale`` only when the scale is
+    greater than zero; otherwise, and for any other ``method``, it is left as
+    it is. The info log line is written for every channel that has data,
+    whether or not it was rescaled.
+
+    Args:
+        df: Temperature data.
+        temperature_channels: Channels to rescale. When None they are detected
+            with ``extract_temperature_channels``.
+        method: One of ``'zscore'``, ``'minmax'``, ``'robust'``.
+
+    Returns:
+        The rescaled copy; the input frame is not modified.
+    """
+    if temperature_channels is None:
+        temperature_channels = extract_temperature_channels(df)
+
+    df_normalized = df.copy()
+
+    for channel in temperature_channels:
+        if channel not in df_normalized.columns:
+            continue
+
+        valid_values = df_normalized[channel].dropna()
+        if len(valid_values) == 0:
+            continue
+
+        if method == 'zscore':
+            offset, scale = valid_values.mean(), valid_values.std()
+        elif method == 'minmax':
+            offset = valid_values.min()
+            scale = valid_values.max() - offset
+        elif method == 'robust':
+            offset = valid_values.median()
+            scale = valid_values.quantile(0.75) - valid_values.quantile(0.25)
+        else:
+            offset = scale = None
+
+        if scale is not None and scale > 0:
+            df_normalized[channel] = (df_normalized[channel] - offset) / scale
+
+        logger.info(f"通道 {channel} 已使用 {method} 方法标准化")
+
+    return df_normalized
+
+
+def segment_temperature_data(df: pd.DataFrame,
+                           segment_length: float = 1.0,
+                           overlap: float = 0.0) -> list:
+    """
+    Cut the data into fixed-length, optionally overlapping segments.
+
+    The sampling step is taken from the first two ``time`` values, so the
+    series is assumed to be uniformly sampled. A trailing remainder shorter
+    than one segment is dropped.
+
+    Args:
+        df: Temperature data with a ``time`` column.
+        segment_length: Segment length, in the units of ``time``.
+        overlap: Fraction of a segment shared with the next one; must be
+            smaller than 1.
+
+    Returns:
+        List of dicts with keys ``data`` (a DataFrame copy), ``start_time``,
+        ``end_time`` and ``segment_index``. Empty when ``df`` has fewer than
+        two rows.
+
+    Raises:
+        ValueError: if ``segment_length`` is not positive, ``overlap`` is 1 or
+            more, the first time step is not positive, or ``segment_length``
+            is shorter than one time step.
+    """
+    segments = []
+    time_data = df['time'].values
+
+    if len(time_data) < 2:
+        return segments
+
+    if not segment_length > 0:
+        raise ValueError(f"segment_length 必须为正数: {segment_length}")
+    if not overlap < 1:
+        raise ValueError(f"overlap 必须小于1: {overlap}")
+
+    # Sampling step, taken from the first two samples
+    time_step = time_data[1] - time_data[0]
+    if not time_step > 0:
+        raise ValueError(f"时间步长必须为正数: {time_step}")
+
+    segment_samples = int(segment_length / time_step)
+    if segment_samples < 1:
+        raise ValueError(f"segment_length ({segment_length}) 小于时间步长 ({time_step})")
+
+    overlap_samples = int(segment_samples * overlap)
+    # Always advance by at least one sample so the loop below makes progress
+    hop_samples = max(1, segment_samples - overlap_samples)
+
+    for i in range(0, len(df) - segment_samples + 1, hop_samples):
+        segment_df = df.iloc[i:i + segment_samples].copy()
+        segment_start = time_data[i]
+        segment_end = time_data[i + segment_samples - 1]
+
+        segments.append({
+            'data': segment_df,
+            'start_time': segment_start,
+            'end_time': segment_end,
+            'segment_index': i // hop_samples
+        })
+
+    logger.info(f"将数据分为 {len(segments)} 段,每段长度 {segment_length} 秒")
+    return segments