Skip to content

Commit d90d377

Browse files
committed
add metric calculation demo
1 parent c14fd1c commit d90d377

File tree

2 files changed

+190
-0
lines changed

2 files changed

+190
-0
lines changed

Diff for: metric.py

+94
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import numpy as np
2+
import scipy.interpolate as interp
3+
from scipy.optimize import linear_sum_assignment
4+
5+
def band_matrix(window: int, N):
    '''
    Return an N x N band mask for element-wise multiplication.

    Entries whose row and column indices differ by at most `window` are 1.0;
    every other entry is NaN. Multiplying a matrix element-wise by this mask
    keeps the values within +/- `window` of the diagonal and turns the rest
    into NaN.

    window: half-width of the band (0 keeps only the diagonal; a negative
            value yields an all-NaN matrix)
    N:      number of rows and columns

    returns: float ndarray of shape (N, N)
    '''
    rows, cols = np.indices((N, N))
    # A single vectorised comparison replaces the per-offset loop.
    return np.where(np.abs(rows - cols) <= window, 1.0, np.nan)
19+
20+
def reorder(A, B, window):
    '''
    Reorder series B to optimally match series A, moving each point by at
    most `window` positions (days).

    A:      1-D array-like reference series of length N
    B:      either a single 1-D array-like of length N, or a `list` of such
            series. A plain `list` is always interpreted as a collection of
            series, so pass a single series as an ndarray or tuple.
    window: maximum allowed |i - j| between a position i in A and the
            position j of the B value assigned to it (expected to be >= 0)

    Cost of pairing A[i] with B[j]:
      - single series:  (A[i] - B[j])**2
      - list of series: median over the series of |A[i] - B_k[j]|
    The assignment minimising the total cost is found with
    scipy.optimize.linear_sum_assignment.

    returns: a list with the reordered values of B; if B was a list of
             series, a list of such lists (one per series, all reordered
             with the same permutation)
    raises:  ValueError if B is an empty list or any series does not have
             the same length as A
    '''
    many = isinstance(B, list)
    A_vals = np.asarray(A, dtype=float)
    N = len(A_vals)  # total number of days

    if many and not B:
        raise ValueError("B must contain at least one series")
    series = [np.asarray(b) for b in B] if many else [np.asarray(B)]
    for s in series:
        if s.shape != (N,):
            raise ValueError(
                f"every series in B must be 1-D with length {N}, got shape {s.shape}"
            )

    if N == 0:
        # Nothing to assign; avoid calling max() on an empty cost matrix.
        return [[] for _ in series] if many else []

    if many:
        cost_matrix = np.median(
            [np.abs(A_vals[:, None] - s.astype(float)[None, :]) for s in series],
            axis=0,
        )
    else:
        cost_matrix = (A_vals[:, None] - series[0].astype(float)[None, :]) ** 2

    # The identity assignment is always inside the band and costs at most
    # N * max, so a single cell priced above that can never be part of an
    # optimal solution. (A fixed multiple such as 10 * max is not enough once
    # the best in-window matching costs more than that.)
    exclude_cost = cost_matrix.max() * (N + 1) + 1

    rows, cols = np.indices(cost_matrix.shape)
    banded_cost_matrix = np.where(np.abs(rows - cols) <= window, cost_matrix, exclude_cost)

    # find assignment of rows to columns with minimum cost
    _, column_index = linear_sum_assignment(banded_cost_matrix)

    matched = [list(s[column_index]) for s in series]
    return matched if many else matched[0]
52+
53+
def rmse(A, B):
    '''
    Root mean squared error between A and B.

    A, B: array-likes (lists, tuples, ndarrays) of numbers with shapes that
          broadcast against each other. Values are converted to float
          before subtracting, so plain sequences are accepted and unsigned
          integer arrays cannot wrap around.

    returns: sqrt(mean((A - B)**2)) as a NumPy float; NaN for empty input
    '''
    residual = np.asarray(A, dtype=float) - np.asarray(B, dtype=float)
    return (residual ** 2).mean() ** 0.5
58+
59+
def threshold_cost(A, B, threshold, threshold_type):
    '''
    RMSE of B with respect to A, restricted to the points selected by a
    threshold on A.

    threshold_type = 'none'  - evaluate all points (threshold is ignored)
    threshold_type = 'lower' - evaluate points where A >= threshold
    threshold_type = 'upper' - evaluate points where A <  threshold

    A, B:      array-likes of equal length; B is selected at the same
               positions as A
    threshold: value compared against the elements of A

    returns: rmse(A_selected, B_selected); if no point satisfies the
             condition this is the value rmse gives for empty input
    raises:  ValueError for any other threshold_type
    '''
    if threshold_type not in ('none', 'lower', 'upper'):
        # Fail loudly instead of printing and then crashing on an undefined
        # selection further down.
        raise ValueError(
            f"select threshold 'none', 'lower', 'upper' (got {threshold_type!r})"
        )

    A_arr = np.array(A)
    B_arr = np.array(B)

    if threshold_type == 'none':
        return rmse(A_arr, B_arr)

    if threshold_type == 'lower':
        selected = A_arr >= threshold
    else:  # 'upper'
        selected = A_arr < threshold

    return rmse(A_arr[selected], B_arr[selected])
86+
87+
def reordering_cost(A, B, window, threshold, threshold_type):
    '''
    Reorder B against A, then score the reordered series.

    B is first passed through reorder(A, B, window); the result is then
    scored against A with threshold_cost using the given threshold and
    threshold_type ('none', 'lower' or 'upper').

    returns: the value returned by threshold_cost
    '''
    return threshold_cost(A, reorder(A, B, window), threshold, threshold_type)

Diff for: metric_demo.ipynb

+96
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)