-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
102 lines (77 loc) · 3.91 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# Initialization utilities.
from trainingInitialization import (
cudaInitialization,
initImageTransforms,
retrieveData,
initDataLoaders,
initModelVGG11,
activateCuda,
)
from torch.nn import CrossEntropyLoss
# Training utilities.
from vggTrainer import train
from torch.optim import AdamW
# Evaluation utilities.
from trainingEvaluation import compareModels
# TO RUN use the environment ai_cv_test as follows:
# * mamba activate ai_cv_test
# And run the program:
# * python main.py
def main():
#############################################################################################################
############################################### ###############################################
############################################# INTIALIZATION #############################################
############################################### ###############################################
#############################################################################################################
# Retrieve the reference to the cuda module to afflict the current program
# and also the running device for a dynamic execution, alowing for CPU/GPU runs.
cudaModuleRef, runningDevice = cudaInitialization()
print(runningDevice)
# Retrieve the set of online image augmentations required
# during the training.
imageTransforms = initImageTransforms()
# Retrieve the data
data = retrieveData()
# Retrieve the initialized data loaders.
trainingLoader, validationLoader, testLoader = initDataLoaders(
data=data, imageTransforms=imageTransforms, batchSize=62, numberOfCPUWorkers=3
)
# trainingLoader.to(runningDevice)
# Retrieve a VGG11 model with pretrained weights
model = initModelVGG11(data=data)
# print(model)
# print(model.classifier)
# Check cuda and load the model on the GPU
if not activateCuda(cudaModuleRef=cudaModuleRef, model=model):
raise Exception("Cuda is not available for training!")
# Transfer-L VGG has the softmax output layer removed, so we need to keep
# it on account into the loss function.
lossFunction = CrossEntropyLoss()
#############################################################################################################
############################################### ###############################################
############################################# TRAINING #############################################
############################################### ###############################################
#############################################################################################################
print("\n\n---STARTING TRAINING---\n")
model = train(
model=model,
lossFunction=lossFunction,
# lr=1e-3 is used just as placeholder
optimizer=AdamW(params=model.parameters(), lr=1e-3, weight_decay=1e-4),
trainingDataLoader=trainingLoader,
validationDataLoader=validationLoader,
learningRates=[0.00008, 0.000005, 0.00000005],
)
print("---TRAINING COMPLETED---\n")
#############################################################################################################
############################################### ###############################################
############################################## EVALUATION ##############################################
############################################### ###############################################
#############################################################################################################
# Compare the model just trained with the stored one.
compareModels(
model,
testLoader,
)
if __name__ == "__main__":
main()