-
Notifications
You must be signed in to change notification settings - Fork 1
/
estimate_measurements.py
61 lines (46 loc) · 1.91 KB
/
estimate_measurements.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import argparse
import numpy as np
import pickle
from landmark_utils import *
from measurement_utils import *
def load_model(sex, model_dir="models"):
    """Load the pickled measurement model for one sex.

    Args:
        sex: Selects the file ``{model_dir}/{sex}.pkl``. The command-line
            interface restricts it to "male" or "female".
        model_dir: Directory holding the pickled models. The default,
            "models", is a relative path resolved against the current
            working directory.

    Returns:
        The unpickled model object. ``estimate_measurements`` indexes it by
        measurement name.
    """
    model_path = f"{model_dir}/{sex}.pkl"
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only point this at trusted model files.
    # The context manager closes the handle (the original leaked it).
    with open(model_path, "rb") as model_file:
        return pickle.load(model_file)
def estimate_measurements(model, landmarks):
    """Predict each measurement named in ``MEASUREMENTS_ORDER``.

    Args:
        model: Mapping from measurement name to an object exposing
            ``predict``, e.g. the result of ``load_model``.
        landmarks: Processed landmarks, passed unchanged to every
            ``predict`` call.

    Returns:
        A list holding one prediction per measurement, in the same order as
        ``MEASUREMENTS_ORDER``. That name is presumably provided by the
        ``measurement_utils`` star import; verify if that module changes.
    """
    return [model[name].predict(landmarks) for name in MEASUREMENTS_ORDER]
def build_parser():
    """Build the command-line parser for measurement estimation.

    Every option has a default, so running the script with no arguments
    uses the demo landmarks path and the "male" model.
    """
    parser = argparse.ArgumentParser()
    # argparse ignores ``default`` on options declared ``required=True``, so
    # these options are optional; that lets the demo defaults actually apply.
    parser.add_argument("-L", "--landmarks_path",
                        required=False,
                        type=str,
                        default="demo/demo_landmarks.json",
                        help="Path to landmarks to use.")
    parser.add_argument("-S", "--sex",
                        required=False,
                        type=str,
                        default="male",
                        choices=["male", "female"],
                        help="Sex of the subject.")
    # float (not int) so non-integer unit factors such as 25.4 (inch -> mm)
    # are accepted. The default stays the int 1, so default behaviour is
    # unchanged.
    parser.add_argument("--scale",
                        required=False,
                        type=float,
                        default=1,
                        help="Scale the landmarks into mm if necessary. "
                             "Multiply scale with coordinates")
    parser.add_argument("--normalize_viewpoint",
                        action="store_true",
                        help="Rotate the landmarks to have the same viewpoint "
                             "as the training data.")
    return parser


def main(argv=None):
    """Load landmarks, run the per-sex model and print every measurement.

    Args:
        argv: Argument list to parse. ``None`` means ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)

    landmarks = load_landmarks(args.landmarks_path)
    landmarks = process_landmarks(landmarks, args.scale, args.normalize_viewpoint)
    model = load_model(args.sex)
    predicted_measurements = estimate_measurements(model, landmarks)

    print("Measurement Estimation:")
    # estimate_measurements yields one prediction per MEASUREMENTS_ORDER entry,
    # in order. .item() assumes each prediction holds a single value.
    for name, value in zip(MEASUREMENTS_ORDER, predicted_measurements):
        print(f"{name:45} {value.round(2).item():>7.2f}")


if __name__ == "__main__":
    main()