-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy patheval.py
48 lines (41 loc) · 2.06 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import glob
import json
import os
import numpy as np
from tqdm import tqdm
import onnxruntime as ort
import matplotlib.pyplot as plt
# Evaluate the exported ONNX bandwidth model on every call trace in a
# directory and save one comparison plot per trace.


def run_model(session, observations, hidden_size=128):
    """Roll the recurrent ONNX model over one call, one time step at a time.

    The model is fed a single observation row per step. The
    ``hidden_states`` / ``cell_states`` it returns are fed back into the
    next step, starting from all-zero states.

    Args:
        session: Object exposing an ``onnxruntime.InferenceSession``-style
            ``run(output_names, input_feed)`` method. The model must accept
            the inputs ``'obs'``, ``'hidden_states'`` and ``'cell_states'``
            and return ``(bw_prediction, hidden_state, cell_state)``.
        observations: 2-D float32 array with one row per time step.
        hidden_size: Width of each recurrent state vector. The default of 128
            matches the shipped model; other exports may differ (confirm
            against the model's input metadata).

    Returns:
        1-D float32 array holding one bandwidth prediction per time step.
    """
    hidden_state = np.zeros((1, hidden_size), dtype=np.float32)
    cell_state = np.zeros((1, hidden_size), dtype=np.float32)
    predictions = np.empty(observations.shape[0], dtype=np.float32)
    for t, row in enumerate(observations):
        feed_dict = {
            "obs": row.reshape(1, 1, -1),  # (batch=1, seq=1, features)
            "hidden_states": hidden_state,
            "cell_states": cell_state,
        }
        # Carry the returned recurrent state into the next step.
        bw_prediction, hidden_state, cell_state = session.run(None, feed_dict)
        predictions[t] = bw_prediction[0, 0, 0]
    return predictions


def plot_predictions(model_predictions, baseline_predictions, true_capacity,
                     out_path, step_ms=60):
    """Plot model output, baseline estimate and true capacity for one call.

    Args:
        model_predictions: Per-step predictions from the ONNX model ("FARC").
        baseline_predictions: Per-step estimates stored in the trace file.
        true_capacity: Per-step ground-truth capacity from the trace file.
        out_path: Destination image path passed to ``Figure.savefig``.
        step_ms: Spacing between consecutive steps in milliseconds. It is used
            only to build the time axis in seconds.

    All three series are divided by 1000 to match the "Kbps" axis label, so
    the inputs are presumably in bps (confirm against the dataset).
    """
    time_s = np.arange(len(model_predictions)) * step_ms / 1000
    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        ax.plot(time_s, model_predictions / 1000, label="FARC", color="g")
        ax.plot(time_s, baseline_predictions / 1000,
                label="Baseline BW Estimator", color="r")
        ax.plot(time_s, true_capacity / 1000, label="True Capacity", color="k")
        ax.set_ylabel("Bandwidth (Kbps)")
        ax.set_xlabel("Call Duration (s)")
        ax.grid(True)
        ax.legend()
        fig.savefig(out_path)
    finally:
        # Always release the figure, even if saving fails, so long runs
        # don't accumulate open figures.
        plt.close(fig)


def main(data_dir="./Data/emulated_dataset/test/",
         figs_dir="./figs",
         onnx_model="./onnx_model/fast_and_furious_model.onnx"):
    """Run the ONNX model on every ``*.json`` trace in *data_dir* and plot it.

    Each trace file must contain the keys ``'observations'``,
    ``'bandwidth_predictions'`` and ``'true_capacity'``. One PNG per trace,
    named after the trace file, is written to *figs_dir* (created if missing).
    """
    os.makedirs(figs_dir, exist_ok=True)
    # Only the top level of data_dir is scanned; sorting makes the processing
    # order deterministic across platforms.
    data_files = sorted(glob.glob(os.path.join(data_dir, "*.json")))
    ort_session = ort.InferenceSession(onnx_model)
    for filename in tqdm(data_files, desc="Processing"):
        with open(filename, "r", encoding="utf-8") as file:
            call_data = json.load(file)
        observations = np.asarray(call_data["observations"], dtype=np.float32)
        bandwidth_predictions = np.asarray(call_data["bandwidth_predictions"],
                                           dtype=np.float32)
        true_capacity = np.asarray(call_data["true_capacity"], dtype=np.float32)

        model_predictions = run_model(ort_session, observations)

        # splitext swaps only the final extension, unlike str.replace,
        # which would also rewrite ".json" appearing inside the stem.
        stem = os.path.splitext(os.path.basename(filename))[0]
        plot_predictions(model_predictions, bandwidth_predictions,
                         true_capacity, os.path.join(figs_dir, stem + ".png"))


if __name__ == "__main__":
    main()