Skip to content

Conversation

@undertaker86001
Copy link

@undertaker86001 undertaker86001 commented Aug 25, 2025

整体架构设计

train_model.py采用了模块化设计,主要包含以下几个核心组件:

1. 数据管理模块

  • MiningAudioDataset类:继承PyTorch的Dataset,负责音频数据的加载和预处理
  • prepare_labels函数:处理多标签数据的编码和转换

2. 模型训练模块

  • train_epoch函数:执行单个训练周期
  • validate_model函数:模型验证和评估
  • train_model主函数:协调整个训练流程

3. 性能评估模块

  • evaluate_model函数:计算各种评估指标
  • plot_training_history函数:可视化训练过程

核心原理详解

1. 多标签分类原理

# 矿业场景的三个分类维度
MINING_LABELS = {
    'equipment_status': ['normal', 'abnormal', 'maintenance_needed'],      # 3个标签
    'fault_type': ['bearing_fault', 'motor_fault', 'hydraulic_leak', 'belt_fault'],  # 4个标签
    'priority': ['low_priority', 'medium_priority', 'high_priority']      # 3个标签
}

原理说明:

  • 每个音频样本可能同时属于多个类别
  • 使用MultiLabelBinarizer将文本标签转换为二进制向量
  • 最终标签维度 = 3 + 4 + 3 = 10维

2. 音频特征提取原理

def extract_features(self, audio_path):
    # 1. 加载音频文件
    y, sr = librosa.load(audio_path, sr=self.sample_rate)
    
    # 2. 时域特征
    zcr = librosa.feature.zero_crossing_rate(y, hop_length=self.hop_length)  # 过零率
    energy = librosa.feature.rms(y=y, hop_length=self.hop_length)            # 能量
    
    # 3. 频域特征
    mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=self.n_mfcc)           # MFCC系数
    spectral_centroid = librosa.feature.spectral_centroid(y=y, sr=sr)       # 频谱质心
    spectral_rolloff = librosa.feature.spectral_rolloff(y=y, sr=sr)         # 频谱滚降
    spectral_bandwidth = librosa.feature.spectral_bandwidth(y=y, sr=sr)     # 频谱带宽
    
    # 4. 统计特征
    feature_stats = np.hstack([
        np.mean(features, axis=1),    # 均值
        np.std(features, axis=1),     # 标准差
        np.max(features, axis=1),     # 最大值
        np.min(features, axis=1)      # 最小值
    ])

特征维度计算:

  • MFCC: 13维 × 4统计量 = 52维
  • 其他特征: 5维 × 4统计量 = 20维
  • 总特征维度: 72维

3. 神经网络架构原理

class MiningAudioClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dims=[256, 128, 64]):
        # 特征提取层
        self.feature_extractor = nn.Sequential(
            nn.Linear(input_dim, 256),      # 输入层 → 256
            nn.BatchNorm1d(256),            # 批归一化
            nn.ReLU(),                      # 激活函数
            nn.Dropout(0.3),                # 防止过拟合
            
            nn.Linear(256, 128),            # 256 → 128
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),
            
            nn.Linear(128, 64),             # 128 → 64
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.3)
        )
        
        # 多标签分类头
        self.equipment_status_head = nn.Linear(64, 3)    # 设备状态分类
        self.fault_type_head = nn.Linear(64, 4)          # 故障类型分类
        self.priority_head = nn.Linear(64, 3)            # 优先级分类

优化策略:

  • 共享特征提取:三个分类任务共享前面的特征提取层
  • 任务特定分类头:每个分类任务有独立的输出层
  • 正则化技术:BatchNorm + Dropout防止过拟合

4. 损失函数设计原理

def train_epoch(model, train_loader, criterion, optimizer, device):
    for batch_idx, (features, labels) in enumerate(train_loader):
        outputs = model(features)
        
        # 计算每个类别的损失
        batch_loss = 0
        start_idx = 0
        for category, output in outputs.items():
            end_idx = start_idx + len(MINING_LABELS[category])
            category_labels = labels[:, start_idx:end_idx]
            loss = criterion(output, category_labels)  # BCELoss
            batch_loss += loss
            start_idx = end_idx
        
        batch_loss.backward()  # 反向传播
        optimizer.step()       # 参数更新

损失计算原理:

  • 使用BCELoss(二元交叉熵损失)
  • 每个类别的损失独立计算,然后求和
  • 支持多标签学习(一个样本可以有多个标签)

5. 训练流程控制原理

def train_model(data_path, model_save_path, scaler_save_path, config=None):
    # 1. 数据准备
    labels, label_encoders = prepare_labels(label_data)
    X_train, X_val, y_train, y_val = train_test_split(...)
    
    # 2. 特征标准化
    scaler = StandardScaler()
    scaler.fit(train_features)
    
    # 3. 训练循环
    for epoch in range(config['epochs']):
        train_loss, train_accuracy = train_epoch(...)
        val_loss, val_predictions, val_labels = validate_model(...)
        
        # 4. 早停机制
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), model_save_path)  # 保存最佳模型
        else:
            patience_counter += 1
            if patience_counter >= config['early_stopping_patience']:
                break  # 早停

训练策略:

  • 数据分割:80%训练,20%验证
  • 特征标准化:使用训练集计算标准化参数
  • 早停机制:防止过拟合,自动停止训练
  • 模型保存:保存验证损失最低的模型

6. 评估指标原理

def evaluate_model(predictions, labels, label_encoders):
    # 整体准确率
    overall_accuracy = accuracy_score(labels.flatten(), predictions.flatten())
    
    # 按类别评估
    for category in MINING_LABELS.keys():
        category_accuracy = accuracy_score(
            category_labels.flatten(), 
            category_predictions.flatten()
        )
        
        # 详细分类报告
        report = classification_report(
            category_labels, 
            category_predictions, 
            target_names=target_names
        )

评估维度:

  • 整体准确率:所有标签的平均准确率
  • 类别准确率:每个分类任务的独立准确率
  • 详细报告:精确率、召回率、F1分数

关键技术特点

1. 多任务学习

  • 一个模型同时学习三个分类任务
  • 共享特征表示,提高泛化能力

2. 自适应学习率

scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5, verbose=True
)
  • 当验证损失不再下降时,自动降低学习率
  • 帮助模型找到更好的局部最优解

3. 数据增强策略

  • 支持不同采样率的音频文件
  • 自动处理音频格式差异

4. 模型版本管理

  • 自动保存最佳模型和标准化器
  • 支持模型热更新和回滚

使用流程

# 1. 准备训练数据CSV文件
# 格式:audio_path, equipment_status, fault_type, priority

# 2. 运行训练
python train_model.py \
    --data_path data/mining_audio_dataset.csv \
    --model_save_path models/mining_audio_model.pth \
    --scaler_save_path models/feature_scaler.pkl \
    --epochs 50 \
    --batch_size 32

# 3. 查看训练日志和图表
# - training.log:详细训练日志
# - training_history.png:训练过程可视化

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant