-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun.py
More file actions
90 lines (79 loc) · 2.34 KB
/
run.py
File metadata and controls
90 lines (79 loc) · 2.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import dataset
import importlib
importlib.reload(dataset)
from dataset import *
import argparse
from models.DeblurModel import DeblurModel, CombinedLoss
import keras
import cv2
import os
import argparse
def parse_arguments():
    """Parse the command-line arguments for the deblurring script.

    Returns:
        argparse.Namespace with:
            model_path (str | None): Path to saved model weights. Only
                needed when running inference; optional otherwise.
            dataset_dir (str): Path to the dataset's parent directory.
            dataset_name (str): Dataset folder name inside dataset_dir.
            inference (str): The literal string "True" selects evaluation;
                any other value selects training. NOTE(review): kept as a
                string (not a bool flag) because the caller compares
                `args.inference == "True"`.
            num_of_epochs (int): Number of training epochs (default 5).
    """
    parser = argparse.ArgumentParser(description="Deblur your photos.")
    parser.add_argument(
        "--model_path",
        type=str,
        required=False,
        help="Path to the saved model.",
    )
    parser.add_argument(
        "--dataset_dir",
        type=str,
        required=True,
        help="Path to the dataset main directory.",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        required=True,
        help="Name of the dataset in the dataset_dir folder.",
    )
    parser.add_argument(
        "--inference",
        type=str,
        required=False,
        default="True",
        help="True - run evaluation, False - run training.",
    )
    parser.add_argument(
        "--num_of_epochs",
        type=int,
        required=False,
        default=5,
        # Fixed typo in user-facing help text: "epoches" -> "epochs".
        help="Number of epochs used for training.",
    )
    return parser.parse_args()
# Entry point: parse CLI args, build the model under a multi-GPU
# distribution strategy, then either run inference (writing predicted
# images to disk) or train and checkpoint the weights.
#
# NOTE(review): `tf` was previously only in scope via `from dataset import *`,
# which breaks if `dataset` ever defines __all__ — import it explicitly.
import tensorflow as tf

args = parse_arguments()

# One batch per replica; total batch scales with the number of GPUs in sync.
strategy = tf.distribute.MirroredStrategy()
BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
print(f"number of GPUs used: {strategy.num_replicas_in_sync = }")
print(f"batch size: {BATCH_SIZE = }")

# Model input resolution (H, W). Previously assigned twice; kept once.
input_shape = (128, 128)

with strategy.scope():
    dataloaders = get_dataloaders(args.dataset_dir, args.dataset_name, batch_size=BATCH_SIZE)
    model = DeblurModel(input_shape)
    optimizer = 'adam'
    loss = 'mean_squared_error'
    model.compile(
        optimizer=tf.keras.optimizers.get(optimizer),
        loss=loss,
        metrics=['mae']
    )
    model.build(input_shape=input_shape)

if args.inference == "True":
    # Inference path: load saved weights, evaluate, and dump predictions.
    if not args.model_path:
        # Fail fast with a clear message instead of a cryptic load_weights(None) error.
        raise ValueError("--model_path is required when running inference.")
    model.load_weights(args.model_path)
    model.evaluate(dataloaders['test'])
    predicted = model.predict(dataloaders['test'])
    results_dir = os.path.join(args.dataset_dir, args.dataset_name, 'results')
    # cv2.imwrite returns False (silently) when the directory is missing.
    os.makedirs(results_dir, exist_ok=True)
    for idx, img in enumerate(predicted):
        cv2.imwrite(os.path.join(results_dir, f'{idx}.jpg'), img)
    print(f"{predicted.shape[0]} Images generated")
else:
    # Training path: fit on train/val splits, then save the final weights.
    history = model.fit(
        dataloaders['train'],
        validation_data=dataloaders['val'],
        epochs=args.num_of_epochs
    )
    model.save_weights('./checkpoints/best_model.weights.h5')