-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
47 lines (36 loc) · 1.46 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import os
import numpy as np
from tflite_model_maker import model_spec, object_detector
from tflite_model_maker.config import ExportFormat, QuantizationConfig
# Short-run training settings. They are NOT applied by default; pass them to
# main() explicitly (e.g. main(epochs=EPOCHS, batch_size=BATCH_SIZE)).
EPOCHS = 2
BATCH_SIZE = 5

# Root directory holding the image and Pascal VOC annotation splits.
DATASET_DIR = "./data/dataset"
# Destination directory for the exported model.
EXPORT_DIR = "./data/models"
# Class names for the detector; every split uses the same map so label ids agree.
LABEL_MAP = ["Tennis ball"]
# options: "efficientdet_liteX", where X ele [0, 1, 2, 3, 4]
MODEL_NAME = "efficientdet_lite1"


def _load_split(images_split, annotations_split):
    """Load one Pascal VOC split as an object-detector DataLoader.

    Args:
        images_split: sub-directory of DATASET_DIR containing the images.
        annotations_split: suffix of the ``annotations_<suffix>`` directory
            (under DATASET_DIR) containing the XML annotations.

    Returns:
        The DataLoader produced by ``DataLoader.from_pascal_voc``.
    """
    return object_detector.DataLoader.from_pascal_voc(
        images_dir=os.path.join(DATASET_DIR, images_split),
        annotations_dir=os.path.join(DATASET_DIR, f"annotations_{annotations_split}"),
        # Pass a copy so the shared module-level list can never be altered.
        label_map=list(LABEL_MAP),
    )


def main(epochs=None, batch_size=None, evaluate_on_test=False):
    """Train the tennis-ball detector and export it to EXPORT_DIR.

    Args:
        epochs: number of training epochs forwarded to
            ``object_detector.create``; ``None`` leaves the choice to Model
            Maker / the model spec.
        batch_size: training batch size forwarded to
            ``object_detector.create``; ``None`` leaves the choice to Model
            Maker / the model spec.
        evaluate_on_test: when True, evaluate the trained model on the test
            split and print the result before exporting.

    Returns:
        The trained model object returned by ``object_detector.create``.
    """
    train_ds = _load_split("train", "train")
    # NOTE(review): validation images are paired with the *test* annotations
    # directory; "annotations_validation" looks like the intended source.
    # Kept as-is to preserve current behavior; confirm with the dataset layout.
    validation_ds = _load_split("validation", "test")
    test_ds = _load_split("test", "test")

    spec = model_spec.get(MODEL_NAME)
    model = object_detector.create(
        train_data=train_ds,
        model_spec=spec,
        epochs=epochs,
        batch_size=batch_size,
        # Only the head is trained; the backbone stays frozen.
        train_whole_model=False,
        validation_data=validation_ds,
    )

    if evaluate_on_test:
        print(model.evaluate(test_ds))

    model.export(export_dir=EXPORT_DIR)
    # TODO: check for quantization
    # https://pub.towardsai.net/object-detection-at-the-edge-with-tf-lite-model-maker-e635a17b0854
    return model


if __name__ == "__main__":
    main()