This software implements the Convolutional Recurrent Neural Network (CRNN) in pytorch. Origin software could be found in crnn
python 3.6 pytorch 4.0 opencv2.4 + pytorch + lmdb +wrap_ctc
ATTENTION!
getLmdb.py must run in python2.x
- [ 11%] Building NVCC (Device) object CMakeFiles/warpctc.dir/src/warpctc_generated_reduce.cu.o sh: cicc: command not found CMake Error at warpctc_generated_reduce.cu.o.cmake:279 (message): Error generating file /home/rice/warp-ctc/build/CMakeFiles/warpctc.dir/src/./warpctc_generated_reduce.cu.o make[2]: *** [CMakeFiles/warpctc.dir/build.make:256: CMakeFiles/warpctc.dir/src/warpctc_generated_reduce.cu.o] Error 1 make[1]: *** [CMakeFiles/Makefile2:104: CMakeFiles/warpctc.dir/all] Error 2 make: *** [Makefile:130: all] Error 2 you should reinstall your cuda, and make sure it install completely
- THCudaMallco error https://github.com/baidu-research/warp-ctc/pull/71/files
- Xtra-Computing/thundersvm#54 (comment)
- ln -s /opt/cuda/include/* /home/rice/anaconda3/lib/python3.6/site-packages/torch/utils/ffi/../../lib/include/THC/
Construct dataset following origin guide. For training with variable length, please sort the image according to the text length. reference:https://github.com/Aurora11111/TextRecognitionDataGenerator
- 数据预处理
运行/contrib/crnn/tool/getLmdb.py
# 生成的lmdb输出路径
outputPath = '/run/media/rice/DATA/chinese1/lmdb'
# 图片及对应的label
imgdata = open("/run/media/rice/DATA/chinese1/labels.txt")
- 训练模型
运行/contrib/crnn/crnn_main.py
python crnn_main.py [--param val]
--trainroot 训练集路径
--valroot 验证集路径
--workers CPU工作核数, default=4
--batchSize 设置batchSize大小, default=64
--imgH 图片高度, default=32
--imgW 图片宽度,default =280(所用训练图片均为280*32)
--nh LSTM隐藏层数, default=256
--niter 训练回合数, default=25
--lr 学习率, default=0.00005
--cuda 使用GPU, action='store_true'
--ngpu 使用GPU的个数, default=1
--crnn 选择预训练模型
--alphabet 设置分类
--experiment 模型保存目录
--displayInterval 设置多少次迭代显示一次, default=1000
--n_test_disp 每次验证显示的个数, default=10
--valInterval 设置多少次迭代验证一次, default=1000
--saveInterval 设置多少次迭代保存一次模型, default=1000
--adam 使用adma优化器, default='True'
--adadelta 使用adadelta优化器, action='store_true'
--keep_ratio 设置图片保持横纵比缩放, action='store_true'
--random_sample 是否使用随机采样器对数据集进行采样, action='store_true'
示例:python /contrib/crnn/crnn_main.py --tainroot [训练集路径] --valroot [验证集路径] --nh 128 --cuda --crnn [预训练模型路径]
修改/contrib/crnn/keys.py
中alphabet = '012346789'
增加或者减少类别
- 注意事项 训练和预测采用的类别数和LSTM隐藏层数需保持一致
when you nclass is diferent from old ones, you can use this to finetune: python finetune.py
A demo program can be found in src/demo.py
. Before running the demo, download a pretrained model
from Baidu Netdisk or Dropbox.
This pretrained model is converted from auther offered one by tool
.
Put the downloaded model file crnn.pth
into directory data/
. Then launch the demo by:
python demo.py
The demo reads an example image and recognizes its text content.
Expected output: loading pretrained model from ./data/crnn.pth a-----v--a-i-l-a-bb-l-ee-- => available