-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_ensemble_final.py
More file actions
449 lines (381 loc) · 20.6 KB
/
Copy pathtest_ensemble_final.py
File metadata and controls
449 lines (381 loc) · 20.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
"""
最终版集成方法测试 - 修复Agent状态提取问题
正确从DetectionStateWrapper对象中提取Agent分析结果
"""
import asyncio
import logging
import json
import time
import numpy as np
from typing import Dict, List
from collections import defaultdict
from agents.keti_ml_classifier import KetiEnhancedMLClassifier
from agents.edit_detection_agent import EditDetectionAgent
from agents.harmfulness_analysis_agent import HarmfulnessAnalysisAgent
from agents.security_assessment_agent import SecurityAssessmentAgent
from core.state_factory import create_detection_state_wrapper
# Module-wide logging: a logger named after this module, with the root
# logger configured for INFO-level output.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class FinalEnsembleDemo:
    """Final ensemble classifier demo: KETI ML classifier enhanced by agents.

    The ML classifier supplies the base prediction. The edit-detection,
    harmfulness-analysis and security-assessment agents run as a pipeline, and
    their outputs are reduced to a flat dict of "agent signals". A rule-based
    strategy then uses those signals to boost the confidence of the ML
    prediction or to reclassify it.
    """

    # Neutral signal values, used when an agent fails or yields nothing usable.
    _DEFAULT_SIGNALS = {
        'edit_score': 0.0, 'edit_has_change': False,
        'harm_score': 0.0, 'harm_level': 0,
        'security_score': 0.0, 'security_level': 0,
        'high_risk_count': 0,
    }

    # Numeric level for each security risk level name; unknown names map to 1.
    _RISK_LEVELS = {'LOW': 1, 'MEDIUM': 2, 'HIGH': 3}

    def __init__(self):
        self.base_ml_classifier = KetiEnhancedMLClassifier()
        self.edit_agent = EditDetectionAgent()
        self.harmfulness_agent = HarmfulnessAnalysisAgent()
        self.security_agent = SecurityAssessmentAgent()
        # KETI class index -> label
        self.keti_labels = {
            0: "Non edited",
            1: "Fact updating",
            2: "Misinformation injection",
            3: "Offensiveness injection",
            4: "Behavioral misleading injection",
            5: "Bias injection",
        }

    @staticmethod
    def _read_field(container, name, default=None):
        """Return ``container.name`` or ``container.get(name)``, else ``default``.

        Agent results (e.g. a DetectionStateWrapper) and the analyses nested in
        them may be plain dicts or attribute-style objects, so both access
        styles are tried. A value of ``None`` is treated as missing.
        """
        if container is None:
            return default
        value = getattr(container, name, None)
        if value is None and callable(getattr(container, 'get', None)):
            value = container.get(name)
        return default if value is None else value

    @classmethod
    def _risk_to_level(cls, risk_level):
        """Map a security risk level to 1 (LOW), 2 (MEDIUM) or 3 (HIGH); unknown -> 1."""
        # NOTE(review): assumes levels are named LOW/MEDIUM/HIGH, given either
        # as a string or as an Enum member's value/name, in any letter case.
        for candidate in (risk_level,
                          getattr(risk_level, 'value', None),
                          getattr(risk_level, 'name', None)):
            if isinstance(candidate, str) and candidate.upper() in cls._RISK_LEVELS:
                return cls._RISK_LEVELS[candidate.upper()]
        return 1

    @staticmethod
    def _count_high_risk(signals):
        """Count how many of the five independent risk indicators fire."""
        indicators = (
            signals['edit_has_change'],
            signals['harm_level'] >= 2,
            signals['security_level'] >= 2,
            signals['harm_score'] > 0.6,
            signals['security_score'] > 0.6,
        )
        return sum(bool(flag) for flag in indicators)

    def _extract_edit_signals(self, result, signals):
        """Copy the edit-detection confidence and has-edit flag into ``signals``."""
        analysis = self._read_field(result, 'edit_analysis_result')
        if not analysis:
            logger.warning("无法提取编辑检测结果,使用默认值")
            return
        signals['edit_score'] = float(self._read_field(analysis, 'confidence', 0.0))
        signals['edit_has_change'] = bool(self._read_field(analysis, 'has_edit', False))
        logger.info(f"提取编辑信号: score={signals['edit_score']}, has_change={signals['edit_has_change']}")

    def _extract_harm_signals(self, result, signals):
        """Copy the harmfulness overall score and its derived 0-4 level into ``signals``."""
        # The field is ``harmfulness_analysis`` (not ``harmfulness_analysis_result``).
        analysis = self._read_field(result, 'harmfulness_analysis')
        if not analysis:
            logger.warning("无法提取有害性分析结果,使用默认值")
            return
        score = float(self._read_field(analysis, 'overall_score', 0.0))
        signals['harm_score'] = score
        # Scale a 0-1 score to a 0-4 level; clamping keeps out-of-range scores in 0-4.
        signals['harm_level'] = int(min(max(score, 0.0), 1.0) * 4)
        logger.info(f"提取有害性信号: score={signals['harm_score']}, level={signals['harm_level']}")

    def _extract_security_signals(self, result, signals):
        """Copy the vulnerability score and numeric risk level into ``signals``."""
        # The field is ``security_assessment`` (not ``security_assessment_result``).
        analysis = self._read_field(result, 'security_assessment')
        if not analysis:
            logger.warning("无法提取安全评估结果,使用默认值")
            return
        signals['security_score'] = float(self._read_field(analysis, 'vulnerability_score', 0.0))
        signals['security_level'] = self._risk_to_level(self._read_field(analysis, 'risk_level', 'LOW'))
        logger.info(f"提取安全信号: score={signals['security_score']}, level={signals['security_level']}")

    async def collect_agent_signals_v2(self, state):
        """Run the three agents on ``state`` and reduce their outputs to signals.

        The agents form a pipeline. Each one receives the most recent non-None
        result produced by an earlier agent, or ``state`` if none succeeded, so
        a failing agent no longer discards the work of the agents before it.
        Each agent failure is logged and leaves that agent's signals at their
        defaults.

        Returns:
            dict with keys edit_score, edit_has_change, harm_score, harm_level,
            security_score, security_level and high_risk_count. If anything
            outside the per-agent steps fails, all-default signals are returned.
        """
        try:
            signals = dict(self._DEFAULT_SIGNALS)
            current = state
            steps = (
                (self.edit_agent, "编辑检测", self._extract_edit_signals),
                (self.harmfulness_agent, "有害性分析", self._extract_harm_signals),
                (self.security_agent, "安全评估", self._extract_security_signals),
            )
            for agent, step_name, extract in steps:
                try:
                    logger.info(f"执行{step_name}...")
                    result = await agent.execute(current)
                    if result is not None:
                        current = result
                    logger.info(f"{step_name}结果类型: {type(result)}")
                    extract(result, signals)
                except Exception as e:
                    logger.error(f"{step_name}失败: {e}")

            signals['high_risk_count'] = self._count_high_risk(signals)
            logger.info(f"最终Agent信号: {signals}")
            return signals
        except Exception as e:
            logger.error(f"Agent信号收集完全失败: {e}")
            return dict(self._DEFAULT_SIGNALS)

    def enhanced_classification_v3(self, ml_result, agent_signals):
        """Combine an ML prediction with agent signals (aggressive V3 strategy).

        Args:
            ml_result: object exposing ``predicted_class`` (a key of
                ``self.keti_labels``) and ``confidence``.
            agent_signals: dict as returned by ``collect_agent_signals_v2``.

        Returns:
            dict with predicted_class, predicted_label, confidence (base plus
            boost, capped at 1.0), ml_confidence (the ML classifier's own
            confidence, even after reclassification), agent_boost,
            agent_signals and reclassified.
        """
        ml_prediction = ml_result.predicted_class
        ml_confidence = ml_result.confidence
        logger.info(f"ML预测: class={ml_prediction}({self.keti_labels[ml_prediction]}), confidence={ml_confidence:.3f}")
        logger.info(f"Agent信号: {agent_signals}")

        agent_boost = 0.0
        reclassify = False
        new_class = ml_prediction
        # The boost is added to this value. On reclassification it becomes a
        # fixed prior, while ml_confidence keeps the classifier's real confidence.
        base_confidence = ml_confidence

        if ml_prediction == 3:  # OI: boost on harmfulness/security confirmation
            if agent_signals['harm_level'] >= 2:
                agent_boost += 0.25  # harmfulness analysis confirms
            if agent_signals['security_level'] >= 2:
                agent_boost += 0.2  # security assessment confirms risk
            if agent_signals['high_risk_count'] >= 2:
                agent_boost += 0.15  # several agents agree
            if agent_signals['harm_score'] > 0.7:
                agent_boost += 0.1  # high harmfulness score
        elif ml_prediction == 4:  # BMI: boost on edit/security/harm evidence
            if agent_signals['edit_has_change']:
                agent_boost += 0.2  # edit detection found a change
            if agent_signals['security_level'] >= 1:
                agent_boost += 0.25  # any security risk detected
            if agent_signals['harm_score'] > 0.4:
                agent_boost += 0.15  # harmfulness score
            if agent_signals['high_risk_count'] >= 3:
                agent_boost += 0.1  # multiple signals confirm
        elif ml_prediction == 0:  # NE: reclassify only on strong malicious signals
            # A high_risk_count of 4+ on its own does not trigger reclassification.
            strong_harm = agent_signals['harm_level'] >= 3 or agent_signals['harm_score'] > 0.8
            if strong_harm:
                new_class = 3  # reclassify as OI
                base_confidence = 0.75
                agent_boost = 0.2
                reclassify = True
                logger.info("重分类: NE -> OI (强有害性信号)")
            elif agent_signals['security_level'] >= 3 and agent_signals['edit_has_change']:
                new_class = 4  # reclassify as BMI
                base_confidence = 0.7
                agent_boost = 0.15
                reclassify = True
                logger.info("重分类: NE -> BMI (强安全风险+编辑变化)")
        elif ml_prediction in (1, 2, 5):  # FU, MI, BI: light boost only
            if agent_signals['high_risk_count'] >= 2:
                agent_boost += 0.1

        final_confidence = min(base_confidence + agent_boost, 1.0)

        result = {
            'predicted_class': new_class,
            'predicted_label': self.keti_labels[new_class],
            'confidence': final_confidence,
            'ml_confidence': ml_confidence,
            'agent_boost': agent_boost,
            'agent_signals': agent_signals,
            'reclassified': reclassify,
        }

        if agent_boost > 0 or reclassify:
            logger.info(f"✅ Agent增强生效: {result}")
        else:
            logger.info(f"⚠️ 无Agent增强: {result}")

        return result

    async def predict_final(self, state):
        """Predict the KETI class of ``state`` using ML plus agent enhancement."""
        ml_result = self.base_ml_classifier.predict(state)
        agent_signals = await self.collect_agent_signals_v2(state)
        return self.enhanced_classification_v3(ml_result, agent_signals)
async def test_final_ensemble(max_samples: int = 10, dataset_path: str = 'datasets/test.json'):
    """Evaluate the final ensemble against the plain ML classifier and log a report.

    Args:
        max_samples: Evaluate only the first ``max_samples`` samples. A falsy
            value (``0`` or ``None``) evaluates the whole dataset.
        dataset_path: UTF-8 JSON file containing a list of samples with
            ``query``, ``object`` and ``type`` fields.

    Returns:
        dict with ``base_accuracy``, ``ensemble_accuracy``, ``improvement``
        (ensemble minus base), per-category ``category_stats`` and
        ``agent_stats``. The ratios are 0.0 when no sample was processed.
    """
    logger.info("🔑 最终版集成方法测试 - 修复Agent状态提取")
    logger.info("=" * 60)

    # Load the test data
    with open(dataset_path, 'r', encoding='utf-8') as f:
        dataset = json.load(f)
    if max_samples:
        dataset = dataset[:max_samples]
    logger.info(f"📊 测试数据: {len(dataset)} 个样本")

    ensemble_demo = FinalEnsembleDemo()
    # Use the ensemble's own ML classifier as the baseline. Both sides of the
    # comparison then share one model instance, which is also loaded only once.
    base_classifier = ensemble_demo.base_ml_classifier

    # KETI label -> class index (inverse of FinalEnsembleDemo.keti_labels)
    label_mapping = {label: index for index, label in ensemble_demo.keti_labels.items()}

    # Statistics accumulators
    base_results = []
    ensemble_results = []
    category_stats = defaultdict(lambda: {"total": 0, "base_correct": 0, "ensemble_correct": 0, "improvements": 0})
    agent_contribution_stats = {
        "total_boost": 0.0,
        "significant_boost_count": 0,
        "reclassification_count": 0,
        "successful_reclassifications": 0,
        "agent_signal_success": 0,
    }

    start_time = time.perf_counter()

    for i, sample in enumerate(dataset):
        try:
            query = sample.get('query', '')
            object_text = sample.get('object', '')
            expected_type = sample.get('type', 'Non edited')
            # Unknown labels are scored as "Non edited" (class 0).
            expected_class = label_mapping.get(expected_type, 0)

            logger.info(f"\n--- 样本 {i+1}/{len(dataset)} ---")
            logger.info(f"类别: {expected_type}")
            logger.info(f"Query: {query[:100]}...")

            # Build the detection state from the query (plus object text, if any)
            test_texts = [query]
            if object_text:
                test_texts.append(object_text)
            state = create_detection_state_wrapper(
                target_model_path="test_model",
                baseline_model_path="baseline_model",
                test_texts=test_texts,
            )

            base_result = base_classifier.predict(state)
            base_correct = base_result.predicted_class == expected_class

            ensemble_result = await ensemble_demo.predict_final(state)
            ensemble_correct = ensemble_result['predicted_class'] == expected_class
            agent_boost = ensemble_result.get('agent_boost', 0.0)
            reclassified = ensemble_result.get('reclassified', False)

            base_results.append({
                'sample_id': i,
                'expected': expected_class,
                'predicted': base_result.predicted_class,
                'confidence': base_result.confidence,
                'correct': base_correct,
                'category': expected_type,
            })
            ensemble_results.append({
                'sample_id': i,
                'expected': expected_class,
                'predicted': ensemble_result['predicted_class'],
                'confidence': ensemble_result['confidence'],
                'correct': ensemble_correct,
                'category': expected_type,
                'agent_boost': agent_boost,
                'reclassified': reclassified,
            })

            # Per-category statistics
            stats = category_stats[expected_type]
            stats["total"] += 1
            if base_correct:
                stats["base_correct"] += 1
            if ensemble_correct:
                stats["ensemble_correct"] += 1
            if ensemble_correct and not base_correct:
                stats["improvements"] += 1

            # Agent contribution statistics
            agent_contribution_stats["total_boost"] += agent_boost
            if agent_boost > 0.1:
                agent_contribution_stats["significant_boost_count"] += 1
            if reclassified:
                agent_contribution_stats["reclassification_count"] += 1
                if ensemble_correct and not base_correct:
                    agent_contribution_stats["successful_reclassifications"] += 1

            # An agent "succeeded" if it produced any non-zero signal.
            agent_signals = ensemble_result.get('agent_signals', {})
            if any(agent_signals.get(key, 0) > 0 for key in ['harm_score', 'security_score', 'high_risk_count']):
                agent_contribution_stats["agent_signal_success"] += 1

            if base_correct != ensemble_correct or agent_boost > 0.1:
                logger.info(f"🎯 样本 {i+1}: 基础={base_correct}, 集成={ensemble_correct}, 增强={agent_boost:.3f}")

        except Exception as e:
            # Log with traceback and keep evaluating the remaining samples.
            logger.exception(f"样本 {i+1} 处理失败: {e}")

    total_time = time.perf_counter() - start_time

    total_samples = len(base_results)
    if total_samples == 0:
        # Empty dataset or every sample failed: nothing to divide by below.
        logger.warning("⚠️ 没有成功处理的样本,无法计算统计结果")
        return {
            "base_accuracy": 0.0,
            "ensemble_accuracy": 0.0,
            "improvement": 0.0,
            "category_stats": dict(category_stats),
            "agent_stats": agent_contribution_stats,
        }

    base_correct_count = sum(1 for r in base_results if r['correct'])
    ensemble_correct_count = sum(1 for r in ensemble_results if r['correct'])
    base_accuracy = base_correct_count / total_samples
    ensemble_accuracy = ensemble_correct_count / total_samples
    improvement_count = sum(
        1 for base_r, ens_r in zip(base_results, ensemble_results)
        if ens_r['correct'] and not base_r['correct']
    )

    logger.info("\n" + "=" * 60)
    logger.info("📊 最终版测试结果分析")
    logger.info("=" * 60)

    logger.info(f"测试样本数: {total_samples}")
    logger.info(f"测试时间: {total_time:.1f}秒")
    logger.info(f"平均处理时间: {total_time/total_samples:.3f}秒/样本")
    logger.info("")

    logger.info("📈 整体性能对比:")
    logger.info(f" 基础ML准确率: {base_accuracy:.1%} ({base_correct_count}/{total_samples})")
    logger.info(f" 集成方法准确率: {ensemble_accuracy:.1%} ({ensemble_correct_count}/{total_samples})")
    logger.info(f" 性能提升: {ensemble_accuracy - base_accuracy:+.1%}")
    logger.info(f" 改进样本数: {improvement_count} ({improvement_count/total_samples:.1%})")

    # Detailed agent contribution
    avg_boost = agent_contribution_stats["total_boost"] / total_samples
    signal_success_rate = agent_contribution_stats["agent_signal_success"] / total_samples
    significant = agent_contribution_stats['significant_boost_count']

    logger.info("\n🤖 Agent贡献详细分析:")
    logger.info(f" 平均Agent增强: +{avg_boost:.3f}")
    logger.info(f" 显著增强样本: {significant}/{total_samples} ({significant/total_samples:.1%})")
    logger.info(f" 重分类次数: {agent_contribution_stats['reclassification_count']}")
    logger.info(f" 成功重分类: {agent_contribution_stats['successful_reclassifications']}")
    logger.info(f" Agent信号成功率: {signal_success_rate:.1%}")

    # Summary
    if ensemble_accuracy > base_accuracy:
        logger.info("\n🎉 最终版集成方法验证成功!")
        logger.info(f" ✅ 准确率提升 {ensemble_accuracy - base_accuracy:+.1%}")
        logger.info(f" ✅ 改进 {improvement_count} 个样本")
        logger.info(f" ✅ Agent平均贡献 +{avg_boost:.3f}")
        logger.info(f" ✅ Agent信号成功率 {signal_success_rate:.1%}")
    elif agent_contribution_stats["agent_signal_success"] > 0:
        logger.info("\n🔧 Agent信号收集已修复,但ensemble效果有限")
        logger.info(f" ✅ Agent信号成功率 {signal_success_rate:.1%}")
        logger.info(" ⚠️ 需要调整ensemble策略")
    else:
        logger.info("\n❌ Agent信号收集仍然失败")

    return {
        "base_accuracy": base_accuracy,
        "ensemble_accuracy": ensemble_accuracy,
        "improvement": ensemble_accuracy - base_accuracy,
        "category_stats": dict(category_stats),
        "agent_stats": agent_contribution_stats,
    }
if __name__ == "__main__":
    # Script entry point: evaluate the final ensemble on the first 10 test
    # samples and keep the summary dict in ``result``.
    result = asyncio.run(
        test_final_ensemble(max_samples=10)
    )