本项目实现了一个基于CIFAR-10数据集的三层神经网络,包含数据加载、模型定义、训练、超参数优化、测试及可视化功能。代码采用纯NumPy实现,无深度学习框架依赖,在CIFAR-10数据集上达到约47-52%的测试准确率。
- 数据处理:自动下载CIFAR-10数据集,支持数据标准化及训练/验证集划分。
- 模型架构:实现ReLU/Sigmoid激活函数的三层神经网络,支持正则化。
- 训练流程:包含批量训练、学习率衰减及早停机制。
- 超参数优化:支持网格搜索,自动记录最佳模型。
- 可视化:生成训练曲线及超参数对比图。
- 模块化设计:6个独立模块
CIFAR-10 数据集包含:
- 60,000张32x32彩色图像
- 10个类别(飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船和卡车)
- 数据划分:
- 训练集:45,000
- 验证集:5,000
- 测试集:10,000
cifar10/
├── cifar-10-batches-py # 图像分类数据集
├── utils.py # 数据下载与预处理
├── model.py # 神经网络模型定义
├── train.py # 训练流程控制
├── test.py # 测试评估模块
├── search.py # 超参数搜索
├── visualize.py # 结果可视化
└── main.py # 主程序入口
pip install numpy scikit-learn matplotlibpython main.py- 自动下载CIFAR-10数据集(约163MB),支持断点续传。
- 数据标准化:对训练集、验证集和测试集分别进行Z-Score标准化。
- 数据集划分:从完整训练集中划分出验证集(默认10%).
- 实现三层神经网络。
输入层(3072) → 隐藏层(256/512) → 输出层(10)- 支持两种激活函数:ReLU 和 Sigmoid。
- 使用He初始化方法初始化网络权重
- Softmax输出层
- 批量训练:支持小批量梯度下降(默认批次大小128)。
- 学习率衰减:每固定轮数(默认10轮)衰减学习率(默认衰减率0.95)。
- 正则化:支持L2正则化,防止过拟合。
- 早停机制:保存验证集上表现最佳的模型参数。
- 网格搜索:支持多种超参数组合的自动搜索。
- 参数范围:
- 隐藏层维度:[128,256,512,1024]
- 学习率:[0.1,0.05,0.01,0.005,0.001]
- 正则化强度:[0.01,0.005,0.001]
- 激活函数:['ReLU','Sigmoid']
- 自动记录每组超参数的训练损失、验证损失及准确率。
- 测试集评估:计算模型在测试集上的准确率。
- 超参数对比图:生成不同超参数组合的训练曲线对比图(hyperparam_results.png)。
- 学习曲线:绘制最佳模型的训练损失、验证损失及验证准确率变化曲线(learning_curves.png)。
- 模型权重保存:将最佳模型的参数保存为best_model.npz文件。
- 结果图表保存:自动生成可视化图表,便于后续分析。
- 输出示例
Training with: {'hidden_size': 256, 'learning_rate': 0.01, 'reg_lambda': 0.01, 'activation': 'relu'}
Epoch 001/50 | Train Loss: 4.4041 | Val Loss: 4.1193 | Val Acc: 0.4302
Epoch 002/50 | Train Loss: 3.9789 | Val Loss: 3.8711 | Val Acc: 0.4574
Epoch 003/50 | Train Loss: 3.7321 | Val Loss: 3.6884 | Val Acc: 0.4660
...
Current Best Acc: 0.5470最佳模型参数自动保存为best_model.npz:
np.savez('best_model.npz',
W1=best_model.W1,
b1=best_model.b1,
W2=best_model.W2,
b2=best_model.b2){'hidden_size': 512, 'learning_rate': 0.01, 'reg_lambda': 0.01, 'activation': 'relu'}
准确率54.77%- 减少超参数组合数量
- 降低训练轮数(epochs)
- 使用更小的隐藏层维度(如128)
- 可能原因:
- 过拟合(验证集准确率高但测试集低)
- 学习率过大/过小
- 解决方法:
- 增加正则化强度(reg_lambda)
- 调整学习率或隐藏层维度
以上代码已验证可在Python 3.8+环境中运行