import os

import tensorflow as tf

# Enable GPU memory growth so TensorFlow allocates GPU memory on demand
# instead of reserving all of it at startup
try:
    for gpu in tf.config.experimental.list_physical_devices("GPU"):
        tf.config.experimental.set_memory_growth(gpu, True)
except Exception:
    pass

import joblib
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, TensorBoard
from tqdm import tqdm

from mltu.preprocessors import ImageReader
from mltu.transformers import ImageResizer, LabelIndexer, LabelPadding
from mltu.augmentors import RandomBrightness, RandomErodeDilate, RandomSharpen
from mltu.annotations.images import CVImage
from mltu.tensorflow.dataProvider import DataProvider
from mltu.tensorflow.losses import CTCloss
from mltu.tensorflow.callbacks import Model2onnx, TrainLogger
from mltu.tensorflow.metrics import CERMetric, WERMetric

from model import train_model
from configs import ModelConfigs

# The IAM dataset must be downloaded and extracted manually from
# https://fki.tic.heia-fr.ch/databases/download-the-iam-handwriting-database
# into Datasets/IAM_Sentences
sentences_txt_path = os.path.join("Datasets", "IAM_Sentences", "ascii", "sentences.txt")
sentences_folder_path = os.path.join("Datasets", "IAM_Sentences", "sentences")

dataset, vocab, max_len = [], set(), 0
with open(sentences_txt_path, "r") as f:
    words = f.readlines()
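
# For reference, each non-comment line of sentences.txt looks roughly like
# this (illustrative sample; the fields are space-separated, field 3 is the
# segmentation result "ok"/"err", and the last field is the "|"-separated
# transcription):
#   a01-000u-s00-00 0 ok 154 19 408 746 1663 91 A|MOVE|to|stop|Mr.|Gaitskell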

for line in tqdm(words):
    if line.startswith("#"):
        continue

    line_split = line.split(" ")
    if line_split[2] == "err":
        continue

    folder1 = line_split[0][:3]
    folder2 = "-".join(line_split[0].split("-")[:2])
    file_name = line_split[0] + ".png"
    label = line_split[-1].rstrip("\n")
    # replace "|" (the IAM word separator) with " " in the label
    label = label.replace("|", " ")

    rel_path = os.path.join(sentences_folder_path, folder1, folder2, file_name)
    if not os.path.exists(rel_path):
        print(f"File not found: {rel_path}")
        continue

    dataset.append([rel_path, label])
    vocab.update(list(label))
    max_len = max(max_len, len(label))
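
# Optional sanity check on the parsed annotations
print(f"Samples: {len(dataset)}, vocab size: {len(vocab)}, max label length: {max_len}")
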
# Create a ModelConfigs object to store model configurations
configs = ModelConfigs()
# Save vocab and maximum text length to configs
configs.vocab = "".join(vocab)
configs.max_text_length = max_len
configs.save()
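
# configs.save() persists these settings to disk under configs.model_path (a
# YAML file, in mltu's convention), so inference code can later rebuild the
# exact same vocabulary and label padding.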

# Create a data provider that reads, resizes and label-encodes each sample
data_provider = DataProvider(
    dataset=dataset,
    skip_validation=True,
    batch_size=configs.batch_size,
    data_preprocessors=[ImageReader(CVImage)],
    transformers=[
        ImageResizer(configs.width, configs.height, keep_aspect_ratio=True),
        LabelIndexer(configs.vocab),
        LabelPadding(max_word_length=configs.max_text_length, padding_value=len(configs.vocab)),
    ],
)
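
# Each batch from the provider is an (images, labels) pair: images resized to
# (configs.height, configs.width, 3) and labels padded to max_text_length with
# padding value len(configs.vocab), one past the last vocabulary index.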

# Split the dataset into training (90%) and validation (10%) sets
train_data_provider, val_data_provider = data_provider.split(split=0.9)

# Augment the training data with random brightness, erode/dilate and sharpen
train_data_provider.augmentors = [
    RandomBrightness(),
    RandomErodeDilate(),
    RandomSharpen(),
]
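
# Note: the augmentors are attached to the training provider only, so the
# validation split is evaluated on unmodified images.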

# Create the TensorFlow model architecture
model = train_model(
    input_dim=(configs.height, configs.width, 3),
    output_dim=len(configs.vocab),
)

# Compile the model and print a summary
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=configs.learning_rate),
    loss=CTCloss(),
    metrics=[
        CERMetric(vocabulary=configs.vocab),
        WERMetric(vocabulary=configs.vocab),
    ],
    run_eagerly=False,
)
model.summary(line_length=110)

# Define callbacks; early stopping, checkpointing and the LR scheduler all
# monitor the validation character error rate (val_CER)
earlystopper = EarlyStopping(monitor="val_CER", patience=20, verbose=1, mode="min")
checkpoint = ModelCheckpoint(f"{configs.model_path}/model.h5", monitor="val_CER", verbose=1, save_best_only=True, mode="min")
trainLogger = TrainLogger(configs.model_path)
tb_callback = TensorBoard(f"{configs.model_path}/logs", update_freq=1)
reduceLROnPlat = ReduceLROnPlateau(monitor="val_CER", factor=0.9, min_delta=1e-10, patience=5, verbose=1, mode="auto")
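# (factor=0.9 is a much gentler decay than the Keras default of 0.1, so the
# learning rate shrinks slowly while val_CER plateaus)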
model2onnx = Model2onnx(f"{configs.model_path}/model.h5")

# Train the model
model.fit(
    train_data_provider,
    validation_data=val_data_provider,
    epochs=configs.train_epochs,
    callbacks=[earlystopper, checkpoint, trainLogger, reduceLROnPlat, tb_callback, model2onnx],
    workers=configs.train_workers,
)

# Save the training and validation splits as CSV files
train_data_provider.to_csv(os.path.join(configs.model_path, "train.csv"))
val_data_provider.to_csv(os.path.join(configs.model_path, "val.csv"))

# Also persist the model with joblib. Note: Keras models are not reliably
# pickle/joblib-serializable across versions; the model.h5 checkpoint and the
# ONNX export written by the callbacks above are the more portable artifacts.
joblib.dump(model, "Handwriting_recognition_model_created.joblib")
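
# A minimal reload sketch for the checkpoint (assuming the mltu custom objects
# deserialize under their class names; adjust if your mltu version differs):
# loaded = tf.keras.models.load_model(
#     f"{configs.model_path}/model.h5",
#     custom_objects={"CTCloss": CTCloss, "CERMetric": CERMetric, "WERMetric": WERMetric},
# )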