-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathShapes_recognition.py
63 lines (63 loc) · 1.87 KB
/
Shapes_recognition.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# importings
from keras.models import Sequential
from keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D
from keras.preprocessing.image import ImageDataGenerator
import keras
from keras import backend as K
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
# Build the training-image pipeline.
# - Rescaling by 1/255 turns uint8 pixel values in [0, 255] into floats in [0, 1].
# - Random rotations (up to 45 degrees) and vertical/horizontal flips augment the
#   data on the fly.
augmentation_options = dict(
    rescale=1. / 255,
    rotation_range=45,
    vertical_flip=True,
    horizontal_flip=True,
)
image_generator = ImageDataGenerator(**augmentation_options)

# Stream batches of 204 images from the 'train' directory, resized to 28x28 RGB.
# NOTE(review): flow_from_directory derives class labels from sub-folders of
# 'train' -- confirm the dataset is laid out one folder per shape class.
train_data_gen = image_generator.flow_from_directory(
    directory='train', target_size=(28, 28), batch_size=204, color_mode='rgb',
)
# Convolutional classifier: three conv + max-pool stages extract features,
# then a dense head produces one softmax probability per class.
model = Sequential([
    # Take the input shape from the generator (target_size plus channel count)
    # so it always matches the images actually fed to the network.
    Conv2D(16, 3, padding='same', activation='relu',
           input_shape=train_data_gen.image_shape),
    MaxPooling2D(),
    Dropout(0.2),
    Conv2D(32, 3, padding='same', activation='relu'),
    MaxPooling2D(),
    Conv2D(64, 3, padding='same', activation='relu'),
    MaxPooling2D(),
    # Rate must be the float 0.2; writing `Dropout(0,2)` would mean rate=0
    # (dropout disabled) with noise_shape=2.
    Dropout(0.2),
    Flatten(),
    Dense(512, activation='relu'),  # MLP part of the CNN
    # One output unit per class sub-folder found by the generator, so the
    # softmax width always matches the one-hot labels it yields.
    Dense(train_data_gen.num_classes, activation='softmax'),
])

# Multi-class classification with one-hot labels: categorical cross-entropy.
model.compile(optimizer='rmsprop',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
model.summary()
# Train on the augmented stream for 30 epochs.
# The generator loops indefinitely, so steps_per_epoch is what defines an epoch.
# len(train_data_gen) is the number of batches needed to see every training
# image once, i.e. ceil(n / batch_size).
# NOTE(review): on TF2-era Keras, fit_generator is deprecated in favour of
# model.fit(train_data_gen, ...) -- switch once the Keras version is pinned.
history = model.fit_generator(
    train_data_gen,
    steps_per_epoch=len(train_data_gen),
    epochs=30,
)

# Save the whole model, plus a separate weights-only file.
model.save('model.h5')
model.save_weights('model_weights.h5')
# Visualize the per-epoch training accuracy and loss recorded during training.
acc = history.history['accuracy']
loss = history.history['loss']
# One x value per recorded epoch. A hard-coded range must equal the number of
# epochs trained, otherwise plt.plot fails on mismatched x/y lengths.
epochs_range = range(len(acc))

plt.figure(figsize=(8, 8))

plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.legend(loc='lower right')
plt.title('Training Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.legend(loc='upper right')
plt.title('Training Loss')

plt.show()